xiterator.hpp 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_ITERATOR_HPP
  10. #define XTENSOR_ITERATOR_HPP
  11. #include <algorithm>
  12. #include <array>
  13. #include <cstddef>
  14. #include <iterator>
  15. #include <numeric>
  16. #include <vector>
  17. #include <xtl/xcompare.hpp>
  18. #include <xtl/xiterator_base.hpp>
  19. #include <xtl/xmeta_utils.hpp>
  20. #include <xtl/xsequence.hpp>
  21. #include "xexception.hpp"
  22. #include "xlayout.hpp"
  23. #include "xshape.hpp"
  24. #include "xutils.hpp"
  25. namespace xt
  26. {
  /***********************
   * iterator meta utils *
   ***********************/

  template <class Closure>
  class xscalar;

  template <bool is_const, class Closure>
  class xscalar_stepper;

  namespace detail
  {
    // Maps a storage type to the iterator type an xstepper walks over.
    // A mutable container exposes it as `container_iterator`.
    template <class Container>
    struct get_stepper_iterator_impl
    {
      using type = typename Container::container_iterator;
    };

    // A const-qualified container exposes `const_container_iterator` instead.
    template <class Container>
    struct get_stepper_iterator_impl<const Container>
    {
      using type = typename Container::const_container_iterator;
    };

    // Scalars have no real storage; they provide dummy iterator types.
    template <class Closure>
    struct get_stepper_iterator_impl<xscalar<Closure>>
    {
      using type = typename xscalar<Closure>::dummy_iterator;
    };

    template <class Closure>
    struct get_stepper_iterator_impl<const xscalar<Closure>>
    {
      using type = typename xscalar<Closure>::const_dummy_iterator;
    };
  }

  /// Iterator type used by xstepper<C> to traverse the storage of C.
  template <class C>
  using get_stepper_iterator = typename detail::get_stepper_iterator_impl<C>::type;
  59. /********************************
  60. * xindex_type_t implementation *
  61. ********************************/
  62. namespace detail
  63. {
  64. template <class ST>
  65. struct index_type_impl
  66. {
  67. using type = dynamic_shape<typename ST::value_type>;
  68. };
  69. template <class V, std::size_t L>
  70. struct index_type_impl<std::array<V, L>>
  71. {
  72. using type = std::array<V, L>;
  73. };
  74. template <std::size_t... I>
  75. struct index_type_impl<fixed_shape<I...>>
  76. {
  77. using type = std::array<std::size_t, sizeof...(I)>;
  78. };
  79. }
  80. template <class C>
  81. using xindex_type_t = typename detail::index_type_impl<C>::type;
  82. /************
  83. * xstepper *
  84. ************/
  85. template <class C>
  86. class xstepper
  87. {
  88. public:
  89. using storage_type = C;
  90. using subiterator_type = get_stepper_iterator<C>;
  91. using subiterator_traits = std::iterator_traits<subiterator_type>;
  92. using value_type = typename subiterator_traits::value_type;
  93. using reference = typename subiterator_traits::reference;
  94. using pointer = typename subiterator_traits::pointer;
  95. using difference_type = typename subiterator_traits::difference_type;
  96. using size_type = typename storage_type::size_type;
  97. using shape_type = typename storage_type::shape_type;
  98. using simd_value_type = xt_simd::simd_type<value_type>;
  99. template <class requested_type>
  100. using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
  101. xstepper() = default;
  102. xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept;
  103. reference operator*() const;
  104. void step(size_type dim, size_type n = 1);
  105. void step_back(size_type dim, size_type n = 1);
  106. void reset(size_type dim);
  107. void reset_back(size_type dim);
  108. void to_begin();
  109. void to_end(layout_type l);
  110. template <class T>
  111. simd_return_type<T> step_simd();
  112. void step_leading();
  113. template <class R>
  114. void store_simd(const R& vec);
  115. private:
  116. storage_type* p_c;
  117. subiterator_type m_it;
  118. size_type m_offset;
  119. };
  120. template <layout_type L>
  121. struct stepper_tools
  122. {
  123. // For performance reasons, increment_stepper and decrement_stepper are
  124. // specialized for the case where n=1, which underlies operator++ and
  125. // operator-- on xiterators.
  126. template <class S, class IT, class ST>
  127. static void increment_stepper(S& stepper, IT& index, const ST& shape);
  128. template <class S, class IT, class ST>
  129. static void decrement_stepper(S& stepper, IT& index, const ST& shape);
  130. template <class S, class IT, class ST>
  131. static void increment_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
  132. template <class S, class IT, class ST>
  133. static void decrement_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
  134. };
  135. /********************
  136. * xindexed_stepper *
  137. ********************/
  138. template <class E, bool is_const>
  139. class xindexed_stepper
  140. {
  141. public:
  142. using self_type = xindexed_stepper<E, is_const>;
  143. using xexpression_type = std::conditional_t<is_const, const E, E>;
  144. using value_type = typename xexpression_type::value_type;
  145. using reference = std::
  146. conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
  147. using pointer = std::
  148. conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
  149. using size_type = typename xexpression_type::size_type;
  150. using difference_type = typename xexpression_type::difference_type;
  151. using shape_type = typename xexpression_type::shape_type;
  152. using index_type = xindex_type_t<shape_type>;
  153. xindexed_stepper() = default;
  154. xindexed_stepper(xexpression_type* e, size_type offset, bool end = false) noexcept;
  155. reference operator*() const;
  156. void step(size_type dim, size_type n = 1);
  157. void step_back(size_type dim, size_type n = 1);
  158. void reset(size_type dim);
  159. void reset_back(size_type dim);
  160. void to_begin();
  161. void to_end(layout_type l);
  162. private:
  163. xexpression_type* p_e;
  164. index_type m_index;
  165. size_type m_offset;
  166. };
  167. template <class T>
  168. struct is_indexed_stepper
  169. {
  170. static const bool value = false;
  171. };
  172. template <class T, bool B>
  173. struct is_indexed_stepper<xindexed_stepper<T, B>>
  174. {
  175. static const bool value = true;
  176. };
  177. template <class T, class R = T>
  178. struct enable_indexed_stepper : std::enable_if<is_indexed_stepper<T>::value, R>
  179. {
  180. };
  181. template <class T, class R = T>
  182. using enable_indexed_stepper_t = typename enable_indexed_stepper<T, R>::type;
  183. template <class T, class R = T>
  184. struct disable_indexed_stepper : std::enable_if<!is_indexed_stepper<T>::value, R>
  185. {
  186. };
  187. template <class T, class R = T>
  188. using disable_indexed_stepper_t = typename disable_indexed_stepper<T, R>::type;
  189. /*************
  190. * xiterator *
  191. *************/
  192. namespace detail
  193. {
  194. template <class S>
  195. class shape_storage
  196. {
  197. public:
  198. using shape_type = S;
  199. using param_type = const S&;
  200. shape_storage() = default;
  201. shape_storage(param_type shape);
  202. const S& shape() const;
  203. private:
  204. S m_shape;
  205. };
  206. template <class S>
  207. class shape_storage<S*>
  208. {
  209. public:
  210. using shape_type = S;
  211. using param_type = const S*;
  212. shape_storage(param_type shape = 0);
  213. const S& shape() const;
  214. private:
  215. const S* p_shape;
  216. };
  217. template <layout_type L>
  218. struct LAYOUT_FORBIDEN_FOR_XITERATOR;
  219. }
  220. template <class St, class S, layout_type L>
  221. class xiterator : public xtl::xrandom_access_iterator_base<
  222. xiterator<St, S, L>,
  223. typename St::value_type,
  224. typename St::difference_type,
  225. typename St::pointer,
  226. typename St::reference>,
  227. private detail::shape_storage<S>
  228. {
  229. public:
  230. using self_type = xiterator<St, S, L>;
  231. using stepper_type = St;
  232. using value_type = typename stepper_type::value_type;
  233. using reference = typename stepper_type::reference;
  234. using pointer = typename stepper_type::pointer;
  235. using difference_type = typename stepper_type::difference_type;
  236. using size_type = typename stepper_type::size_type;
  237. using iterator_category = std::random_access_iterator_tag;
  238. using private_base = detail::shape_storage<S>;
  239. using shape_type = typename private_base::shape_type;
  240. using shape_param_type = typename private_base::param_type;
  241. using index_type = xindex_type_t<shape_type>;
  242. xiterator() = default;
  243. // end_index means either reverse_iterator && !end or !reverse_iterator && end
  244. xiterator(St st, shape_param_type shape, bool end_index);
  245. self_type& operator++();
  246. self_type& operator--();
  247. self_type& operator+=(difference_type n);
  248. self_type& operator-=(difference_type n);
  249. difference_type operator-(const self_type& rhs) const;
  250. reference operator*() const;
  251. pointer operator->() const;
  252. bool equal(const xiterator& rhs) const;
  253. bool less_than(const xiterator& rhs) const;
  254. private:
  255. stepper_type m_st;
  256. index_type m_index;
  257. difference_type m_linear_index;
  258. using checking_type = typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
  259. };
  260. template <class St, class S, layout_type L>
  261. bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs);
  262. template <class St, class S, layout_type L>
  263. bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs);
  264. template <class St, class S, layout_type L>
  265. struct is_contiguous_container<xiterator<St, S, L>> : std::false_type
  266. {
  267. };
  268. /*********************
  269. * xbounded_iterator *
  270. *********************/
  271. template <class It, class BIt>
  272. class xbounded_iterator : public xtl::xrandom_access_iterator_base<
  273. xbounded_iterator<It, BIt>,
  274. typename std::iterator_traits<It>::value_type,
  275. typename std::iterator_traits<It>::difference_type,
  276. typename std::iterator_traits<It>::pointer,
  277. typename std::iterator_traits<It>::reference>
  278. {
  279. public:
  280. using self_type = xbounded_iterator<It, BIt>;
  281. using subiterator_type = It;
  282. using bound_iterator_type = BIt;
  283. using value_type = typename std::iterator_traits<It>::value_type;
  284. using reference = typename std::iterator_traits<It>::reference;
  285. using pointer = typename std::iterator_traits<It>::pointer;
  286. using difference_type = typename std::iterator_traits<It>::difference_type;
  287. using iterator_category = std::random_access_iterator_tag;
  288. xbounded_iterator() = default;
  289. xbounded_iterator(It it, BIt bound_it);
  290. self_type& operator++();
  291. self_type& operator--();
  292. self_type& operator+=(difference_type n);
  293. self_type& operator-=(difference_type n);
  294. difference_type operator-(const self_type& rhs) const;
  295. value_type operator*() const;
  296. bool equal(const self_type& rhs) const;
  297. bool less_than(const self_type& rhs) const;
  298. private:
  299. subiterator_type m_it;
  300. bound_iterator_type m_bound_it;
  301. };
  302. template <class It, class BIt>
  303. bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs);
  304. template <class It, class BIt>
  305. bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs);
  306. /*****************************
  307. * linear_begin / linear_end *
  308. *****************************/
  309. namespace detail
  310. {
  311. template <class C, class = void_t<>>
  312. struct has_linear_iterator : std::false_type
  313. {
  314. };
  315. template <class C>
  316. struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
  317. {
  318. };
  319. }
  320. template <class C>
  321. XTENSOR_CONSTEXPR_RETURN auto linear_begin(C& c) noexcept
  322. {
  323. return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
  324. [&](auto self)
  325. {
  326. return self(c).linear_begin();
  327. },
  328. /*else*/
  329. [&](auto self)
  330. {
  331. return self(c).begin();
  332. }
  333. );
  334. }
  335. template <class C>
  336. XTENSOR_CONSTEXPR_RETURN auto linear_end(C& c) noexcept
  337. {
  338. return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
  339. [&](auto self)
  340. {
  341. return self(c).linear_end();
  342. },
  343. /*else*/
  344. [&](auto self)
  345. {
  346. return self(c).end();
  347. }
  348. );
  349. }
  350. template <class C>
  351. XTENSOR_CONSTEXPR_RETURN auto linear_begin(const C& c) noexcept
  352. {
  353. return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
  354. [&](auto self)
  355. {
  356. return self(c).linear_cbegin();
  357. },
  358. /*else*/
  359. [&](auto self)
  360. {
  361. return self(c).cbegin();
  362. }
  363. );
  364. }
  365. template <class C>
  366. XTENSOR_CONSTEXPR_RETURN auto linear_end(const C& c) noexcept
  367. {
  368. return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
  369. [&](auto self)
  370. {
  371. return self(c).linear_cend();
  372. },
  373. /*else*/
  374. [&](auto self)
  375. {
  376. return self(c).cend();
  377. }
  378. );
  379. }
  380. /***************************
  381. * xstepper implementation *
  382. ***************************/
  383. template <class C>
  384. inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
  385. : p_c(c)
  386. , m_it(it)
  387. , m_offset(offset)
  388. {
  389. }
  390. template <class C>
  391. inline auto xstepper<C>::operator*() const -> reference
  392. {
  393. return *m_it;
  394. }
  395. template <class C>
  396. inline void xstepper<C>::step(size_type dim, size_type n)
  397. {
  398. if (dim >= m_offset)
  399. {
  400. using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
  401. m_it += difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
  402. }
  403. }
  404. template <class C>
  405. inline void xstepper<C>::step_back(size_type dim, size_type n)
  406. {
  407. if (dim >= m_offset)
  408. {
  409. using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
  410. m_it -= difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
  411. }
  412. }
  413. template <class C>
  414. inline void xstepper<C>::reset(size_type dim)
  415. {
  416. if (dim >= m_offset)
  417. {
  418. m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
  419. }
  420. }
  421. template <class C>
  422. inline void xstepper<C>::reset_back(size_type dim)
  423. {
  424. if (dim >= m_offset)
  425. {
  426. m_it += difference_type(p_c->backstrides()[dim - m_offset]);
  427. }
  428. }
  429. template <class C>
  430. inline void xstepper<C>::to_begin()
  431. {
  432. m_it = p_c->data_xbegin();
  433. }
  434. template <class C>
  435. inline void xstepper<C>::to_end(layout_type l)
  436. {
  437. m_it = p_c->data_xend(l, m_offset);
  438. }
  439. namespace detail
  440. {
  441. template <class It>
  442. struct step_simd_invoker
  443. {
  444. template <class R>
  445. static R apply(const It& it)
  446. {
  447. R reg;
  448. return reg.load_unaligned(&(*it));
  449. // return reg;
  450. }
  451. };
  452. template <bool is_const, class T, class S, layout_type L>
  453. struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
  454. {
  455. template <class R>
  456. static R apply(const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
  457. {
  458. return R(*it);
  459. }
  460. };
  461. }
  462. template <class C>
  463. template <class T>
  464. inline auto xstepper<C>::step_simd() -> simd_return_type<T>
  465. {
  466. using simd_type = simd_return_type<T>;
  467. simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
  468. m_it += xt_simd::revert_simd_traits<simd_type>::size;
  469. return reg;
  470. }
  471. template <class C>
  472. template <class R>
  473. inline void xstepper<C>::store_simd(const R& vec)
  474. {
  475. vec.store_unaligned(&(*m_it));
  476. m_it += xt_simd::revert_simd_traits<R>::size;
  477. ;
  478. }
  479. template <class C>
  480. void xstepper<C>::step_leading()
  481. {
  482. ++m_it;
  483. }
  484. template <>
  485. template <class S, class IT, class ST>
  486. void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
  487. {
  488. using size_type = typename S::size_type;
  489. const size_type size = index.size();
  490. size_type i = size;
  491. while (i != 0)
  492. {
  493. --i;
  494. if (index[i] != shape[i] - 1)
  495. {
  496. ++index[i];
  497. stepper.step(i);
  498. return;
  499. }
  500. else
  501. {
  502. index[i] = 0;
  503. if (i != 0)
  504. {
  505. stepper.reset(i);
  506. }
  507. }
  508. }
  509. if (i == 0)
  510. {
  511. if (size != size_type(0))
  512. {
  513. std::transform(
  514. shape.cbegin(),
  515. shape.cend() - 1,
  516. index.begin(),
  517. [](const auto& v)
  518. {
  519. return v - 1;
  520. }
  521. );
  522. index[size - 1] = shape[size - 1];
  523. }
  524. stepper.to_end(layout_type::row_major);
  525. }
  526. }
  /**
   * Row-major increment by n positions (the last dimension varies fastest).
   *
   * The leading (last) dimension absorbs as much of n as fits. When it would
   * overflow, it is stepped to its last position, set back to 0 (stepper
   * reset along it), and a carry of one position is propagated to the next
   * outer dimension. After a successful carry the scan restarts from the
   * leading dimension with the remaining n. If the carry runs past dimension
   * 0, index and stepper are moved to the past-the-end position.
   */
  template <>
  template <class S, class IT, class ST>
  void stepper_tools<layout_type::row_major>::increment_stepper(
    S& stepper,
    IT& index,
    const ST& shape,
    typename S::size_type n
  )
  {
    using size_type = typename S::size_type;
    const size_type size = index.size();
    const size_type leading_i = size - 1;
    size_type i = size;
    while (i != 0 && n != 0)
    {
      --i;
      // The leading dimension takes all remaining steps; outer dimensions
      // only ever receive a carry of 1.
      size_type inc = (i == leading_i) ? n : 1;
      if (xtl::cmp_less(index[i] + inc, shape[i]))
      {
        index[i] += inc;
        stepper.step(i, inc);
        n -= inc;
        // After a carry, restart from the leading dimension (the --i at the
        // top of the loop brings i back to leading_i). In 1-D this also keeps
        // i != 0, so the past-the-end branch below is not taken.
        if (i != leading_i || index.size() == 1)
        {
          i = index.size();
        }
      }
      else
      {
        if (i == leading_i)
        {
          // Consume the steps that reach the last position of the leading
          // dimension; the wrapping step is accounted for by the carry.
          size_type off = shape[i] - index[i] - 1;
          stepper.step(i, off);
          n -= off;
        }
        index[i] = 0;
        // Dimension 0 is not reset: exhausting it means the end is reached.
        if (i != 0)
        {
          stepper.reset(i);
        }
      }
    }
    // Every dimension was exhausted with steps left: move past the end.
    if (i == 0 && n != 0)
    {
      if (size != size_type(0))
      {
        // Past-the-end index: {shape[0]-1, ..., shape[size-2]-1, shape[size-1]}.
        std::transform(
          shape.cbegin(),
          shape.cend() - 1,
          index.begin(),
          [](const auto& v)
          {
            return v - 1;
          }
        );
        index[leading_i] = shape[leading_i];
      }
      stepper.to_end(layout_type::row_major);
    }
  }
  587. template <>
  588. template <class S, class IT, class ST>
  589. void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
  590. {
  591. using size_type = typename S::size_type;
  592. size_type i = index.size();
  593. while (i != 0)
  594. {
  595. --i;
  596. if (index[i] != 0)
  597. {
  598. --index[i];
  599. stepper.step_back(i);
  600. return;
  601. }
  602. else
  603. {
  604. index[i] = shape[i] - 1;
  605. if (i != 0)
  606. {
  607. stepper.reset_back(i);
  608. }
  609. }
  610. }
  611. if (i == 0)
  612. {
  613. stepper.to_begin();
  614. }
  615. }
  /**
   * Row-major decrement by n positions (the last dimension varies fastest).
   *
   * Mirror of the n-step increment: the leading (last) dimension absorbs as
   * much of n as its current coordinate allows. When that is not enough, it
   * is stepped back to 0, set to shape - 1 (stepper reset_back along it), and
   * a borrow of one position is taken from the next outer dimension. After a
   * successful borrow the scan restarts from the leading dimension. If the
   * borrow runs past dimension 0, the stepper is sent to its beginning.
   *
   * NOTE(review): in that last case the index is left with shape - 1
   * coordinates while the stepper is at the beginning. This is only
   * reachable when moving before the first element.
   */
  template <>
  template <class S, class IT, class ST>
  void stepper_tools<layout_type::row_major>::decrement_stepper(
    S& stepper,
    IT& index,
    const ST& shape,
    typename S::size_type n
  )
  {
    using size_type = typename S::size_type;
    size_type i = index.size();
    size_type leading_i = index.size() - 1;
    while (i != 0 && n != 0)
    {
      --i;
      // The leading dimension takes all remaining steps; outer dimensions
      // only ever lend 1.
      size_type inc = (i == leading_i) ? n : 1;
      if (xtl::cmp_greater_equal(index[i], inc))
      {
        index[i] -= inc;
        stepper.step_back(i, inc);
        n -= inc;
        // After a borrow, restart from the leading dimension. In 1-D this
        // also keeps i != 0, so the to_begin branch below is not taken.
        if (i != leading_i || index.size() == 1)
        {
          i = index.size();
        }
      }
      else
      {
        if (i == leading_i)
        {
          // Consume the steps that reach coordinate 0 of the leading
          // dimension; the wrapping step is accounted for by the borrow.
          size_type off = index[i];
          stepper.step_back(i, off);
          n -= off;
        }
        index[i] = shape[i] - 1;
        // Dimension 0 is not reset_back: exhausting it means the beginning was passed.
        if (i != 0)
        {
          stepper.reset_back(i);
        }
      }
    }
    if (i == 0 && n != 0)
    {
      stepper.to_begin();
    }
  }
  662. template <>
  663. template <class S, class IT, class ST>
  664. void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
  665. {
  666. using size_type = typename S::size_type;
  667. const size_type size = index.size();
  668. size_type i = 0;
  669. while (i != size)
  670. {
  671. if (index[i] != shape[i] - 1)
  672. {
  673. ++index[i];
  674. stepper.step(i);
  675. return;
  676. }
  677. else
  678. {
  679. index[i] = 0;
  680. if (i != size - 1)
  681. {
  682. stepper.reset(i);
  683. }
  684. }
  685. ++i;
  686. }
  687. if (i == size)
  688. {
  689. if (size != size_type(0))
  690. {
  691. std::transform(
  692. shape.cbegin() + 1,
  693. shape.cend(),
  694. index.begin() + 1,
  695. [](const auto& v)
  696. {
  697. return v - 1;
  698. }
  699. );
  700. index[0] = shape[0];
  701. }
  702. stepper.to_end(layout_type::column_major);
  703. }
  704. }
  705. template <>
  706. template <class S, class IT, class ST>
  707. void stepper_tools<layout_type::column_major>::increment_stepper(
  708. S& stepper,
  709. IT& index,
  710. const ST& shape,
  711. typename S::size_type n
  712. )
  713. {
  714. using size_type = typename S::size_type;
  715. const size_type size = index.size();
  716. const size_type leading_i = 0;
  717. size_type i = 0;
  718. while (i != size && n != 0)
  719. {
  720. size_type inc = (i == leading_i) ? n : 1;
  721. if (index[i] + inc < shape[i])
  722. {
  723. index[i] += inc;
  724. stepper.step(i, inc);
  725. n -= inc;
  726. if (i != leading_i || size == 1)
  727. {
  728. i = 0;
  729. continue;
  730. }
  731. }
  732. else
  733. {
  734. if (i == leading_i)
  735. {
  736. size_type off = shape[i] - index[i] - 1;
  737. stepper.step(i, off);
  738. n -= off;
  739. }
  740. index[i] = 0;
  741. if (i != size - 1)
  742. {
  743. stepper.reset(i);
  744. }
  745. }
  746. ++i;
  747. }
  748. if (i == size && n != 0)
  749. {
  750. if (size != size_type(0))
  751. {
  752. std::transform(
  753. shape.cbegin() + 1,
  754. shape.cend(),
  755. index.begin() + 1,
  756. [](const auto& v)
  757. {
  758. return v - 1;
  759. }
  760. );
  761. index[leading_i] = shape[leading_i];
  762. }
  763. stepper.to_end(layout_type::column_major);
  764. }
  765. }
  766. template <>
  767. template <class S, class IT, class ST>
  768. void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
  769. {
  770. using size_type = typename S::size_type;
  771. size_type size = index.size();
  772. size_type i = 0;
  773. while (i != size)
  774. {
  775. if (index[i] != 0)
  776. {
  777. --index[i];
  778. stepper.step_back(i);
  779. return;
  780. }
  781. else
  782. {
  783. index[i] = shape[i] - 1;
  784. if (i != size - 1)
  785. {
  786. stepper.reset_back(i);
  787. }
  788. }
  789. ++i;
  790. }
  791. if (i == size)
  792. {
  793. stepper.to_begin();
  794. }
  795. }
  796. template <>
  797. template <class S, class IT, class ST>
  798. void stepper_tools<layout_type::column_major>::decrement_stepper(
  799. S& stepper,
  800. IT& index,
  801. const ST& shape,
  802. typename S::size_type n
  803. )
  804. {
  805. using size_type = typename S::size_type;
  806. size_type size = index.size();
  807. size_type i = 0;
  808. size_type leading_i = 0;
  809. while (i != size && n != 0)
  810. {
  811. size_type inc = (i == leading_i) ? n : 1;
  812. if (index[i] >= inc)
  813. {
  814. index[i] -= inc;
  815. stepper.step_back(i, inc);
  816. n -= inc;
  817. if (i != leading_i || index.size() == 1)
  818. {
  819. i = 0;
  820. continue;
  821. }
  822. }
  823. else
  824. {
  825. if (i == leading_i)
  826. {
  827. size_type off = index[i];
  828. stepper.step_back(i, off);
  829. n -= off;
  830. }
  831. index[i] = shape[i] - 1;
  832. if (i != size - 1)
  833. {
  834. stepper.reset_back(i);
  835. }
  836. }
  837. ++i;
  838. }
  839. if (i == size && n != 0)
  840. {
  841. stepper.to_begin();
  842. }
  843. }
  844. /***********************************
  845. * xindexed_stepper implementation *
  846. ***********************************/
  847. template <class C, bool is_const>
  848. inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset, bool end) noexcept
  849. : p_e(e)
  850. , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
  851. , m_offset(offset)
  852. {
  853. if (end)
  854. {
  855. // Note: the layout here doesn't matter (unused) but using default traversal looks more "correct".
  856. to_end(XTENSOR_DEFAULT_TRAVERSAL);
  857. }
  858. }
  859. template <class C, bool is_const>
  860. inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
  861. {
  862. return p_e->element(m_index.cbegin(), m_index.cend());
  863. }
  864. template <class C, bool is_const>
  865. inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
  866. {
  867. if (dim >= m_offset)
  868. {
  869. m_index[dim - m_offset] += static_cast<typename index_type::value_type>(n);
  870. }
  871. }
  872. template <class C, bool is_const>
  873. inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
  874. {
  875. if (dim >= m_offset)
  876. {
  877. m_index[dim - m_offset] -= static_cast<typename index_type::value_type>(n);
  878. }
  879. }
  880. template <class C, bool is_const>
  881. inline void xindexed_stepper<C, is_const>::reset(size_type dim)
  882. {
  883. if (dim >= m_offset)
  884. {
  885. m_index[dim - m_offset] = 0;
  886. }
  887. }
  888. template <class C, bool is_const>
  889. inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
  890. {
  891. if (dim >= m_offset)
  892. {
  893. m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
  894. }
  895. }
  896. template <class C, bool is_const>
  897. inline void xindexed_stepper<C, is_const>::to_begin()
  898. {
  899. std::fill(m_index.begin(), m_index.end(), size_type(0));
  900. }
  901. template <class C, bool is_const>
  902. inline void xindexed_stepper<C, is_const>::to_end(layout_type l)
  903. {
  904. const auto& shape = p_e->shape();
  905. std::transform(
  906. shape.cbegin(),
  907. shape.cend(),
  908. m_index.begin(),
  909. [](const auto& v)
  910. {
  911. return v - 1;
  912. }
  913. );
  914. size_type l_dim = (l == layout_type::row_major) ? shape.size() - 1 : 0;
  915. m_index[l_dim] = shape[l_dim];
  916. }
  917. /****************************
  918. * xiterator implementation *
  919. ****************************/
  920. namespace detail
  921. {
  922. template <class S>
  923. inline shape_storage<S>::shape_storage(param_type shape)
  924. : m_shape(shape)
  925. {
  926. }
  927. template <class S>
  928. inline const S& shape_storage<S>::shape() const
  929. {
  930. return m_shape;
  931. }
  932. template <class S>
  933. inline shape_storage<S*>::shape_storage(param_type shape)
  934. : p_shape(shape)
  935. {
  936. }
  937. template <class S>
  938. inline const S& shape_storage<S*>::shape() const
  939. {
  940. return *p_shape;
  941. }
  942. template <>
  943. struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::row_major>
  944. {
  945. using type = int;
  946. };
  947. template <>
  948. struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::column_major>
  949. {
  950. using type = int;
  951. };
  952. }
  953. template <class St, class S, layout_type L>
  954. inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape, bool end_index)
  955. : private_base(shape)
  956. , m_st(st)
  957. , m_index(
  958. end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
  959. : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
  960. )
  961. , m_linear_index(0)
  962. {
  963. // end_index means either reverse_iterator && !end or !reverse_iterator && end
  964. if (end_index)
  965. {
  966. if (m_index.size() != size_type(0))
  967. {
  968. auto iter_begin = (L == layout_type::row_major) ? m_index.begin() : m_index.begin() + 1;
  969. auto iter_end = (L == layout_type::row_major) ? m_index.end() - 1 : m_index.end();
  970. std::transform(
  971. iter_begin,
  972. iter_end,
  973. iter_begin,
  974. [](const auto& v)
  975. {
  976. return v - 1;
  977. }
  978. );
  979. }
  980. m_linear_index = difference_type(std::accumulate(
  981. this->shape().cbegin(),
  982. this->shape().cend(),
  983. size_type(1),
  984. std::multiplies<size_type>()
  985. ));
  986. }
  987. }
  988. template <class St, class S, layout_type L>
  989. inline auto xiterator<St, S, L>::operator++() -> self_type&
  990. {
  991. stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
  992. ++m_linear_index;
  993. return *this;
  994. }
  995. template <class St, class S, layout_type L>
  996. inline auto xiterator<St, S, L>::operator--() -> self_type&
  997. {
  998. stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
  999. --m_linear_index;
  1000. return *this;
  1001. }
  1002. template <class St, class S, layout_type L>
  1003. inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
  1004. {
  1005. if (n >= 0)
  1006. {
  1007. stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
  1008. }
  1009. else
  1010. {
  1011. stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
  1012. }
  1013. m_linear_index += n;
  1014. return *this;
  1015. }
  1016. template <class St, class S, layout_type L>
  1017. inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
  1018. {
  1019. if (n >= 0)
  1020. {
  1021. stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
  1022. }
  1023. else
  1024. {
  1025. stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
  1026. }
  1027. m_linear_index -= n;
  1028. return *this;
  1029. }
  1030. template <class St, class S, layout_type L>
  1031. inline auto xiterator<St, S, L>::operator-(const self_type& rhs) const -> difference_type
  1032. {
  1033. return m_linear_index - rhs.m_linear_index;
  1034. }
  1035. template <class St, class S, layout_type L>
  1036. inline auto xiterator<St, S, L>::operator*() const -> reference
  1037. {
  1038. return *m_st;
  1039. }
  1040. template <class St, class S, layout_type L>
  1041. inline auto xiterator<St, S, L>::operator->() const -> pointer
  1042. {
  1043. return &(*m_st);
  1044. }
  1045. template <class St, class S, layout_type L>
  1046. inline bool xiterator<St, S, L>::equal(const xiterator& rhs) const
  1047. {
  1048. XTENSOR_ASSERT(this->shape() == rhs.shape());
  1049. return m_linear_index == rhs.m_linear_index;
  1050. }
  1051. template <class St, class S, layout_type L>
  1052. inline bool xiterator<St, S, L>::less_than(const xiterator& rhs) const
  1053. {
  1054. XTENSOR_ASSERT(this->shape() == rhs.shape());
  1055. return m_linear_index < rhs.m_linear_index;
  1056. }
  1057. template <class St, class S, layout_type L>
  1058. inline bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
  1059. {
  1060. return lhs.equal(rhs);
  1061. }
  1062. template <class St, class S, layout_type L>
  1063. bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
  1064. {
  1065. return lhs.less_than(rhs);
  1066. }
  1067. /************************************
  1068. * xbounded_iterator implementation *
  1069. ************************************/
  1070. template <class It, class BIt>
  1071. xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
  1072. : m_it(it)
  1073. , m_bound_it(bound_it)
  1074. {
  1075. }
  1076. template <class It, class BIt>
  1077. inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
  1078. {
  1079. ++m_it;
  1080. ++m_bound_it;
  1081. return *this;
  1082. }
  1083. template <class It, class BIt>
  1084. inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
  1085. {
  1086. --m_it;
  1087. --m_bound_it;
  1088. return *this;
  1089. }
  1090. template <class It, class BIt>
  1091. inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
  1092. {
  1093. m_it += n;
  1094. m_bound_it += n;
  1095. return *this;
  1096. }
  1097. template <class It, class BIt>
  1098. inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
  1099. {
  1100. m_it -= n;
  1101. m_bound_it -= n;
  1102. return *this;
  1103. }
  1104. template <class It, class BIt>
  1105. inline auto xbounded_iterator<It, BIt>::operator-(const self_type& rhs) const -> difference_type
  1106. {
  1107. return m_it - rhs.m_it;
  1108. }
  1109. template <class It, class BIt>
  1110. inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
  1111. {
  1112. using type = decltype(*m_bound_it);
  1113. return (static_cast<type>(*m_it) < *m_bound_it) ? *m_it : static_cast<value_type>((*m_bound_it) - 1);
  1114. }
  1115. template <class It, class BIt>
  1116. inline bool xbounded_iterator<It, BIt>::equal(const self_type& rhs) const
  1117. {
  1118. return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
  1119. }
  1120. template <class It, class BIt>
  1121. inline bool xbounded_iterator<It, BIt>::less_than(const self_type& rhs) const
  1122. {
  1123. return m_it < rhs.m_it;
  1124. }
  1125. template <class It, class BIt>
  1126. inline bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
  1127. {
  1128. return lhs.equal(rhs);
  1129. }
  1130. template <class It, class BIt>
  1131. inline bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
  1132. {
  1133. return lhs.less_than(rhs);
  1134. }
  1135. }
  1136. #endif