xshape.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_XSHAPE_HPP
  10. #define XTENSOR_XSHAPE_HPP
  11. #include <algorithm>
  12. #include <cassert>
  13. #include <cstddef>
  14. #include <cstdlib>
  15. #include <cstring>
  16. #include <initializer_list>
  17. #include <iterator>
  18. #include <memory>
  19. #include "xlayout.hpp"
  20. #include "xstorage.hpp"
  21. #include "xtensor_forward.hpp"
  22. namespace xt
  23. {
  24. template <class T>
  25. using dynamic_shape = svector<T, 4>;
  26. template <class T, std::size_t N>
  27. using static_shape = std::array<T, N>;
  28. template <std::size_t... X>
  29. class fixed_shape;
  30. using xindex = dynamic_shape<std::size_t>;
  31. template <class S1, class S2>
  32. bool same_shape(const S1& s1, const S2& s2) noexcept;
  33. template <class U>
  34. struct initializer_dimension;
  35. template <class R, class T>
  36. constexpr R shape(T t);
  37. template <class R = std::size_t, class T, std::size_t N>
  38. xt::static_shape<R, N> shape(const T (&aList)[N]);
  39. template <class S>
  40. struct static_dimension;
  41. template <layout_type L, class S>
  42. struct select_layout;
  43. template <class... S>
  44. struct promote_shape;
  45. template <class... S>
  46. struct promote_strides;
  47. template <class S>
  48. struct index_from_shape;
  49. }
  50. namespace xtl
  51. {
  52. namespace detail
  53. {
  54. template <class S>
  55. struct sequence_builder;
  56. template <std::size_t... I>
  57. struct sequence_builder<xt::fixed_shape<I...>>
  58. {
  59. using sequence_type = xt::fixed_shape<I...>;
  60. using value_type = typename sequence_type::value_type;
  61. inline static sequence_type make(std::size_t /*size*/)
  62. {
  63. return sequence_type{};
  64. }
  65. inline static sequence_type make(std::size_t /*size*/, value_type /*v*/)
  66. {
  67. return sequence_type{};
  68. }
  69. };
  70. }
  71. }
  72. namespace xt
  73. {
  74. /**
  75. * @defgroup xt_xshape Support functions to get/check a shape array.
  76. */
  /**************
   * same_shape *
   **************/

  /**
   * Check if two objects have the same shape.
   *
   * The shapes match when they have the same number of dimensions and equal
   * extents along every dimension.
   *
   * @ingroup xt_xshape
   * @param s1 an array
   * @param s2 an array
   * @return bool
   */
  template <class S1, class S2>
  inline bool same_shape(const S1& s1, const S2& s2) noexcept
  {
      // Different ranks can never match; this also guards the 3-iterator
      // std::equal below against reading past the end of s2.
      if (s1.size() != s2.size())
      {
          return false;
      }
      return std::equal(s1.begin(), s1.end(), s2.begin());
  }
  /*************
   * has_shape *
   *************/

  /**
   * Check if an object has a certain shape.
   *
   * @ingroup xt_xshape
   * @param e an expression; any object exposing a ``shape()`` member
   * @param shape the shape to test, given as a braced list
   * @return true if ``e.shape()`` has the same length and extents as ``shape``
   */
  template <class E, class S>
  inline bool has_shape(const E& e, std::initializer_list<S> shape) noexcept
  {
      // Bind the shape once. If shape() returns by value, calling it separately
      // for cbegin() and cend() would produce iterators into two different
      // temporaries (undefined behaviour); the const reference also extends the
      // lifetime of such a temporary for the whole expression below.
      const auto& actual = e.shape();
      return actual.size() == shape.size() && std::equal(actual.cbegin(), actual.cend(), shape.begin());
  }
  110. /**
  111. * Check if an object has a certain shape.
  112. *
  113. * @ingroup has_shape
  114. * @param a an array
  115. * @param shape the shape to test
  116. * @return bool
  117. */
  118. template <class E, class S, class = typename std::enable_if_t<has_iterator_interface<S>::value>>
  119. inline bool has_shape(const E& e, const S& shape)
  120. {
  121. return e.shape().size() == shape.size()
  122. && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin());
  123. }
  /*************************
   * initializer_dimension *
   *************************/

  namespace detail
  {
      // Counts the std::initializer_list layers wrapping a type; any type that
      // is not an initializer_list contributes nothing.
      template <class U>
      struct initializer_depth_impl
      {
          static constexpr std::size_t value = 0;
      };

      template <class T>
      struct initializer_depth_impl<std::initializer_list<T>>
      {
          static constexpr std::size_t value = initializer_depth_impl<T>::value + 1;
      };
  }

  /** Number of nested std::initializer_list levels in ``U`` (0 for non-lists). */
  template <class U>
  struct initializer_dimension
  {
      static constexpr std::size_t value = detail::initializer_depth_impl<U>::value;
  };

  /*********************
   * initializer_shape *
   *********************/

  namespace detail
  {
      // Extent of nesting level I, measured by descending through the first
      // element of every enclosing level. An empty enclosing level yields 0.
      template <std::size_t I>
      struct initializer_shape_impl
      {
          template <class T>
          static constexpr std::size_t value(T t)
          {
              if (t.size() == 0)
              {
                  return 0;
              }
              return initializer_shape_impl<I - 1>::value(*t.begin());
          }
      };

      // Level 0 is simply the length of the list handed in.
      template <>
      struct initializer_shape_impl<0>
      {
          template <class T>
          static constexpr std::size_t value(T t)
          {
              return t.size();
          }
      };

      // Builds an R holding the extent of every nesting level of t, outermost first.
      template <class R, class U, std::size_t... I>
      constexpr R initializer_shape(U t, std::index_sequence<I...>)
      {
          using extent_type = typename R::value_type;
          return {extent_type(initializer_shape_impl<I>::value(t))...};
      }
  }

  template <class R, class T>
  constexpr R shape(T t)
  {
      using depth_sequence = std::make_index_sequence<initializer_dimension<T>::value>;
      return detail::initializer_shape<R, T>(t, depth_sequence{});
  }
  183. /** @brief Generate an xt::static_shape of the given size. */
  184. template <class R, class T, std::size_t N>
  185. xt::static_shape<R, N> shape(const T (&list)[N])
  186. {
  187. xt::static_shape<R, N> shape;
  188. std::copy(std::begin(list), std::end(list), std::begin(shape));
  189. return shape;
  190. }
  191. /********************
  192. * static_dimension *
  193. ********************/
  194. namespace detail
  195. {
  196. template <class T, class E = void>
  197. struct static_dimension_impl
  198. {
  199. static constexpr std::ptrdiff_t value = -1;
  200. };
  201. template <class T>
  202. struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>>
  203. {
  204. static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(std::tuple_size<T>::value);
  205. };
  206. }
  207. template <class S>
  208. struct static_dimension
  209. {
  210. static constexpr std::ptrdiff_t value = detail::static_dimension_impl<S>::value;
  211. };
  212. /**
  213. * Compute a layout based on a layout and a shape type.
  214. *
  215. * The main functionality of this function is that it reduces vectors to
  216. * ``xt::layout_type::any`` so that assigning a row major 1D container to another
  217. * row_major container becomes free.
  218. *
  219. * @ingroup xt_xshape
  220. */
  221. template <layout_type L, class S>
  222. struct select_layout
  223. {
  224. static constexpr std::ptrdiff_t static_dimension = xt::static_dimension<S>::value;
  225. static constexpr bool is_any = static_dimension != -1 && static_dimension <= 1
  226. && L != layout_type::dynamic;
  227. static constexpr layout_type value = is_any ? layout_type::any : L;
  228. };
  229. /*************************************
  230. * promote_shape and promote_strides *
  231. *************************************/
  232. namespace detail
  233. {
  // Larger of two values, returned in their common type.
  template <class T1, class T2>
  constexpr std::common_type_t<T1, T2> imax(const T1& a, const T2& b)
  {
      if (a > b)
      {
          return a;
      }
      return b;
  }

  // Variadic meta-function returning the maximal size of std::arrays
  // (0 for an empty pack).
  template <class... T>
  struct max_array_size;

  template <>
  struct max_array_size<>
  {
      static constexpr std::size_t value = 0;
  };

  template <class T, class... Ts>
  struct max_array_size<T, Ts...>
      : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)>
  {
  };
  // Broadcasting for fixed shapes

  // at<IDX, X...>::value is the IDX-th element of the pack X..., or 0 when IDX
  // is out of range (including an "underflowed" huge IDX).
  // The lookup table carries a trailing 0 sentinel so it is never zero-sized.
  // A zero-length array is ill-formed in ISO C++, and an empty pack is
  // instantiated for e.g. fixed_shape<>.
  template <std::size_t IDX, std::size_t... X>
  struct at
  {
      static constexpr std::size_t arr[sizeof...(X) + 1] = {X..., 0};
      static constexpr std::size_t value = (IDX < sizeof...(X)) ? arr[IDX] : 0;
  };
  259. template <class S1, class S2>
  260. struct broadcast_fixed_shape;
  261. template <class IX, class A, class B>
  262. struct broadcast_fixed_shape_impl;
  263. template <std::size_t IX, class A, class B>
  264. struct broadcast_fixed_shape_cmp_impl;
  265. template <std::size_t JX, std::size_t... I, std::size_t... J>
  266. struct broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>
  267. {
  268. // We line the shapes up from the last index
  269. // IX may underflow, thus being a very large number
  270. static constexpr std::size_t IX = JX - (sizeof...(J) - sizeof...(I));
  271. // Out of bounds access gives value 0
  272. static constexpr std::size_t I_v = at<IX, I...>::value;
  273. static constexpr std::size_t J_v = at<JX, J...>::value;
  274. // we're statically checking if the broadcast shapes are either one on either of them or equal
  275. static_assert(!I_v || I_v == 1 || J_v == 1 || J_v == I_v, "broadcast shapes do not match.");
  276. static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v;
  277. static constexpr bool value = (I_v == J_v);
  278. };
  279. template <std::size_t... JX, std::size_t... I, std::size_t... J>
  280. struct broadcast_fixed_shape_impl<std::index_sequence<JX...>, fixed_shape<I...>, fixed_shape<J...>>
  281. {
  282. static_assert(sizeof...(J) >= sizeof...(I), "broadcast shapes do not match.");
  283. using type = xt::fixed_shape<
  284. broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>::ordinate...>;
  285. static constexpr bool value = xtl::conjunction<
  286. broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>...>::value;
  287. };
  288. /* broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
  289. * Just like a call to broadcast_shape(cont S1& input, S2& output),
  290. * except that the result shape is alised as type, and the returned
  291. * bool is the member value. Asserts on an illegal broadcast, including
  292. * the case where pack I is strictly longer than pack J. */
  293. template <std::size_t... I, std::size_t... J>
  294. struct broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
  295. : broadcast_fixed_shape_impl<std::make_index_sequence<sizeof...(J)>, fixed_shape<I...>, fixed_shape<J...>>
  296. {
  297. };
  298. // Simple is_array and only_array meta-functions
  299. template <class S>
  300. struct is_array
  301. {
  302. static constexpr bool value = false;
  303. };
  304. template <class T, std::size_t N>
  305. struct is_array<std::array<T, N>>
  306. {
  307. static constexpr bool value = true;
  308. };
  309. template <class S>
  310. struct is_fixed : std::false_type
  311. {
  312. };
  313. template <std::size_t... N>
  314. struct is_fixed<fixed_shape<N...>> : std::true_type
  315. {
  316. };
  317. template <class S>
  318. struct is_scalar_shape
  319. {
  320. static constexpr bool value = false;
  321. };
  322. template <class T>
  323. struct is_scalar_shape<std::array<T, 0>>
  324. {
  325. static constexpr bool value = true;
  326. };
  327. template <class... S>
  328. using only_array = xtl::conjunction<xtl::disjunction<is_array<S>, is_fixed<S>>...>;
  329. // test that at least one argument is a fixed shape. If yes, then either argument has to be fixed or
  330. // scalar
  331. template <class... S>
  332. using only_fixed = std::integral_constant<
  333. bool,
  334. xtl::disjunction<is_fixed<S>...>::value
  335. && xtl::conjunction<xtl::disjunction<is_fixed<S>, is_scalar_shape<S>>...>::value>;
  336. template <class... S>
  337. using all_fixed = xtl::conjunction<is_fixed<S>...>;
  338. // The promote_index meta-function returns std::vector<promoted_value_type> in the
  339. // general case and an array of the promoted value type and maximal size if all
  340. // arguments are of type std::array
  341. template <class... S>
  342. struct promote_array
  343. {
  344. using type = std::
  345. array<typename std::common_type<typename S::value_type...>::type, max_array_size<S...>::value>;
  346. };
  347. template <>
  348. struct promote_array<>
  349. {
  350. using type = std::array<std::size_t, 0>;
  351. };
  352. template <class S>
  353. struct filter_scalar
  354. {
  355. using type = S;
  356. };
  357. template <class T>
  358. struct filter_scalar<std::array<T, 0>>
  359. {
  360. using type = fixed_shape<1>;
  361. };
  362. template <class S>
  363. using filter_scalar_t = typename filter_scalar<S>::type;
  364. template <class... S>
  365. struct promote_fixed : promote_fixed<filter_scalar_t<S>...>
  366. {
  367. };
  368. template <std::size_t... I>
  369. struct promote_fixed<fixed_shape<I...>>
  370. {
  371. using type = fixed_shape<I...>;
  372. static constexpr bool value = true;
  373. };
  374. template <std::size_t... I, std::size_t... J, class... S>
  375. struct promote_fixed<fixed_shape<I...>, fixed_shape<J...>, S...>
  376. {
  377. private:
  378. using intermediate = std::conditional_t<
  379. (sizeof...(I) > sizeof...(J)),
  380. broadcast_fixed_shape<fixed_shape<J...>, fixed_shape<I...>>,
  381. broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>>;
  382. using result = promote_fixed<typename intermediate::type, S...>;
  383. public:
  384. using type = typename result::type;
  385. static constexpr bool value = xtl::conjunction<intermediate, result>::value;
  386. };
  387. template <bool all_index, bool all_array, class... S>
  388. struct select_promote_index;
  389. template <class... S>
  390. struct select_promote_index<true, true, S...> : promote_fixed<S...>
  391. {
  392. };
  393. template <>
  394. struct select_promote_index<true, true>
  395. {
  396. // todo correct? used in xvectorize
  397. using type = dynamic_shape<std::size_t>;
  398. };
  399. template <class... S>
  400. struct select_promote_index<false, true, S...> : promote_array<S...>
  401. {
  402. };
  403. template <class... S>
  404. struct select_promote_index<false, false, S...>
  405. {
  406. using type = dynamic_shape<typename std::common_type<typename S::value_type...>::type>;
  407. };
  408. template <class... S>
  409. struct promote_index : select_promote_index<only_fixed<S...>::value, only_array<S...>::value, S...>
  410. {
  411. };
  412. template <class T>
  413. struct index_from_shape_impl
  414. {
  415. using type = T;
  416. };
  417. template <std::size_t... N>
  418. struct index_from_shape_impl<fixed_shape<N...>>
  419. {
  420. using type = std::array<std::size_t, sizeof...(N)>;
  421. };
  422. }
  423. template <class... S>
  424. struct promote_shape
  425. {
  426. using type = typename detail::promote_index<S...>::type;
  427. };
  428. /**
  429. * @ingroup xt_xshape
  430. */
  431. template <class... S>
  432. using promote_shape_t = typename promote_shape<S...>::type;
  433. template <class... S>
  434. struct promote_strides
  435. {
  436. using type = typename detail::promote_index<S...>::type;
  437. };
  438. /**
  439. * @ingroup xt_xshape
  440. */
  441. template <class... S>
  442. using promote_strides_t = typename promote_strides<S...>::type;
  443. template <class S>
  444. struct index_from_shape
  445. {
  446. using type = typename detail::index_from_shape_impl<S>::type;
  447. };
  448. /**
  449. * @ingroup xt_xshape
  450. */
  451. template <class S>
  452. using index_from_shape_t = typename index_from_shape<S>::type;
  453. /**********************
  454. * filter_fixed_shape *
  455. **********************/
  456. namespace detail
  457. {
  458. template <class S>
  459. struct filter_fixed_shape_impl
  460. {
  461. using type = S;
  462. };
  463. template <std::size_t... N>
  464. struct filter_fixed_shape_impl<fixed_shape<N...>>
  465. {
  466. using type = std::array<std::size_t, sizeof...(N)>;
  467. };
  468. }
  469. template <class S>
  470. struct filter_fixed_shape : detail::filter_fixed_shape_impl<S>
  471. {
  472. };
  473. /**
  474. * @ingroup xt_xshape
  475. */
  476. template <class S>
  477. using filter_fixed_shape_t = typename filter_fixed_shape<S>::type;
  478. }
  479. #endif