xaxis_slice_iterator.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_AXIS_SLICE_ITERATOR_HPP
  10. #define XTENSOR_AXIS_SLICE_ITERATOR_HPP
  11. #include "xstrided_view.hpp"
  12. namespace xt
  13. {
  14. /**
  15. * @class xaxis_slice_iterator
  16. * @brief Class for iteration over one-dimensional slices
  17. *
  18. * The xaxis_slice_iterator iterates over one-dimensional slices
  19. * oriented along the specified axis
  20. *
  21. * @tparam CT the closure type of the \ref xexpression
  22. */
  23. template <class CT>
  24. class xaxis_slice_iterator
  25. {
  26. public:
  27. using self_type = xaxis_slice_iterator<CT>;
  28. using xexpression_type = std::decay_t<CT>;
  29. using size_type = typename xexpression_type::size_type;
  30. using difference_type = typename xexpression_type::difference_type;
  31. using shape_type = typename xexpression_type::shape_type;
  32. using strides_type = typename xexpression_type::strides_type;
  33. using value_type = xstrided_view<CT, shape_type>;
  34. using reference = std::remove_reference_t<apply_cv_t<CT, value_type>>;
  35. using pointer = xtl::xclosure_pointer<std::remove_reference_t<apply_cv_t<CT, value_type>>>;
  36. using iterator_category = std::forward_iterator_tag;
  37. template <class CTA>
  38. xaxis_slice_iterator(CTA&& e, size_type axis);
  39. template <class CTA>
  40. xaxis_slice_iterator(CTA&& e, size_type axis, size_type index, size_type offset);
  41. self_type& operator++();
  42. self_type operator++(int);
  43. reference operator*() const;
  44. pointer operator->() const;
  45. bool equal(const self_type& rhs) const;
  46. private:
  47. using storing_type = xtl::ptr_closure_type_t<CT>;
  48. mutable storing_type p_expression;
  49. size_type m_index;
  50. size_type m_offset;
  51. size_type m_axis_stride;
  52. size_type m_lower_shape;
  53. size_type m_upper_shape;
  54. size_type m_iter_size;
  55. bool m_is_target_axis;
  56. value_type m_sv;
  57. template <class T, class CTA>
  58. std::enable_if_t<std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
  59. template <class T, class CTA>
  60. std::enable_if_t<!std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
  61. };
  62. template <class CT>
  63. bool operator==(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs);
  64. template <class CT>
  65. bool operator!=(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs);
  66. template <class E>
  67. auto xaxis_slice_begin(E&& e);
  68. template <class E>
  69. auto xaxis_slice_begin(E&& e, typename std::decay_t<E>::size_type axis);
  70. template <class E>
  71. auto xaxis_slice_end(E&& e);
  72. template <class E>
  73. auto xaxis_slice_end(E&& e, typename std::decay_t<E>::size_type axis);
  74. /***************************************
  75. * xaxis_slice_iterator implementation *
  76. ***************************************/
  77. template <class CT>
  78. template <class T, class CTA>
  79. inline std::enable_if_t<std::is_pointer<T>::value, T>
  80. xaxis_slice_iterator<CT>::get_storage_init(CTA&& e) const
  81. {
  82. return &e;
  83. }
  84. template <class CT>
  85. template <class T, class CTA>
  86. inline std::enable_if_t<!std::is_pointer<T>::value, T>
  87. xaxis_slice_iterator<CT>::get_storage_init(CTA&& e) const
  88. {
  89. return e;
  90. }
  91. /**
  92. * @name Constructors
  93. */
  94. //@{
  95. /**
  96. * Constructs an xaxis_slice_iterator
  97. *
  98. * @param e the expression to iterate over
  99. * @param axis the axis to iterate over taking one dimensional slices
  100. */
  101. template <class CT>
  102. template <class CTA>
  103. inline xaxis_slice_iterator<CT>::xaxis_slice_iterator(CTA&& e, size_type axis)
  104. : xaxis_slice_iterator(std::forward<CTA>(e), axis, 0, e.data_offset())
  105. {
  106. }
  107. /**
  108. * Constructs an xaxis_slice_iterator starting at specified index and offset
  109. *
  110. * @param e the expression to iterate over
  111. * @param axis the axis to iterate over taking one dimensional slices
  112. * @param index the starting index for the iterator
  113. * @param offset the starting offset for the iterator
  114. */
  115. template <class CT>
  116. template <class CTA>
  117. inline xaxis_slice_iterator<CT>::xaxis_slice_iterator(CTA&& e, size_type axis, size_type index, size_type offset)
  118. : p_expression(get_storage_init<storing_type>(std::forward<CTA>(e)))
  119. , m_index(index)
  120. , m_offset(offset)
  121. , m_axis_stride(static_cast<size_type>(e.strides()[axis]) * (e.shape()[axis] - 1u))
  122. , m_lower_shape(0)
  123. , m_upper_shape(0)
  124. , m_iter_size(0)
  125. , m_is_target_axis(false)
  126. , m_sv(strided_view(
  127. std::forward<CT>(e),
  128. std::forward<shape_type>({e.shape()[axis]}),
  129. std::forward<strides_type>({e.strides()[axis]}),
  130. offset,
  131. e.layout()
  132. ))
  133. {
  134. if (e.layout() == layout_type::row_major)
  135. {
  136. m_is_target_axis = axis == e.dimension() - 1;
  137. m_lower_shape = std::accumulate(
  138. e.shape().begin() + axis + 1,
  139. e.shape().end(),
  140. size_t(1),
  141. std::multiplies<>()
  142. );
  143. m_iter_size = std::accumulate(e.shape().begin() + 1, e.shape().end(), size_t(1), std::multiplies<>());
  144. }
  145. else
  146. {
  147. m_is_target_axis = axis == 0;
  148. m_lower_shape = std::accumulate(
  149. e.shape().begin(),
  150. e.shape().begin() + axis,
  151. size_t(1),
  152. std::multiplies<>()
  153. );
  154. m_iter_size = std::accumulate(e.shape().begin(), e.shape().end() - 1, size_t(1), std::multiplies<>());
  155. }
  156. m_upper_shape = m_lower_shape + m_axis_stride;
  157. }
  158. //@}
  159. /**
  160. * @name Increment
  161. */
  162. //@{
  163. /**
  164. * Increments the iterator to the next position and returns it.
  165. */
  166. template <class CT>
  167. inline auto xaxis_slice_iterator<CT>::operator++() -> self_type&
  168. {
  169. ++m_index;
  170. ++m_offset;
  171. auto index_compare = (m_offset % m_iter_size);
  172. if (m_is_target_axis || (m_upper_shape >= index_compare && index_compare >= m_lower_shape))
  173. {
  174. m_offset += m_axis_stride;
  175. }
  176. m_sv.set_offset(m_offset);
  177. return *this;
  178. }
  179. /**
  180. * Makes a copy of the iterator, increments it to the next
  181. * position, and returns the copy.
  182. */
  183. template <class CT>
  184. inline auto xaxis_slice_iterator<CT>::operator++(int) -> self_type
  185. {
  186. self_type tmp(*this);
  187. ++(*this);
  188. return tmp;
  189. }
  190. //@}
  191. /**
  192. * @name Reference
  193. */
  194. //@{
  195. /**
  196. * Returns the strided view at the current iteration position
  197. *
  198. * @return a strided_view
  199. */
  200. template <class CT>
  201. inline auto xaxis_slice_iterator<CT>::operator*() const -> reference
  202. {
  203. return m_sv;
  204. }
  205. /**
  206. * Returns a pointer to the strided view at the current iteration position
  207. *
  208. * @return a pointer to a strided_view
  209. */
  210. template <class CT>
  211. inline auto xaxis_slice_iterator<CT>::operator->() const -> pointer
  212. {
  213. return xtl::closure_pointer(operator*());
  214. }
  215. //@}
  216. /*
  217. * @name Comparisons
  218. */
  219. //@{
  220. /**
  221. * Checks equality of the xaxis_slice_iterator and \c rhs.
  222. *
  223. * @return true if the iterators are equivalent, false otherwise
  224. */
  225. template <class CT>
  226. inline bool xaxis_slice_iterator<CT>::equal(const self_type& rhs) const
  227. {
  228. return p_expression == rhs.p_expression && m_index == rhs.m_index;
  229. }
  230. /**
  231. * Checks equality of the iterators.
  232. *
  233. * @return true if the iterators are equivalent, false otherwise
  234. */
  235. template <class CT>
  236. inline bool operator==(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs)
  237. {
  238. return lhs.equal(rhs);
  239. }
  240. /**
  241. * Checks inequality of the iterators
  242. * @return true if the iterators are different, true otherwise
  243. */
  244. template <class CT>
  245. inline bool operator!=(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs)
  246. {
  247. return !(lhs == rhs);
  248. }
  249. //@}
  250. /**
  251. * @name Iterators
  252. */
  253. //@{
  254. /**
  255. * Returns an iterator to the first element of the expression for axis 0
  256. *
  257. * @param e the expession to iterate over
  258. * @return an instance of xaxis_slice_iterator
  259. */
  260. template <class E>
  261. inline auto axis_slice_begin(E&& e)
  262. {
  263. using return_type = xaxis_slice_iterator<xtl::closure_type_t<E>>;
  264. return return_type(std::forward<E>(e), 0);
  265. }
  266. /**
  267. * Returns an iterator to the first element of the expression for the specified axis
  268. *
  269. * @param e the expession to iterate over
  270. * @param axis the axis to iterate over
  271. * @return an instance of xaxis_slice_iterator
  272. */
  273. template <class E>
  274. inline auto axis_slice_begin(E&& e, typename std::decay_t<E>::size_type axis)
  275. {
  276. using return_type = xaxis_slice_iterator<xtl::closure_type_t<E>>;
  277. return return_type(std::forward<E>(e), axis, 0, e.data_offset());
  278. }
  279. /**
  280. * Returns an iterator to the element following the last element of
  281. * the expression for axis 0
  282. *
  283. * @param e the expession to iterate over
  284. * @return an instance of xaxis_slice_iterator
  285. */
  286. template <class E>
  287. inline auto axis_slice_end(E&& e)
  288. {
  289. using return_type = xaxis_slice_iterator<xtl::closure_type_t<E>>;
  290. return return_type(
  291. std::forward<E>(e),
  292. 0,
  293. std::accumulate(e.shape().begin() + 1, e.shape().end(), size_t(1), std::multiplies<>()),
  294. e.size()
  295. );
  296. }
  297. /**
  298. * Returns an iterator to the element following the last element of
  299. * the expression for the specified axis
  300. *
  301. * @param e the expression to iterate over
  302. * @param axis the axis to iterate over
  303. * @return an instance of xaxis_slice_iterator
  304. */
  305. template <class E>
  306. inline auto axis_slice_end(E&& e, typename std::decay_t<E>::size_type axis)
  307. {
  308. using return_type = xaxis_slice_iterator<xtl::closure_type_t<E>>;
  309. auto index_sum = std::accumulate(
  310. e.shape().begin(),
  311. e.shape().begin() + axis,
  312. size_t(1),
  313. std::multiplies<>()
  314. );
  315. return return_type(
  316. std::forward<E>(e),
  317. axis,
  318. std::accumulate(e.shape().begin() + axis + 1, e.shape().end(), index_sum, std::multiplies<>()),
  319. e.size() + axis
  320. );
  321. }
  322. //@}
  323. }
  324. #endif