xbuilder.hpp 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. /**
  10. * @brief standard mathematical functions for xexpressions
  11. */
  12. #ifndef XTENSOR_BUILDER_HPP
  13. #define XTENSOR_BUILDER_HPP
  14. #include <array>
  15. #include <chrono>
  16. #include <cmath>
  17. #include <cstddef>
  18. #include <functional>
  19. #include <utility>
  20. #include <vector>
  21. #include <xtl/xclosure.hpp>
  22. #include <xtl/xsequence.hpp>
  23. #include <xtl/xtype_traits.hpp>
  24. #include "xbroadcast.hpp"
  25. #include "xfunction.hpp"
  26. #include "xgenerator.hpp"
  27. #include "xoperation.hpp"
  28. namespace xt
  29. {
  30. /********
  31. * ones *
  32. ********/
  33. /**
  34. * Returns an \ref xexpression containing ones of the specified shape.
  35. * @tparam shape the shape of the returned expression.
  36. */
  37. template <class T, class S>
  38. inline auto ones(S shape) noexcept
  39. {
  40. return broadcast(T(1), std::forward<S>(shape));
  41. }
  42. template <class T, class I, std::size_t L>
  43. inline auto ones(const I (&shape)[L]) noexcept
  44. {
  45. return broadcast(T(1), shape);
  46. }
  47. /*********
  48. * zeros *
  49. *********/
  50. /**
  51. * Returns an \ref xexpression containing zeros of the specified shape.
  52. * @tparam shape the shape of the returned expression.
  53. */
  54. template <class T, class S>
  55. inline auto zeros(S shape) noexcept
  56. {
  57. return broadcast(T(0), std::forward<S>(shape));
  58. }
  59. template <class T, class I, std::size_t L>
  60. inline auto zeros(const I (&shape)[L]) noexcept
  61. {
  62. return broadcast(T(0), shape);
  63. }
  64. /**
  65. * Create a xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of
  66. * with value_type T and shape. Selects the best container match automatically
  67. * from the supplied shape.
  68. *
  69. * - ``std::vector`` → ``xarray<T>``
  70. * - ``std::array`` or ``initializer_list`` → ``xtensor<T, N>``
  71. * - ``xshape<N...>`` → ``xtensor_fixed<T, xshape<N...>>``
  72. *
  73. * @param shape shape of the new xcontainer
  74. */
  75. template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class S>
  76. inline xarray<T, L> empty(const S& shape)
  77. {
  78. return xarray<T, L>::from_shape(shape);
  79. }
  80. template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class ST, std::size_t N>
  81. inline xtensor<T, N, L> empty(const std::array<ST, N>& shape)
  82. {
  83. using shape_type = typename xtensor<T, N>::shape_type;
  84. return xtensor<T, N, L>(xtl::forward_sequence<shape_type, decltype(shape)>(shape));
  85. }
  86. template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class I, std::size_t N>
  87. inline xtensor<T, N, L> empty(const I (&shape)[N])
  88. {
  89. using shape_type = typename xtensor<T, N>::shape_type;
  90. return xtensor<T, N, L>(xtl::forward_sequence<shape_type, decltype(shape)>(shape));
  91. }
  92. template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, std::size_t... N>
  93. inline xtensor_fixed<T, fixed_shape<N...>, L> empty(const fixed_shape<N...>& /*shape*/)
  94. {
  95. return xtensor_fixed<T, fixed_shape<N...>, L>();
  96. }
  97. /**
  98. * Create a xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of
  99. * the same shape, value type and layout as the input xexpression *e*.
  100. *
  101. * @param e the xexpression from which to extract shape, value type and layout.
  102. */
  103. template <class E>
  104. inline auto empty_like(const xexpression<E>& e)
  105. {
  106. using xtype = temporary_type_t<E>;
  107. auto res = xtype::from_shape(e.derived_cast().shape());
  108. return res;
  109. }
  110. /**
  111. * Create a xcontainer (xarray, xtensor or xtensor_fixed), filled with *fill_value* and of
  112. * the same shape, value type and layout as the input xexpression *e*.
  113. *
  114. * @param e the xexpression from which to extract shape, value type and layout.
  115. * @param fill_value the value used to set each element of the returned xcontainer.
  116. */
  117. template <class E>
  118. inline auto full_like(const xexpression<E>& e, typename E::value_type fill_value)
  119. {
  120. using xtype = temporary_type_t<E>;
  121. auto res = xtype::from_shape(e.derived_cast().shape());
  122. res.fill(fill_value);
  123. return res;
  124. }
  125. /**
  126. * Create a xcontainer (xarray, xtensor or xtensor_fixed), filled with zeros and of
  127. * the same shape, value type and layout as the input xexpression *e*.
  128. *
  129. * Note: contrary to zeros(shape), this function returns a non-lazy, allocated container!
  130. * Use ``xt::zeros<double>(e.shape());` for a lazy version.
  131. *
  132. * @param e the xexpression from which to extract shape, value type and layout.
  133. */
  134. template <class E>
  135. inline auto zeros_like(const xexpression<E>& e)
  136. {
  137. return full_like(e, typename E::value_type(0));
  138. }
  139. /**
  140. * Create a xcontainer (xarray, xtensor or xtensor_fixed), filled with ones and of
  141. * the same shape, value type and layout as the input xexpression *e*.
  142. *
  143. * Note: contrary to ones(shape), this function returns a non-lazy, evaluated container!
  144. * Use ``xt::ones<double>(e.shape());`` for a lazy version.
  145. *
  146. * @param e the xexpression from which to extract shape, value type and layout.
  147. */
  148. template <class E>
  149. inline auto ones_like(const xexpression<E>& e)
  150. {
  151. return full_like(e, typename E::value_type(1));
  152. }
  153. namespace detail
  154. {
  155. template <class T, class S>
  156. struct get_mult_type_impl
  157. {
  158. using type = T;
  159. };
  160. template <class T, class R, class P>
  161. struct get_mult_type_impl<T, std::chrono::duration<R, P>>
  162. {
  163. using type = R;
  164. };
  165. template <class T, class S>
  166. using get_mult_type = typename get_mult_type_impl<T, S>::type;
  167. // These methods should be private methods of arange_generator, however thi leads
  168. // to ICE on VS2015
  169. template <class R, class E, class U, class X, XTL_REQUIRES(xtl::is_integral<X>)>
  170. inline void arange_assign_to(xexpression<E>& e, U start, U, X step, bool) noexcept
  171. {
  172. auto& de = e.derived_cast();
  173. U value = start;
  174. for (auto&& el : de.storage())
  175. {
  176. el = static_cast<R>(value);
  177. value += step;
  178. }
  179. }
  180. template <class R, class E, class U, class X, XTL_REQUIRES(xtl::negation<xtl::is_integral<X>>)>
  181. inline void arange_assign_to(xexpression<E>& e, U start, U stop, X step, bool endpoint) noexcept
  182. {
  183. auto& buf = e.derived_cast().storage();
  184. using size_type = decltype(buf.size());
  185. using mult_type = get_mult_type<U, X>;
  186. size_type num = buf.size();
  187. for (size_type i = 0; i < num; ++i)
  188. {
  189. buf[i] = static_cast<R>(start + step * mult_type(i));
  190. }
  191. if (endpoint && num > 1)
  192. {
  193. buf[num - 1] = static_cast<R>(stop);
  194. }
  195. }
  196. template <class T, class R = T, class S = T>
  197. class arange_generator
  198. {
  199. public:
  200. using value_type = R;
  201. using step_type = S;
  202. arange_generator(T start, T stop, S step, size_t num_steps, bool endpoint = false)
  203. : m_start(start)
  204. , m_stop(stop)
  205. , m_step(step)
  206. , m_num_steps(num_steps)
  207. , m_endpoint(endpoint)
  208. {
  209. }
  210. template <class... Args>
  211. inline R operator()(Args... args) const
  212. {
  213. return access_impl(args...);
  214. }
  215. template <class It>
  216. inline R element(It first, It) const
  217. {
  218. return access_impl(*first);
  219. }
  220. template <class E>
  221. inline void assign_to(xexpression<E>& e) const noexcept
  222. {
  223. arange_assign_to<R>(e, m_start, m_stop, m_step, m_endpoint);
  224. }
  225. private:
  226. T m_start;
  227. T m_stop;
  228. step_type m_step;
  229. size_t m_num_steps;
  230. bool m_endpoint; // true for setting the last element to m_stop
  231. template <class T1, class... Args>
  232. inline R access_impl(T1 t, Args...) const
  233. {
  234. if (m_endpoint && m_num_steps > 1 && size_t(t) == m_num_steps - 1)
  235. {
  236. return static_cast<R>(m_stop);
  237. }
  238. // Avoids warning when T = char (because char + char => int!)
  239. using mult_type = get_mult_type<T, S>;
  240. return static_cast<R>(m_start + m_step * mult_type(t));
  241. }
  242. inline R access_impl() const
  243. {
  244. return static_cast<R>(m_start);
  245. }
  246. };
  247. template <class T, class S>
  248. using both_integer = xtl::conjunction<xtl::is_integral<T>, xtl::is_integral<S>>;
  249. template <class T, class S>
  250. using integer_with_signed_integer = xtl::conjunction<both_integer<T, S>, xtl::is_signed<S>>;
  251. template <class T, class S>
  252. using integer_with_unsigned_integer = xtl::conjunction<both_integer<T, S>, std::is_unsigned<S>>;
  253. template <class T, class S = T, XTL_REQUIRES(xtl::negation<both_integer<T, S>>)>
  254. inline auto arange_impl(T start, T stop, S step = 1) noexcept
  255. {
  256. std::size_t shape = static_cast<std::size_t>(std::ceil((stop - start) / step));
  257. return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
  258. }
  259. template <class T, class S = T, XTL_REQUIRES(integer_with_signed_integer<T, S>)>
  260. inline auto arange_impl(T start, T stop, S step = 1) noexcept
  261. {
  262. bool empty_cond = (stop - start) / step <= 0;
  263. std::size_t shape = 0;
  264. if (!empty_cond)
  265. {
  266. shape = stop > start ? static_cast<std::size_t>((stop - start + step - S(1)) / step)
  267. : static_cast<std::size_t>((start - stop - step - S(1)) / -step);
  268. }
  269. return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
  270. }
  271. template <class T, class S = T, XTL_REQUIRES(integer_with_unsigned_integer<T, S>)>
  272. inline auto arange_impl(T start, T stop, S step = 1) noexcept
  273. {
  274. bool empty_cond = stop <= start;
  275. std::size_t shape = 0;
  276. if (!empty_cond)
  277. {
  278. shape = static_cast<std::size_t>((stop - start + step - S(1)) / step);
  279. }
  280. return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
  281. }
  282. template <class F>
  283. class fn_impl
  284. {
  285. public:
  286. using value_type = typename F::value_type;
  287. using size_type = std::size_t;
  288. fn_impl(F&& f)
  289. : m_ft(f)
  290. {
  291. }
  292. inline value_type operator()() const
  293. {
  294. size_type idx[1] = {0ul};
  295. return access_impl(std::begin(idx), std::end(idx));
  296. }
  297. template <class... Args>
  298. inline value_type operator()(Args... args) const
  299. {
  300. size_type idx[sizeof...(Args)] = {static_cast<size_type>(args)...};
  301. return access_impl(std::begin(idx), std::end(idx));
  302. }
  303. template <class It>
  304. inline value_type element(It first, It last) const
  305. {
  306. return access_impl(first, last);
  307. }
  308. private:
  309. F m_ft;
  310. template <class It>
  311. inline value_type access_impl(const It& begin, const It& end) const
  312. {
  313. return m_ft(begin, end);
  314. }
  315. };
  316. template <class T>
  317. class eye_fn
  318. {
  319. public:
  320. using value_type = T;
  321. eye_fn(int k)
  322. : m_k(k)
  323. {
  324. }
  325. template <class It>
  326. inline T operator()(const It& /*begin*/, const It& end) const
  327. {
  328. using lvalue_type = typename std::iterator_traits<It>::value_type;
  329. return *(end - 1) == *(end - 2) + static_cast<lvalue_type>(m_k) ? T(1) : T(0);
  330. }
  331. private:
  332. std::ptrdiff_t m_k;
  333. };
  334. }
  335. /**
  336. * Generates an array with ones on the diagonal.
  337. * @param shape shape of the resulting expression
  338. * @param k index of the diagonal. 0 (default) refers to the main diagonal,
  339. * a positive value refers to an upper diagonal, and a negative
  340. * value to a lower diagonal.
  341. * @tparam T value_type of xexpression
  342. * @return xgenerator that generates the values on access
  343. */
  344. template <class T = bool>
  345. inline auto eye(const std::vector<std::size_t>& shape, int k = 0)
  346. {
  347. return detail::make_xgenerator(detail::fn_impl<detail::eye_fn<T>>(detail::eye_fn<T>(k)), shape);
  348. }
  349. /**
  350. * Generates a (n x n) array with ones on the diagonal.
  351. * @param n length of the diagonal.
  352. * @param k index of the diagonal. 0 (default) refers to the main diagonal,
  353. * a positive value refers to an upper diagonal, and a negative
  354. * value to a lower diagonal.
  355. * @tparam T value_type of xexpression
  356. * @return xgenerator that generates the values on access
  357. */
  358. template <class T = bool>
  359. inline auto eye(std::size_t n, int k = 0)
  360. {
  361. return eye<T>({n, n}, k);
  362. }
  363. /**
  364. * Generates numbers evenly spaced within given half-open interval [start, stop).
  365. * @param start start of the interval
  366. * @param stop stop of the interval
  367. * @param step stepsize
  368. * @tparam T value_type of xexpression
  369. * @return xgenerator that generates the values on access
  370. */
  371. template <class T, class S = T>
  372. inline auto arange(T start, T stop, S step = 1) noexcept
  373. {
  374. return detail::arange_impl(start, stop, step);
  375. }
  376. /**
  377. * Generate numbers evenly spaced within given half-open interval [0, stop)
  378. * with a step size of 1.
  379. * @param stop stop of the interval
  380. * @tparam T value_type of xexpression
  381. * @return xgenerator that generates the values on access
  382. */
  383. template <class T>
  384. inline auto arange(T stop) noexcept
  385. {
  386. return arange<T>(T(0), stop, T(1));
  387. }
  388. /**
  389. * Generates @a num_samples evenly spaced numbers over given interval
  390. * @param start start of interval
  391. * @param stop stop of interval
  392. * @param num_samples number of samples (defaults to 50)
  393. * @param endpoint if true, include endpoint (defaults to true)
  394. * @tparam T value_type of xexpression
  395. * @return xgenerator that generates the values on access
  396. */
  397. template <class T>
  398. inline auto linspace(T start, T stop, std::size_t num_samples = 50, bool endpoint = true) noexcept
  399. {
  400. using fp_type = std::common_type_t<T, double>;
  401. fp_type step = fp_type(stop - start) / std::fmax(fp_type(1), fp_type(num_samples - (endpoint ? 1 : 0)));
  402. return detail::make_xgenerator(
  403. detail::arange_generator<fp_type, T>(fp_type(start), fp_type(stop), step, num_samples, endpoint),
  404. {num_samples}
  405. );
  406. }
  407. /**
  408. * Generates @a num_samples numbers evenly spaced on a log scale over given interval
  409. * @param start start of interval (pow(base, start) is the first value).
  410. * @param stop stop of interval (pow(base, stop) is the final value, except if endpoint = false)
  411. * @param num_samples number of samples (defaults to 50)
  412. * @param base the base of the log space.
  413. * @param endpoint if true, include endpoint (defaults to true)
  414. * @tparam T value_type of xexpression
  415. * @return xgenerator that generates the values on access
  416. */
  417. template <class T>
  418. inline auto logspace(T start, T stop, std::size_t num_samples, T base = 10, bool endpoint = true) noexcept
  419. {
  420. return pow(std::move(base), linspace(start, stop, num_samples, endpoint));
  421. }
  422. namespace detail
  423. {
  424. template <class... CT>
  425. class concatenate_access
  426. {
  427. public:
  428. using tuple_type = std::tuple<CT...>;
  429. using size_type = std::size_t;
  430. using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
  431. template <class It>
  432. inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
  433. {
  434. // trim off extra indices if provided to match behavior of containers
  435. auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension();
  436. size_t axis_dim = *(first + axis + dim_offset);
  437. auto match = [&](auto& arr)
  438. {
  439. if (axis_dim >= arr.shape()[axis])
  440. {
  441. axis_dim -= arr.shape()[axis];
  442. return false;
  443. }
  444. return true;
  445. };
  446. auto get = [&](auto& arr)
  447. {
  448. size_t offset = 0;
  449. const size_t end = arr.dimension();
  450. for (size_t i = 0; i < end; i++)
  451. {
  452. const auto& shape = arr.shape();
  453. const size_t stride = std::accumulate(
  454. shape.begin() + i + 1,
  455. shape.end(),
  456. 1,
  457. std::multiplies<size_t>()
  458. );
  459. if (i == axis)
  460. {
  461. offset += axis_dim * stride;
  462. }
  463. else
  464. {
  465. const auto len = (*(first + i + dim_offset));
  466. offset += len * stride;
  467. }
  468. }
  469. const auto element = arr.begin() + offset;
  470. return *element;
  471. };
  472. size_type i = 0;
  473. for (; i < sizeof...(CT); ++i)
  474. {
  475. if (apply<bool>(i, match, t))
  476. {
  477. break;
  478. }
  479. }
  480. return apply<value_type>(i, get, t);
  481. }
  482. };
  483. template <class... CT>
  484. class stack_access
  485. {
  486. public:
  487. using tuple_type = std::tuple<CT...>;
  488. using size_type = std::size_t;
  489. using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
  490. template <class It>
  491. inline value_type access(const tuple_type& t, size_type axis, It first, It) const
  492. {
  493. auto get_item = [&](auto& arr)
  494. {
  495. size_t offset = 0;
  496. const size_t end = arr.dimension();
  497. size_t after_axis = 0;
  498. for (size_t i = 0; i < end; i++)
  499. {
  500. if (i == axis)
  501. {
  502. after_axis = 1;
  503. }
  504. const auto& shape = arr.shape();
  505. const size_t stride = std::accumulate(
  506. shape.begin() + i + 1,
  507. shape.end(),
  508. 1,
  509. std::multiplies<size_t>()
  510. );
  511. const auto len = (*(first + i + after_axis));
  512. offset += len * stride;
  513. }
  514. const auto element = arr.begin() + offset;
  515. return *element;
  516. };
  517. size_type i = *(first + axis);
  518. return apply<value_type>(i, get_item, t);
  519. }
  520. };
  521. template <class... CT>
  522. class vstack_access
  523. {
  524. public:
  525. using tuple_type = std::tuple<CT...>;
  526. using size_type = std::size_t;
  527. using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
  528. template <class It>
  529. inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
  530. {
  531. if (std::get<0>(t).dimension() == 1)
  532. {
  533. return stack.access(t, axis, first, last);
  534. }
  535. else
  536. {
  537. return concatonate.access(t, axis, first, last);
  538. }
  539. }
  540. private:
  541. concatenate_access<CT...> concatonate;
  542. stack_access<CT...> stack;
  543. };
  544. template <template <class...> class F, class... CT>
  545. class concatenate_invoker
  546. {
  547. public:
  548. using tuple_type = std::tuple<CT...>;
  549. using size_type = std::size_t;
  550. using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
  551. inline concatenate_invoker(tuple_type&& t, size_type axis)
  552. : m_t(std::move(t))
  553. , m_axis(axis)
  554. {
  555. }
  556. template <class... Args>
  557. inline value_type operator()(Args... args) const
  558. {
  559. // TODO: avoid memory allocation
  560. xindex index({static_cast<size_type>(args)...});
  561. return access_method.access(m_t, m_axis, index.begin(), index.end());
  562. }
  563. template <class It>
  564. inline value_type element(It first, It last) const
  565. {
  566. return access_method.access(m_t, m_axis, first, last);
  567. }
  568. private:
  569. F<CT...> access_method;
  570. tuple_type m_t;
  571. size_type m_axis;
  572. };
  573. template <class... CT>
  574. using concatenate_impl = concatenate_invoker<concatenate_access, CT...>;
  575. template <class... CT>
  576. using stack_impl = concatenate_invoker<stack_access, CT...>;
  577. template <class... CT>
  578. using vstack_impl = concatenate_invoker<vstack_access, CT...>;
  579. template <class CT>
  580. class repeat_impl
  581. {
  582. public:
  583. using xexpression_type = std::decay_t<CT>;
  584. using size_type = typename xexpression_type::size_type;
  585. using value_type = typename xexpression_type::value_type;
  586. template <class CTA>
  587. repeat_impl(CTA&& source, size_type axis)
  588. : m_source(std::forward<CTA>(source))
  589. , m_axis(axis)
  590. {
  591. }
  592. template <class... Args>
  593. value_type operator()(Args... args) const
  594. {
  595. std::array<size_type, sizeof...(Args)> args_arr = {static_cast<size_type>(args)...};
  596. return m_source(args_arr[m_axis]);
  597. }
  598. template <class It>
  599. inline value_type element(It first, It) const
  600. {
  601. return m_source(*(first + static_cast<std::ptrdiff_t>(m_axis)));
  602. }
  603. private:
  604. CT m_source;
  605. size_type m_axis;
  606. };
  607. }
  608. /**
  609. * @brief Creates tuples from arguments for \ref concatenate and \ref stack.
  610. * Very similar to std::make_tuple.
  611. */
  612. template <class... Types>
  613. inline auto xtuple(Types&&... args)
  614. {
  615. return std::tuple<xtl::const_closure_type_t<Types>...>(std::forward<Types>(args)...);
  616. }
  617. namespace detail
  618. {
  619. template <bool... values>
  620. using all_true = xtl::conjunction<std::integral_constant<bool, values>...>;
  621. template <class X, class Y, std::size_t axis, class AxesSequence>
  622. struct concat_fixed_shape_impl;
  623. template <class X, class Y, std::size_t axis, std::size_t... Is>
  624. struct concat_fixed_shape_impl<X, Y, axis, std::index_sequence<Is...>>
  625. {
  626. static_assert(X::size() == Y::size(), "Concatenation requires equisized shapes");
  627. static_assert(axis < X::size(), "Concatenation requires a valid axis");
  628. static_assert(
  629. all_true<(axis == Is || X::template get<Is>() == Y::template get<Is>())...>::value,
  630. "Concatenation requires compatible shapes and axis"
  631. );
  632. using type = fixed_shape<
  633. (axis == Is ? X::template get<Is>() + Y::template get<Is>() : X::template get<Is>())...>;
  634. };
  635. template <std::size_t axis, class X, class Y, class... Rest>
  636. struct concat_fixed_shape;
  637. template <std::size_t axis, class X, class Y>
  638. struct concat_fixed_shape<axis, X, Y>
  639. {
  640. using type = typename concat_fixed_shape_impl<X, Y, axis, std::make_index_sequence<X::size()>>::type;
  641. };
  642. template <std::size_t axis, class X, class Y, class... Rest>
  643. struct concat_fixed_shape
  644. {
  645. using type = typename concat_fixed_shape<axis, X, typename concat_fixed_shape<axis, Y, Rest...>::type>::type;
  646. };
  647. template <std::size_t axis, class... Args>
  648. using concat_fixed_shape_t = typename concat_fixed_shape<axis, Args...>::type;
  649. template <class... CT>
  650. using all_fixed_shapes = detail::all_fixed<typename std::decay_t<CT>::shape_type...>;
  651. struct concat_shape_builder_t
  652. {
  653. template <class Shape, bool = detail::is_fixed<Shape>::value>
  654. struct concat_shape;
  655. template <class Shape>
  656. struct concat_shape<Shape, true>
  657. {
  658. // Convert `fixed_shape` to `static_shape` to allow runtime dimension calculation.
  659. using type = static_shape<typename Shape::value_type, Shape::size()>;
  660. };
  661. template <class Shape>
  662. struct concat_shape<Shape, false>
  663. {
  664. using type = Shape;
  665. };
  666. template <class... Args>
  667. static auto build(const std::tuple<Args...>& t, std::size_t axis)
  668. {
  669. using shape_type = promote_shape_t<
  670. typename concat_shape<typename std::decay_t<Args>::shape_type>::type...>;
  671. using source_shape_type = decltype(std::get<0>(t).shape());
  672. shape_type new_shape = xtl::forward_sequence<shape_type, source_shape_type>(
  673. std::get<0>(t).shape()
  674. );
  675. auto check_shape = [&axis, &new_shape](auto& arr)
  676. {
  677. std::size_t s = new_shape.size();
  678. bool res = s == arr.dimension();
  679. for (std::size_t i = 0; i < s; ++i)
  680. {
  681. res = res && (i == axis || new_shape[i] == arr.shape(i));
  682. }
  683. if (!res)
  684. {
  685. throw_concatenate_error(new_shape, arr.shape());
  686. }
  687. };
  688. for_each(check_shape, t);
  689. auto shape_at_axis = [&axis](std::size_t prev, auto& arr) -> std::size_t
  690. {
  691. return prev + arr.shape()[axis];
  692. };
  693. new_shape[axis] += accumulate(shape_at_axis, std::size_t(0), t) - new_shape[axis];
  694. return new_shape;
  695. }
  696. };
  697. } // namespace detail
  698. /***************
  699. * concatenate *
  700. ***************/
  701. /**
  702. * @brief Concatenates xexpressions along \em axis.
  703. *
  704. * @param t \ref xtuple of xexpressions to concatenate
  705. * @param axis axis along which elements are concatenated
  706. * @returns xgenerator evaluating to concatenated elements
  707. *
  708. * @code{.cpp}
  709. * xt::xarray<double> a = {{1, 2, 3}};
  710. * xt::xarray<double> b = {{2, 3, 4}};
  711. * xt::xarray<double> c = xt::concatenate(xt::xtuple(a, b)); // => {{1, 2, 3},
  712. * // {2, 3, 4}}
  713. * xt::xarray<double> d = xt::concatenate(xt::xtuple(a, b), 1); // => {{1, 2, 3, 2, 3, 4}}
  714. * @endcode
  715. */
  716. template <class... CT>
  717. inline auto concatenate(std::tuple<CT...>&& t, std::size_t axis = 0)
  718. {
  719. const auto shape = detail::concat_shape_builder_t::build(t, axis);
  720. return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape);
  721. }
  722. template <std::size_t axis, class... CT, typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
  723. inline auto concatenate(std::tuple<CT...>&& t)
  724. {
  725. using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;
  726. return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape_type{});
  727. }
  728. namespace detail
  729. {
  730. template <class T, std::size_t N>
  731. inline std::array<T, N + 1> add_axis(std::array<T, N> arr, std::size_t axis, std::size_t value)
  732. {
  733. std::array<T, N + 1> temp;
  734. std::copy(arr.begin(), arr.begin() + axis, temp.begin());
  735. temp[axis] = value;
  736. std::copy(arr.begin() + axis, arr.end(), temp.begin() + axis + 1);
  737. return temp;
  738. }
  739. template <class T>
  740. inline T add_axis(T arr, std::size_t axis, std::size_t value)
  741. {
  742. T temp(arr);
  743. temp.insert(temp.begin() + std::ptrdiff_t(axis), value);
  744. return temp;
  745. }
  746. }
  747. /**
  748. * @brief Stack xexpressions along \em axis.
  749. * Stacking always creates a new dimension along which elements are stacked.
  750. *
  751. * @param t \ref xtuple of xexpressions to concatenate
  752. * @param axis axis along which elements are stacked
  753. * @returns xgenerator evaluating to stacked elements
  754. *
  755. * @code{.cpp}
  756. * xt::xarray<double> a = {1, 2, 3};
  757. * xt::xarray<double> b = {5, 6, 7};
  758. * xt::xarray<double> s = xt::stack(xt::xtuple(a, b)); // => {{1, 2, 3},
  759. * // {5, 6, 7}}
  760. * xt::xarray<double> t = xt::stack(xt::xtuple(a, b), 1); // => {{1, 5},
  761. * // {2, 6},
  762. * // {3, 7}}
  763. * @endcode
  764. */
  765. template <class... CT>
  766. inline auto stack(std::tuple<CT...>&& t, std::size_t axis = 0)
  767. {
  768. using shape_type = promote_shape_t<typename std::decay_t<CT>::shape_type...>;
  769. using source_shape_type = decltype(std::get<0>(t).shape());
  770. auto new_shape = detail::add_axis(
  771. xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(t).shape()),
  772. axis,
  773. sizeof...(CT)
  774. );
  775. return detail::make_xgenerator(detail::stack_impl<CT...>(std::move(t), axis), new_shape);
  776. }
  777. /**
  778. * @brief Stack xexpressions in sequence horizontally (column wise).
  779. * This is equivalent to concatenation along the second axis, except for 1-D
 * xexpressions where it concatenates along the first axis.
  781. *
  782. * @param t \ref xtuple of xexpressions to stack
  783. * @return xgenerator evaluating to stacked elements
  784. */
  785. template <class... CT>
  786. inline auto hstack(std::tuple<CT...>&& t)
  787. {
  788. auto dim = std::get<0>(t).dimension();
  789. std::size_t axis = dim > std::size_t(1) ? 1 : 0;
  790. return concatenate(std::move(t), axis);
  791. }
