xmanipulation.hpp 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_MANIPULATION_HPP
  10. #define XTENSOR_MANIPULATION_HPP
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include <xtl/xcompare.hpp>
#include <xtl/xsequence.hpp>

#include "xbuilder.hpp"
#include "xexception.hpp"
#include "xrepeat.hpp"
#include "xstrided_view.hpp"
#include "xtensor_config.hpp"
#include "xutils.hpp"
  21. namespace xt
  22. {
  23. /**
  24. * @defgroup xt_xmanipulation
  25. */
  26. namespace check_policy
  27. {
  28. struct none
  29. {
  30. };
  31. struct full
  32. {
  33. };
  34. }
  35. template <class E>
  36. auto transpose(E&& e) noexcept;
  37. template <class E, class S, class Tag = check_policy::none>
  38. auto transpose(E&& e, S&& permutation, Tag check_policy = Tag());
  39. template <class E>
  40. auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2);
  41. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
  42. auto ravel(E&& e);
  43. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
  44. auto flatten(E&& e);
  45. template <layout_type L, class T>
  46. auto flatnonzero(const T& arr);
  47. template <class E>
  48. auto trim_zeros(E&& e, const std::string& direction = "fb");
  49. template <class E>
  50. auto squeeze(E&& e);
  51. template <class E, class S, class Tag = check_policy::none, std::enable_if_t<!xtl::is_integral<S>::value, int> = 0>
  52. auto squeeze(E&& e, S&& axis, Tag check_policy = Tag());
  53. template <class E>
  54. auto expand_dims(E&& e, std::size_t axis);
  55. template <std::size_t N, class E>
  56. auto atleast_Nd(E&& e);
  57. template <class E>
  58. auto atleast_1d(E&& e);
  59. template <class E>
  60. auto atleast_2d(E&& e);
  61. template <class E>
  62. auto atleast_3d(E&& e);
  63. template <class E>
  64. auto split(E& e, std::size_t n, std::size_t axis = 0);
  65. template <class E>
  66. auto hsplit(E& e, std::size_t n);
  67. template <class E>
  68. auto vsplit(E& e, std::size_t n);
  69. template <class E>
  70. auto flip(E&& e);
  71. template <class E>
  72. auto flip(E&& e, std::size_t axis);
  73. template <std::ptrdiff_t N = 1, class E>
  74. auto rot90(E&& e, const std::array<std::ptrdiff_t, 2>& axes = {0, 1});
  75. template <class E>
  76. auto roll(E&& e, std::ptrdiff_t shift);
  77. template <class E>
  78. auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis);
  79. template <class E>
  80. auto repeat(E&& e, std::size_t repeats, std::size_t axis);
  81. template <class E>
  82. auto repeat(E&& e, const std::vector<std::size_t>& repeats, std::size_t axis);
  83. template <class E>
  84. auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis);
  85. /****************************
  86. * transpose implementation *
  87. ****************************/
  88. namespace detail
  89. {
  90. inline layout_type transpose_layout_noexcept(layout_type l) noexcept
  91. {
  92. layout_type result = l;
  93. if (l == layout_type::row_major)
  94. {
  95. result = layout_type::column_major;
  96. }
  97. else if (l == layout_type::column_major)
  98. {
  99. result = layout_type::row_major;
  100. }
  101. return result;
  102. }
  103. inline layout_type transpose_layout(layout_type l)
  104. {
  105. if (l != layout_type::row_major && l != layout_type::column_major)
  106. {
  107. XTENSOR_THROW(transpose_error, "cannot compute transposed layout of dynamic layout");
  108. }
  109. return transpose_layout_noexcept(l);
  110. }
  111. template <class E, class S>
  112. inline auto transpose_impl(E&& e, S&& permutation, check_policy::none)
  113. {
  114. if (sequence_size(permutation) != e.dimension())
  115. {
  116. XTENSOR_THROW(transpose_error, "Permutation does not have the same size as shape");
  117. }
  118. // permute stride and shape
  119. using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
  120. shape_type temp_shape;
  121. resize_container(temp_shape, e.shape().size());
  122. using strides_type = get_strides_t<shape_type>;
  123. strides_type temp_strides;
  124. resize_container(temp_strides, e.strides().size());
  125. using size_type = typename std::decay_t<E>::size_type;
  126. for (std::size_t i = 0; i < e.shape().size(); ++i)
  127. {
  128. if (std::size_t(permutation[i]) >= e.dimension())
  129. {
  130. XTENSOR_THROW(transpose_error, "Permutation contains wrong axis");
  131. }
  132. size_type perm = static_cast<size_type>(permutation[i]);
  133. temp_shape[i] = e.shape()[perm];
  134. temp_strides[i] = e.strides()[perm];
  135. }
  136. layout_type new_layout = layout_type::dynamic;
  137. if (std::is_sorted(std::begin(permutation), std::end(permutation)))
  138. {
  139. // keep old layout
  140. new_layout = e.layout();
  141. }
  142. else if (std::is_sorted(std::begin(permutation), std::end(permutation), std::greater<>()))
  143. {
  144. new_layout = transpose_layout_noexcept(e.layout());
  145. }
  146. return strided_view(
  147. std::forward<E>(e),
  148. std::move(temp_shape),
  149. std::move(temp_strides),
  150. get_offset<XTENSOR_DEFAULT_LAYOUT>(e),
  151. new_layout
  152. );
  153. }
  154. template <class E, class S>
  155. inline auto transpose_impl(E&& e, S&& permutation, check_policy::full)
  156. {
  157. // check if axis appears twice in permutation
  158. for (std::size_t i = 0; i < sequence_size(permutation); ++i)
  159. {
  160. for (std::size_t j = i + 1; j < sequence_size(permutation); ++j)
  161. {
  162. if (permutation[i] == permutation[j])
  163. {
  164. XTENSOR_THROW(transpose_error, "Permutation contains axis more than once");
  165. }
  166. }
  167. }
  168. return transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy::none());
  169. }
  170. template <class E, class S, class X, std::enable_if_t<has_data_interface<std::decay_t<E>>::value>* = nullptr>
  171. inline void compute_transposed_strides(E&& e, const S&, X& strides)
  172. {
  173. std::copy(e.strides().crbegin(), e.strides().crend(), strides.begin());
  174. }
  175. template <class E, class S, class X, std::enable_if_t<!has_data_interface<std::decay_t<E>>::value>* = nullptr>
  176. inline void compute_transposed_strides(E&&, const S& shape, X& strides)
  177. {
  178. // In the case where E does not have a data interface, the transposition
  179. // makes use of a flat storage adaptor that has layout XTENSOR_DEFAULT_TRAVERSAL
  180. // which should be the one inverted.
  181. layout_type l = transpose_layout(XTENSOR_DEFAULT_TRAVERSAL);
  182. compute_strides(shape, l, strides);
  183. }
  184. }
  185. /**
  186. * Returns a transpose view by reversing the dimensions of xexpression e
  187. *
  188. * @ingroup xt_xmanipulation
  189. * @param e the input expression
  190. */
  191. template <class E>
  192. inline auto transpose(E&& e) noexcept
  193. {
  194. using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
  195. shape_type shape;
  196. resize_container(shape, e.shape().size());
  197. std::copy(e.shape().crbegin(), e.shape().crend(), shape.begin());
  198. get_strides_t<shape_type> strides;
  199. resize_container(strides, e.shape().size());
  200. detail::compute_transposed_strides(e, shape, strides);
  201. layout_type new_layout = detail::transpose_layout_noexcept(e.layout());
  202. return strided_view(
  203. std::forward<E>(e),
  204. std::move(shape),
  205. std::move(strides),
  206. detail::get_offset<XTENSOR_DEFAULT_TRAVERSAL>(e),
  207. new_layout
  208. );
  209. }
  210. /**
  211. * Returns a transpose view by permuting the xexpression e with @p permutation.
  212. *
  213. * @ingroup xt_xmanipulation
  214. * @param e the input expression
  215. * @param permutation the sequence containing permutation
  216. * @param check_policy the check level (check_policy::full() or check_policy::none())
  217. * @tparam Tag selects the level of error checking on permutation vector defaults to check_policy::none.
  218. */
  219. template <class E, class S, class Tag>
  220. inline auto transpose(E&& e, S&& permutation, Tag check_policy)
  221. {
  222. return detail::transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy);
  223. }
  224. /// @cond DOXYGEN_INCLUDE_SFINAE
  225. template <class E, class I, std::size_t N, class Tag = check_policy::none>
  226. inline auto transpose(E&& e, const I (&permutation)[N], Tag check_policy = Tag())
  227. {
  228. return detail::transpose_impl(std::forward<E>(e), permutation, check_policy);
  229. }
  230. /// @endcond
  231. /*****************************
  232. * swapaxes implementation *
  233. *****************************/
  234. namespace detail
  235. {
  236. template <class S>
  237. inline S swapaxes_perm(std::size_t dim, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
  238. {
  239. const std::size_t ax1 = normalize_axis(dim, axis1);
  240. const std::size_t ax2 = normalize_axis(dim, axis2);
  241. auto perm = xtl::make_sequence<S>(dim, 0);
  242. using id_t = typename S::value_type;
  243. std::iota(perm.begin(), perm.end(), id_t(0));
  244. perm[ax1] = ax2;
  245. perm[ax2] = ax1;
  246. return perm;
  247. }
  248. }
  249. /**
  250. * Return a new expression with two axes interchanged.
  251. *
  252. * The two axis parameter @p axis and @p axis2 are interchangable.
  253. *
  254. * @ingroup xt_xmanipulation
  255. * @param e The input expression
  256. * @param axis1 First axis to swap
  257. * @param axis2 Second axis to swap
  258. */
  259. template <class E>
  260. inline auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
  261. {
  262. const auto dim = e.dimension();
  263. check_axis_in_dim(axis1, dim, "Parameter axis1");
  264. check_axis_in_dim(axis2, dim, "Parameter axis2");
  265. using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
  266. return transpose(std::forward<E>(e), detail::swapaxes_perm<strides_t>(dim, axis1, axis2));
  267. }
  268. /*****************************
  269. * moveaxis implementation *
  270. *****************************/
namespace detail
{
    /**
     * Builds the permutation that moveaxis passes to transpose: axis @p src
     * ends up at position @p dest, and all other axes keep their relative order.
     *
     * @param dim number of axes
     * @param src axis to move (normalized with normalize_axis)
     * @param dest destination position (normalized with normalize_axis)
     * @return a sequence of type S of length @p dim
     */
    template <class S>
    inline S moveaxis_perm(std::size_t dim, std::ptrdiff_t src, std::ptrdiff_t dest)
    {
        using id_t = typename S::value_type;
        const std::size_t src_norm = normalize_axis(dim, src);
        const std::size_t dest_norm = normalize_axis(dim, dest);
        // Initializing to src_norm handles case where `dest == -1` and the loop
        // does not go check `perm_idx == dest_norm` a `dim+1`th time.
        auto perm = xtl::make_sequence<S>(dim, src_norm);
        // perm_idx is the next output position to fill.
        id_t perm_idx = 0;
        for (id_t i = 0; xtl::cmp_less(i, dim); ++i)
        {
            // Reaching the destination slot: place the moved axis there first.
            if (xtl::cmp_equal(perm_idx, dest_norm))
            {
                perm[perm_idx] = src_norm;
                ++perm_idx;
            }
            // Copy every other axis in its original order, skipping the moved one.
            if (xtl::cmp_not_equal(i, src_norm))
            {
                perm[perm_idx] = i;
                ++perm_idx;
            }
        }
        return perm;
    }
}
  299. /**
  300. * Return a new expression with an axis move to a new position.
  301. *
  302. * @ingroup xt_xmanipulation
  303. * @param e The input expression
  304. * @param src Original position of the axis to move
  305. * @param dest Destination position for the original axis.
  306. */
  307. template <class E>
  308. inline auto moveaxis(E&& e, std::ptrdiff_t src, std::ptrdiff_t dest)
  309. {
  310. const auto dim = e.dimension();
  311. check_axis_in_dim(src, dim, "Parameter src");
  312. check_axis_in_dim(dest, dim, "Parameter dest");
  313. using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
  314. return xt::transpose(std::forward<E>(e), detail::moveaxis_perm<strides_t>(e.dimension(), src, dest));
  315. }
  316. /************************************
  317. * ravel and flatten implementation *
  318. ************************************/
  319. namespace detail
  320. {
  321. template <class E, layout_type L>
  322. struct expression_iterator_getter
  323. {
  324. using iterator = decltype(std::declval<E>().template begin<L>());
  325. using const_iterator = decltype(std::declval<E>().template cbegin<L>());
  326. inline static iterator begin(E& e)
  327. {
  328. return e.template begin<L>();
  329. }
  330. inline static const_iterator cbegin(E& e)
  331. {
  332. return e.template cbegin<L>();
  333. }
  334. inline static auto size(E& e)
  335. {
  336. return e.size();
  337. }
  338. };
  339. }
  340. /**
  341. * Return a flatten view of the given expression. No copy is made.
  342. *
  343. * @ingroup xt_xmanipulation
  344. * @param e the input expression
  345. * @tparam L the layout used to read the elements of e.
  346. * If no parameter is specified, XTENSOR_DEFAULT_TRAVERSAL is used.
  347. * @tparam E the type of the expression
  348. */
  349. template <layout_type L, class E>
  350. inline auto ravel(E&& e)
  351. {
  352. using iterator = decltype(e.template begin<L>());
  353. using iterator_getter = detail::expression_iterator_getter<std::remove_reference_t<E>, L>;
  354. auto size = e.size();
  355. auto adaptor = make_xiterator_adaptor(std::forward<E>(e), iterator_getter());
  356. constexpr layout_type layout = std::is_pointer<iterator>::value ? L : layout_type::dynamic;
  357. using type = xtensor_view<decltype(adaptor), 1, layout, extension::get_expression_tag_t<E>>;
  358. return type(std::move(adaptor), {size});
  359. }
  360. /**
  361. * Return a flatten view of the given expression.
  362. *
  363. * No copy is made.
  364. * This method is equivalent to ravel and is provided for API sameness with NumPy.
  365. *
  366. * @ingroup xt_xmanipulation
  367. * @param e the input expression
  368. * @tparam L the layout used to read the elements of e.
  369. * If no parameter is specified, XTENSOR_DEFAULT_TRAVERSAL is used.
  370. * @tparam E the type of the expression
  371. * @sa ravel
  372. */
  373. template <layout_type L, class E>
  374. inline auto flatten(E&& e)
  375. {
  376. return ravel<L>(std::forward<E>(e));
  377. }
  378. /**
  379. * Return indices that are non-zero in the flattened version of arr.
  380. *
  381. * Equivalent to ``nonzero(ravel<layout_type>(arr))[0];``
  382. *
  383. * @param arr input array
  384. * @return indices that are non-zero in the flattened version of arr
  385. */
  386. template <layout_type L, class T>
  387. inline auto flatnonzero(const T& arr)
  388. {
  389. return nonzero(ravel<L>(arr))[0];
  390. }
  391. /*****************************
  392. * trim_zeros implementation *
  393. *****************************/
  394. /**
  395. * Trim zeros at beginning, end or both of 1D sequence.
  396. *
  397. * @ingroup xt_xmanipulation
  398. * @param e input xexpression
  399. * @param direction string of either 'f' for trim from beginning, 'b' for trim from end
  400. * or 'fb' (default) for both.
  401. * @return returns a view without zeros at the beginning and end
  402. */
  403. template <class E>
  404. inline auto trim_zeros(E&& e, const std::string& direction)
  405. {
  406. XTENSOR_ASSERT_MSG(e.dimension() == 1, "Dimension for trim_zeros has to be 1.");
  407. std::ptrdiff_t begin = 0, end = static_cast<std::ptrdiff_t>(e.size());
  408. auto find_fun = [](const auto& i)
  409. {
  410. return i != 0;
  411. };
  412. if (direction.find("f") != std::string::npos)
  413. {
  414. begin = std::find_if(e.cbegin(), e.cend(), find_fun) - e.cbegin();
  415. }
  416. if (direction.find("b") != std::string::npos && begin != end)
  417. {
  418. end -= std::find_if(e.crbegin(), e.crend(), find_fun) - e.crbegin();
  419. }
  420. return strided_view(std::forward<E>(e), {range(begin, end)});
  421. }
  422. /**************************
  423. * squeeze implementation *
  424. **************************/
  425. /**
  426. * Returns a squeeze view of the given expression.
  427. *
  428. * No copy is made. Squeezing an expression removes dimensions of extent 1.
  429. *
  430. * @ingroup xt_xmanipulation
  431. * @param e the input expression
  432. * @tparam E the type of the expression
  433. */
  434. template <class E>
  435. inline auto squeeze(E&& e)
  436. {
  437. dynamic_shape<std::size_t> new_shape;
  438. dynamic_shape<std::ptrdiff_t> new_strides;
  439. std::copy_if(
  440. e.shape().cbegin(),
  441. e.shape().cend(),
  442. std::back_inserter(new_shape),
  443. [](std::size_t i)
  444. {
  445. return i != 1;
  446. }
  447. );
  448. decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
  449. std::copy_if(
  450. old_strides.cbegin(),
  451. old_strides.cend(),
  452. std::back_inserter(new_strides),
  453. [](std::ptrdiff_t i)
  454. {
  455. return i != 0;
  456. }
  457. );
  458. return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
  459. }
  460. namespace detail
  461. {
  462. template <class E, class S>
  463. inline auto squeeze_impl(E&& e, S&& axis, check_policy::none)
  464. {
  465. std::size_t new_dim = e.dimension() - axis.size();
  466. dynamic_shape<std::size_t> new_shape(new_dim);
  467. dynamic_shape<std::ptrdiff_t> new_strides(new_dim);
  468. decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
  469. for (std::size_t i = 0, ix = 0; i < e.dimension(); ++i)
  470. {
  471. if (axis.cend() == std::find(axis.cbegin(), axis.cend(), i))
  472. {
  473. new_shape[ix] = e.shape()[i];
  474. new_strides[ix++] = old_strides[i];
  475. }
  476. }
  477. return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
  478. }
  479. template <class E, class S>
  480. inline auto squeeze_impl(E&& e, S&& axis, check_policy::full)
  481. {
  482. for (auto ix : axis)
  483. {
  484. if (static_cast<std::size_t>(ix) > e.dimension())
  485. {
  486. XTENSOR_THROW(std::runtime_error, "Axis argument to squeeze > dimension of expression");
  487. }
  488. if (e.shape()[static_cast<std::size_t>(ix)] != 1)
  489. {
  490. XTENSOR_THROW(std::runtime_error, "Trying to squeeze axis != 1");
  491. }
  492. }
  493. return squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy::none());
  494. }
  495. }
  496. /**
  497. * Remove single-dimensional entries from the shape of an xexpression
  498. *
  499. * @ingroup xt_xmanipulation
  500. * @param e input xexpression
  501. * @param axis integer or container of integers, select a subset of single-dimensional
  502. * entries of the shape.
  503. * @param check_policy select check_policy. With check_policy::full(), selecting an axis
  504. * which is greater than one will throw a runtime_error.
  505. */
  506. template <class E, class S, class Tag, std::enable_if_t<!xtl::is_integral<S>::value, int>>
  507. inline auto squeeze(E&& e, S&& axis, Tag check_policy)
  508. {
  509. return detail::squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy);
  510. }
  511. /// @cond DOXYGEN_INCLUDE_SFINAE
  512. template <class E, class I, std::size_t N, class Tag = check_policy::none>
  513. inline auto squeeze(E&& e, const I (&axis)[N], Tag check_policy = Tag())
  514. {
  515. using arr_t = std::array<I, N>;
  516. return detail::squeeze_impl(
  517. std::forward<E>(e),
  518. xtl::forward_sequence<arr_t, decltype(axis)>(axis),
  519. check_policy
  520. );
  521. }
  522. template <class E, class Tag = check_policy::none>
  523. inline auto squeeze(E&& e, std::size_t axis, Tag check_policy = Tag())
  524. {
  525. return squeeze(std::forward<E>(e), std::array<std::size_t, 1>{axis}, check_policy);
  526. }
  527. /// @endcond
  528. /******************************
  529. * expand_dims implementation *
  530. ******************************/
  531. /**
  532. * Expand the shape of an xexpression.
  533. *
  534. * Insert a new axis that will appear at the axis position in the expanded array shape.
  535. * This will return a ``strided_view`` with a ``xt::newaxis()`` at the indicated axis.
  536. *
  537. * @ingroup xt_xmanipulation
  538. * @param e input xexpression
  539. * @param axis axis to expand
  540. * @return returns a ``strided_view`` with expanded dimension
  541. */
  542. template <class E>
  543. inline auto expand_dims(E&& e, std::size_t axis)
  544. {
  545. xstrided_slice_vector sv(e.dimension() + 1, all());
  546. sv[axis] = newaxis();
  547. return strided_view(std::forward<E>(e), std::move(sv));
  548. }
  549. /*****************************
  550. * atleast_Nd implementation *
  551. *****************************/
