xstrides.hpp 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_STRIDES_HPP
  10. #define XTENSOR_STRIDES_HPP
  11. #include <cstddef>
  12. #include <functional>
  13. #include <limits>
  14. #include <numeric>
  15. #include <xtl/xsequence.hpp>
  16. #include "xexception.hpp"
  17. #include "xshape.hpp"
  18. #include "xtensor_config.hpp"
  19. #include "xtensor_forward.hpp"
  20. namespace xt
  21. {
  22. template <class shape_type>
  23. std::size_t compute_size(const shape_type& shape) noexcept;
  24. /**
* @defgroup xt_xstrides Support functions to switch between array indices and flat indices
  26. */
  27. /***************
  28. * data offset *
  29. ***************/
  30. template <class offset_type, class S>
  31. offset_type data_offset(const S& strides) noexcept;
  32. /**
  33. * @brief Return the flat index for an array index.
  34. *
* Given ``m`` arguments, and dimension ``n`` of the array (``n == strides.size()``).
  36. *
  37. * - If ``m == n``, the index is
  38. * ``strides[0] * index[0] + ... + strides[n - 1] * index[n - 1]``.
  39. *
  40. * - If ``m < n`` and the last argument is ``xt::missing`` the indices are zero-padded at
  41. * the end to match the dimension of the array. The index is then
  42. * ``strides[0] * index[0] + ... + strides[m - 1] * index[m - 1]``.
  43. *
  44. * - If ``m < n`` (and the last argument is not ``xt::missing``), the index is
  45. * ``strides[n - m - 1] * index[0] + ... + strides[n - 1] * index[m - 1]``.
  46. *
  47. * - If ``m > n``, then the first ``m - n`` arguments are ignored. The index is then
  48. * ``strides[0] * index[m - n] + ... + strides[n - 1] * index[m - 1]``.
  49. *
  50. * @ingroup xt_xstrides
  51. * @param strides Strides of the array.
  52. * @param args Array index.
  53. * @return The flat index.
  54. */
  55. template <class offset_type, class S, class Arg, class... Args>
  56. offset_type data_offset(const S& strides, Arg arg, Args... args) noexcept;
  57. template <class offset_type, layout_type L = layout_type::dynamic, class S, class... Args>
  58. offset_type unchecked_data_offset(const S& strides, Args... args) noexcept;
  59. template <class offset_type, class S, class It>
  60. offset_type element_offset(const S& strides, It first, It last) noexcept;
  61. /*******************
  62. * strides builder *
  63. *******************/
  64. /**
  65. * @brief Compute the strides given the shape and the layout of an array.
  66. *
  67. * @ingroup xt_xstrides
  68. * @param shape Shape of the array.
  69. * @param l Layout type, see xt::layout_type().
  70. * @param strides (output) Strides of the array.
  71. * @return The size: the product of the shape.
  72. */
  73. template <layout_type L = layout_type::dynamic, class shape_type, class strides_type>
  74. std::size_t compute_strides(const shape_type& shape, layout_type l, strides_type& strides);
  75. template <layout_type L = layout_type::dynamic, class shape_type, class strides_type, class backstrides_type>
  76. std::size_t
  77. compute_strides(const shape_type& shape, layout_type l, strides_type& strides, backstrides_type& backstrides);
  78. template <class shape_type, class strides_type>
  79. void adapt_strides(const shape_type& shape, strides_type& strides) noexcept;
  80. template <class shape_type, class strides_type, class backstrides_type>
  81. void adapt_strides(const shape_type& shape, strides_type& strides, backstrides_type& backstrides) noexcept;
  82. /*****************
  83. * unravel_index *
  84. *****************/
  85. template <class S>
  86. S unravel_from_strides(typename S::value_type index, const S& strides, layout_type l = layout_type::row_major);
  87. template <class S>
  88. get_strides_t<S>
  89. unravel_index(typename S::value_type index, const S& shape, layout_type l = layout_type::row_major);
  90. template <class S, class T>
  91. std::vector<get_strides_t<S>>
  92. unravel_indices(const T& indices, const S& shape, layout_type l = layout_type::row_major);
  93. /***********************
  94. * broadcast functions *
  95. ***********************/
  96. template <class S, class size_type>
  97. S uninitialized_shape(size_type size);
  98. template <class S1, class S2>
  99. bool broadcast_shape(const S1& input, S2& output);
  100. template <class S1, class S2>
  101. bool broadcastable(const S1& s1, S2& s2);
  102. /*************************
  103. * check strides overlap *
  104. *************************/
  105. template <layout_type L>
  106. struct check_strides_overlap;
  107. /**********************************
  108. * check bounds, without throwing *
  109. **********************************/
  110. /**
  111. * @brief Check if the index is within the bounds of the array.
  112. *
  113. * @param shape Shape of the array.
  114. * @param args Array index.
  115. * @return true If the index is within the bounds of the array.
  116. * @return false Otherwise.
  117. */
  118. template <class S, class... Args>
  119. bool in_bounds(const S& shape, Args&... args);
  120. /********************************
  121. * apply periodicity to indices *
  122. *******************************/
  123. /**
  124. * @brief Normalise an index of a periodic array.
  125. * For example if the shape is ``(3, 4)`` and the index is ``(3, -4)`` the result is ``(0, 0)``.
  126. *
  127. * @ingroup xt_xstrides
  128. * @param shape Shape of the array.
  129. * @param args (input/output) Array index.
  130. */
  131. template <class S, class... Args>
  132. void normalize_periodic(const S& shape, Args&... args);
  133. /********************************************
  134. * utility functions for strided containers *
  135. ********************************************/
// Returns the past-the-end iterator of the data spanned by a strided
// container `c`, starting from `begin` (iterator on the first element).
// `offset` is the flat offset of the first element inside the buffer.
template <class C, class It, class size_type>
It strided_data_end(const C& c, It begin, layout_type l, size_type offset)
{
    using difference_type = typename std::iterator_traits<It>::difference_type;
    if (c.dimension() == 0)
    {
        // A 0-D container holds exactly one element.
        ++begin;
    }
    else
    {
        // Step to the last addressable element...
        for (std::size_t i = 0; i != c.dimension(); ++i)
        {
            begin += c.strides()[i] * difference_type(c.shape()[i] - 1);
        }
        // ...then one step past it along the fastest-varying axis.
        if (l == layout_type::row_major)
        {
            begin += c.strides().back();
        }
        else
        {
            // NOTE(review): for non-row-major the extra step is only taken
            // when the view starts at the buffer origin — confirm with callers.
            if (offset == 0)
            {
                begin += c.strides().front();
            }
        }
    }
    return begin;
}
  164. /***********
  165. * strides *
  166. ***********/
namespace detail
{
    // Stride that axis `axis` would have in a contiguous array of the given
    // `shape` and `layout`: product of the extents after the axis (row-major)
    // or before it (column-major). Any other layout yields `default_stride`.
    template <class return_type, class S, class T, class D>
    inline return_type compute_stride_impl(layout_type layout, const S& shape, T axis, D default_stride)
    {
        if (layout == layout_type::row_major)
        {
            return std::accumulate(
                shape.cbegin() + axis + 1,
                shape.cend(),
                static_cast<return_type>(1),
                std::multiplies<return_type>()
            );
        }
        if (layout == layout_type::column_major)
        {
            return std::accumulate(
                shape.cbegin(),
                shape.cbegin() + axis,
                static_cast<return_type>(1),
                std::multiplies<return_type>()
            );
        }
        return default_stride;
    }
}
/**
 * @brief Choose the flavour of strides reported by xt::strides().
 * @ingroup xt_xstrides
 */
enum class stride_type
{
    internal = 0, ///< As used internally (with `stride(axis) == 0` if `shape(axis) == 1`)
    normal = 1,   ///< Normal stride corresponding to storage.
    bytes = 2,    ///< Normal stride in bytes.
};
/**
 * @brief Get strides of an object.
 *
 * @ingroup xt_xstrides
 * @param e an array
 * @param type the flavour of strides to return, see xt::stride_type
 * @return array of strides
 */
template <class E>
inline auto strides(const E& e, stride_type type = stride_type::normal) noexcept
{
    using strides_type = typename E::strides_type;
    using return_type = typename strides_type::value_type;
    strides_type ret = e.strides();
    auto shape = e.shape();
    // Internal strides are returned untouched (singleton axes keep stride 0).
    if (type == stride_type::internal)
    {
        return ret;
    }
    // Replace the internal stride of every singleton axis by the stride a
    // contiguous array of this shape/layout would have on that axis.
    for (std::size_t i = 0; i < ret.size(); ++i)
    {
        if (shape[i] == 1)
        {
            ret[i] = detail::compute_stride_impl<return_type>(e.layout(), shape, i, ret[i]);
        }
    }
    // Scale by the element size when strides in bytes are requested.
    if (type == stride_type::bytes)
    {
        return_type f = static_cast<return_type>(sizeof(typename E::value_type));
        std::for_each(
            ret.begin(),
            ret.end(),
            [f](auto& c)
            {
                c *= f;
            }
        );
    }
    return ret;
}
/**
 * @brief Get stride of an object along an axis.
 *
 * @ingroup xt_xstrides
 * @param e an array
 * @param axis the axis along which to return the stride
 * @param type the flavour of stride to return, see xt::stride_type
 * @return integer
 */
template <class E>
inline auto strides(const E& e, std::size_t axis, stride_type type = stride_type::normal) noexcept
{
    using strides_type = typename E::strides_type;
    using return_type = typename strides_type::value_type;
    return_type ret = e.strides()[axis];
    // Internal stride is returned untouched (0 on a singleton axis).
    if (type == stride_type::internal)
    {
        return ret;
    }
    // A zero stride on a singleton axis is replaced by the stride a
    // contiguous array of this shape/layout would have on that axis.
    if (ret == 0)
    {
        if (e.shape(axis) == 1)
        {
            ret = detail::compute_stride_impl<return_type>(e.layout(), e.shape(), axis, ret);
        }
    }
    // Scale by the element size when the stride in bytes is requested.
    if (type == stride_type::bytes)
    {
        return_type f = static_cast<return_type>(sizeof(typename E::value_type));
        ret *= f;
    }
    return ret;
}
  273. /******************
  274. * Implementation *
  275. ******************/
namespace detail
{
    // Product of all extents for a signed shape value type; std::abs
    // presumably guards against negative placeholder extents — TODO confirm
    // against callers.
    template <class shape_type>
    inline std::size_t compute_size_impl(const shape_type& shape, std::true_type /* is signed */)
    {
        using size_type = std::decay_t<typename shape_type::value_type>;
        return static_cast<std::size_t>(std::abs(
            std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
        ));
    }

    // Product of all extents for an unsigned shape value type.
    template <class shape_type>
    inline std::size_t compute_size_impl(const shape_type& shape, std::false_type /* is not signed */)
    {
        using size_type = std::decay_t<typename shape_type::value_type>;
        return static_cast<std::size_t>(
            std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
        );
    }
}
/**
 * @brief Number of elements described by a shape: the product of all its
 * extents. Dispatches on the signedness of the shape's value type.
 */
template <class shape_type>
inline std::size_t compute_size(const shape_type& shape) noexcept
{
    return detail::compute_size_impl(
        shape,
        xtl::is_signed<std::decay_t<typename std::decay_t<shape_type>::value_type>>()
    );
}
namespace detail
{
    // Recursion terminator: no index arguments left, contributes 0.
    template <std::size_t dim, class S>
    inline auto raw_data_offset(const S&) noexcept
    {
        using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
        return strides_value_type(0);
    }

    // Trailing xt::missing placeholder: remaining axes are treated as index 0.
    template <std::size_t dim, class S>
    inline auto raw_data_offset(const S&, missing_type) noexcept
    {
        using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
        return strides_value_type(0);
    }

    // Scalar product of the index arguments with strides[dim], strides[dim+1], ...
    template <std::size_t dim, class S, class Arg, class... Args>
    inline auto raw_data_offset(const S& strides, Arg arg, Args... args) noexcept
    {
        return static_cast<std::ptrdiff_t>(arg) * strides[dim] + raw_data_offset<dim + 1>(strides, args...);
    }

    // Layout-aware offset computation for a statically known dimension.
    // The generic case just forwards to raw_data_offset.
    template <layout_type L, std::ptrdiff_t static_dim>
    struct layout_data_offset
    {
        template <std::size_t dim, class S, class Arg, class... Args>
        inline static auto run(const S& strides, Arg arg, Args... args) noexcept
        {
            return raw_data_offset<dim>(strides, arg, args...);
        }
    };

    // Row-major: the last axis is contiguous, so the multiplication by its
    // stride can be skipped for the final argument.
    template <std::ptrdiff_t static_dim>
    struct layout_data_offset<layout_type::row_major, static_dim>
    {
        using self_type = layout_data_offset<layout_type::row_major, static_dim>;

        template <std::size_t dim, class S, class Arg>
        inline static auto run(const S& strides, Arg arg) noexcept
        {
            if (std::ptrdiff_t(dim) + 1 == static_dim)
            {
                return arg;
            }
            else
            {
                return arg * strides[dim];
            }
        }

        template <std::size_t dim, class S, class Arg, class... Args>
        inline static auto run(const S& strides, Arg arg, Args... args) noexcept
        {
            return arg * strides[dim] + self_type::template run<dim + 1>(strides, args...);
        }
    };

    // Column-major: the first axis is contiguous, so the multiplication by
    // its stride can be skipped for the first argument.
    template <std::ptrdiff_t static_dim>
    struct layout_data_offset<layout_type::column_major, static_dim>
    {
        using self_type = layout_data_offset<layout_type::column_major, static_dim>;

        template <std::size_t dim, class S, class Arg>
        inline static auto run(const S& strides, Arg arg) noexcept
        {
            if (dim == 0)
            {
                return arg;
            }
            else
            {
                return arg * strides[dim];
            }
        }

        template <std::size_t dim, class S, class Arg, class... Args>
        inline static auto run(const S& strides, Arg arg, Args... args) noexcept
        {
            if (dim == 0)
            {
                return arg + self_type::template run<dim + 1>(strides, args...);
            }
            else
            {
                return arg * strides[dim] + self_type::template run<dim + 1>(strides, args...);
            }
        }
    };
}
  383. template <class offset_type, class S>
  384. inline offset_type data_offset(const S&) noexcept
  385. {
  386. return offset_type(0);
  387. }
/**
 * @brief Flat index for an array index; see the contract documented on the
 * declaration above for the four argument-count cases.
 */
template <class offset_type, class S, class Arg, class... Args>
inline offset_type data_offset(const S& strides, Arg arg, Args... args) noexcept
{
    constexpr std::size_t nargs = sizeof...(Args) + 1;
    if (nargs == strides.size())
    {
        // Correct number of arguments: iterate
        return static_cast<offset_type>(detail::raw_data_offset<0>(strides, arg, args...));
    }
    else if (nargs > strides.size())
    {
        // Too many arguments: drop the first
        return data_offset<offset_type, S>(strides, args...);
    }
    else if (detail::last_type_is_missing<Args...>)
    {
        // Too few arguments & last argument xt::missing: postfix index with zeros
        return static_cast<offset_type>(detail::raw_data_offset<0>(strides, arg, args...));
    }
    else
    {
        // Too few arguments: right to left scalar product
        auto view = strides.cend() - nargs;
        return static_cast<offset_type>(detail::raw_data_offset<0>(view, arg, args...));
    }
}
/**
 * @brief Flat index for an array index without argument-count adjustment.
 * Dispatches on the static layout L so that row-/column-major can skip the
 * stride multiplication on the contiguous axis (see detail::layout_data_offset).
 */
template <class offset_type, layout_type L, class S, class... Args>
inline offset_type unchecked_data_offset(const S& strides, Args... args) noexcept
{
    return static_cast<offset_type>(
        detail::layout_data_offset<L, static_dimension<S>::value>::template run<0>(strides.cbegin(), args...)
    );
}
  421. template <class offset_type, class S, class It>
  422. inline offset_type element_offset(const S& strides, It first, It last) noexcept
  423. {
  424. using difference_type = typename std::iterator_traits<It>::difference_type;
  425. auto size = static_cast<difference_type>(
  426. (std::min)(static_cast<typename S::size_type>(std::distance(first, last)), strides.size())
  427. );
  428. return std::inner_product(last - size, last, strides.cend() - size, offset_type(0));
  429. }
namespace detail
{
    // Zero the stride of a singleton axis and record the corresponding
    // backstride (stride * (extent - 1)).
    template <class shape_type, class strides_type, class bs_ptr>
    inline void adapt_strides(
        const shape_type& shape,
        strides_type& strides,
        bs_ptr backstrides,
        typename strides_type::size_type i
    ) noexcept
    {
        if (shape[i] == 1)
        {
            strides[i] = 0;
        }
        (*backstrides)[i] = strides[i] * std::ptrdiff_t(shape[i] - 1);
    }

    // Overload for callers that do not maintain backstrides (nullptr tag).
    template <class shape_type, class strides_type>
    inline void adapt_strides(
        const shape_type& shape,
        strides_type& strides,
        std::nullptr_t,
        typename strides_type::size_type i
    ) noexcept
    {
        if (shape[i] == 1)
        {
            strides[i] = 0;
        }
    }

    // Fill `strides` (and backstrides through `bs` when non-null) for `shape`,
    // using the layout selected either statically (L) or at runtime (l).
    // Returns the total number of elements.
    template <layout_type L, class shape_type, class strides_type, class bs_ptr>
    inline std::size_t
    compute_strides(const shape_type& shape, layout_type l, strides_type& strides, bs_ptr bs)
    {
        using strides_value_type = typename std::decay_t<strides_type>::value_type;
        strides_value_type data_size = 1;
#if defined(_MSC_VER) && (1931 <= _MSC_VER)
        // Workaround MSVC compiler optimization bug, xtensor#2568
        if (0 == shape.size())
        {
            return static_cast<std::size_t>(data_size);
        }
#endif
        if (L == layout_type::row_major || l == layout_type::row_major)
        {
            // Row-major: last axis varies fastest, accumulate right to left.
            for (std::size_t i = shape.size(); i != 0; --i)
            {
                strides[i - 1] = data_size;
                data_size = strides[i - 1] * static_cast<strides_value_type>(shape[i - 1]);
                adapt_strides(shape, strides, bs, i - 1);
            }
        }
        else
        {
            // Column-major (and any other layout): first axis varies fastest.
            for (std::size_t i = 0; i < shape.size(); ++i)
            {
                strides[i] = data_size;
                data_size = strides[i] * static_cast<strides_value_type>(shape[i]);
                adapt_strides(shape, strides, bs, i);
            }
        }
        return static_cast<std::size_t>(data_size);
    }
}
/**
 * @brief Compute the strides given the shape and the layout of an array.
 *
 * @param shape Shape of the array.
 * @param l Layout type, see xt::layout_type().
 * @param strides (output) Strides of the array.
 * @return The size: the product of the shape.
 */
template <layout_type L, class shape_type, class strides_type>
inline std::size_t compute_strides(const shape_type& shape, layout_type l, strides_type& strides)
{
    // nullptr: no backstrides requested.
    return detail::compute_strides<L>(shape, l, strides, nullptr);
}
/**
 * @brief Compute the strides and backstrides given the shape and the layout
 * of an array.
 *
 * @param shape Shape of the array.
 * @param l Layout type, see xt::layout_type().
 * @param strides (output) Strides of the array.
 * @param backstrides (output) Backstrides of the array.
 * @return The size: the product of the shape.
 */
template <layout_type L, class shape_type, class strides_type, class backstrides_type>
inline std::size_t
compute_strides(const shape_type& shape, layout_type l, strides_type& strides, backstrides_type& backstrides)
{
    return detail::compute_strides<L>(shape, l, strides, &backstrides);
}
  504. template <class T1, class T2>
  505. inline bool
  506. stride_match_condition(const T1& stride, const T2& shape, const T1& data_size, bool zero_strides)
  507. {
  508. return (shape == T2(1) && stride == T1(0) && zero_strides) || (stride == data_size);
  509. }
  510. // zero_strides should be true when strides are set to 0 if the corresponding dimensions are 1
  511. template <class shape_type, class strides_type>
  512. inline bool
  513. do_strides_match(const shape_type& shape, const strides_type& strides, layout_type l, bool zero_strides)
  514. {
  515. using value_type = typename strides_type::value_type;
  516. value_type data_size = 1;
  517. if (l == layout_type::row_major)
  518. {
  519. for (std::size_t i = strides.size(); i != 0; --i)
  520. {
  521. if (!stride_match_condition(strides[i - 1], shape[i - 1], data_size, zero_strides))
  522. {
  523. return false;
  524. }
  525. data_size *= static_cast<value_type>(shape[i - 1]);
  526. }
  527. return true;
  528. }
  529. else if (l == layout_type::column_major)
  530. {
  531. for (std::size_t i = 0; i < strides.size(); ++i)
  532. {
  533. if (!stride_match_condition(strides[i], shape[i], data_size, zero_strides))
  534. {
  535. return false;
  536. }
  537. data_size *= static_cast<value_type>(shape[i]);
  538. }
  539. return true;
  540. }
  541. else
  542. {
  543. return false;
  544. }
  545. }
  546. template <class shape_type, class strides_type>
  547. inline void adapt_strides(const shape_type& shape, strides_type& strides) noexcept
  548. {
  549. for (typename shape_type::size_type i = 0; i < shape.size(); ++i)
  550. {
  551. detail::adapt_strides(shape, strides, nullptr, i);
  552. }
  553. }
  554. template <class shape_type, class strides_type, class backstrides_type>
  555. inline void
  556. adapt_strides(const shape_type& shape, strides_type& strides, backstrides_type& backstrides) noexcept
  557. {
  558. for (typename shape_type::size_type i = 0; i < shape.size(); ++i)
  559. {
  560. detail::adapt_strides(shape, strides, &backstrides, i);
  561. }
  562. }
namespace detail
{
    // Convert flat index `idx` to an array index given the strides, visiting
    // axes from the largest stride to the smallest and peeling off
    // quotient/remainder. Zero strides (singleton axes) yield index 0.
    template <class S>
    inline S unravel_noexcept(typename S::value_type idx, const S& strides, layout_type l) noexcept
    {
        using value_type = typename S::value_type;
        using size_type = typename S::size_type;
        S result = xtl::make_sequence<S>(strides.size(), 0);
        if (l == layout_type::row_major)
        {
            // Row-major: strides decrease left to right.
            for (size_type i = 0; i < strides.size(); ++i)
            {
                value_type str = strides[i];
                value_type quot = str != 0 ? idx / str : 0;
                idx = str != 0 ? idx % str : idx;
                result[i] = quot;
            }
        }
        else
        {
            // Column-major: strides decrease right to left.
            for (size_type i = strides.size(); i != 0; --i)
            {
                value_type str = strides[i - 1];
                value_type quot = str != 0 ? idx / str : 0;
                idx = str != 0 ? idx % str : idx;
                result[i - 1] = quot;
            }
        }
        return result;
    }
}
/**
 * @brief Convert a flat index into an array index, given the strides.
 *
 * @param index Flat index.
 * @param strides Strides of the array.
 * @param l Layout; only row- and column-major are supported.
 * @throws std::runtime_error (via XTENSOR_THROW) for any other layout.
 */
template <class S>
inline S unravel_from_strides(typename S::value_type index, const S& strides, layout_type l)
{
    if (l != layout_type::row_major && l != layout_type::column_major)
    {
        XTENSOR_THROW(std::runtime_error, "unravel_index: dynamic layout not supported");
    }
    return detail::unravel_noexcept(index, strides, l);
}
/**
 * @brief Convert an array index into a flat index, given the strides
 * (scalar product of index and strides, see element_offset).
 */
template <class S, class T>
inline get_value_type_t<T> ravel_from_strides(const T& index, const S& strides)
{
    return element_offset<get_value_type_t<T>>(strides, index.begin(), index.end());
}
/**
 * @brief Convert a flat index into an array index for an array of the given
 * shape and layout: the strides are computed from the shape first.
 */
template <class S>
inline get_strides_t<S> unravel_index(typename S::value_type index, const S& shape, layout_type l)
{
    using strides_type = get_strides_t<S>;
    using strides_value_type = typename strides_type::value_type;
    strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
    compute_strides(shape, l, strides);
    return unravel_from_strides(static_cast<strides_value_type>(index), strides, l);
}
/**
 * @brief Convert a sequence of flat indices into array indices for an array
 * of the given shape and layout.
 *
 * @param idx Sequence of flat indices.
 * @param shape Shape of the array.
 * @param l Layout; only row- and column-major are supported.
 * @return Vector of array indices, one per flat index.
 */
template <class S, class T>
inline std::vector<get_strides_t<S>> unravel_indices(const T& idx, const S& shape, layout_type l)
{
    using strides_type = get_strides_t<S>;
    using strides_value_type = typename strides_type::value_type;
    // Compute the strides once and reuse them for every index.
    strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
    compute_strides(shape, l, strides);
    std::vector<get_strides_t<S>> out(idx.size());
    auto out_iter = out.begin();
    auto idx_iter = idx.begin();
    for (; out_iter != out.end(); ++out_iter, ++idx_iter)
    {
        *out_iter = unravel_from_strides(static_cast<strides_value_type>(*idx_iter), strides, l);
    }
    return out;
}
/**
 * @brief Convert an array index into a flat index for an array of the given
 * shape and layout: the strides are computed from the shape first.
 */
template <class S, class T>
inline get_value_type_t<T> ravel_index(const T& index, const S& shape, layout_type l)
{
    using strides_type = get_strides_t<S>;
    strides_type strides = xtl::make_sequence<strides_type>(shape.size(), 0);
    compute_strides(shape, l, strides);
    return ravel_from_strides(index, strides);
}
/**
 * @brief Build a shape of the given size with every entry set to the maximum
 * representable value, which marks the entry as "not yet broadcast"
 * (see broadcast_shape).
 */
template <class S, class stype>
inline S uninitialized_shape(stype size)
{
    using value_type = typename S::value_type;
    using size_type = typename S::size_type;
    return xtl::make_sequence<S>(static_cast<size_type>(size), std::numeric_limits<value_type>::max());
}
/**
 * @brief Broadcast `input` into `output` in place (right-aligned, NumPy
 * rules), where `output` may already hold the result of previous broadcasts.
 *
 * @param input Shape to broadcast.
 * @param output (input/output) Accumulated broadcast shape.
 * @return true if the broadcast is trivial (no dimension was stretched).
 * @throws broadcast_error (via throw_broadcast_error) if the shapes are
 * incompatible.
 */
template <class S1, class S2>
inline bool broadcast_shape(const S1& input, S2& output)
{
    bool trivial_broadcast = (input.size() == output.size());
    // Indices are faster than reverse iterators
    using value_type = typename S2::value_type;
    auto output_index = output.size();
    auto input_index = input.size();
    if (output_index < input_index)
    {
        throw_broadcast_error(output, input);
    }
    // Walk both shapes right-aligned, from the last axis backwards.
    for (; input_index != 0; --input_index, --output_index)
    {
        // First case: output = (MAX, MAX, ...., MAX)
        // output is a new shape that has not been through
        // the broadcast process yet; broadcast is trivial
        if (output[output_index - 1] == std::numeric_limits<value_type>::max())
        {
            output[output_index - 1] = static_cast<value_type>(input[input_index - 1]);
        }
        // Second case: output has been initialized to 1. Broadcast is trivial
        // only if input is 1 too.
        else if (output[output_index - 1] == 1)
        {
            output[output_index - 1] = static_cast<value_type>(input[input_index - 1]);
            trivial_broadcast = trivial_broadcast && (input[input_index - 1] == 1);
        }
        // Third case: output has been initialized to something different from 1.
        // if input is 1, then the broadcast is not trivial
        else if (input[input_index - 1] == 1)
        {
            trivial_broadcast = false;
        }
        // Last case: input and output must have the same value, else
        // shape are not compatible and an exception is thrown
        else if (static_cast<value_type>(input[input_index - 1]) != output[output_index - 1])
        {
            throw_broadcast_error(output, input);
        }
    }
    return trivial_broadcast;
}
  691. template <class S1, class S2>
  692. inline bool broadcastable(const S1& src_shape, const S2& dst_shape)
  693. {
  694. auto src_iter = src_shape.crbegin();
  695. auto dst_iter = dst_shape.crbegin();
  696. bool res = dst_shape.size() >= src_shape.size();
  697. for (; src_iter != src_shape.crend() && res; ++src_iter, ++dst_iter)
  698. {
  699. res = (static_cast<std::size_t>(*src_iter) == static_cast<std::size_t>(*dst_iter))
  700. || (*src_iter == 1);
  701. }
  702. return res;
  703. }
// Row-major specialization: compares the two stride sequences right-aligned,
// from the last axis backwards, and returns the count of leading entries of
// s1 that were not matched (0 when every stride of s2 matches).
template <>
struct check_strides_overlap<layout_type::row_major>
{
    template <class S1, class S2>
    static std::size_t get(const S1& s1, const S2& s2)
    {
        using value_type = typename S1::value_type;
        // Indices are faster than reverse iterators
        auto s1_index = s1.size();
        auto s2_index = s2.size();
        for (; s2_index != 0; --s1_index, --s2_index)
        {
            if (static_cast<value_type>(s1[s1_index - 1]) != static_cast<value_type>(s2[s2_index - 1]))
            {
                break;
            }
        }
        return s1_index;
    }
};
// Column-major specialization: compares the stride sequences left-aligned
// and returns the index of the first mismatch (s2.size() when all match).
template <>
struct check_strides_overlap<layout_type::column_major>
{
    template <class S1, class S2>
    static std::size_t get(const S1& s1, const S2& s2)
    {
        // Indices are faster than reverse iterators
        using size_type = typename S1::size_type;
        using value_type = typename S1::value_type;
        size_type index = 0;
        // This check is necessary as column major "broadcasting" is still
        // performed in a row major fashion
        if (s1.size() != s2.size())
        {
            return 0;
        }
        auto size = s2.size();
        for (; index < size; ++index)
        {
            if (static_cast<value_type>(s1[index]) != static_cast<value_type>(s2[index]))
            {
                break;
            }
        }
        return index;
    }
};
namespace detail
{
    // Recursion terminator: no indices left, everything checked so far is in bounds.
    template <class S, std::size_t dim>
    inline bool check_in_bounds_impl(const S&)
    {
        return true;
    }

    // Trailing xt::missing placeholder: treated as in bounds.
    template <class S, std::size_t dim>
    inline bool check_in_bounds_impl(const S&, missing_type)
    {
        return true;
    }

    // Check `arg` against the extent of axis `dim`, then recurse on the rest.
    // When more indices than axes are given, the extra leading indices are
    // skipped without being checked.
    template <class S, std::size_t dim, class T, class... Args>
    inline bool check_in_bounds_impl(const S& shape, T& arg, Args&... args)
    {
        if (sizeof...(Args) + 1 > shape.size())
        {
            return check_in_bounds_impl<S, dim>(shape, args...);
        }
        else
        {
            return arg >= T(0) && arg < static_cast<T>(shape[dim])
                   && check_in_bounds_impl<S, dim + 1>(shape, args...);
        }
    }
}
/**
 * @brief Check if the index is within the bounds of the array, without throwing.
 *
 * @param shape Shape of the array.
 * @param args Array index.
 * @return true if the index is within the bounds of the array, false otherwise.
 */
template <class S, class... Args>
inline bool check_in_bounds(const S& shape, Args&... args)
{
    return detail::check_in_bounds_impl<S, 0>(shape, args...);
}
namespace detail
{
    // Recursion terminator: no indices left to normalize.
    template <class S, std::size_t dim>
    inline void normalize_periodic_impl(const S&)
    {
    }

    // Trailing xt::missing placeholder: nothing to normalize.
    template <class S, std::size_t dim>
    inline void normalize_periodic_impl(const S&, missing_type)
    {
    }

    // Map `arg` into [0, shape[dim]) modulo the axis extent, then recurse.
    // Extra leading indices (more indices than axes) are left untouched.
    template <class S, std::size_t dim, class T, class... Args>
    inline void normalize_periodic_impl(const S& shape, T& arg, Args&... args)
    {
        if (sizeof...(Args) + 1 > shape.size())
        {
            normalize_periodic_impl<S, dim>(shape, args...);
        }
        else
        {
            T n = static_cast<T>(shape[dim]);
            // (n + (arg % n)) % n keeps the result non-negative for negative arg.
            arg = (n + (arg % n)) % n;
            normalize_periodic_impl<S, dim + 1>(shape, args...);
        }
    }
}
/**
 * @brief Normalise an index of a periodic array.
 * For example if the shape is ``(3, 4)`` and the index is ``(3, -4)`` the result is ``(0, 0)``.
 *
 * @ingroup xt_xstrides
 * @param shape Shape of the array.
 * @param args (input/output) Array index, normalized in place.
 */
template <class S, class... Args>
inline void normalize_periodic(const S& shape, Args&... args)
{
    // Validate the number of indices against the shape first
    // (see check_dimension in xexception.hpp).
    check_dimension(shape, args...);
    detail::normalize_periodic_impl<S, 0>(shape, args...);
}
  813. }
  814. #endif