xaxis_iterator.hpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_AXIS_ITERATOR_HPP
  10. #define XTENSOR_AXIS_ITERATOR_HPP
  11. #include "xstrided_view.hpp"
  12. namespace xt
  13. {
  14. /******************
  15. * xaxis_iterator *
  16. ******************/
  17. /**
  18. * @class xaxis_iterator
  19. * @brief Class for iteration over (N-1)-dimensional slices, where
  20. * N is the dimension of the underlying expression
  21. *
  22. * If N is the number of dimensions of an expression, the xaxis_iterator
  23. * iterates over (N-1)-dimensional slices oriented along the specified axis.
  24. *
  25. * @tparam CT the closure type of the \ref xexpression
  26. */
  27. template <class CT>
  28. class xaxis_iterator
  29. {
  30. public:
  31. using self_type = xaxis_iterator<CT>;
  32. using xexpression_type = std::decay_t<CT>;
  33. using size_type = typename xexpression_type::size_type;
  34. using difference_type = typename xexpression_type::difference_type;
  35. using shape_type = typename xexpression_type::shape_type;
  36. using value_type = xstrided_view<CT, shape_type>;
  37. using reference = std::remove_reference_t<apply_cv_t<CT, value_type>>;
  38. using pointer = xtl::xclosure_pointer<std::remove_reference_t<apply_cv_t<CT, value_type>>>;
  39. using iterator_category = std::forward_iterator_tag;
  40. template <class CTA>
  41. xaxis_iterator(CTA&& e, size_type axis);
  42. template <class CTA>
  43. xaxis_iterator(CTA&& e, size_type axis, size_type index, size_type offset);
  44. self_type& operator++();
  45. self_type operator++(int);
  46. reference operator*() const;
  47. pointer operator->() const;
  48. bool equal(const self_type& rhs) const;
  49. private:
  50. using storing_type = xtl::ptr_closure_type_t<CT>;
  51. mutable storing_type p_expression;
  52. size_type m_index;
  53. size_type m_add_offset;
  54. value_type m_sv;
  55. template <class T, class CTA>
  56. std::enable_if_t<std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
  57. template <class T, class CTA>
  58. std::enable_if_t<!std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
  59. };
  60. template <class CT>
  61. bool operator==(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs);
  62. template <class CT>
  63. bool operator!=(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs);
  64. template <class E>
  65. auto axis_begin(E&& e);
  66. template <class E>
  67. auto axis_begin(E&& e, typename std::decay_t<E>::size_type axis);
  68. template <class E>
  69. auto axis_end(E&& e);
  70. template <class E>
  71. auto axis_end(E&& e, typename std::decay_t<E>::size_type axis);
  /*********************************
   * xaxis_iterator implementation *
   *********************************/

  namespace detail
  {
      /**
       * Builds the strided view obtained from \c e by removing dimension
       * \c axis from both its shape and its strides.
       *
       * @param e the expression to slice (forwarded into the view)
       * @param axis the dimension to drop
       * @param offset the data offset given to the resulting view
       * @return the result of strided_view(e, reduced shape, reduced strides,
       *         offset, e.layout())
       */
      template <class CT>
      auto derive_xstrided_view(
          CT&& e,
          typename std::decay_t<CT>::size_type axis,
          typename std::decay_t<CT>::size_type offset
      )
      {
          using xexpression_type = std::decay_t<CT>;
          using shape_type = typename xexpression_type::shape_type;
          using strides_type = typename xexpression_type::strides_type;

          // Copies every entry of `source` except the one at index `axis`
          // into `target`, which must already hold source.size() - 1 entries.
          auto copy_without_axis = [axis](const auto& source, auto& target)
          {
              auto cut = source.cbegin() + axis;
              auto out = std::copy(source.cbegin(), cut, target.begin());
              std::copy(cut + 1, source.cend(), out);
          };

          const auto& full_shape = e.shape();
          shape_type reduced_shape(full_shape.size() - 1);
          copy_without_axis(full_shape, reduced_shape);

          const auto& full_strides = e.strides();
          strides_type reduced_strides(full_strides.size() - 1);
          copy_without_axis(full_strides, reduced_strides);

          return strided_view(
              std::forward<CT>(e),
              std::move(reduced_shape),
              std::move(reduced_strides),
              offset,
              e.layout()
          );
      }
  }
  98. template <class CT>
  99. template <class T, class CTA>
  100. inline std::enable_if_t<std::is_pointer<T>::value, T> xaxis_iterator<CT>::get_storage_init(CTA&& e) const
  101. {
  102. return &e;
  103. }
  104. template <class CT>
  105. template <class T, class CTA>
  106. inline std::enable_if_t<!std::is_pointer<T>::value, T> xaxis_iterator<CT>::get_storage_init(CTA&& e) const
  107. {
  108. return e;
  109. }
  110. /**
  111. * @name Constructors
  112. */
  113. //@{
  114. /**
  115. * Constructs an xaxis_iterator
  116. *
  117. * @param e the expression to iterate over
  118. * @param axis the axis to iterate over taking N-1 dimensional slices
  119. */
  120. template <class CT>
  121. template <class CTA>
  122. inline xaxis_iterator<CT>::xaxis_iterator(CTA&& e, size_type axis)
  123. : xaxis_iterator(std::forward<CTA>(e), axis, 0, e.data_offset())
  124. {
  125. }
  126. /**
  127. * Constructs an xaxis_iterator starting at specified index and offset
  128. *
  129. * @param e the expression to iterate over
  130. * @param axis the axis to iterate over taking N-1 dimensional slices
  131. * @param index the starting index for the iterator
  132. * @param offset the starting offset for the iterator
  133. */
  134. template <class CT>
  135. template <class CTA>
  136. inline xaxis_iterator<CT>::xaxis_iterator(CTA&& e, size_type axis, size_type index, size_type offset)
  137. : p_expression(get_storage_init<storing_type>(std::forward<CTA>(e)))
  138. , m_index(index)
  139. , m_add_offset(static_cast<size_type>(e.strides()[axis]))
  140. , m_sv(detail::derive_xstrided_view<CTA>(std::forward<CTA>(e), axis, offset))
  141. {
  142. }
  143. //@}
  144. /**
  145. * @name Increment
  146. */
  147. //@{
  148. /**
  149. * Increments the iterator to the next position and returns it.
  150. */
  151. template <class CT>
  152. inline auto xaxis_iterator<CT>::operator++() -> self_type&
  153. {
  154. m_sv.set_offset(m_sv.data_offset() + m_add_offset);
  155. ++m_index;
  156. return *this;
  157. }
  158. /**
  159. * Makes a copy of the iterator, increments it to the next
  160. * position, and returns the copy.
  161. */
  162. template <class CT>
  163. inline auto xaxis_iterator<CT>::operator++(int) -> self_type
  164. {
  165. self_type tmp(*this);
  166. ++(*this);
  167. return tmp;
  168. }
  169. //@}
  170. /**
  171. * @name Reference
  172. */
  173. //@{
  174. /**
  175. * Returns the strided view at the current iteration position
  176. *
  177. * @return a strided_view
  178. */
  179. template <class CT>
  180. inline auto xaxis_iterator<CT>::operator*() const -> reference
  181. {
  182. return m_sv;
  183. }
  184. /**
  185. * Returns a pointer to the strided view at the current iteration position
  186. *
  187. * @return a pointer to a strided_view
  188. */
  189. template <class CT>
  190. inline auto xaxis_iterator<CT>::operator->() const -> pointer
  191. {
  192. return xtl::closure_pointer(operator*());
  193. }
  194. //@}
  195. /*
  196. * @name Comparisons
  197. */
  198. //@{
  199. /**
  200. * Checks equality of the xaxis_slice_iterator and \c rhs.
  201. *
  202. * @param
  203. * @return true if the iterators are equivalent, false otherwise
  204. */
  205. template <class CT>
  206. inline bool xaxis_iterator<CT>::equal(const self_type& rhs) const
  207. {
  208. return p_expression == rhs.p_expression && m_index == rhs.m_index
  209. && m_sv.data_offset() == rhs.m_sv.data_offset();
  210. }
  211. /**
  212. * Checks equality of the iterators.
  213. *
  214. * @return true if the iterators are equivalent, false otherwise
  215. */
  216. template <class CT>
  217. inline bool operator==(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs)
  218. {
  219. return lhs.equal(rhs);
  220. }
  221. /**
  222. * Checks inequality of the iterators
  223. * @return true if the iterators are different, true otherwise
  224. */
  225. template <class CT>
  226. inline bool operator!=(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs)
  227. {
  228. return !(lhs == rhs);
  229. }
  230. //@}
  231. /**
  232. * @name Iterators
  233. */
  234. //@{
  235. /**
  236. * Returns an iterator to the first element of the expression for axis 0
  237. *
  238. * @param e the expession to iterate over
  239. * @return an instance of xaxis_iterator
  240. */
  241. template <class E>
  242. inline auto axis_begin(E&& e)
  243. {
  244. using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
  245. return return_type(std::forward<E>(e), 0);
  246. }
  247. /**
  248. * Returns an iterator to the first element of the expression for the specified axis
  249. *
  250. * @param e the expession to iterate over
  251. * @param axis the axis to iterate over
  252. * @return an instance of xaxis_iterator
  253. */
  254. template <class E>
  255. inline auto axis_begin(E&& e, typename std::decay_t<E>::size_type axis)
  256. {
  257. using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
  258. return return_type(std::forward<E>(e), axis);
  259. }
  260. /**
  261. * Returns an iterator to the element following the last element of
  262. * the expression for axis 0
  263. *
  264. * @param e the expession to iterate over
  265. * @return an instance of xaxis_iterator
  266. */
  267. template <class E>
  268. inline auto axis_end(E&& e)
  269. {
  270. using size_type = typename std::decay_t<E>::size_type;
  271. using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
  272. return return_type(
  273. std::forward<E>(e),
  274. 0,
  275. e.shape()[0],
  276. static_cast<size_type>(e.strides()[0]) * e.shape()[0]
  277. );
  278. }
  279. /**
  280. * Returns an iterator to the element following the last element of
  281. * the expression for the specified axis
  282. *
  283. * @param e the expression to iterate over
  284. * @param axis the axis to iterate over
  285. * @return an instance of xaxis_iterator
  286. */
  287. template <class E>
  288. inline auto axis_end(E&& e, typename std::decay_t<E>::size_type axis)
  289. {
  290. using size_type = typename std::decay_t<E>::size_type;
  291. using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
  292. return return_type(
  293. std::forward<E>(e),
  294. axis,
  295. e.shape()[axis],
  296. static_cast<size_type>(e.strides()[axis]) * e.shape()[axis]
  297. );
  298. }
  299. //@}
  300. }
  301. #endif