/**
 * Expand dimensions of xexpression to at least `N`.
 *
 * This adds ``newaxis()`` slices to a ``strided_view`` until
 * the dimension of the view reaches at least `N`. An expression that already
 * has `N` or more dimensions gets a view that keeps all of its axes.
 *
 * Placement of the `N - dim` new axes: round((N - dim) / N) of them go in
 * front of the existing axes. Because 0 < N - dim <= N, that count is 0 or 1.
 * All remaining new axes are appended after the existing axes.
 * For example, with N = 3 a 1-D array of shape (k,) becomes a view of shape
 * (1, k, 1); with N = 2 it becomes (1, k).
 *
 * @ingroup xt_xmanipulation
 * @param e input xexpression
 * @tparam N the number of requested dimensions
 * @return ``strided_view`` with expanded dimensions
 */
template <std::size_t N, class E>
inline auto atleast_Nd(E&& e)
{
    // One slot per output axis; all() slots keep the existing axes.
    xstrided_slice_vector sv((std::max)(e.dimension(), N), all());
    if (e.dimension() < N)
    {
        std::size_t i = 0;
        // Number of leading new axes (0 or 1, see above).
        std::size_t end = static_cast<std::size_t>(std::round(double(N - e.dimension()) / double(N)));
        for (; i < end; ++i)
        {
            sv[i] = newaxis();
        }
        // Skip the slots of the existing axes, then append the remaining new axes.
        i += e.dimension();
        for (; i < N; ++i)
        {
            sv[i] = newaxis();
        }
    }
    return strided_view(std::forward<E>(e), std::move(sv));
}
  585. /**
  586. * Expand to at least 1D
  587. *
  588. * @ingroup xt_xmanipulation
  589. * @sa atleast_Nd
  590. */
  591. template <class E>
  592. inline auto atleast_1d(E&& e)
  593. {
  594. return atleast_Nd<1>(std::forward<E>(e));
  595. }
  596. /**
  597. * Expand to at least 2D
  598. *
  599. * @ingroup xt_xmanipulation
  600. * @sa atleast_Nd
  601. */
  602. template <class E>
  603. inline auto atleast_2d(E&& e)
  604. {
  605. return atleast_Nd<2>(std::forward<E>(e));
  606. }
  607. /**
  608. * Expand to at least 3D
  609. *
  610. * @ingroup xt_xmanipulation
  611. * @sa atleast_Nd
  612. */
  613. template <class E>
  614. inline auto atleast_3d(E&& e)
  615. {
  616. return atleast_Nd<3>(std::forward<E>(e));
  617. }
  618. /************************
  619. * split implementation *
  620. ************************/
  621. /**
  622. * Split xexpression along axis into subexpressions
  623. *
  624. * This splits an xexpression along the axis in `n` equal parts and
  625. * returns a vector of ``strided_view``.
  626. * Calling split with axis > dimension of e or a `n` that does not result in
  627. * an equal division of the xexpression will throw a runtime_error.
  628. *
  629. * @ingroup xt_xmanipulation
  630. * @param e input xexpression
  631. * @param n number of elements to return
  632. * @param axis axis along which to split the expression
  633. */
  634. template <class E>
  635. inline auto split(E& e, std::size_t n, std::size_t axis)
  636. {
  637. if (axis >= e.dimension())
  638. {
  639. XTENSOR_THROW(std::runtime_error, "Split along axis > dimension.");
  640. }
  641. std::size_t ax_sz = e.shape()[axis];
  642. xstrided_slice_vector sv(e.dimension(), all());
  643. std::size_t step = ax_sz / n;
  644. std::size_t rest = ax_sz % n;
  645. if (rest)
  646. {
  647. XTENSOR_THROW(std::runtime_error, "Split does not result in equal division.");
  648. }
  649. std::vector<decltype(strided_view(e, sv))> result;
  650. for (std::size_t i = 0; i < n; ++i)
  651. {
  652. sv[axis] = range(i * step, (i + 1) * step);
  653. result.emplace_back(strided_view(e, sv));
  654. }
  655. return result;
  656. }
  657. /**
  658. * Split an xexpression into subexpressions horizontally (column-wise)
  659. *
  660. * This method is equivalent to ``split(e, n, 1)``.
  661. *
  662. * @ingroup xt_xmanipulation
  663. * @param e input xexpression
  664. * @param n number of elements to return
  665. */
  666. template <class E>
  667. inline auto hsplit(E& e, std::size_t n)
  668. {
  669. return split(e, n, std::size_t(1));
  670. }
  671. /**
  672. * Split an xexpression into subexpressions vertically (row-wise)
  673. *
  674. * This method is equivalent to ``split(e, n, 0)``.
  675. *
  676. * @ingroup xt_xmanipulation
  677. * @param e input xexpression
  678. * @param n number of elements to return
  679. */
  680. template <class E>
  681. inline auto vsplit(E& e, std::size_t n)
  682. {
  683. return split(e, n, std::size_t(0));
  684. }
  685. /***********************
  686. * flip implementation *
  687. ***********************/
  688. /**
  689. * Reverse the order of elements in an xexpression along every axis.
  690. *
  691. * @ingroup xt_xmanipulation
  692. * @param e the input xexpression
  693. * @return returns a view with the result of the flip.
  694. */
  695. template <class E>
  696. inline auto flip(E&& e)
  697. {
  698. using size_type = typename std::decay_t<E>::size_type;
  699. auto r = flip(e, 0);
  700. for (size_type d = 1; d < e.dimension(); ++d)
  701. {
  702. r = flip(r, d);
  703. }
  704. return r;
  705. }
  706. /**
  707. * Reverse the order of elements in an xexpression along the given axis.
  708. *
  709. * Note: A NumPy/Matlab style `flipud(arr)` is equivalent to `xt::flip(arr, 0)`,
  710. * `fliplr(arr)` to `xt::flip(arr, 1)`.
  711. *
  712. * @ingroup xt_xmanipulation
  713. * @param e the input xexpression
  714. * @param axis the axis along which elements should be reversed
  715. * @return returns a view with the result of the flip
  716. */
  717. template <class E>
  718. inline auto flip(E&& e, std::size_t axis)
  719. {
  720. using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
  721. shape_type shape;
  722. resize_container(shape, e.shape().size());
  723. std::copy(e.shape().cbegin(), e.shape().cend(), shape.begin());
  724. get_strides_t<shape_type> strides;
  725. decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
  726. resize_container(strides, old_strides.size());
  727. std::copy(old_strides.cbegin(), old_strides.cend(), strides.begin());
  728. strides[axis] *= -1;
  729. std::size_t offset = static_cast<std::size_t>(
  730. static_cast<std::ptrdiff_t>(e.data_offset())
  731. + old_strides[axis] * (static_cast<std::ptrdiff_t>(e.shape()[axis]) - 1)
  732. );
  733. return strided_view(std::forward<E>(e), std::move(shape), std::move(strides), offset);
  734. }
  735. /************************
  736. * rot90 implementation *
  737. ************************/
  738. namespace detail
  739. {
  740. template <std::ptrdiff_t N>
  741. struct rot90_impl;
  742. template <>
  743. struct rot90_impl<0>
  744. {
  745. template <class E>
  746. inline auto operator()(E&& e, const std::array<std::size_t, 2>& /*axes*/)
  747. {
  748. return std::forward<E>(e);
  749. }
  750. };
  751. template <>
  752. struct rot90_impl<1>
  753. {
  754. template <class E>
  755. inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
  756. {
  757. using std::swap;
  758. dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
  759. std::iota(axes_list.begin(), axes_list.end(), 0);
  760. swap(axes_list[axes[0]], axes_list[axes[1]]);
  761. return transpose(flip(std::forward<E>(e), axes[1]), std::move(axes_list));
  762. }
  763. };
  764. template <>
  765. struct rot90_impl<2>
  766. {
  767. template <class E>
  768. inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
  769. {
  770. return flip(flip(std::forward<E>(e), axes[0]), axes[1]);
  771. }
  772. };
  773. template <>
  774. struct rot90_impl<3>
  775. {
  776. template <class E>
  777. inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
  778. {
  779. using std::swap;
  780. dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
  781. std::iota(axes_list.begin(), axes_list.end(), 0);
  782. swap(axes_list[axes[0]], axes_list[axes[1]]);
  783. return flip(transpose(std::forward<E>(e), std::move(axes_list)), axes[1]);
  784. }
  785. };
  786. }
  787. /**
  788. * Rotate an array by 90 degrees in the plane specified by axes.
  789. *
  790. * Rotation direction is from the first towards the second axis.
  791. *
  792. * @ingroup xt_xmanipulation
  793. * @param e the input xexpression
  794. * @param axes the array is rotated in the plane defined by the axes. Axes must be different.
  795. * @tparam N number of times the array is rotated by 90 degrees. Default is 1.
  796. * @return returns a view with the result of the rotation
  797. */
  798. template <std::ptrdiff_t N, class E>
  799. inline auto rot90(E&& e, const std::array<std::ptrdiff_t, 2>& axes)
  800. {
  801. auto ndim = static_cast<std::ptrdiff_t>(e.shape().size());
  802. if (axes[0] == axes[1] || std::abs(axes[0] - axes[1]) == ndim)
  803. {
  804. XTENSOR_THROW(std::runtime_error, "Axes must be different");
  805. }
  806. auto norm_axes = forward_normalize<std::array<std::size_t, 2>>(e, axes);
  807. constexpr std::ptrdiff_t n = (4 + (N % 4)) % 4;
  808. return detail::rot90_impl<n>()(std::forward<E>(e), norm_axes);
  809. }
  810. /***********************
  811. * roll implementation *
  812. ***********************/
  /**
   * Roll an expression.
   *
   * The expression is flattened before shifting, after which the original
   * shape is restored. Elements that roll beyond the last position are
   * re-introduced at the first. This function does not change the input
   * expression.
   *
   * @ingroup xt_xmanipulation
   * @param e the input xexpression
   * @param shift the number of places by which elements are shifted. Any value
   *        is accepted: it is reduced modulo the total number of elements, and
   *        negative values roll towards the front.
   * @return a roll of the input expression, stored in a new expression created
   *         with empty_like(e). An empty input yields an empty result.
   */
  template <class E>
  inline auto roll(E&& e, std::ptrdiff_t shift)
  {
      auto cpy = empty_like(e);
      const auto& shape = cpy.shape();

      // Seed with std::size_t so the product is not narrowed to `long`
      // (32 bits on LLP64 platforms).
      const auto flat_size = static_cast<std::ptrdiff_t>(
          std::accumulate(shape.begin(), shape.end(), std::size_t(1), std::multiplies<std::size_t>())
      );

      // Nothing to roll. This must also come before the modulo below, which
      // would otherwise divide by zero.
      if (flat_size == 0)
      {
          return cpy;
      }

      // Bring shift into [0, flat_size) in O(1) (C++ `%` truncates toward zero).
      shift %= flat_size;
      if (shift < 0)
      {
          shift += flat_size;
      }

      // The last `shift` elements go to the front, followed by the rest.
      std::copy(e.begin(), e.end() - shift, std::copy(e.end() - shift, e.end(), cpy.begin()));
      return cpy;
  }
  namespace detail
  {
      /**
       * Copy the sub-block described by shape[M:] from @p from into @p to,
       * rotating it by @p shift positions along @p axis.
       *
       * Elements are read and written in row-major order (the last dimension
       * is contiguous). Algorithm adapted from pythran/pythonic/numpy/roll.hpp
       *
       * @param to output position, written sequentially
       * @param from position of the first element of the current sub-block
       * @param shift rotation amount; must lie in [0, shape[axis]]
       * @param axis dimension along which elements are rotated
       * @param shape full shape of the data
       * @param M dimension handled by this recursion level (0 at the top)
       * @return the output position one past the last element written
       */
      template <class To, class From, class S>
      To roll(To to, From from, std::ptrdiff_t shift, std::size_t axis, const S& shape, std::size_t M)
      {
          const std::ptrdiff_t extent = std::ptrdiff_t(shape[M]);
          // Distance between consecutive entries of dimension M.
          const std::ptrdiff_t stride = std::accumulate(
              shape.begin() + M + 1,
              shape.end(),
              std::ptrdiff_t(1),
              std::multiplies<std::ptrdiff_t>()
          );
          const bool innermost = (M + 1 == shape.size());

          // Visit the entries of dimension M in [first, last): copy them when
          // this is the last dimension, otherwise recurse into each sub-block.
          auto transfer = [&](From first, From last)
          {
              for (From it = first; it != last; it += stride)
              {
                  if (innermost)
                  {
                      *to = *it;
                      ++to;
                  }
                  else
                  {
                      to = roll(to, it, shift, axis, shape, M + 1);
                  }
              }
          };

          const From stop = from + extent * stride;
          if (axis == M)
          {
              // The last `shift` entries come first, then the leading ones.
              const From pivot = from + (extent - shift) * stride;
              transfer(pivot, stop);
              transfer(from, pivot);
          }
          else
          {
              transfer(from, stop);
          }
          return to;
      }
  }
  906. /**
  907. * Roll an expression along a given axis.
  908. *
  909. * Elements that roll beyond the last position are re-introduced at the first.
  910. * This function does not change the input expression.
  911. *
  912. * @ingroup xt_xmanipulation
  913. * @param e the input xexpression
  914. * @param shift the number of places by which elements are shifted
  915. * @param axis the axis along which elements are shifted.
  916. * @return a roll of the input expression
  917. */
  918. template <class E>
  919. inline auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis)
  920. {
  921. auto cpy = empty_like(e);
  922. const auto& shape = cpy.shape();
  923. std::size_t saxis = static_cast<std::size_t>(axis);
  924. if (axis < 0)
  925. {
  926. axis += std::ptrdiff_t(cpy.dimension());
  927. }
  928. if (saxis >= cpy.dimension() || axis < 0)
  929. {
  930. XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension.");
  931. }
  932. const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]);
  933. while (shift < 0)
  934. {
  935. shift += axis_dim;
  936. }
  937. detail::roll(cpy.begin(), e.begin(), shift, saxis, shape, 0);
  938. return cpy;
  939. }
  940. /****************************
  941. * repeat implementation *
  942. ****************************/
  943. namespace detail
  944. {
  945. template <class E, class R>
  946. inline auto make_xrepeat(E&& e, R&& r, typename std::decay_t<E>::size_type axis)
  947. {
  948. const auto casted_axis = static_cast<typename std::decay_t<E>::size_type>(axis);
  949. if (r.size() != e.shape(casted_axis))
  950. {
  951. XTENSOR_THROW(std::invalid_argument, "repeats must have the same size as the specified axis");
  952. }
  953. return xrepeat<const_xclosure_t<E>, R>(std::forward<E>(e), std::forward<R>(r), axis);
  954. }
  955. }
  /**
   * Repeat elements of an expression along a given axis.
   *
   * Every element along @p axis is repeated the same number of times.
   *
   * @ingroup xt_xmanipulation
   * @param e the input xexpression
   * @param repeats the number of repetitions of each element. This single value
   *        is broadcast to every position along the given @p axis.
   * @param axis the axis along which to repeat the values
   * @return an expression which has the same shape as \ref e, except along the given \ref axis
   */
  template <class E>
  inline auto repeat(E&& e, std::size_t repeats, std::size_t axis)
  {
      using size_type = typename std::decay_t<E>::size_type;
      const std::size_t axis_length = e.shape(static_cast<size_type>(axis));
      // One identical count per position along the axis.
      std::vector<std::size_t> per_element(axis_length, repeats);
      return repeat(std::forward<E>(e), std::move(per_element), axis);
  }
  974. /**
  975. * Repeat elements of an expression along a given axis.
  976. *
  977. * @ingroup xt_xmanipulation
  978. * @param e the input xexpression
  979. * @param repeats The number of repetition of each elements.
  980. * The size of @p repeats must match the shape of the given @p axis.
  981. * @param axis the axis along which to repeat the value
  982. *
  983. * @return an expression which as the same shape as \ref e, except along the given \ref axis
  984. */
  985. template <class E>
  986. inline auto repeat(E&& e, const std::vector<std::size_t>& repeats, std::size_t axis)
  987. {
  988. return detail::make_xrepeat(std::forward<E>(e), repeats, axis);
  989. }
  990. /**
  991. * Repeat elements of an expression along a given axis.
  992. *
  993. * @ingroup xt_xmanipulation
  994. * @param e the input xexpression
  995. * @param repeats The number of repetition of each elements.
  996. * The size of @p repeats must match the shape of the given @p axis.
  997. * @param axis the axis along which to repeat the value
  998. * @return an expression which as the same shape as \ref e, except along the given \ref axis
  999. */
  1000. template <class E>
  1001. inline auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis)
  1002. {
  1003. return detail::make_xrepeat(std::forward<E>(e), std::move(repeats), axis);
  1004. }
  1005. }
  1006. #endif