namespace detail
{
    // Computes the result shape of vstack: a 1-D input shape (N) becomes
    // (sizeof...(CT), N); anything higher-dimensional delegates to the
    // ordinary concatenation shape along axis 0.
    template <class S, class... CT>
    inline auto vstack_shape(std::tuple<CT...>& t, const S& shape)
    {
        using size_type = typename S::value_type;
        // Only one branch of the ternary is evaluated; the std::move on
        // ``t`` is just a cast and only takes effect in the taken
        // branch. NOTE(review): both branches must share a common type —
        // presumably build() returns something convertible to S; confirm
        // against concat_shape_builder_t.
        auto res = shape.size() == size_type(1)
            ? S({sizeof...(CT), shape[0]})
            : concat_shape_builder_t::build(std::move(t), size_type(0));
        return res;
    }

    // Overload for fixed 1-D shapes: the result is always 2-D, so a
    // std::array<T, 2> can be returned directly.
    template <class T, class... CT>
    inline auto vstack_shape(const std::tuple<CT...>&, std::array<T, 1> shape)
    {
        std::array<T, 2> res = {sizeof...(CT), shape[0]};
        return res;
    }
}
  810. /**
  811. * @brief Stack xexpressions in sequence vertically (row wise).
  812. * This is equivalent to concatenation along the first axis after
 * 1-D arrays of shape (N) have been reshaped to (1, N).
  814. *
  815. * @param t \ref xtuple of xexpressions to stack
  816. * @return xgenerator evaluating to stacked elements
  817. */
  818. template <class... CT>
  819. inline auto vstack(std::tuple<CT...>&& t)
  820. {
  821. using shape_type = promote_shape_t<typename std::decay_t<CT>::shape_type...>;
  822. using source_shape_type = decltype(std::get<0>(t).shape());
  823. auto new_shape = detail::vstack_shape(
  824. t,
  825. xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(t).shape())
  826. );
  827. return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), new_shape);
  828. }
