xtensor.hpp 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_TENSOR_HPP
  10. #define XTENSOR_TENSOR_HPP
  11. #include <algorithm>
  12. #include <array>
  13. #include <cstddef>
  14. #include <utility>
  15. #include <vector>
  16. #include "xbuffer_adaptor.hpp"
  17. #include "xcontainer.hpp"
  18. #include "xsemantic.hpp"
  19. namespace xt
  20. {
  21. /***********************
  22. * xtensor declaration *
  23. ***********************/
  // Extension point for xtensor_container: xtensor_container_base maps the
  // template parameters (in particular the expression tag) to the type that
  // xtensor_container uses as its extension base class.
  24. namespace extension
  25. {
  // Primary template: only declared here, no definition is visible.
  26. template <class EC, std::size_t N, layout_type L, class Tag>
  27. struct xtensor_container_base;
  // Specialization for xtensor_expression_tag: selects xtensor_empty_base.
  28. template <class EC, std::size_t N, layout_type L>
  29. struct xtensor_container_base<EC, N, L, xtensor_expression_tag>
  30. {
  31. using type = xtensor_empty_base;
  32. };
  // Shorthand for the nested ::type of the trait above.
  33. template <class EC, std::size_t N, layout_type L, class Tag>
  34. using xtensor_container_base_t = typename xtensor_container_base<EC, N, L, Tag>::type;
  35. }
  // Inner type traits of xtensor_container<EC, N, L, Tag>.
  36. template <class EC, std::size_t N, layout_type L, class Tag>
  37. struct xcontainer_inner_types<xtensor_container<EC, N, L, Tag>>
  38. {
  // The container type EC is used directly as the storage.
  39. using storage_type = EC;
  40. using reference = inner_reference_t<storage_type>;
  41. using const_reference = typename storage_type::const_reference;
  42. using size_type = typename storage_type::size_type;
  // The shape is a std::array with exactly N entries, so the number of
  // dimensions is part of the type.
  43. using shape_type = std::array<typename storage_type::size_type, N>;
  // Strides and backstrides share the type derived from the shape type.
  44. using strides_type = get_strides_t<shape_type>;
  45. using backstrides_type = get_strides_t<shape_type>;
  // The inner_* types are identical to the corresponding public types.
  46. using inner_shape_type = shape_type;
  47. using inner_strides_type = strides_type;
  48. using inner_backstrides_type = backstrides_type;
  // The temporary type is the container type itself.
  49. using temporary_type = xtensor_container<EC, N, L, Tag>;
  // Static layout, taken from the template parameter L.
  50. static constexpr layout_type layout = L;
  51. };
  // Iteration-related inner types of xtensor_container: everything is
  // inherited from xcontainer_iterable_types, nothing is added here.
  52. template <class EC, std::size_t N, layout_type L, class Tag>
  53. struct xiterable_inner_types<xtensor_container<EC, N, L, Tag>>
  54. : xcontainer_iterable_types<xtensor_container<EC, N, L, Tag>>
  55. {
  56. };
  57. /**
  58. * @class xtensor_container
  59. * @brief Dense multidimensional container with tensor semantic and fixed
  60. * dimension.
  61. *
  62. * The xtensor_container class implements a dense multidimensional container
  63. * with tensor semantics and fixed dimension
  64. *
  65. * @tparam EC The type of the container holding the elements.
  66. * @tparam N The dimension of the container.
  67. * @tparam L The layout_type of the tensor.
  68. * @tparam Tag The expression tag.
  69. * @sa xtensor, xstrided_container, xcontainer
  70. */
  // Declaration of xtensor_container. It derives from the strided container
  // interface, the container assignment semantic, and the extension base
  // selected through extension::xtensor_container_base_t.
  71. template <class EC, std::size_t N, layout_type L, class Tag>
  72. class xtensor_container : public xstrided_container<xtensor_container<EC, N, L, Tag>>,
  73. public xcontainer_semantic<xtensor_container<EC, N, L, Tag>>,
  74. public extension::xtensor_container_base_t<EC, N, L, Tag>
  75. {
  76. public:
  // Shorthands for this type and its base classes.
  77. using self_type = xtensor_container<EC, N, L, Tag>;
  78. using base_type = xstrided_container<self_type>;
  79. using semantic_base = xcontainer_semantic<self_type>;
  80. using extension_base = extension::xtensor_container_base_t<EC, N, L, Tag>;
  // Types re-exported from the base classes.
  81. using storage_type = typename base_type::storage_type;
  82. using allocator_type = typename base_type::allocator_type;
  83. using value_type = typename base_type::value_type;
  84. using reference = typename base_type::reference;
  85. using const_reference = typename base_type::const_reference;
  86. using pointer = typename base_type::pointer;
  87. using const_pointer = typename base_type::const_pointer;
  88. using shape_type = typename base_type::shape_type;
  89. using inner_shape_type = typename base_type::inner_shape_type;
  90. using strides_type = typename base_type::strides_type;
  91. using backstrides_type = typename base_type::backstrides_type;
  92. using inner_backstrides_type = typename base_type::inner_backstrides_type;
  93. using inner_strides_type = typename base_type::inner_strides_type;
  94. using temporary_type = typename semantic_base::temporary_type;
  95. using expression_tag = Tag;
  // Number of dimensions, available as a compile-time constant.
  96. static constexpr std::size_t rank = N;
  // Constructors. The nested-initializer-list overload is not marked explicit;
  // the shape-based and storage-based overloads are.
  97. xtensor_container();
  98. xtensor_container(nested_initializer_list_t<value_type, N> t);
  99. explicit xtensor_container(const shape_type& shape, layout_type l = L);
  100. explicit xtensor_container(const shape_type& shape, const_reference value, layout_type l = L);
  101. explicit xtensor_container(const shape_type& shape, const strides_type& strides);
  102. explicit xtensor_container(const shape_type& shape, const strides_type& strides, const_reference value);
  103. explicit xtensor_container(storage_type&& storage, inner_shape_type&& shape, inner_strides_type&& strides);
  // Static factory building a container from a shape-like sequence.
  104. template <class S = shape_type>
  105. static xtensor_container from_shape(S&& s);
  // Destructor, copy and move operations are all compiler-generated.
  106. ~xtensor_container() = default;
  107. xtensor_container(const xtensor_container&) = default;
  108. xtensor_container& operator=(const xtensor_container&) = default;
  109. xtensor_container(xtensor_container&&) = default;
  110. xtensor_container& operator=(xtensor_container&&) = default;
  // Move construction / assignment from an xarray_container that shares the
  // same EC, L and Tag parameters.
  111. template <class SC>
  112. explicit xtensor_container(xarray_container<EC, L, SC, Tag>&&);
  113. template <class SC>
  114. xtensor_container& operator=(xarray_container<EC, L, SC, Tag>&&);
  // Construction / assignment from an arbitrary xexpression.
  115. template <class E>
  116. xtensor_container(const xexpression<E>& e);
  117. template <class E>
  118. xtensor_container& operator=(const xexpression<E>& e);
  119. private:
  // The element storage owned by the container.
  120. storage_type m_storage;
  // Accessors returning m_storage.
  121. storage_type& storage_impl() noexcept;
  122. const storage_type& storage_impl() const noexcept;
  // NOTE(review): presumably befriended so that xcontainer can call the
  // private storage_impl() accessors - confirm against xcontainer.hpp.
  123. friend class xcontainer<xtensor_container<EC, N, L, Tag>>;
  124. };
  125. /*****************************************
  126. * xtensor_container_adaptor declaration *
  127. *****************************************/
  // Extension point for xtensor_adaptor: xtensor_adaptor_base maps the
  // template parameters (in particular the expression tag) to the type that
  // xtensor_adaptor uses as its extension base class.
  128. namespace extension
  129. {
  // Primary template: only declared here, no definition is visible.
  130. template <class EC, std::size_t N, layout_type L, class Tag>
  131. struct xtensor_adaptor_base;
  // Specialization for xtensor_expression_tag: selects xtensor_empty_base.
  132. template <class EC, std::size_t N, layout_type L>
  133. struct xtensor_adaptor_base<EC, N, L, xtensor_expression_tag>
  134. {
  135. using type = xtensor_empty_base;
  136. };
  // Shorthand for the nested ::type of the trait above.
  137. template <class EC, std::size_t N, layout_type L, class Tag>
  138. using xtensor_adaptor_base_t = typename xtensor_adaptor_base<EC, N, L, Tag>::type;
  139. }
  // Inner type traits of xtensor_adaptor<EC, N, L, Tag>.
  140. template <class EC, std::size_t N, layout_type L, class Tag>
  141. struct xcontainer_inner_types<xtensor_adaptor<EC, N, L, Tag>>
  142. {
  // EC is a closure type and may be a reference; the storage type is the
  // underlying container type with the reference removed.
  143. using storage_type = std::remove_reference_t<EC>;
  144. using reference = inner_reference_t<storage_type>;
  145. using const_reference = typename storage_type::const_reference;
  146. using size_type = typename storage_type::size_type;
  // The shape is a std::array with exactly N entries.
  147. using shape_type = std::array<typename storage_type::size_type, N>;
  // Strides and backstrides share the type derived from the shape type.
  148. using strides_type = get_strides_t<shape_type>;
  149. using backstrides_type = get_strides_t<shape_type>;
  // The inner_* types are identical to the corresponding public types.
  150. using inner_shape_type = shape_type;
  151. using inner_strides_type = strides_type;
  152. using inner_backstrides_type = backstrides_type;
  // Temporaries are xtensor_container objects whose storage is the temporary
  // container type associated with storage_type.
  153. using temporary_type = xtensor_container<temporary_container_t<storage_type>, N, L, Tag>;
  // Static layout, taken from the template parameter L.
  154. static constexpr layout_type layout = L;
  155. };
  // Iteration-related inner types of xtensor_adaptor: everything is
  // inherited from xcontainer_iterable_types, nothing is added here.
  156. template <class EC, std::size_t N, layout_type L, class Tag>
  157. struct xiterable_inner_types<xtensor_adaptor<EC, N, L, Tag>>
  158. : xcontainer_iterable_types<xtensor_adaptor<EC, N, L, Tag>>
  159. {
  160. };
  161. /**
  162. * @class xtensor_adaptor
  163. * @brief Dense multidimensional container adaptor with tensor
  164. * semantics and fixed dimension.
  165. *
  166. * The xtensor_adaptor class implements a dense multidimensional
  167. * container adaptor with tensor semantics and fixed dimension. It
  168. * is used to provide a multidimensional container semantic and a
  169. * tensor semantic to stl-like containers.
  170. *
  171. * @tparam EC The closure for the container type to adapt.
  172. * @tparam N The dimension of the adaptor.
  173. * @tparam L The layout_type of the adaptor.
  174. * @tparam Tag The expression tag.
  175. * @sa xstrided_container, xcontainer
  176. */
  // Declaration of xtensor_adaptor. It derives from the strided container
  // interface, the container assignment semantic, and the extension base
  // selected through extension::xtensor_adaptor_base_t.
  177. template <class EC, std::size_t N, layout_type L, class Tag>
  178. class xtensor_adaptor : public xstrided_container<xtensor_adaptor<EC, N, L, Tag>>,
  179. public xcontainer_semantic<xtensor_adaptor<EC, N, L, Tag>>,
  180. public extension::xtensor_adaptor_base_t<EC, N, L, Tag>
  181. {
  182. public:
  // The closure type exactly as given by the user (may be a reference type).
  183. using container_closure_type = EC;
  // Shorthands for this type and its base classes.
  184. using self_type = xtensor_adaptor<EC, N, L, Tag>;
  185. using base_type = xstrided_container<self_type>;
  186. using semantic_base = xcontainer_semantic<self_type>;
  187. using extension_base = extension::xtensor_adaptor_base_t<EC, N, L, Tag>;
  // Types re-exported from the base classes.
  188. using storage_type = typename base_type::storage_type;
  189. using allocator_type = typename base_type::allocator_type;
  190. using shape_type = typename base_type::shape_type;
  191. using strides_type = typename base_type::strides_type;
  192. using backstrides_type = typename base_type::backstrides_type;
  193. using temporary_type = typename semantic_base::temporary_type;
  194. using expression_tag = Tag;
  // Number of dimensions, available as a compile-time constant.
  195. static constexpr std::size_t rank = N;
  // Constructors taking only the container to adapt (not marked explicit).
  196. xtensor_adaptor(storage_type&& storage);
  197. xtensor_adaptor(const storage_type& storage);
  // Constructors taking the container plus a shape and either a layout or
  // explicit strides; the container is taken as a forwarding reference.
  198. template <class D>
  199. xtensor_adaptor(D&& storage, const shape_type& shape, layout_type l = L);
  200. template <class D>
  201. xtensor_adaptor(D&& storage, const shape_type& shape, const strides_type& strides);
  // Destructor and copy/move constructors are compiler-generated, whereas the
  // copy/move assignment operators are user-provided (defined out of class).
  202. ~xtensor_adaptor() = default;
  203. xtensor_adaptor(const xtensor_adaptor&) = default;
  204. xtensor_adaptor& operator=(const xtensor_adaptor&);
  205. xtensor_adaptor(xtensor_adaptor&&) = default;
  206. xtensor_adaptor& operator=(xtensor_adaptor&&);
  // Assignment from an rvalue of the temporary type.
  207. xtensor_adaptor& operator=(temporary_type&&);
  // Assignment from an arbitrary xexpression.
  208. template <class E>
  209. xtensor_adaptor& operator=(const xexpression<E>& e);
  // Forwards a pointer and a size to the reset_data() method of the storage.
  210. template <class P, class S>
  211. void reset_buffer(P&& pointer, S&& size);
  212. private:
  // Stored with the closure type, so this member is a reference when EC is a
  // reference type and a value otherwise.
  213. container_closure_type m_storage;
  // Accessors returning m_storage.
  214. storage_type& storage_impl() noexcept;
  215. const storage_type& storage_impl() const noexcept;
  // NOTE(review): presumably befriended so that xcontainer can call the
  // private storage_impl() accessors - confirm against xcontainer.hpp.
  216. friend class xcontainer<xtensor_adaptor<EC, N, L, Tag>>;
  217. };
  218. /****************************
  219. * xtensor_view declaration *
  220. ****************************/
  // Forward declaration of xtensor_view, so that the trait specializations
  // that follow can name the class before its definition.
  221. template <class EC, std::size_t N, layout_type L, class Tag>
  222. class xtensor_view;
  // Extension point for xtensor_view: xtensor_view_base maps the template
  // parameters (in particular the expression tag) to the type that
  // xtensor_view uses as its extension base class.
  223. namespace extension
  224. {
  // Primary template: only declared here, no definition is visible.
  225. template <class EC, std::size_t N, layout_type L, class Tag>
  226. struct xtensor_view_base;
  // Specialization for xtensor_expression_tag: selects xtensor_empty_base.
  227. template <class EC, std::size_t N, layout_type L>
  228. struct xtensor_view_base<EC, N, L, xtensor_expression_tag>
  229. {
  230. using type = xtensor_empty_base;
  231. };
  // Shorthand for the nested ::type of the trait above.
  232. template <class EC, std::size_t N, layout_type L, class Tag>
  233. using xtensor_view_base_t = typename xtensor_view_base<EC, N, L, Tag>::type;
  234. }
  // Inner type traits of xtensor_view<EC, N, L, Tag>.
  235. template <class EC, std::size_t N, layout_type L, class Tag>
  236. struct xcontainer_inner_types<xtensor_view<EC, N, L, Tag>>
  237. {
  // EC is a closure type and may be a reference; the storage type is the
  // underlying container type with the reference removed.
  238. using storage_type = std::remove_reference_t<EC>;
  239. using reference = inner_reference_t<storage_type>;
  240. using const_reference = typename storage_type::const_reference;
  241. using size_type = typename storage_type::size_type;
  // The shape is a std::array with exactly N entries.
  242. using shape_type = std::array<typename storage_type::size_type, N>;
  // Strides and backstrides share the type derived from the shape type.
  243. using strides_type = get_strides_t<shape_type>;
  244. using backstrides_type = get_strides_t<shape_type>;
  // The inner_* types are identical to the corresponding public types.
  245. using inner_shape_type = shape_type;
  246. using inner_strides_type = strides_type;
  247. using inner_backstrides_type = backstrides_type;
  // Temporaries are xtensor_container objects whose storage is the temporary
  // container type associated with storage_type.
  248. using temporary_type = xtensor_container<temporary_container_t<storage_type>, N, L, Tag>;
  // Static layout, taken from the template parameter L.
  249. static constexpr layout_type layout = L;
  250. };
  // Iteration-related inner types of xtensor_view: everything is inherited
  // from xcontainer_iterable_types, nothing is added here.
  251. template <class EC, std::size_t N, layout_type L, class Tag>
  252. struct xiterable_inner_types<xtensor_view<EC, N, L, Tag>>
  253. : xcontainer_iterable_types<xtensor_view<EC, N, L, Tag>>
  254. {
  255. };
  256. /**
  257. * @class xtensor_view
  258. * @brief Dense multidimensional container adaptor with view
  259. * semantics and fixed dimension.
  260. *
  261. * The xtensor_view class implements a dense multidimensional
  262. * container adaptor with viewsemantics and fixed dimension. It
  263. * is used to provide a multidimensional container semantic and a
  264. * view semantic to stl-like containers.
  265. *
  266. * @tparam EC The closure for the container type to adapt.
  267. * @tparam N The dimension of the view.
  268. * @tparam L The layout_type of the view.
  269. * @tparam Tag The expression tag.
  270. * @sa xstrided_container, xcontainer
  271. */
  272. template <class EC, std::size_t N, layout_type L, class Tag>
  273. class xtensor_view : public xstrided_container<xtensor_view<EC, N, L, Tag>>,
  274. public xview_semantic<xtensor_view<EC, N, L, Tag>>,
  275. public extension::xtensor_view_base_t<EC, N, L, Tag>
  276. {
  277. public:
  278. using container_closure_type = EC;
  279. using self_type = xtensor_view<EC, N, L, Tag>;
  280. using base_type = xstrided_container<self_type>;
  281. using semantic_base = xview_semantic<self_type>;
  282. using extension_base = extension::xtensor_adaptor_base_t<EC, N, L, Tag>;
  283. using storage_type = typename base_type::storage_type;
  284. using allocator_type = typename base_type::allocator_type;
  285. using shape_type = typename base_type::shape_type;
  286. using strides_type = typename base_type::strides_type;
  287. using backstrides_type = typename base_type::backstrides_type;
  288. using temporary_type = typename semantic_base::temporary_type;
  289. using expression_tag = Tag;
  290. xtensor_view(storage_type&& storage);
  291. xtensor_view(const storage_type& storage);
  292. template <class D>
  293. xtensor_view(D&& storage, const shape_type& shape, layout_type l = L);
  294. template <class D>
  295. xtensor_view(D&& storage, const shape_type& shape, const strides_type& strides);
  296. ~xtensor_view() = default;
  297. xtensor_view(const xtensor_view&) = default;
  298. xtensor_view& operator=(const xtensor_view&);
  299. xtensor_view(xtensor_view&&) = default;
  300. xtensor_view& operator=(xtensor_view&&);
  301. template <class E>
  302. self_type& operator=(const xexpression<E>& e);
  303. template <class E>
  304. disable_xexpression<E, self_type>& operator=(const E& e);
  305. private:
  306. container_closure_type m_storage;
  307. storage_type& storage_impl() noexcept;
  308. const storage_type& storage_impl() const noexcept;
  309. void assign_temporary_impl(temporary_type&& tmp);
  310. friend class xcontainer<xtensor_view<EC, N, L, Tag>>;
  311. friend class xview_semantic<xtensor_view<EC, N, L, Tag>>;
  312. };
  313. namespace detail
  314. {
  // Computes whether a view type V can expose the SIMD interface: both
  // conditions below must hold.
  315. template <class V>
  316. struct tensor_view_simd_helper
  317. {
  // Result of has_simd_interface_impl for V and its value type.
  318. using valid_return_type = detail::has_simd_interface_impl<V, typename V::value_type>;
  // True when V::reference is a genuine lvalue reference (not a proxy or a value).
  319. using valid_reference = std::is_lvalue_reference<typename V::reference>;
  320. static constexpr bool value = valid_return_type::value && valid_reference::value;
  // The same result wrapped in a std::integral_constant, usable as a base class.
  321. using type = std::integral_constant<bool, value>;
  322. };
  323. }
  324. // xtensor_view can be used on pseudo containers, i.e. containers
  325. // whose access operator does not return a reference. Since it
  326. // is not possible to take the address of a temporary, the load_simd
  327. // method implementation leads to a compilation error.
  // Specialization of has_simd_interface for xtensor_view: it derives from the
  // std::integral_constant computed by detail::tensor_view_simd_helper.
  328. template <class EC, std::size_t N, layout_type L, class Tag>
  329. struct has_simd_interface<xtensor_view<EC, N, L, Tag>>
  330. : detail::tensor_view_simd_helper<xtensor_view<EC, N, L, Tag>>::type
  331. {
  332. };
  333. /************************************
  334. * xtensor_container implementation *
  335. ************************************/
  336. /**
  337. * @name Constructors
  338. */
  339. //@{
  340. /**
  341. * Allocates an uninitialized xtensor_container that holds 0 elements.
  342. */
  343. template <class EC, std::size_t N, layout_type L, class Tag>
  344. inline xtensor_container<EC, N, L, Tag>::xtensor_container()
  345. : base_type()
  346. , m_storage(N == 0 ? 1 : 0, value_type())
  347. {
  348. }
  349. /**
  350. * Allocates an xtensor_container with nested initializer lists.
  351. */
  352. template <class EC, std::size_t N, layout_type L, class Tag>
  353. inline xtensor_container<EC, N, L, Tag>::xtensor_container(nested_initializer_list_t<value_type, N> t)
  354. : base_type()
  355. {
  356. base_type::resize(xt::shape<shape_type>(t), true);
  357. constexpr auto tmp = layout_type::row_major;
  358. L == tmp ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<tmp>(), t);
  359. }
  360. /**
  361. * Allocates an uninitialized xtensor_container with the specified shape and
  362. * layout_type.
  363. * @param shape the shape of the xtensor_container
  364. * @param l the layout_type of the xtensor_container
  365. */
  366. template <class EC, std::size_t N, layout_type L, class Tag>
  367. inline xtensor_container<EC, N, L, Tag>::xtensor_container(const shape_type& shape, layout_type l)
  368. : base_type()
  369. {
  370. base_type::resize(shape, l);
  371. }
  372. /**
  373. * Allocates an xtensor_container with the specified shape and layout_type. Elements
  374. * are initialized to the specified value.
  375. * @param shape the shape of the xtensor_container
  376. * @param value the value of the elements
  377. * @param l the layout_type of the xtensor_container
  378. */
  379. template <class EC, std::size_t N, layout_type L, class Tag>
  380. inline xtensor_container<EC, N, L, Tag>::xtensor_container(
  381. const shape_type& shape,
  382. const_reference value,
  383. layout_type l
  384. )
  385. : base_type()
  386. {
  387. base_type::resize(shape, l);
  388. std::fill(m_storage.begin(), m_storage.end(), value);
  389. }
  390. /**
  391. * Allocates an uninitialized xtensor_container with the specified shape and strides.
  392. * @param shape the shape of the xtensor_container
  393. * @param strides the strides of the xtensor_container
  394. */
  395. template <class EC, std::size_t N, layout_type L, class Tag>
  396. inline xtensor_container<EC, N, L, Tag>::xtensor_container(const shape_type& shape, const strides_type& strides)
  397. : base_type()
  398. {
  399. base_type::resize(shape, strides);
  400. }
  401. /**
  402. * Allocates an uninitialized xtensor_container with the specified shape and strides.
  403. * Elements are initialized to the specified value.
  404. * @param shape the shape of the xtensor_container
  405. * @param strides the strides of the xtensor_container
  406. * @param value the value of the elements
  407. */
  408. template <class EC, std::size_t N, layout_type L, class Tag>
  409. inline xtensor_container<EC, N, L, Tag>::xtensor_container(
  410. const shape_type& shape,
  411. const strides_type& strides,
  412. const_reference value
  413. )
  414. : base_type()
  415. {
  416. base_type::resize(shape, strides);
  417. std::fill(m_storage.begin(), m_storage.end(), value);
  418. }
  419. /**
  420. * Allocates an xtensor_container by moving specified data, shape and strides
  421. *
  422. * @param storage the data for the xtensor_container
  423. * @param shape the shape of the xtensor_container
  424. * @param strides the strides of the xtensor_container
  425. */
  426. template <class EC, std::size_t N, layout_type L, class Tag>
  427. inline xtensor_container<EC, N, L, Tag>::xtensor_container(
  428. storage_type&& storage,
  429. inner_shape_type&& shape,
  430. inner_strides_type&& strides
  431. )
  432. : base_type(std::move(shape), std::move(strides))
  433. , m_storage(std::move(storage))
  434. {
  435. }
  436. template <class EC, std::size_t N, layout_type L, class Tag>
  437. template <class SC>
  438. inline xtensor_container<EC, N, L, Tag>::xtensor_container(xarray_container<EC, L, SC, Tag>&& rhs)
  439. : base_type(
  440. xtl::forward_sequence<inner_shape_type, decltype(rhs.shape())>(rhs.shape()),
  441. xtl::forward_sequence<inner_strides_type, decltype(rhs.strides())>(rhs.strides()),
  442. xtl::forward_sequence<inner_backstrides_type, decltype(rhs.backstrides())>(rhs.backstrides()),
  443. std::move(rhs.layout())
  444. )
  445. , m_storage(std::move(rhs.storage()))
  446. {
  447. }
  448. template <class EC, std::size_t N, layout_type L, class Tag>
  449. template <class SC>
  450. inline xtensor_container<EC, N, L, Tag>&
  451. xtensor_container<EC, N, L, Tag>::operator=(xarray_container<EC, L, SC, Tag>&& rhs)
  452. {
  453. XTENSOR_ASSERT_MSG(N == rhs.dimension(), "Cannot change dimension of xtensor.");
  454. std::copy(rhs.shape().begin(), rhs.shape().end(), this->shape_impl().begin());
  455. std::copy(rhs.strides().cbegin(), rhs.strides().cend(), this->strides_impl().begin());
  456. std::copy(rhs.backstrides().cbegin(), rhs.backstrides().cend(), this->backstrides_impl().begin());
  457. this->mutable_layout() = std::move(rhs.layout());
  458. m_storage = std::move(std::move(rhs.storage()));
  459. return *this;
  460. }
  461. template <class EC, std::size_t N, layout_type L, class Tag>
  462. template <class S>
  463. inline xtensor_container<EC, N, L, Tag> xtensor_container<EC, N, L, Tag>::from_shape(S&& s)
  464. {
  465. XTENSOR_ASSERT_MSG(s.size() == N, "Cannot change dimension of xtensor.");
  466. shape_type shape = xtl::forward_sequence<shape_type, S>(s);
  467. return self_type(shape);
  468. }
  469. //@}
  470. /**
  471. * @name Extended copy semantic
  472. */
  473. //@{
  474. /**
  475. * The extended copy constructor.
  476. */
  477. template <class EC, std::size_t N, layout_type L, class Tag>
  478. template <class E>
  479. inline xtensor_container<EC, N, L, Tag>::xtensor_container(const xexpression<E>& e)
  480. : base_type()
  481. {
  482. XTENSOR_ASSERT_MSG(N == e.derived_cast().dimension(), "Cannot change dimension of xtensor.");
  483. // Avoids uninitialized data because of (m_shape == shape) condition
  484. // in resize (called by assign), which is always true when dimension() == 0.
  485. if (e.derived_cast().dimension() == 0)
  486. {
  487. detail::resize_data_container(m_storage, std::size_t(1));
  488. }
  489. semantic_base::assign(e);
  490. }
  491. /**
  492. * The extended assignment operator.
  493. */
  494. template <class EC, std::size_t N, layout_type L, class Tag>
  495. template <class E>
  496. inline auto xtensor_container<EC, N, L, Tag>::operator=(const xexpression<E>& e) -> self_type&
  497. {
  498. return semantic_base::operator=(e);
  499. }
  500. //@}
  501. template <class EC, std::size_t N, layout_type L, class Tag>
  502. inline auto xtensor_container<EC, N, L, Tag>::storage_impl() noexcept -> storage_type&
  503. {
  504. return m_storage;
  505. }
  506. template <class EC, std::size_t N, layout_type L, class Tag>
  507. inline auto xtensor_container<EC, N, L, Tag>::storage_impl() const noexcept -> const storage_type&
  508. {
  509. return m_storage;
  510. }
  511. /**********************************
  512. * xtensor_adaptor implementation *
  513. **********************************/
  514. /**
  515. * @name Constructors
  516. */
  517. //@{
  518. /**
  519. * Constructs an xtensor_adaptor of the given stl-like container.
  520. * @param storage the container to adapt
  521. */
  522. template <class EC, std::size_t N, layout_type L, class Tag>
  523. inline xtensor_adaptor<EC, N, L, Tag>::xtensor_adaptor(storage_type&& storage)
  524. : base_type()
  525. , m_storage(std::move(storage))
  526. {
  527. }
  528. /**
  529. * Constructs an xtensor_adaptor of the given stl-like container.
  530. * @param storage the container to adapt
  531. */
  532. template <class EC, std::size_t N, layout_type L, class Tag>
  533. inline xtensor_adaptor<EC, N, L, Tag>::xtensor_adaptor(const storage_type& storage)
  534. : base_type()
  535. , m_storage(storage)
  536. {
  537. }
  538. /**
  539. * Constructs an xtensor_adaptor of the given stl-like container,
  540. * with the specified shape and layout_type.
  541. * @param storage the container to adapt
  542. * @param shape the shape of the xtensor_adaptor
  543. * @param l the layout_type of the xtensor_adaptor
  544. */
  545. template <class EC, std::size_t N, layout_type L, class Tag>
  546. template <class D>
  547. inline xtensor_adaptor<EC, N, L, Tag>::xtensor_adaptor(D&& storage, const shape_type& shape, layout_type l)
  548. : base_type()
  549. , m_storage(std::forward<D>(storage))
  550. {
  551. base_type::resize(shape, l);
  552. }
  553. /**
  554. * Constructs an xtensor_adaptor of the given stl-like container,
  555. * with the specified shape and strides.
  556. * @param storage the container to adapt
  557. * @param shape the shape of the xtensor_adaptor
  558. * @param strides the strides of the xtensor_adaptor
  559. */
  560. template <class EC, std::size_t N, layout_type L, class Tag>
  561. template <class D>
  562. inline xtensor_adaptor<EC, N, L, Tag>::xtensor_adaptor(
  563. D&& storage,
  564. const shape_type& shape,
  565. const strides_type& strides
  566. )
  567. : base_type()
  568. , m_storage(std::forward<D>(storage))
  569. {
  570. base_type::resize(shape, strides);
  571. }
  572. //@}
  573. template <class EC, std::size_t N, layout_type L, class Tag>
  574. inline auto xtensor_adaptor<EC, N, L, Tag>::operator=(const xtensor_adaptor& rhs) -> self_type&
  575. {
  576. base_type::operator=(rhs);
  577. m_storage = rhs.m_storage;
  578. return *this;
  579. }
  580. template <class EC, std::size_t N, layout_type L, class Tag>
  581. inline auto xtensor_adaptor<EC, N, L, Tag>::operator=(xtensor_adaptor&& rhs) -> self_type&
  582. {
  583. base_type::operator=(std::move(rhs));
  584. m_storage = rhs.m_storage;
  585. return *this;
  586. }
  587. template <class EC, std::size_t N, layout_type L, class Tag>
  588. inline auto xtensor_adaptor<EC, N, L, Tag>::operator=(temporary_type&& rhs) -> self_type&
  589. {
  590. base_type::shape_impl() = std::move(const_cast<shape_type&>(rhs.shape()));
  591. base_type::strides_impl() = std::move(const_cast<strides_type&>(rhs.strides()));
  592. base_type::backstrides_impl() = std::move(const_cast<backstrides_type&>(rhs.backstrides()));
  593. m_storage = std::move(rhs.storage());
  594. return *this;
  595. }
  596. /**
  597. * @name Extended copy semantic
  598. */
  599. //@{
  600. /**
  601. * The extended assignment operator.
  602. */
  603. template <class EC, std::size_t N, layout_type L, class Tag>
  604. template <class E>
  605. inline auto xtensor_adaptor<EC, N, L, Tag>::operator=(const xexpression<E>& e) -> self_type&
  606. {
  607. return semantic_base::operator=(e);
  608. }
  609. //@}
  610. template <class EC, std::size_t N, layout_type L, class Tag>
  611. inline auto xtensor_adaptor<EC, N, L, Tag>::storage_impl() noexcept -> storage_type&
  612. {
  613. return m_storage;
  614. }
  615. template <class EC, std::size_t N, layout_type L, class Tag>
  616. inline auto xtensor_adaptor<EC, N, L, Tag>::storage_impl() const noexcept -> const storage_type&
  617. {
  618. return m_storage;
  619. }
  620. template <class EC, std::size_t N, layout_type L, class Tag>
  621. template <class P, class S>
  622. inline void xtensor_adaptor<EC, N, L, Tag>::reset_buffer(P&& pointer, S&& size)
  623. {
  624. return m_storage.reset_data(std::forward<P>(pointer), std::forward<S>(size));
  625. }
  626. /*******************************
  627. * xtensor_view implementation *
  628. *******************************/
  629. /**
  630. * @name Constructors
  631. */
  632. //@{
  633. /**
  634. * Constructs an xtensor_view of the given stl-like container.
  635. * @param storage the container to adapt
  636. */
  637. template <class EC, std::size_t N, layout_type L, class Tag>
  638. inline xtensor_view<EC, N, L, Tag>::xtensor_view(storage_type&& storage)
  639. : base_type()
  640. , m_storage(std::move(storage))
  641. {
  642. }
  643. /**
  644. * Constructs an xtensor_view of the given stl-like container.
  645. * @param storage the container to adapt
  646. */
  647. template <class EC, std::size_t N, layout_type L, class Tag>
  648. inline xtensor_view<EC, N, L, Tag>::xtensor_view(const storage_type& storage)
  649. : base_type()
  650. , m_storage(storage)
  651. {
  652. }
  653. /**
  654. * Constructs an xtensor_view of the given stl-like container,
  655. * with the specified shape and layout_type.
  656. * @param storage the container to adapt
  657. * @param shape the shape of the xtensor_view
  658. * @param l the layout_type of the xtensor_view
  659. */
  660. template <class EC, std::size_t N, layout_type L, class Tag>
  661. template <class D>
  662. inline xtensor_view<EC, N, L, Tag>::xtensor_view(D&& storage, const shape_type& shape, layout_type l)
  663. : base_type()
  664. , m_storage(std::forward<D>(storage))
  665. {
  666. base_type::resize(shape, l);
  667. }
  668. /**
  669. * Constructs an xtensor_view of the given stl-like container,
  670. * with the specified shape and strides.
  671. * @param storage the container to adapt
  672. * @param shape the shape of the xtensor_view
  673. * @param strides the strides of the xtensor_view
  674. */
  675. template <class EC, std::size_t N, layout_type L, class Tag>
  676. template <class D>
  677. inline xtensor_view<EC, N, L, Tag>::xtensor_view(D&& storage, const shape_type& shape, const strides_type& strides)
  678. : base_type()
  679. , m_storage(std::forward<D>(storage))
  680. {
  681. base_type::resize(shape, strides);
  682. }
  683. //@}
  684. template <class EC, std::size_t N, layout_type L, class Tag>
  685. inline auto xtensor_view<EC, N, L, Tag>::operator=(const xtensor_view& rhs) -> self_type&
  686. {
  687. base_type::operator=(rhs);
  688. m_storage = rhs.m_storage;
  689. return *this;
  690. }
  691. template <class EC, std::size_t N, layout_type L, class Tag>
  692. inline auto xtensor_view<EC, N, L, Tag>::operator=(xtensor_view&& rhs) -> self_type&
  693. {
  694. base_type::operator=(std::move(rhs));
  695. m_storage = rhs.m_storage;
  696. return *this;
  697. }
  698. /**
  699. * @name Extended copy semantic
  700. */
  701. //@{
  702. /**
  703. * The extended assignment operator.
  704. */
  705. template <class EC, std::size_t N, layout_type L, class Tag>
  706. template <class E>
  707. inline auto xtensor_view<EC, N, L, Tag>::operator=(const xexpression<E>& e) -> self_type&
  708. {
  709. return semantic_base::operator=(e);
  710. }
  711. //@}
  712. template <class EC, std::size_t N, layout_type L, class Tag>
  713. template <class E>
  714. inline auto xtensor_view<EC, N, L, Tag>::operator=(const E& e) -> disable_xexpression<E, self_type>&
  715. {
  716. std::fill(m_storage.begin(), m_storage.end(), e);
  717. return *this;
  718. }
  719. template <class EC, std::size_t N, layout_type L, class Tag>
  720. inline auto xtensor_view<EC, N, L, Tag>::storage_impl() noexcept -> storage_type&
  721. {
  722. return m_storage;
  723. }
  724. template <class EC, std::size_t N, layout_type L, class Tag>
  725. inline auto xtensor_view<EC, N, L, Tag>::storage_impl() const noexcept -> const storage_type&
  726. {
  727. return m_storage;
  728. }
  729. template <class EC, std::size_t N, layout_type L, class Tag>
  730. inline void xtensor_view<EC, N, L, Tag>::assign_temporary_impl(temporary_type&& tmp)
  731. {
  732. std::copy(tmp.cbegin(), tmp.cend(), m_storage.begin());
  733. }
  734. /**
  735. * Converts ``std::vector<index_type>`` (returned e.g. from ``xt::argwhere``) to ``xtensor``.
  736. *
  737. * @param idx vector of indices
  738. *
  739. * @return ``xt::xtensor<typename index_type::value_type, 2>`` (e.g. ``xt::xtensor<size_t, 2>``)
  740. */
  741. template <class T>
  742. inline auto from_indices(const std::vector<T>& idx)
  743. {
  744. using return_type = xtensor<typename T::value_type, 2>;
  745. using size_type = typename return_type::size_type;
  746. if (idx.size() == 0)
  747. {
  748. return return_type::from_shape({size_type(0), size_type(0)});
  749. }
  750. return_type out = return_type::from_shape({idx.size(), idx[0].size()});
  751. for (size_type i = 0; i < out.shape()[0]; ++i)
  752. {
  753. for (size_type j = 0; j < out.shape()[1]; ++j)
  754. {
  755. out(i, j) = idx[i][j];
  756. }
  757. }
  758. return out;
  759. }
  760. /**
  761. * Converts ``std::vector<index_type>`` (returned e.g. from ``xt::argwhere``) to a flattened
  762. * ``xtensor``.
  763. *
  764. * @param idx a vector of indices
  765. *
  766. * @return ``xt::xtensor<typename index_type::value_type, 1>`` (e.g. ``xt::xtensor<size_t, 1>``)
  767. */
  768. template <class T>
  769. inline auto flatten_indices(const std::vector<T>& idx)
  770. {
  771. auto n = idx.size();
  772. if (n != 0)
  773. {
  774. n *= idx[0].size();
  775. }
  776. using return_type = xtensor<typename T::value_type, 1>;
  777. return_type out = return_type::from_shape({n});
  778. auto iter = out.begin();
  779. for_each(
  780. idx.begin(),
  781. idx.end(),
  782. [&iter](const auto& t)
  783. {
  784. iter = std::copy(t.cbegin(), t.cend(), iter);
  785. }
  786. );
  787. return out;
  788. }
  789. struct ravel_vector_tag;
  790. struct ravel_tensor_tag;
  791. namespace detail
  792. {
  793. template <class C, class Tag>
  794. struct ravel_return_type;
  795. template <class C>
  796. struct ravel_return_type<C, ravel_vector_tag>
  797. {
  798. using index_type = typename C::value_type;
  799. using value_type = typename index_type::value_type;
  800. using type = std::vector<value_type>;
  801. template <class T>
  802. static std::vector<value_type> init(T n)
  803. {
  804. return std::vector<value_type>(n);
  805. }
  806. };
  807. template <class C>
  808. struct ravel_return_type<C, ravel_tensor_tag>
  809. {
  810. using index_type = typename C::value_type;
  811. using value_type = typename index_type::value_type;
  812. using type = xt::xtensor<value_type, 1>;
  813. template <class T>
  814. static xt::xtensor<value_type, 1> init(T n)
  815. {
  816. return xtensor<value_type, 1>::from_shape({n});
  817. }
  818. };
  819. }
  820. template <class C, class Tag>
  821. using ravel_return_type_t = typename detail::ravel_return_type<C, Tag>::type;
  822. /**
  823. * Converts ``std::vector<index_type>`` (returned e.g. from ``xt::argwhere``) to ``xtensor``
  824. * whereby the indices are ravelled. For 1-d input there is no conversion.
  825. *
  826. * @param idx vector of indices
  827. * @param shape the shape of the original array
  828. * @param l the layout type (row-major or column-major)
  829. *
  830. * @return ``xt::xtensor<typename index_type::value_type, 1>`` (e.g. ``xt::xtensor<size_t, 1>``)
  831. */
  832. template <class Tag = ravel_tensor_tag, class C, class S>
  833. ravel_return_type_t<C, Tag>
  834. ravel_indices(const C& idx, const S& shape, layout_type l = layout_type::row_major)
  835. {
  836. using return_type = typename detail::ravel_return_type<C, Tag>::type;
  837. using value_type = typename detail::ravel_return_type<C, Tag>::value_type;
  838. using strides_type = get_strides_t<S>;
  839. strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
  840. compute_strides(shape, l, strides);
  841. return_type out = detail::ravel_return_type<C, Tag>::init(idx.size());
  842. auto out_iter = out.begin();
  843. auto idx_iter = idx.begin();
  844. for (; out_iter != out.end(); ++out_iter, ++idx_iter)
  845. {
  846. *out_iter = element_offset<value_type>(strides, (*idx_iter).cbegin(), (*idx_iter).cend());
  847. }
  848. return out;
  849. }
  850. }
  851. #endif