/***************************************************************************
 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

#ifndef XTENSOR_CHUNKED_ASSIGN_HPP
#define XTENSOR_CHUNKED_ASSIGN_HPP

#include "xnoalias.hpp"
#include "xstrided_view.hpp"

namespace xt
{
  15. /*******************
  16. * xchunk_assigner *
  17. *******************/
  18. template <class T, class chunk_storage>
  19. class xchunked_assigner
  20. {
  21. public:
  22. using temporary_type = T;
  23. template <class E, class DST>
  24. void build_and_assign_temporary(const xexpression<E>& e, DST& dst);
  25. };
  26. /*********************************
  27. * xchunked_semantic declaration *
  28. *********************************/
  29. template <class D>
  30. class xchunked_semantic : public xsemantic_base<D>
  31. {
  32. public:
  33. using base_type = xsemantic_base<D>;
  34. using derived_type = D;
  35. using temporary_type = typename base_type::temporary_type;
  36. template <class E>
  37. derived_type& assign_xexpression(const xexpression<E>& e);
  38. template <class E>
  39. derived_type& computed_assign(const xexpression<E>& e);
  40. template <class E, class F>
  41. derived_type& scalar_computed_assign(const E& e, F&& f);
  42. protected:
  43. xchunked_semantic() = default;
  44. ~xchunked_semantic() = default;
  45. xchunked_semantic(const xchunked_semantic&) = default;
  46. xchunked_semantic& operator=(const xchunked_semantic&) = default;
  47. xchunked_semantic(xchunked_semantic&&) = default;
  48. xchunked_semantic& operator=(xchunked_semantic&&) = default;
  49. template <class E>
  50. derived_type& operator=(const xexpression<E>& e);
  51. private:
  52. template <class CS>
  53. xchunked_assigner<temporary_type, CS> get_assigner(const CS&) const;
  54. };
  55. /*******************
  56. * xchunk_iterator *
  57. *******************/
  58. template <class CS>
  59. class xchunked_array;
  60. template <class E>
  61. class xchunked_view;
  62. namespace detail
  63. {
  64. template <class T>
  65. struct is_xchunked_array : std::false_type
  66. {
  67. };
  68. template <class CS>
  69. struct is_xchunked_array<xchunked_array<CS>> : std::true_type
  70. {
  71. };
  72. template <class T>
  73. struct is_xchunked_view : std::false_type
  74. {
  75. };
  76. template <class E>
  77. struct is_xchunked_view<xchunked_view<E>> : std::true_type
  78. {
  79. };
  80. struct invalid_chunk_iterator
  81. {
  82. };
  83. template <class A>
  84. struct xchunk_iterator_array
  85. {
  86. using reference = decltype(*(std::declval<A>().chunks().begin()));
  87. inline decltype(auto) get_chunk(A& arr, typename A::size_type i, const xstrided_slice_vector&) const
  88. {
  89. using difference_type = typename A::difference_type;
  90. return *(arr.chunks().begin() + static_cast<difference_type>(i));
  91. }
  92. };
  93. template <class V>
  94. struct xchunk_iterator_view
  95. {
  96. using reference = decltype(xt::strided_view(
  97. std::declval<V>().expression(),
  98. std::declval<xstrided_slice_vector>()
  99. ));
  100. inline auto get_chunk(V& view, typename V::size_type, const xstrided_slice_vector& sv) const
  101. {
  102. return xt::strided_view(view.expression(), sv);
  103. }
  104. };
  105. template <class T>
  106. struct xchunk_iterator_base
  107. : std::conditional_t<
  108. is_xchunked_array<std::decay_t<T>>::value,
  109. xchunk_iterator_array<T>,
  110. std::conditional_t<is_xchunked_view<std::decay_t<T>>::value, xchunk_iterator_view<T>, invalid_chunk_iterator>>
  111. {
  112. };
  113. }
  114. template <class E>
  115. class xchunk_iterator : private detail::xchunk_iterator_base<E>
  116. {
  117. public:
  118. using base_type = detail::xchunk_iterator_base<E>;
  119. using self_type = xchunk_iterator<E>;
  120. using size_type = typename E::size_type;
  121. using shape_type = typename E::shape_type;
  122. using slice_vector = xstrided_slice_vector;
  123. using reference = typename base_type::reference;
  124. using value_type = std::remove_reference_t<reference>;
  125. using pointer = value_type*;
  126. using difference_type = typename E::difference_type;
  127. using iterator_category = std::forward_iterator_tag;
  128. xchunk_iterator() = default;
  129. xchunk_iterator(E& chunked_expression, shape_type&& chunk_index, size_type chunk_linear_index);
  130. self_type& operator++();
  131. self_type operator++(int);
  132. decltype(auto) operator*() const;
  133. bool operator==(const self_type& rhs) const;
  134. bool operator!=(const self_type& rhs) const;
  135. const shape_type& chunk_index() const;
  136. const slice_vector& get_slice_vector() const;
  137. slice_vector get_chunk_slice_vector() const;
  138. private:
  139. void fill_slice_vector(size_type index);
  140. E* p_chunked_expression;
  141. shape_type m_chunk_index;
  142. size_type m_chunk_linear_index;
  143. xstrided_slice_vector m_slice_vector;
  144. };
  145. /************************************
  146. * xchunked_semantic implementation *
  147. ************************************/
  148. template <class T, class CS>
  149. template <class E, class DST>
  150. inline void xchunked_assigner<T, CS>::build_and_assign_temporary(const xexpression<E>& e, DST& dst)
  151. {
  152. temporary_type tmp(e, CS(), dst.chunk_shape());
  153. dst = std::move(tmp);
  154. }
  155. template <class D>
  156. template <class E>
  157. inline auto xchunked_semantic<D>::assign_xexpression(const xexpression<E>& e) -> derived_type&
  158. {
  159. auto& d = this->derived_cast();
  160. const auto& chunk_shape = d.chunk_shape();
  161. size_t i = 0;
  162. auto it_end = d.chunk_end();
  163. for (auto it = d.chunk_begin(); it != it_end; ++it, ++i)
  164. {
  165. auto rhs = strided_view(e.derived_cast(), it.get_slice_vector());
  166. if (rhs.shape() != chunk_shape)
  167. {
  168. noalias(strided_view(*it, it.get_chunk_slice_vector())) = rhs;
  169. }
  170. else
  171. {
  172. noalias(*it) = rhs;
  173. }
  174. }
  175. return this->derived_cast();
  176. }
  177. template <class D>
  178. template <class E>
  179. inline auto xchunked_semantic<D>::computed_assign(const xexpression<E>& e) -> derived_type&
  180. {
  181. D& d = this->derived_cast();
  182. if (e.derived_cast().dimension() > d.dimension() || e.derived_cast().shape() > d.shape())
  183. {
  184. return operator=(e);
  185. }
  186. else
  187. {
  188. return assign_xexpression(e);
  189. }
  190. }
  191. template <class D>
  192. template <class E, class F>
  193. inline auto xchunked_semantic<D>::scalar_computed_assign(const E& e, F&& f) -> derived_type&
  194. {
  195. for (auto& c : this->derived_cast().chunks())
  196. {
  197. c.scalar_computed_assign(e, f);
  198. }
  199. return this->derived_cast();
  200. }
  201. template <class D>
  202. template <class E>
  203. inline auto xchunked_semantic<D>::operator=(const xexpression<E>& e) -> derived_type&
  204. {
  205. D& d = this->derived_cast();
  206. get_assigner(d.chunks()).build_and_assign_temporary(e, d);
  207. return d;
  208. }
  209. template <class D>
  210. template <class CS>
  211. inline auto xchunked_semantic<D>::get_assigner(const CS&) const -> xchunked_assigner<temporary_type, CS>
  212. {
  213. return xchunked_assigner<temporary_type, CS>();
  214. }
  215. /**********************************
  216. * xchunk_iterator implementation *
  217. **********************************/
  218. template <class E>
  219. inline xchunk_iterator<E>::xchunk_iterator(E& expression, shape_type&& chunk_index, size_type chunk_linear_index)
  220. : p_chunked_expression(&expression)
  221. , m_chunk_index(std::move(chunk_index))
  222. , m_chunk_linear_index(chunk_linear_index)
  223. , m_slice_vector(m_chunk_index.size())
  224. {
  225. for (size_type i = 0; i < m_chunk_index.size(); ++i)
  226. {
  227. fill_slice_vector(i);
  228. }
  229. }
  230. template <class E>
  231. inline xchunk_iterator<E>& xchunk_iterator<E>::operator++()
  232. {
  233. if (m_chunk_linear_index + 1u != p_chunked_expression->grid_size())
  234. {
  235. size_type i = p_chunked_expression->dimension();
  236. while (i != 0)
  237. {
  238. --i;
  239. if (m_chunk_index[i] + 1u == p_chunked_expression->grid_shape()[i])
  240. {
  241. m_chunk_index[i] = 0;
  242. fill_slice_vector(i);
  243. }
  244. else
  245. {
  246. m_chunk_index[i] += 1;
  247. fill_slice_vector(i);
  248. break;
  249. }
  250. }
  251. }
  252. m_chunk_linear_index++;
  253. return *this;
  254. }
  255. template <class E>
  256. inline xchunk_iterator<E> xchunk_iterator<E>::operator++(int)
  257. {
  258. xchunk_iterator<E> it = *this;
  259. ++(*this);
  260. return it;
  261. }
  262. template <class E>
  263. inline decltype(auto) xchunk_iterator<E>::operator*() const
  264. {
  265. return base_type::get_chunk(*p_chunked_expression, m_chunk_linear_index, m_slice_vector);
  266. }
  267. template <class E>
  268. inline bool xchunk_iterator<E>::operator==(const xchunk_iterator& other) const
  269. {
  270. return m_chunk_linear_index == other.m_chunk_linear_index;
  271. }
  272. template <class E>
  273. inline bool xchunk_iterator<E>::operator!=(const xchunk_iterator& other) const
  274. {
  275. return !(*this == other);
  276. }
  277. template <class E>
  278. inline auto xchunk_iterator<E>::get_slice_vector() const -> const slice_vector&
  279. {
  280. return m_slice_vector;
  281. }
  282. template <class E>
  283. auto xchunk_iterator<E>::chunk_index() const -> const shape_type&
  284. {
  285. return m_chunk_index;
  286. }
  287. template <class E>
  288. inline auto xchunk_iterator<E>::get_chunk_slice_vector() const -> slice_vector
  289. {
  290. slice_vector slices(m_chunk_index.size());
  291. for (size_type i = 0; i < m_chunk_index.size(); ++i)
  292. {
  293. size_type chunk_shape = p_chunked_expression->chunk_shape()[i];
  294. size_type end = std::min(
  295. chunk_shape,
  296. p_chunked_expression->shape()[i] - m_chunk_index[i] * chunk_shape
  297. );
  298. slices[i] = range(0u, end);
  299. }
  300. return slices;
  301. }
  302. template <class E>
  303. inline void xchunk_iterator<E>::fill_slice_vector(size_type i)
  304. {
  305. size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape()[i];
  306. size_type range_end = std::min(
  307. (m_chunk_index[i] + 1) * p_chunked_expression->chunk_shape()[i],
  308. p_chunked_expression->shape()[i]
  309. );
  310. m_slice_vector[i] = range(range_start, range_end);
  311. }
}

#endif