namespace detail
{
    // Builds one repeat-generator per input expression: the I-th output
    // varies along dimension I and broadcasts along the others. All
    // outputs share the shape {e.shape()[0]...} (inputs are expected to
    // be 1-D — each contributes its length as one dimension).
    template <std::size_t... I, class... E>
    inline auto meshgrid_impl(std::index_sequence<I...>, E&&... e) noexcept
    {
#if defined _MSC_VER
        // MSVC workaround: the braced shape expansion inside the pack
        // expansion below does not compile there, so the shape is
        // materialized in a named std::array first.
        const std::array<std::size_t, sizeof...(E)> shape = {e.shape()[0]...};
        return std::make_tuple(
            detail::make_xgenerator(detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I), shape)...
        );
#else
        return std::make_tuple(detail::make_xgenerator(
            detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I),
            {e.shape()[0]...}
        )...);
#endif
    }
}
  847. /**
  848. * @brief Return coordinate tensors from coordinate vectors.
  849. * Make N-D coordinate tensor expressions for vectorized evaluations of N-D scalar/vector
  850. * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,..., xn.
  851. *
  852. * @param e xexpressions to concatenate
  853. * @returns tuple of xgenerator expressions.
  854. */
  855. template <class... E>
  856. inline auto meshgrid(E&&... e) noexcept
  857. {
  858. return detail::meshgrid_impl(std::make_index_sequence<sizeof...(E)>(), std::forward<E>(e)...);
  859. }
