xchunked_view.hpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_CHUNKED_VIEW_HPP
  10. #define XTENSOR_CHUNKED_VIEW_HPP
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>

#include <xtl/xsequence.hpp>

#include "xchunked_array.hpp"
#include "xnoalias.hpp"
#include "xstorage.hpp"
#include "xstrided_view.hpp"
  16. namespace xt
  17. {
  18. template <class E>
  19. struct is_chunked_t : detail::chunk_helper<E>::is_chunked
  20. {
  21. };
  22. /*****************
  23. * xchunked_view *
  24. *****************/
  25. template <class E>
  26. class xchunk_iterator;
  27. template <class E>
  28. class xchunked_view
  29. {
  30. public:
  31. using self_type = xchunked_view<E>;
  32. using expression_type = std::decay_t<E>;
  33. using value_type = typename expression_type::value_type;
  34. using reference = typename expression_type::reference;
  35. using const_reference = typename expression_type::const_reference;
  36. using pointer = typename expression_type::pointer;
  37. using const_pointer = typename expression_type::const_pointer;
  38. using size_type = typename expression_type::size_type;
  39. using difference_type = typename expression_type::difference_type;
  40. using shape_type = svector<size_type>;
  41. using chunk_iterator = xchunk_iterator<self_type>;
  42. using const_chunk_iterator = xchunk_iterator<const self_type>;
  43. template <class OE, class S>
  44. xchunked_view(OE&& e, S&& chunk_shape);
  45. template <class OE>
  46. xchunked_view(OE&& e);
  47. void init();
  48. template <class OE>
  49. typename std::enable_if_t<!is_chunked_t<OE>::value, xchunked_view<E>&> operator=(const OE& e);
  50. template <class OE>
  51. typename std::enable_if_t<is_chunked_t<OE>::value, xchunked_view<E>&> operator=(const OE& e);
  52. size_type dimension() const noexcept;
  53. const shape_type& shape() const noexcept;
  54. const shape_type& chunk_shape() const noexcept;
  55. size_type grid_size() const noexcept;
  56. const shape_type& grid_shape() const noexcept;
  57. expression_type& expression() noexcept;
  58. const expression_type& expression() const noexcept;
  59. chunk_iterator chunk_begin();
  60. chunk_iterator chunk_end();
  61. const_chunk_iterator chunk_begin() const;
  62. const_chunk_iterator chunk_end() const;
  63. const_chunk_iterator chunk_cbegin() const;
  64. const_chunk_iterator chunk_cend() const;
  65. private:
  66. E m_expression;
  67. shape_type m_shape;
  68. shape_type m_chunk_shape;
  69. shape_type m_grid_shape;
  70. size_type m_chunk_nb;
  71. };
  72. template <class E, class S>
  73. xchunked_view<E> as_chunked(E&& e, S&& chunk_shape);
  74. /********************************
  75. * xchunked_view implementation *
  76. ********************************/
  77. template <class E>
  78. template <class OE, class S>
  79. inline xchunked_view<E>::xchunked_view(OE&& e, S&& chunk_shape)
  80. : m_expression(std::forward<OE>(e))
  81. , m_chunk_shape(xtl::forward_sequence<shape_type, S>(chunk_shape))
  82. {
  83. m_shape.resize(e.dimension());
  84. const auto& s = e.shape();
  85. std::copy(s.cbegin(), s.cend(), m_shape.begin());
  86. init();
  87. }
  88. template <class E>
  89. template <class OE>
  90. inline xchunked_view<E>::xchunked_view(OE&& e)
  91. : m_expression(std::forward<OE>(e))
  92. {
  93. m_shape.resize(e.dimension());
  94. const auto& s = e.shape();
  95. std::copy(s.cbegin(), s.cend(), m_shape.begin());
  96. }
  97. template <class E>
  98. void xchunked_view<E>::init()
  99. {
  100. // compute chunk number in each dimension
  101. m_grid_shape.resize(m_shape.size());
  102. std::transform(
  103. m_shape.cbegin(),
  104. m_shape.cend(),
  105. m_chunk_shape.cbegin(),
  106. m_grid_shape.begin(),
  107. [](auto s, auto cs)
  108. {
  109. std::size_t cn = s / cs;
  110. if (s % cs > 0)
  111. {
  112. cn++; // edge_chunk
  113. }
  114. return cn;
  115. }
  116. );
  117. m_chunk_nb = std::accumulate(
  118. std::begin(m_grid_shape),
  119. std::end(m_grid_shape),
  120. std::size_t(1),
  121. std::multiplies<>()
  122. );
  123. }
  124. template <class E>
  125. template <class OE>
  126. typename std::enable_if_t<!is_chunked_t<OE>::value, xchunked_view<E>&>
  127. xchunked_view<E>::operator=(const OE& e)
  128. {
  129. auto end = chunk_end();
  130. for (auto it = chunk_begin(); it != end; ++it)
  131. {
  132. auto el = *it;
  133. noalias(el) = strided_view(e, it.get_slice_vector());
  134. }
  135. return *this;
  136. }
  137. template <class E>
  138. template <class OE>
  139. typename std::enable_if_t<is_chunked_t<OE>::value, xchunked_view<E>&>
  140. xchunked_view<E>::operator=(const OE& e)
  141. {
  142. m_chunk_shape.resize(e.dimension());
  143. const auto& cs = e.chunk_shape();
  144. std::copy(cs.cbegin(), cs.cend(), m_chunk_shape.begin());
  145. init();
  146. auto it2 = e.chunks().begin();
  147. auto end1 = chunk_end();
  148. for (auto it1 = chunk_begin(); it1 != end1; ++it1, ++it2)
  149. {
  150. auto el1 = *it1;
  151. auto el2 = *it2;
  152. auto lhs_shape = el1.shape();
  153. if (lhs_shape != el2.shape())
  154. {
  155. xstrided_slice_vector esv(el2.dimension()); // element slice in edge chunk
  156. std::transform(
  157. lhs_shape.begin(),
  158. lhs_shape.end(),
  159. esv.begin(),
  160. [](auto size)
  161. {
  162. return range(0, size);
  163. }
  164. );
  165. noalias(el1) = strided_view(el2, esv);
  166. }
  167. else
  168. {
  169. noalias(el1) = el2;
  170. }
  171. }
  172. return *this;
  173. }
  174. template <class E>
  175. inline auto xchunked_view<E>::dimension() const noexcept -> size_type
  176. {
  177. return m_shape.size();
  178. }
  179. template <class E>
  180. inline auto xchunked_view<E>::shape() const noexcept -> const shape_type&
  181. {
  182. return m_shape;
  183. }
  184. template <class E>
  185. inline auto xchunked_view<E>::chunk_shape() const noexcept -> const shape_type&
  186. {
  187. return m_chunk_shape;
  188. }
  189. template <class E>
  190. inline auto xchunked_view<E>::grid_size() const noexcept -> size_type
  191. {
  192. return m_chunk_nb;
  193. }
  194. template <class E>
  195. inline auto xchunked_view<E>::grid_shape() const noexcept -> const shape_type&
  196. {
  197. return m_grid_shape;
  198. }
  199. template <class E>
  200. inline auto xchunked_view<E>::expression() noexcept -> expression_type&
  201. {
  202. return m_expression;
  203. }
  204. template <class E>
  205. inline auto xchunked_view<E>::expression() const noexcept -> const expression_type&
  206. {
  207. return m_expression;
  208. }
  209. template <class E>
  210. inline auto xchunked_view<E>::chunk_begin() -> chunk_iterator
  211. {
  212. shape_type chunk_index(m_shape.size(), size_type(0));
  213. return chunk_iterator(*this, std::move(chunk_index), 0u);
  214. }
  215. template <class E>
  216. inline auto xchunked_view<E>::chunk_end() -> chunk_iterator
  217. {
  218. return chunk_iterator(*this, shape_type(grid_shape()), grid_size());
  219. }
  220. template <class E>
  221. inline auto xchunked_view<E>::chunk_begin() const -> const_chunk_iterator
  222. {
  223. shape_type chunk_index(m_shape.size(), size_type(0));
  224. return const_chunk_iterator(*this, std::move(chunk_index), 0u);
  225. }
  226. template <class E>
  227. inline auto xchunked_view<E>::chunk_end() const -> const_chunk_iterator
  228. {
  229. return const_chunk_iterator(*this, shape_type(grid_shape()), grid_size());
  230. }
  231. template <class E>
  232. inline auto xchunked_view<E>::chunk_cbegin() const -> const_chunk_iterator
  233. {
  234. return chunk_begin();
  235. }
  236. template <class E>
  237. inline auto xchunked_view<E>::chunk_cend() const -> const_chunk_iterator
  238. {
  239. return chunk_end();
  240. }
  241. template <class E, class S>
  242. inline xchunked_view<E> as_chunked(E&& e, S&& chunk_shape)
  243. {
  244. return xchunked_view<E>(std::forward<E>(e), std::forward<S>(chunk_shape));
  245. }
  246. template <class E>
  247. inline xchunked_view<E> as_chunked(E&& e)
  248. {
  249. return xchunked_view<E>(std::forward<E>(e));
  250. }
  251. }
  252. #endif