xcontainer.hpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_CONTAINER_HPP
  10. #define XTENSOR_CONTAINER_HPP
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>

#include <xtl/xmeta_utils.hpp>
#include <xtl/xsequence.hpp>

#include "xaccessible.hpp"
#include "xiterable.hpp"
#include "xiterator.hpp"
#include "xmath.hpp"
#include "xoperation.hpp"
#include "xstrides.hpp"
#include "xtensor_config.hpp"
#include "xtensor_forward.hpp"
  26. namespace xt
  27. {
  28. template <class D>
  29. struct xcontainer_iterable_types
  30. {
  31. using inner_shape_type = typename xcontainer_inner_types<D>::inner_shape_type;
  32. using stepper = xstepper<D>;
  33. using const_stepper = xstepper<const D>;
  34. };
  namespace detail
  {
      /// Primary case: the storage type S exposes its own allocator_type.
      template <class S>
      struct allocator_type_impl
      {
          using type = typename S::allocator_type;
      };

      /// std::array has no allocator_type, so std::allocator<Elem> stands in
      /// for it (a fake allocator, used for testing).
      template <class Elem, std::size_t Extent>
      struct allocator_type_impl<std::array<Elem, Extent>>
      {
          using type = std::allocator<Elem>;
      };
  }

  /// Allocator type associated with the storage container S.
  template <class S>
  using allocator_type_t = typename detail::allocator_type_impl<S>::type;
  50. /**
  51. * @class xcontainer
  52. * @brief Base class for dense multidimensional containers.
  53. *
  54. * The xcontainer class defines the interface for dense multidimensional
  55. * container classes. It does not embed any data container, this responsibility
  56. * is delegated to the inheriting classes.
  57. *
  58. * @tparam D The derived type, i.e. the inheriting class for which xcontainer
  59. * provides the interface.
  60. */
  61. template <class D>
  62. class xcontainer : public xcontiguous_iterable<D>,
  63. private xaccessible<D>
  64. {
  65. public:
  66. using derived_type = D;
  67. using inner_types = xcontainer_inner_types<D>;
  68. using storage_type = typename inner_types::storage_type;
  69. using allocator_type = allocator_type_t<std::decay_t<storage_type>>;
  70. using value_type = typename storage_type::value_type;
  71. using reference = typename inner_types::reference;
  72. using const_reference = typename inner_types::const_reference;
  73. using pointer = typename storage_type::pointer;
  74. using const_pointer = typename storage_type::const_pointer;
  75. using size_type = typename inner_types::size_type;
  76. using difference_type = typename storage_type::difference_type;
  77. using simd_value_type = xt_simd::simd_type<value_type>;
  78. using bool_load_type = xt::bool_load_type<value_type>;
  79. using shape_type = typename inner_types::shape_type;
  80. using strides_type = typename inner_types::strides_type;
  81. using backstrides_type = typename inner_types::backstrides_type;
  82. using inner_shape_type = typename inner_types::inner_shape_type;
  83. using inner_strides_type = typename inner_types::inner_strides_type;
  84. using inner_backstrides_type = typename inner_types::inner_backstrides_type;
  85. using iterable_base = xcontiguous_iterable<D>;
  86. using stepper = typename iterable_base::stepper;
  87. using const_stepper = typename iterable_base::const_stepper;
  88. using accessible_base = xaccessible<D>;
  89. static constexpr layout_type static_layout = inner_types::layout;
  90. static constexpr bool contiguous_layout = static_layout != layout_type::dynamic;
  91. using data_alignment = xt_simd::container_alignment_t<storage_type>;
  92. using simd_type = xt_simd::simd_type<value_type>;
  93. using linear_iterator = typename iterable_base::linear_iterator;
  94. using const_linear_iterator = typename iterable_base::const_linear_iterator;
  95. using reverse_linear_iterator = typename iterable_base::reverse_linear_iterator;
  96. using const_reverse_linear_iterator = typename iterable_base::const_reverse_linear_iterator;
  97. static_assert(static_layout != layout_type::any, "Container layout can never be layout_type::any!");
  98. size_type size() const noexcept;
  99. XTENSOR_CONSTEXPR_RETURN size_type dimension() const noexcept;
  100. XTENSOR_CONSTEXPR_RETURN const inner_shape_type& shape() const noexcept;
  101. XTENSOR_CONSTEXPR_RETURN const inner_strides_type& strides() const noexcept;
  102. XTENSOR_CONSTEXPR_RETURN const inner_backstrides_type& backstrides() const noexcept;
  103. template <class T>
  104. void fill(const T& value);
  105. template <class... Args>
  106. reference operator()(Args... args);
  107. template <class... Args>
  108. const_reference operator()(Args... args) const;
  109. template <class... Args>
  110. reference unchecked(Args... args);
  111. template <class... Args>
  112. const_reference unchecked(Args... args) const;
  113. using accessible_base::at;
  114. using accessible_base::shape;
  115. using accessible_base::operator[];
  116. using accessible_base::back;
  117. using accessible_base::front;
  118. using accessible_base::in_bounds;
  119. using accessible_base::periodic;
  120. template <class It>
  121. reference element(It first, It last);
  122. template <class It>
  123. const_reference element(It first, It last) const;
  124. storage_type& storage() noexcept;
  125. const storage_type& storage() const noexcept;
  126. pointer data() noexcept;
  127. const_pointer data() const noexcept;
  128. const size_type data_offset() const noexcept;
  129. template <class S>
  130. bool broadcast_shape(S& shape, bool reuse_cache = false) const;
  131. template <class S>
  132. bool has_linear_assign(const S& strides) const noexcept;
  133. template <class S>
  134. stepper stepper_begin(const S& shape) noexcept;
  135. template <class S>
  136. stepper stepper_end(const S& shape, layout_type l) noexcept;
  137. template <class S>
  138. const_stepper stepper_begin(const S& shape) const noexcept;
  139. template <class S>
  140. const_stepper stepper_end(const S& shape, layout_type l) const noexcept;
  141. reference data_element(size_type i);
  142. const_reference data_element(size_type i) const;
  143. reference flat(size_type i);
  144. const_reference flat(size_type i) const;
  145. template <class requested_type>
  146. using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
  147. template <class align, class simd>
  148. void store_simd(size_type i, const simd& e);
  149. template <class align, class requested_type = value_type, std::size_t N = xt_simd::simd_traits<requested_type>::size>
  150. container_simd_return_type_t<storage_type, value_type, requested_type>
  151. /*simd_return_type<requested_type>*/ load_simd(size_type i) const;
  152. linear_iterator linear_begin() noexcept;
  153. linear_iterator linear_end() noexcept;
  154. const_linear_iterator linear_begin() const noexcept;
  155. const_linear_iterator linear_end() const noexcept;
  156. const_linear_iterator linear_cbegin() const noexcept;
  157. const_linear_iterator linear_cend() const noexcept;
  158. reverse_linear_iterator linear_rbegin() noexcept;
  159. reverse_linear_iterator linear_rend() noexcept;
  160. const_reverse_linear_iterator linear_rbegin() const noexcept;
  161. const_reverse_linear_iterator linear_rend() const noexcept;
  162. const_reverse_linear_iterator linear_crbegin() const noexcept;
  163. const_reverse_linear_iterator linear_crend() const noexcept;
  164. using container_iterator = linear_iterator;
  165. using const_container_iterator = const_linear_iterator;
  166. protected:
  167. xcontainer() = default;
  168. ~xcontainer() = default;
  169. xcontainer(const xcontainer&) = default;
  170. xcontainer& operator=(const xcontainer&) = default;
  171. xcontainer(xcontainer&&) = default;
  172. xcontainer& operator=(xcontainer&&) = default;
  173. container_iterator data_xbegin() noexcept;
  174. const_container_iterator data_xbegin() const noexcept;
  175. container_iterator data_xend(layout_type l, size_type offset) noexcept;
  176. const_container_iterator data_xend(layout_type l, size_type offset) const noexcept;
  177. protected:
  178. derived_type& derived_cast() & noexcept;
  179. const derived_type& derived_cast() const& noexcept;
  180. derived_type derived_cast() && noexcept;
  181. private:
  182. template <class It>
  183. It data_xend_impl(It begin, layout_type l, size_type offset) const noexcept;
  184. inner_shape_type& mutable_shape();
  185. inner_strides_type& mutable_strides();
  186. inner_backstrides_type& mutable_backstrides();
  187. template <class C>
  188. friend class xstepper;
  189. friend class xaccessible<D>;
  190. friend class xconst_accessible<D>;
  191. };
  192. /**
  193. * @class xstrided_container
  194. * @brief Partial implementation of xcontainer that embeds the strides and the shape
  195. *
  196. * The xstrided_container class is a partial implementation of the xcontainer interface
  197. * that embed the strides and the shape of the multidimensional container. It does
  198. * not embed the data container, this responsibility is delegated to the inheriting
  199. * classes.
  200. *
  201. * @tparam D The derived type, i.e. the inheriting class for which xstrided_container
  202. * provides the partial imlpementation of xcontainer.
  203. */
  204. template <class D>
  205. class xstrided_container : public xcontainer<D>
  206. {
  207. public:
  208. using base_type = xcontainer<D>;
  209. using storage_type = typename base_type::storage_type;
  210. using value_type = typename base_type::value_type;
  211. using reference = typename base_type::reference;
  212. using const_reference = typename base_type::const_reference;
  213. using pointer = typename base_type::pointer;
  214. using const_pointer = typename base_type::const_pointer;
  215. using size_type = typename base_type::size_type;
  216. using shape_type = typename base_type::shape_type;
  217. using strides_type = typename base_type::strides_type;
  218. using inner_shape_type = typename base_type::inner_shape_type;
  219. using inner_strides_type = typename base_type::inner_strides_type;
  220. using inner_backstrides_type = typename base_type::inner_backstrides_type;
  221. template <class S = shape_type>
  222. void resize(S&& shape, bool force = false);
  223. template <class S = shape_type>
  224. void resize(S&& shape, layout_type l);
  225. template <class S = shape_type>
  226. void resize(S&& shape, const strides_type& strides);
  227. template <class S = shape_type>
  228. auto& reshape(S&& shape, layout_type layout = base_type::static_layout) &;
  229. template <class T>
  230. auto& reshape(std::initializer_list<T> shape, layout_type layout = base_type::static_layout) &;
  231. layout_type layout() const noexcept;
  232. bool is_contiguous() const noexcept;
  233. protected:
  234. xstrided_container() noexcept;
  235. ~xstrided_container() = default;
  236. xstrided_container(const xstrided_container&) = default;
  237. xstrided_container& operator=(const xstrided_container&) = default;
  238. xstrided_container(xstrided_container&&) = default;
  239. xstrided_container& operator=(xstrided_container&&) = default;
  240. explicit xstrided_container(inner_shape_type&&, inner_strides_type&&) noexcept;
  241. explicit xstrided_container(inner_shape_type&&, inner_strides_type&&, inner_backstrides_type&&, layout_type&&) noexcept;
  242. inner_shape_type& shape_impl() noexcept;
  243. const inner_shape_type& shape_impl() const noexcept;
  244. inner_strides_type& strides_impl() noexcept;
  245. const inner_strides_type& strides_impl() const noexcept;
  246. inner_backstrides_type& backstrides_impl() noexcept;
  247. const inner_backstrides_type& backstrides_impl() const noexcept;
  248. template <class S = shape_type>
  249. void reshape_impl(S&& shape, std::true_type, layout_type layout = base_type::static_layout);
  250. template <class S = shape_type>
  251. void reshape_impl(S&& shape, std::false_type, layout_type layout = base_type::static_layout);
  252. layout_type& mutable_layout() noexcept;
  253. private:
  254. inner_shape_type m_shape;
  255. inner_strides_type m_strides;
  256. inner_backstrides_type m_backstrides;
  257. layout_type m_layout = base_type::static_layout;
  258. };
  259. /******************************
  260. * xcontainer implementation *
  261. ******************************/
  262. template <class D>
  263. template <class It>
  264. inline It xcontainer<D>::data_xend_impl(It begin, layout_type l, size_type offset) const noexcept
  265. {
  266. return strided_data_end(*this, begin, l, offset);
  267. }
  268. template <class D>
  269. inline auto xcontainer<D>::mutable_shape() -> inner_shape_type&
  270. {
  271. return derived_cast().shape_impl();
  272. }
  273. template <class D>
  274. inline auto xcontainer<D>::mutable_strides() -> inner_strides_type&
  275. {
  276. return derived_cast().strides_impl();
  277. }
  278. template <class D>
  279. inline auto xcontainer<D>::mutable_backstrides() -> inner_backstrides_type&
  280. {
  281. return derived_cast().backstrides_impl();
  282. }
  283. /**
  284. * @name Size and shape
  285. */
  286. //@{
  287. /**
  288. * Returns the number of element in the container.
  289. */
  290. template <class D>
  291. inline auto xcontainer<D>::size() const noexcept -> size_type
  292. {
  293. return contiguous_layout ? storage().size() : compute_size(shape());
  294. }
  295. /**
  296. * Returns the number of dimensions of the container.
  297. */
  298. template <class D>
  299. XTENSOR_CONSTEXPR_RETURN auto xcontainer<D>::dimension() const noexcept -> size_type
  300. {
  301. return shape().size();
  302. }
  303. /**
  304. * Returns the shape of the container.
  305. */
  306. template <class D>
  307. XTENSOR_CONSTEXPR_RETURN auto xcontainer<D>::shape() const noexcept -> const inner_shape_type&
  308. {
  309. return derived_cast().shape_impl();
  310. }
  311. /**
  312. * Returns the strides of the container.
  313. */
  314. template <class D>
  315. XTENSOR_CONSTEXPR_RETURN auto xcontainer<D>::strides() const noexcept -> const inner_strides_type&
  316. {
  317. return derived_cast().strides_impl();
  318. }
  319. /**
  320. * Returns the backstrides of the container.
  321. */
  322. template <class D>
  323. XTENSOR_CONSTEXPR_RETURN auto xcontainer<D>::backstrides() const noexcept -> const inner_backstrides_type&
  324. {
  325. return derived_cast().backstrides_impl();
  326. }
  327. //@}
  328. /**
  329. * @name Data
  330. */
  331. //@{
  332. /**
  333. * Fills the container with the given value.
  334. * @param value the value to fill the container with.
  335. */
  336. template <class D>
  337. template <class T>
  338. inline void xcontainer<D>::fill(const T& value)
  339. {
  340. if (contiguous_layout)
  341. {
  342. std::fill(this->linear_begin(), this->linear_end(), value);
  343. }
  344. else
  345. {
  346. std::fill(this->begin(), this->end(), value);
  347. }
  348. }
  349. /**
  350. * Returns a reference to the element at the specified position in the container.
  351. * @param args a list of indices specifying the position in the container. Indices
  352. * must be unsigned integers, the number of indices should be equal or greater than
  353. * the number of dimensions of the container.
  354. */
  355. template <class D>
  356. template <class... Args>
  357. inline auto xcontainer<D>::operator()(Args... args) -> reference
  358. {
  359. XTENSOR_TRY(check_index(shape(), args...));
  360. XTENSOR_CHECK_DIMENSION(shape(), args...);
  361. size_type index = xt::data_offset<size_type>(strides(), args...);
  362. return storage()[index];
  363. }
  364. /**
  365. * Returns a constant reference to the element at the specified position in the container.
  366. * @param args a list of indices specifying the position in the container. Indices
  367. * must be unsigned integers, the number of indices should be equal or greater than
  368. * the number of dimensions of the container.
  369. */
  370. template <class D>
  371. template <class... Args>
  372. inline auto xcontainer<D>::operator()(Args... args) const -> const_reference
  373. {
  374. XTENSOR_TRY(check_index(shape(), args...));
  375. XTENSOR_CHECK_DIMENSION(shape(), args...);
  376. size_type index = xt::data_offset<size_type>(strides(), args...);
  377. return storage()[index];
  378. }
  379. /**
  380. * Returns a reference to the element at the specified position in the container.
  381. * @param args a list of indices specifying the position in the container. Indices
  382. * must be unsigned integers, the number of indices must be equal to the number of
  383. * dimensions of the container, else the behavior is undefined.
  384. *
  385. * @warning This method is meant for performance, for expressions with a dynamic
  386. * number of dimensions (i.e. not known at compile time). Since it may have
  387. * undefined behavior (see parameters), operator() should be preferred whenever
  388. * it is possible.
  389. * @warning This method is NOT compatible with broadcasting, meaning the following
  390. * code has undefined behavior:
  391. * @code{.cpp}
  392. * xt::xarray<double> a = {{0, 1}, {2, 3}};
  393. * xt::xarray<double> b = {0, 1};
  394. * auto fd = a + b;
  395. * double res = fd.uncheked(0, 1);
  396. * @endcode
  397. */
  398. template <class D>
  399. template <class... Args>
  400. inline auto xcontainer<D>::unchecked(Args... args) -> reference
  401. {
  402. size_type index = xt::unchecked_data_offset<size_type, static_layout>(
  403. strides(),
  404. static_cast<std::ptrdiff_t>(args)...
  405. );
  406. return storage()[index];
  407. }
  408. /**
  409. * Returns a constant reference to the element at the specified position in the container.
  410. * @param args a list of indices specifying the position in the container. Indices
  411. * must be unsigned integers, the number of indices must be equal to the number of
  412. * dimensions of the container, else the behavior is undefined.
  413. *
  414. * @warning This method is meant for performance, for expressions with a dynamic
  415. * number of dimensions (i.e. not known at compile time). Since it may have
  416. * undefined behavior (see parameters), operator() should be preferred whenever
  417. * it is possible.
  418. * @warning This method is NOT compatible with broadcasting, meaning the following
  419. * code has undefined behavior:
  420. * @code{.cpp}
  421. * xt::xarray<double> a = {{0, 1}, {2, 3}};
  422. * xt::xarray<double> b = {0, 1};
  423. * auto fd = a + b;
  424. * double res = fd.uncheked(0, 1);
  425. * @endcode
  426. */
  427. template <class D>
  428. template <class... Args>
  429. inline auto xcontainer<D>::unchecked(Args... args) const -> const_reference
  430. {
  431. size_type index = xt::unchecked_data_offset<size_type, static_layout>(
  432. strides(),
  433. static_cast<std::ptrdiff_t>(args)...
  434. );
  435. return storage()[index];
  436. }
  437. /**
  438. * Returns a reference to the element at the specified position in the container.
  439. * @param first iterator starting the sequence of indices
  440. * @param last iterator ending the sequence of indices
  441. * The number of indices in the sequence should be equal to or greater
  442. * than the number of dimensions of the container.
  443. */
  444. template <class D>
  445. template <class It>
  446. inline auto xcontainer<D>::element(It first, It last) -> reference
  447. {
  448. XTENSOR_TRY(check_element_index(shape(), first, last));
  449. return storage()[element_offset<size_type>(strides(), first, last)];
  450. }
  451. /**
  452. * Returns a reference to the element at the specified position in the container.
  453. * @param first iterator starting the sequence of indices
  454. * @param last iterator ending the sequence of indices
  455. * The number of indices in the sequence should be equal to or greater
  456. * than the number of dimensions of the container.
  457. */
  458. template <class D>
  459. template <class It>
  460. inline auto xcontainer<D>::element(It first, It last) const -> const_reference
  461. {
  462. XTENSOR_TRY(check_element_index(shape(), first, last));
  463. return storage()[element_offset<size_type>(strides(), first, last)];
  464. }
  465. /**
  466. * Returns a reference to the buffer containing the elements of the container.
  467. */
  468. template <class D>
  469. inline auto xcontainer<D>::storage() noexcept -> storage_type&
  470. {
  471. return derived_cast().storage_impl();
  472. }
  473. /**
  474. * Returns a constant reference to the buffer containing the elements of the
  475. * container.
  476. */
  477. template <class D>
  478. inline auto xcontainer<D>::storage() const noexcept -> const storage_type&
  479. {
  480. return derived_cast().storage_impl();
  481. }
  482. /**
  483. * Returns a pointer to the underlying array serving as element storage. The pointer
  484. * is such that range [data(); data() + size()] is always a valid range, even if the
  485. * container is empty (data() is not is not dereferenceable in that case)
  486. */
  487. template <class D>
  488. inline auto xcontainer<D>::data() noexcept -> pointer
  489. {
  490. return storage().data();
  491. }
  492. /**
  493. * Returns a constant pointer to the underlying array serving as element storage. The pointer
  494. * is such that range [data(); data() + size()] is always a valid range, even if the
  495. * container is empty (data() is not is not dereferenceable in that case)
  496. */
  497. template <class D>
  498. inline auto xcontainer<D>::data() const noexcept -> const_pointer
  499. {
  500. return storage().data();
  501. }
  502. /**
  503. * Returns the offset to the first element in the container.
  504. */
  505. template <class D>
  506. inline auto xcontainer<D>::data_offset() const noexcept -> const size_type
  507. {
  508. return size_type(0);
  509. }
  510. //@}
  511. /**
  512. * @name Broadcasting
  513. */
  514. //@{
  515. /**
  516. * Broadcast the shape of the container to the specified parameter.
  517. * @param shape the result shape
  518. * @param reuse_cache parameter for internal optimization
  519. * @return a boolean indicating whether the broadcasting is trivial
  520. */
  521. template <class D>
  522. template <class S>
  523. inline bool xcontainer<D>::broadcast_shape(S& shape, bool) const
  524. {
  525. return xt::broadcast_shape(this->shape(), shape);
  526. }
  527. /**
  528. * Checks whether the xcontainer can be linearly assigned to an expression
  529. * with the specified strides.
  530. * @return a boolean indicating whether a linear assign is possible
  531. */
  532. template <class D>
  533. template <class S>
  534. inline bool xcontainer<D>::has_linear_assign(const S& str) const noexcept
  535. {
  536. return str.size() == strides().size() && std::equal(str.cbegin(), str.cend(), strides().begin());
  537. }
  538. //@}
  539. template <class D>
  540. inline auto xcontainer<D>::derived_cast() const& noexcept -> const derived_type&
  541. {
  542. return *static_cast<const derived_type*>(this);
  543. }
  544. template <class D>
  545. inline auto xcontainer<D>::derived_cast() && noexcept -> derived_type
  546. {
  547. return *static_cast<derived_type*>(this);
  548. }
  549. template <class D>
  550. inline auto xcontainer<D>::data_element(size_type i) -> reference
  551. {
  552. return storage()[i];
  553. }
  554. template <class D>
  555. inline auto xcontainer<D>::data_element(size_type i) const -> const_reference
  556. {
  557. return storage()[i];
  558. }
  559. /**
  560. * Returns a reference to the element at the specified position in the container
  561. * storage (as if it was one dimensional).
  562. * @param i index specifying the position in the storage.
  563. * Must be smaller than the number of elements in the container.
  564. */
  565. template <class D>
  566. inline auto xcontainer<D>::flat(size_type i) -> reference
  567. {
  568. XTENSOR_ASSERT(i < size());
  569. return storage()[i];
  570. }
  571. /**
  572. * Returns a constant reference to the element at the specified position in the container
  573. * storage (as if it was one dimensional).
  574. * @param i index specifying the position in the storage.
  575. * Must be smaller than the number of elements in the container.
  576. */
  577. template <class D>
  578. inline auto xcontainer<D>::flat(size_type i) const -> const_reference
  579. {
  580. XTENSOR_ASSERT(i < size());
  581. return storage()[i];
  582. }
  583. /***************
  584. * stepper api *
  585. ***************/
  586. template <class D>
  587. template <class S>
  588. inline auto xcontainer<D>::stepper_begin(const S& shape) noexcept -> stepper
  589. {
  590. size_type offset = shape.size() - dimension();
  591. return stepper(static_cast<derived_type*>(this), data_xbegin(), offset);
  592. }
  593. template <class D>
  594. template <class S>
  595. inline auto xcontainer<D>::stepper_end(const S& shape, layout_type l) noexcept -> stepper
  596. {
  597. size_type offset = shape.size() - dimension();
  598. return stepper(static_cast<derived_type*>(this), data_xend(l, offset), offset);
  599. }
  600. template <class D>
  601. template <class S>
  602. inline auto xcontainer<D>::stepper_begin(const S& shape) const noexcept -> const_stepper
  603. {
  604. size_type offset = shape.size() - dimension();
  605. return const_stepper(static_cast<const derived_type*>(this), data_xbegin(), offset);
  606. }
  607. template <class D>
  608. template <class S>
  609. inline auto xcontainer<D>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
  610. {
  611. size_type offset = shape.size() - dimension();
  612. return const_stepper(static_cast<const derived_type*>(this), data_xend(l, offset), offset);
  613. }
  614. template <class D>
  615. inline auto xcontainer<D>::data_xbegin() noexcept -> container_iterator
  616. {
  617. return storage().begin();
  618. }
  619. template <class D>
  620. inline auto xcontainer<D>::data_xbegin() const noexcept -> const_container_iterator
  621. {
  622. return storage().cbegin();
  623. }
  624. template <class D>
  625. inline auto xcontainer<D>::data_xend(layout_type l, size_type offset) noexcept -> container_iterator
  626. {
  627. return data_xend_impl(storage().begin(), l, offset);
  628. }
  629. template <class D>
  630. inline auto xcontainer<D>::data_xend(layout_type l, size_type offset) const noexcept
  631. -> const_container_iterator
  632. {
  633. return data_xend_impl(storage().cbegin(), l, offset);
  634. }
  635. template <class D>
  636. template <class alignment, class simd>
  637. inline void xcontainer<D>::store_simd(size_type i, const simd& e)
  638. {
  639. using align_mode = driven_align_mode_t<alignment, data_alignment>;
  640. xt_simd::store_as(std::addressof(storage()[i]), e, align_mode());
  641. }
  642. template <class D>
  643. template <class alignment, class requested_type, std::size_t N>
  644. inline auto xcontainer<D>::load_simd(size_type i) const
  645. -> container_simd_return_type_t<storage_type, value_type, requested_type>
  646. {
  647. using align_mode = driven_align_mode_t<alignment, data_alignment>;
  648. return xt_simd::load_as<requested_type>(std::addressof(storage()[i]), align_mode());
  649. }
  650. template <class D>
  651. inline auto xcontainer<D>::linear_begin() noexcept -> linear_iterator
  652. {
  653. return storage().begin();
  654. }
  655. template <class D>
  656. inline auto xcontainer<D>::linear_end() noexcept -> linear_iterator
  657. {
  658. return storage().end();
  659. }
  660. template <class D>
  661. inline auto xcontainer<D>::linear_begin() const noexcept -> const_linear_iterator
  662. {
  663. return storage().begin();
  664. }
  665. template <class D>
  666. inline auto xcontainer<D>::linear_end() const noexcept -> const_linear_iterator
  667. {
  668. return storage().cend();
  669. }
  670. template <class D>
  671. inline auto xcontainer<D>::linear_cbegin() const noexcept -> const_linear_iterator
  672. {
  673. return storage().cbegin();
  674. }
  675. template <class D>
  676. inline auto xcontainer<D>::linear_cend() const noexcept -> const_linear_iterator
  677. {
  678. return storage().cend();
  679. }
  680. template <class D>
  681. inline auto xcontainer<D>::linear_rbegin() noexcept -> reverse_linear_iterator
  682. {
  683. return storage().rbegin();
  684. }
  685. template <class D>
  686. inline auto xcontainer<D>::linear_rend() noexcept -> reverse_linear_iterator
  687. {
  688. return storage().rend();
  689. }
  690. template <class D>
  691. inline auto xcontainer<D>::linear_rbegin() const noexcept -> const_reverse_linear_iterator
  692. {
  693. return storage().rbegin();
  694. }
  695. template <class D>
  696. inline auto xcontainer<D>::linear_rend() const noexcept -> const_reverse_linear_iterator
  697. {
  698. return storage().rend();
  699. }
  700. template <class D>
  701. inline auto xcontainer<D>::linear_crbegin() const noexcept -> const_reverse_linear_iterator
  702. {
  703. return storage().crbegin();
  704. }
  705. template <class D>
  706. inline auto xcontainer<D>::linear_crend() const noexcept -> const_reverse_linear_iterator
  707. {
  708. return storage().crend();
  709. }
  710. template <class D>
  711. inline auto xcontainer<D>::derived_cast() & noexcept -> derived_type&
  712. {
  713. return *static_cast<derived_type*>(this);
  714. }
  715. /*************************************
  716. * xstrided_container implementation *
  717. *************************************/
  718. template <class D>
  719. inline xstrided_container<D>::xstrided_container() noexcept
  720. : base_type()
  721. {
  722. m_shape = xtl::make_sequence<inner_shape_type>(base_type::dimension(), 0);
  723. m_strides = xtl::make_sequence<inner_strides_type>(base_type::dimension(), 0);
  724. m_backstrides = xtl::make_sequence<inner_backstrides_type>(base_type::dimension(), 0);
  725. }
  726. template <class D>
  727. inline xstrided_container<D>::xstrided_container(inner_shape_type&& shape, inner_strides_type&& strides) noexcept
  728. : base_type()
  729. , m_shape(std::move(shape))
  730. , m_strides(std::move(strides))
  731. {
  732. m_backstrides = xtl::make_sequence<inner_backstrides_type>(m_shape.size(), 0);
  733. adapt_strides(m_shape, m_strides, m_backstrides);
  734. }
  735. template <class D>
  736. inline xstrided_container<D>::xstrided_container(
  737. inner_shape_type&& shape,
  738. inner_strides_type&& strides,
  739. inner_backstrides_type&& backstrides,
  740. layout_type&& layout
  741. ) noexcept
  742. : base_type()
  743. , m_shape(std::move(shape))
  744. , m_strides(std::move(strides))
  745. , m_backstrides(std::move(backstrides))
  746. , m_layout(std::move(layout))
  747. {
  748. }
  749. template <class D>
  750. inline auto xstrided_container<D>::shape_impl() noexcept -> inner_shape_type&
  751. {
  752. return m_shape;
  753. }
  754. template <class D>
  755. inline auto xstrided_container<D>::shape_impl() const noexcept -> const inner_shape_type&
  756. {
  757. return m_shape;
  758. }
  759. template <class D>
  760. inline auto xstrided_container<D>::strides_impl() noexcept -> inner_strides_type&
  761. {
  762. return m_strides;
  763. }
  764. template <class D>
  765. inline auto xstrided_container<D>::strides_impl() const noexcept -> const inner_strides_type&
  766. {
  767. return m_strides;
  768. }
  769. template <class D>
  770. inline auto xstrided_container<D>::backstrides_impl() noexcept -> inner_backstrides_type&
  771. {
  772. return m_backstrides;
  773. }
  774. template <class D>
  775. inline auto xstrided_container<D>::backstrides_impl() const noexcept -> const inner_backstrides_type&
  776. {
  777. return m_backstrides;
  778. }
  779. /**
  780. * Return the layout_type of the container
  781. * @return layout_type of the container
  782. */
  783. template <class D>
  784. inline layout_type xstrided_container<D>::layout() const noexcept
  785. {
  786. return m_layout;
  787. }
  788. template <class D>
  789. inline bool xstrided_container<D>::is_contiguous() const noexcept
  790. {
  791. using str_type = typename inner_strides_type::value_type;
  792. auto is_zero = [](auto i)
  793. {
  794. return i == 0;
  795. };
  796. if (!is_contiguous_container<storage_type>::value)
  797. {
  798. return false;
  799. }
  800. // We need to make sure the inner-most non-zero stride is one.
  801. // Trailing zero strides are ignored because they indicate bradcasted dimensions.
  802. if (m_layout == layout_type::row_major)
  803. {
  804. auto it = std::find_if_not(m_strides.rbegin(), m_strides.rend(), is_zero);
  805. // If the array has strides of zero, it is a constant, and therefore contiguous.
  806. return it == m_strides.rend() || *it == str_type(1);
  807. }
  808. else if (m_layout == layout_type::column_major)
  809. {
  810. auto it = std::find_if_not(m_strides.begin(), m_strides.end(), is_zero);
  811. // If the array has strides of zero, it is a constant, and therefore contiguous.
  812. return it == m_strides.end() || *it == str_type(1);
  813. }
  814. else
  815. {
  816. return m_strides.empty();
  817. }
  818. }
  819. namespace detail
  820. {
  821. template <class C, class S>
  822. inline void resize_data_container(C& c, S size)
  823. {
  824. xt::resize_container(c, size);
  825. }
  826. template <class C, class S>
  827. inline void resize_data_container(const C& c, S size)
  828. {
  829. (void) c; // remove unused parameter warning
  830. (void) size;
  831. XTENSOR_ASSERT_MSG(c.size() == size, "Trying to resize const data container with wrong size.");
  832. }
  833. template <class S, class T>
  834. constexpr bool check_resize_dimension(const S&, const T&)
  835. {
  836. return true;
  837. }
  838. template <class T, size_t N, class S>
  839. constexpr bool check_resize_dimension(const std::array<T, N>&, const S& s)
  840. {
  841. return N == s.size();
  842. }
  843. }
  844. /**
  845. * Resizes the container.
  846. * @warning Contrary to STL containers like std::vector, resize
  847. * does NOT preserve the container elements.
  848. * @param shape the new shape
  849. * @param force force reshaping, even if the shape stays the same (default: false)
  850. */
  851. template <class D>
  852. template <class S>
  853. inline void xstrided_container<D>::resize(S&& shape, bool force)
  854. {
  855. XTENSOR_ASSERT_MSG(
  856. detail::check_resize_dimension(m_shape, shape),
  857. "cannot change the number of dimensions of xtensor"
  858. )
  859. std::size_t dim = shape.size();
  860. if (m_shape.size() != dim || !std::equal(std::begin(shape), std::end(shape), std::begin(m_shape))
  861. || force)
  862. {
  863. if (D::static_layout == layout_type::dynamic && m_layout == layout_type::dynamic)
  864. {
  865. m_layout = XTENSOR_DEFAULT_LAYOUT; // fall back to default layout
  866. }
  867. m_shape = xtl::forward_sequence<shape_type, S>(shape);
  868. resize_container(m_strides, dim);
  869. resize_container(m_backstrides, dim);
  870. size_type data_size = compute_strides<D::static_layout>(m_shape, m_layout, m_strides, m_backstrides);
  871. detail::resize_data_container(this->storage(), data_size);
  872. }
  873. }
  874. /**
  875. * Resizes the container.
  876. * @warning Contrary to STL containers like std::vector, resize
  877. * does NOT preserve the container elements.
  878. * @param shape the new shape
  879. * @param l the new layout_type
  880. */
  881. template <class D>
  882. template <class S>
  883. inline void xstrided_container<D>::resize(S&& shape, layout_type l)
  884. {
  885. XTENSOR_ASSERT_MSG(
  886. detail::check_resize_dimension(m_shape, shape),
  887. "cannot change the number of dimensions of xtensor"
  888. )
  889. if (base_type::static_layout != layout_type::dynamic && l != base_type::static_layout)
  890. {
  891. XTENSOR_THROW(
  892. std::runtime_error,
  893. "Cannot change layout_type if template parameter not layout_type::dynamic."
  894. );
  895. }
  896. m_layout = l;
  897. resize(std::forward<S>(shape), true);
  898. }
  899. /**
  900. * Resizes the container.
  901. * @warning Contrary to STL containers like std::vector, resize
  902. * does NOT preserve the container elements.
  903. * @param shape the new shape
  904. * @param strides the new strides
  905. */
  906. template <class D>
  907. template <class S>
  908. inline void xstrided_container<D>::resize(S&& shape, const strides_type& strides)
  909. {
  910. XTENSOR_ASSERT_MSG(
  911. detail::check_resize_dimension(m_shape, shape),
  912. "cannot change the number of dimensions of xtensor"
  913. )
  914. if (base_type::static_layout != layout_type::dynamic)
  915. {
  916. XTENSOR_THROW(
  917. std::runtime_error,
  918. "Cannot resize with custom strides when layout() is != layout_type::dynamic."
  919. );
  920. }
  921. m_shape = xtl::forward_sequence<shape_type, S>(shape);
  922. m_strides = strides;
  923. resize_container(m_backstrides, m_strides.size());
  924. adapt_strides(m_shape, m_strides, m_backstrides);
  925. m_layout = layout_type::dynamic;
  926. detail::resize_data_container(this->storage(), compute_size(m_shape));
  927. }
  928. /**
  929. * Reshapes the container and keeps old elements. The `shape` argument can have one of its value
  930. * equal to `-1`, in this case the value is inferred from the number of elements in the container
  931. * and the remaining values in the `shape`.
  932. * @code{.cpp}
  933. * xt::xarray<int> a = { 1, 2, 3, 4, 5, 6, 7, 8 };
  934. * a.reshape({-1, 4});
  935. * //a.shape() is {2, 4}
  936. * @endcode
  937. * @param shape the new shape (has to have same number of elements as the original container)
  938. * @param layout the layout to compute the strides (defaults to static layout of the container,
  939. * or for a container with dynamic layout to XTENSOR_DEFAULT_LAYOUT)
  940. */
  941. template <class D>
  942. template <class S>
  943. inline auto& xstrided_container<D>::reshape(S&& shape, layout_type layout) &
  944. {
  945. reshape_impl(
  946. std::forward<S>(shape),
  947. xtl::is_signed<std::decay_t<typename std::decay_t<S>::value_type>>(),
  948. std::forward<layout_type>(layout)
  949. );
  950. return this->derived_cast();
  951. }
  952. template <class D>
  953. template <class T>
  954. inline auto& xstrided_container<D>::reshape(std::initializer_list<T> shape, layout_type layout) &
  955. {
  956. using sh_type = rebind_container_t<T, shape_type>;
  957. sh_type sh = xtl::make_sequence<sh_type>(shape.size());
  958. std::copy(shape.begin(), shape.end(), sh.begin());
  959. reshape_impl(std::move(sh), xtl::is_signed<T>(), std::forward<layout_type>(layout));
  960. return this->derived_cast();
  961. }
  962. template <class D>
  963. template <class S>
  964. inline void
  965. xstrided_container<D>::reshape_impl(S&& shape, std::false_type /* is unsigned */, layout_type layout)
  966. {
  967. if (compute_size(shape) != this->size())
  968. {
  969. XTENSOR_THROW(
  970. std::runtime_error,
  971. "Cannot reshape with incorrect number of elements. Do you mean to resize?"
  972. );
  973. }
  974. if (D::static_layout == layout_type::dynamic && layout == layout_type::dynamic)
  975. {
  976. layout = XTENSOR_DEFAULT_LAYOUT; // fall back to default layout
  977. }
  978. if (D::static_layout != layout_type::dynamic && layout != D::static_layout)
  979. {
  980. XTENSOR_THROW(std::runtime_error, "Cannot reshape with different layout if static layout != dynamic.");
  981. }
  982. m_layout = layout;
  983. m_shape = xtl::forward_sequence<shape_type, S>(shape);
  984. resize_container(m_strides, m_shape.size());
  985. resize_container(m_backstrides, m_shape.size());
  986. compute_strides<D::static_layout>(m_shape, m_layout, m_strides, m_backstrides);
  987. }
  988. template <class D>
  989. template <class S>
  990. inline void
  991. xstrided_container<D>::reshape_impl(S&& _shape, std::true_type /* is signed */, layout_type layout)
  992. {
  993. using tmp_value_type = typename std::decay_t<S>::value_type;
  994. auto new_size = compute_size(_shape);
  995. if (this->size() % new_size)
  996. {
  997. XTENSOR_THROW(std::runtime_error, "Negative axis size cannot be inferred. Shape mismatch.");
  998. }
  999. std::decay_t<S> shape = _shape;
  1000. tmp_value_type accumulator = 1;
  1001. std::size_t neg_idx = 0;
  1002. std::size_t i = 0;
  1003. for (auto it = shape.begin(); it != shape.end(); ++it, i++)
  1004. {
  1005. auto&& dim = *it;
  1006. if (dim < 0)
  1007. {
  1008. XTENSOR_ASSERT(dim == -1 && !neg_idx);
  1009. neg_idx = i;
  1010. }
  1011. accumulator *= dim;
  1012. }
  1013. if (accumulator < 0)
  1014. {
  1015. shape[neg_idx] = static_cast<tmp_value_type>(this->size()) / std::abs(accumulator);
  1016. }
  1017. else if (this->size() != new_size)
  1018. {
  1019. XTENSOR_THROW(
  1020. std::runtime_error,
  1021. "Cannot reshape with incorrect number of elements. Do you mean to resize?"
  1022. );
  1023. }
  1024. m_layout = layout;
  1025. m_shape = xtl::forward_sequence<shape_type, S>(shape);
  1026. resize_container(m_strides, m_shape.size());
  1027. resize_container(m_backstrides, m_shape.size());
  1028. compute_strides<D::static_layout>(m_shape, m_layout, m_strides, m_backstrides);
  1029. }
  1030. template <class D>
  1031. inline auto xstrided_container<D>::mutable_layout() noexcept -> layout_type&
  1032. {
  1033. return m_layout;
  1034. }
  1035. }
  1036. #endif