namespace detail
{
    // Functor backing xt::diagonal: maps an index into the diagonal
    // result back to an index into the source expression.
    template <class CT>
    class diagonal_fn
    {
    public:

        using xexpression_type = std::decay_t<CT>;
        using value_type = typename xexpression_type::value_type;

        template <class CTA>
        diagonal_fn(CTA&& source, int offset, std::size_t axis_1, std::size_t axis_2)
            : m_source(std::forward<CTA>(source))
            , m_offset(offset)
            , m_axis_1(axis_1)
            , m_axis_2(axis_2)
        {
        }

        // ``begin`` iterates the result index: the first (rank - 2)
        // entries address the non-diagonal axes, the last entry is the
        // position along the diagonal.
        template <class It>
        inline value_type operator()(It begin, It) const
        {
            xindex idx(m_source.shape().size());
            // Copy the leading result-index entries onto every source
            // axis except the two diagonal ones.
            for (std::size_t i = 0; i < idx.size(); i++)
            {
                if (i != m_axis_1 && i != m_axis_2)
                {
                    idx[i] = static_cast<std::size_t>(*begin++);
                }
            }
            using it_vtype = typename std::iterator_traits<It>::value_type;
            it_vtype uoffset = static_cast<it_vtype>(m_offset);
            // After the loop ``begin`` points at the diagonal position d.
            // offset >= 0: (axis_1, axis_2) = (d, d + offset)
            // offset <  0: (axis_1, axis_2) = (d - offset, d)
            if (m_offset >= 0)
            {
                idx[m_axis_1] = static_cast<std::size_t>(*(begin));
                idx[m_axis_2] = static_cast<std::size_t>(*(begin) + uoffset);
            }
            else
            {
                idx[m_axis_1] = static_cast<std::size_t>(*(begin) - uoffset);
                idx[m_axis_2] = static_cast<std::size_t>(*(begin));
            }
            return m_source[idx];
        }

    private:

        CT m_source;
        const int m_offset;
        const std::size_t m_axis_1;
        const std::size_t m_axis_2;
    };

    // Functor backing xt::diag: builds a 2-D view whose k-th diagonal
    // carries the (1-D) source values and which is zero elsewhere.
    template <class CT>
    class diag_fn
    {
    public:

        using xexpression_type = std::decay_t<CT>;
        using value_type = typename xexpression_type::value_type;

        template <class CTA>
        diag_fn(CTA&& source, int k)
            : m_source(std::forward<CTA>(source))
            , m_k(k)
        {
        }

        // ``begin`` addresses (row, col) of the 2-D result; the element
        // lies on the selected diagonal iff col == row + k.
        template <class It>
        inline value_type operator()(It begin, It) const
        {
            using it_vtype = typename std::iterator_traits<It>::value_type;
            it_vtype umk = static_cast<it_vtype>(m_k);
            if (m_k > 0)
            {
                // Super-diagonal: source index is the row.
                return *begin + umk == *(begin + 1) ? m_source(*begin) : value_type(0);
            }
            else
            {
                // Main/sub-diagonal: source index is the column
                // (row + k == col here).
                return *begin + umk == *(begin + 1) ? m_source(*begin + umk) : value_type(0);
            }
        }

    private:

        CT m_source;
        const int m_k;
    };

    // Functor backing xt::tril / xt::triu: keeps the source element when
    // comp(row + k, col) holds (comp is >= for tril, <= for triu) and
    // yields zero otherwise.
    template <class CT, class Comp>
    class trilu_fn
    {
    public:

        using xexpression_type = std::decay_t<CT>;
        using value_type = typename xexpression_type::value_type;
        using signed_idx_type = long int;

        template <class CTA>
        trilu_fn(CTA&& source, int k, Comp comp)
            : m_source(std::forward<CTA>(source))
            , m_k(k)
            , m_comp(comp)
        {
        }

        template <class It>
        inline value_type operator()(It begin, It end) const
        {
            // have to cast to signed int otherwise -1 can lead to overflow
            return m_comp(signed_idx_type(*begin) + m_k, signed_idx_type(*(begin + 1)))
                ? m_source.element(begin, end)
                : value_type(0);
        }

    private:

        CT m_source;
        const signed_idx_type m_k;
        const Comp m_comp;
    };
}
namespace detail
{
    // meta-function returning the shape type for a diagonal
    //
    // Generic case: dynamically-sized shape containers keep their own
    // type — dropping a dimension does not change the container type.
    template <class ST, class... S>
    struct diagonal_shape_type
    {
        using type = ST;
    };

    // Fixed-size shapes lose exactly one dimension: the two diagonal
    // axes are removed and one axis holding the diagonal is appended.
    template <class I, std::size_t L>
    struct diagonal_shape_type<std::array<I, L>>
    {
        using type = std::array<I, L - 1>;
    };
}
  979. /**
  980. * @brief Returns the elements on the diagonal of arr
  981. * If arr has more than two dimensions, then the axes specified by
  982. * axis_1 and axis_2 are used to determine the 2-D sub-array whose
  983. * diagonal is returned. The shape of the resulting array can be
  984. * determined by removing axis1 and axis2 and appending an index
  985. * to the right equal to the size of the resulting diagonals.
  986. *
  987. * @param arr the input array
  988. * @param offset offset of the diagonal from the main diagonal. Can
  989. * be positive or negative.
  990. * @param axis_1 Axis to be used as the first axis of the 2-D sub-arrays
  991. * from which the diagonals should be taken.
  992. * @param axis_2 Axis to be used as the second axis of the 2-D sub-arrays
  993. * from which the diagonals should be taken.
  994. * @returns xexpression with values of the diagonal
  995. *
  996. * @code{.cpp}
  997. * xt::xarray<double> a = {{1, 2, 3},
  998. * {4, 5, 6}
  999. * {7, 8, 9}};
  1000. * auto b = xt::diagonal(a); // => {1, 5, 9}
  1001. * @endcode
  1002. */
  1003. template <class E>
  1004. inline auto diagonal(E&& arr, int offset = 0, std::size_t axis_1 = 0, std::size_t axis_2 = 1)
  1005. {
  1006. using CT = xclosure_t<E>;
  1007. using shape_type = typename detail::diagonal_shape_type<typename std::decay_t<E>::shape_type>::type;
  1008. auto shape = arr.shape();
  1009. auto dimension = arr.dimension();
  1010. // The following shape calculation code is an almost verbatim adaptation of NumPy:
  1011. // https://github.com/numpy/numpy/blob/2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0/numpy/core/src/multiarray/item_selection.c#L1799
  1012. auto ret_shape = xtl::make_sequence<shape_type>(dimension - 1, 0);
  1013. int dim_1 = static_cast<int>(shape[axis_1]);
  1014. int dim_2 = static_cast<int>(shape[axis_2]);
  1015. offset >= 0 ? dim_2 -= offset : dim_1 += offset;
  1016. auto diag_size = std::size_t(dim_2 < dim_1 ? dim_2 : dim_1);
  1017. std::size_t i = 0;
  1018. for (std::size_t idim = 0; idim < dimension; ++idim)
  1019. {
  1020. if (idim != axis_1 && idim != axis_2)
  1021. {
  1022. ret_shape[i++] = shape[idim];
  1023. }
  1024. }
  1025. ret_shape.back() = diag_size;
  1026. return detail::make_xgenerator(
  1027. detail::fn_impl<detail::diagonal_fn<CT>>(
  1028. detail::diagonal_fn<CT>(std::forward<E>(arr), offset, axis_1, axis_2)
  1029. ),
  1030. ret_shape
  1031. );
  1032. }
  1033. /**
  1034. * @brief xexpression with values of arr on the diagonal, zeroes otherwise
  1035. *
  1036. * @param arr the 1D input array of length n
  1037. * @param k the offset of the considered diagonal
  1038. * @returns xexpression function with shape n x n and arr on the diagonal
  1039. *
  1040. * @code{.cpp}
  1041. * xt::xarray<double> a = {1, 5, 9};
  1042. * auto b = xt::diag(a); // => {{1, 0, 0},
  1043. * // {0, 5, 0},
  1044. * // {0, 0, 9}}
  1045. * @endcode
  1046. */
  1047. template <class E>
  1048. inline auto diag(E&& arr, int k = 0)
  1049. {
  1050. using CT = xclosure_t<E>;
  1051. std::size_t sk = std::size_t(std::abs(k));
  1052. std::size_t s = arr.shape()[0] + sk;
  1053. return detail::make_xgenerator(
  1054. detail::fn_impl<detail::diag_fn<CT>>(detail::diag_fn<CT>(std::forward<E>(arr), k)),
  1055. {s, s}
  1056. );
  1057. }
  1058. /**
  1059. * @brief Extract lower triangular matrix from xexpression. The parameter k selects the
  1060. * offset of the diagonal.
  1061. *
  1062. * @param arr the input array
  1063. * @param k the diagonal above which to zero elements. 0 (default) selects the main diagonal,
  1064. * k < 0 is below the main diagonal, k > 0 above.
  1065. * @returns xexpression containing lower triangle from arr, 0 otherwise
  1066. */
  1067. template <class E>
  1068. inline auto tril(E&& arr, int k = 0)
  1069. {
  1070. using CT = xclosure_t<E>;
  1071. auto shape = arr.shape();
  1072. return detail::make_xgenerator(
  1073. detail::fn_impl<detail::trilu_fn<CT, std::greater_equal<long int>>>(
  1074. detail::trilu_fn<CT, std::greater_equal<long int>>(
  1075. std::forward<E>(arr),
  1076. k,
  1077. std::greater_equal<long int>()
  1078. )
  1079. ),
  1080. shape
  1081. );
  1082. }
  1083. /**
  1084. * @brief Extract upper triangular matrix from xexpression. The parameter k selects the
  1085. * offset of the diagonal.
  1086. *
  1087. * @param arr the input array
  1088. * @param k the diagonal below which to zero elements. 0 (default) selects the main diagonal,
  1089. * k < 0 is below the main diagonal, k > 0 above.
 * @returns xexpression containing upper triangle from arr, 0 otherwise
  1091. */
  1092. template <class E>
  1093. inline auto triu(E&& arr, int k = 0)
  1094. {
  1095. using CT = xclosure_t<E>;
  1096. auto shape = arr.shape();
  1097. return detail::make_xgenerator(
  1098. detail::fn_impl<detail::trilu_fn<CT, std::less_equal<long int>>>(
  1099. detail::trilu_fn<CT, std::less_equal<long int>>(std::forward<E>(arr), k, std::less_equal<long int>())
  1100. ),
  1101. shape
  1102. );
  1103. }
  1104. }
  1105. #endif