| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261 |
- /***************************************************************************
- * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
- * Copyright (c) QuantStack *
- * *
- * Distributed under the terms of the BSD 3-Clause License. *
- * *
- * The full license is in the file LICENSE, distributed with this software. *
- ****************************************************************************/
- /**
- * @brief builder functions (ones, zeros, empty, eye, arange, linspace, concatenate, stack, ...) for xexpressions
- */
- #ifndef XTENSOR_BUILDER_HPP
- #define XTENSOR_BUILDER_HPP
- #include <array>
- #include <chrono>
- #include <cmath>
- #include <cstddef>
- #include <functional>
- #include <utility>
- #include <vector>
- #include <xtl/xclosure.hpp>
- #include <xtl/xsequence.hpp>
- #include <xtl/xtype_traits.hpp>
- #include "xbroadcast.hpp"
- #include "xfunction.hpp"
- #include "xgenerator.hpp"
- #include "xoperation.hpp"
- namespace xt
- {
- /********
- * ones *
- ********/
/**
 * Returns an \ref xexpression containing ones of the specified shape.
 *
 * The result is whatever ``broadcast`` returns for the scalar ``T(1)`` and
 * the given shape.
 *
 * @tparam T value type of the broadcast scalar.
 * @param shape the shape of the returned expression.
 */
template <class T, class S>
inline auto ones(S shape) noexcept
{
    using scalar_type = T;
    // `shape` was received by value; forwarding it passes it on as an rvalue
    // whenever S is a non-reference type.
    return broadcast(scalar_type(1), std::forward<S>(shape));
}

/**
 * Same as above for a shape given as a C array (this is the overload selected
 * for a braced list of extents).
 *
 * @tparam T value type of the broadcast scalar.
 * @param shape the shape of the returned expression.
 */
template <class T, class I, std::size_t L>
inline auto ones(const I (&shape)[L]) noexcept
{
    using scalar_type = T;
    return broadcast(scalar_type(1), shape);
}
- /*********
- * zeros *
- *********/
/**
 * Returns an \ref xexpression containing zeros of the specified shape.
 *
 * The result is whatever ``broadcast`` returns for the scalar ``T(0)`` and
 * the given shape.
 *
 * @tparam T value type of the broadcast scalar.
 * @param shape the shape of the returned expression.
 */
template <class T, class S>
inline auto zeros(S shape) noexcept
{
    using scalar_type = T;
    // `shape` was received by value; forwarding it passes it on as an rvalue
    // whenever S is a non-reference type.
    return broadcast(scalar_type(0), std::forward<S>(shape));
}

/**
 * Same as above for a shape given as a C array (this is the overload selected
 * for a braced list of extents).
 *
 * @tparam T value type of the broadcast scalar.
 * @param shape the shape of the returned expression.
 */
template <class T, class I, std::size_t L>
inline auto zeros(const I (&shape)[L]) noexcept
{
    using scalar_type = T;
    return broadcast(scalar_type(0), shape);
}
- /**
- * Create a xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of
- * with value_type T and shape. Selects the best container match automatically
- * from the supplied shape.
- *
- * - ``std::vector`` → ``xarray<T>``
- * - ``std::array`` or ``initializer_list`` → ``xtensor<T, N>``
- * - ``xshape<N...>`` → ``xtensor_fixed<T, xshape<N...>>``
- *
- * @param shape shape of the new xcontainer
- */
- template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class S>
- inline xarray<T, L> empty(const S& shape)
- {
- return xarray<T, L>::from_shape(shape);
- }
- template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class ST, std::size_t N>
- inline xtensor<T, N, L> empty(const std::array<ST, N>& shape)
- {
- using shape_type = typename xtensor<T, N>::shape_type;
- return xtensor<T, N, L>(xtl::forward_sequence<shape_type, decltype(shape)>(shape));
- }
- template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class I, std::size_t N>
- inline xtensor<T, N, L> empty(const I (&shape)[N])
- {
- using shape_type = typename xtensor<T, N>::shape_type;
- return xtensor<T, N, L>(xtl::forward_sequence<shape_type, decltype(shape)>(shape));
- }
- template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, std::size_t... N>
- inline xtensor_fixed<T, fixed_shape<N...>, L> empty(const fixed_shape<N...>& /*shape*/)
- {
- return xtensor_fixed<T, fixed_shape<N...>, L>();
- }
- /**
- * Create a xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of
- * the same shape, value type and layout as the input xexpression *e*.
- *
- * @param e the xexpression from which to extract shape, value type and layout.
- */
- template <class E>
- inline auto empty_like(const xexpression<E>& e)
- {
- using xtype = temporary_type_t<E>;
- auto res = xtype::from_shape(e.derived_cast().shape());
- return res;
- }
- /**
- * Create a xcontainer (xarray, xtensor or xtensor_fixed), filled with *fill_value* and of
- * the same shape, value type and layout as the input xexpression *e*.
- *
- * @param e the xexpression from which to extract shape, value type and layout.
- * @param fill_value the value used to set each element of the returned xcontainer.
- */
- template <class E>
- inline auto full_like(const xexpression<E>& e, typename E::value_type fill_value)
- {
- using xtype = temporary_type_t<E>;
- auto res = xtype::from_shape(e.derived_cast().shape());
- res.fill(fill_value);
- return res;
- }
- /**
- * Create a xcontainer (xarray, xtensor or xtensor_fixed), filled with zeros and of
- * the same shape, value type and layout as the input xexpression *e*.
- *
- * Note: contrary to zeros(shape), this function returns a non-lazy, allocated container!
- * Use ``xt::zeros<double>(e.shape());` for a lazy version.
- *
- * @param e the xexpression from which to extract shape, value type and layout.
- */
- template <class E>
- inline auto zeros_like(const xexpression<E>& e)
- {
- return full_like(e, typename E::value_type(0));
- }
- /**
- * Create a xcontainer (xarray, xtensor or xtensor_fixed), filled with ones and of
- * the same shape, value type and layout as the input xexpression *e*.
- *
- * Note: contrary to ones(shape), this function returns a non-lazy, evaluated container!
- * Use ``xt::ones<double>(e.shape());`` for a lazy version.
- *
- * @param e the xexpression from which to extract shape, value type and layout.
- */
- template <class E>
- inline auto ones_like(const xexpression<E>& e)
- {
- return full_like(e, typename E::value_type(1));
- }
- namespace detail
- {
// Picks the type an element index is converted to before being multiplied by
// a step of type Step. In the general case this is simply Fallback.
template <class Fallback, class Step>
struct get_mult_type_impl
{
    using type = Fallback;
};

// When the step is a std::chrono::duration, the index is converted to the
// duration's representation type instead.
template <class Fallback, class Rep, class Period>
struct get_mult_type_impl<Fallback, std::chrono::duration<Rep, Period>>
{
    using type = Rep;
};

// Convenience alias for get_mult_type_impl<...>::type.
template <class Fallback, class Step>
using get_mult_type = typename get_mult_type_impl<Fallback, Step>::type;
- // These methods should be private methods of arange_generator, however thi leads
- // to ICE on VS2015
- template <class R, class E, class U, class X, XTL_REQUIRES(xtl::is_integral<X>)>
- inline void arange_assign_to(xexpression<E>& e, U start, U, X step, bool) noexcept
- {
- auto& de = e.derived_cast();
- U value = start;
- for (auto&& el : de.storage())
- {
- el = static_cast<R>(value);
- value += step;
- }
- }
- template <class R, class E, class U, class X, XTL_REQUIRES(xtl::negation<xtl::is_integral<X>>)>
- inline void arange_assign_to(xexpression<E>& e, U start, U stop, X step, bool endpoint) noexcept
- {
- auto& buf = e.derived_cast().storage();
- using size_type = decltype(buf.size());
- using mult_type = get_mult_type<U, X>;
- size_type num = buf.size();
- for (size_type i = 0; i < num; ++i)
- {
- buf[i] = static_cast<R>(start + step * mult_type(i));
- }
- if (endpoint && num > 1)
- {
- buf[num - 1] = static_cast<R>(stop);
- }
- }
- template <class T, class R = T, class S = T>
- class arange_generator
- {
- public:
- using value_type = R;
- using step_type = S;
- arange_generator(T start, T stop, S step, size_t num_steps, bool endpoint = false)
- : m_start(start)
- , m_stop(stop)
- , m_step(step)
- , m_num_steps(num_steps)
- , m_endpoint(endpoint)
- {
- }
- template <class... Args>
- inline R operator()(Args... args) const
- {
- return access_impl(args...);
- }
- template <class It>
- inline R element(It first, It) const
- {
- return access_impl(*first);
- }
- template <class E>
- inline void assign_to(xexpression<E>& e) const noexcept
- {
- arange_assign_to<R>(e, m_start, m_stop, m_step, m_endpoint);
- }
- private:
- T m_start;
- T m_stop;
- step_type m_step;
- size_t m_num_steps;
- bool m_endpoint; // true for setting the last element to m_stop
- template <class T1, class... Args>
- inline R access_impl(T1 t, Args...) const
- {
- if (m_endpoint && m_num_steps > 1 && size_t(t) == m_num_steps - 1)
- {
- return static_cast<R>(m_stop);
- }
- // Avoids warning when T = char (because char + char => int!)
- using mult_type = get_mult_type<T, S>;
- return static_cast<R>(m_start + m_step * mult_type(t));
- }
- inline R access_impl() const
- {
- return static_cast<R>(m_start);
- }
- };
- template <class T, class S>
- using both_integer = xtl::conjunction<xtl::is_integral<T>, xtl::is_integral<S>>;
- template <class T, class S>
- using integer_with_signed_integer = xtl::conjunction<both_integer<T, S>, xtl::is_signed<S>>;
- template <class T, class S>
- using integer_with_unsigned_integer = xtl::conjunction<both_integer<T, S>, std::is_unsigned<S>>;
- template <class T, class S = T, XTL_REQUIRES(xtl::negation<both_integer<T, S>>)>
- inline auto arange_impl(T start, T stop, S step = 1) noexcept
- {
- std::size_t shape = static_cast<std::size_t>(std::ceil((stop - start) / step));
- return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
- }
- template <class T, class S = T, XTL_REQUIRES(integer_with_signed_integer<T, S>)>
- inline auto arange_impl(T start, T stop, S step = 1) noexcept
- {
- bool empty_cond = (stop - start) / step <= 0;
- std::size_t shape = 0;
- if (!empty_cond)
- {
- shape = stop > start ? static_cast<std::size_t>((stop - start + step - S(1)) / step)
- : static_cast<std::size_t>((start - stop - step - S(1)) / -step);
- }
- return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
- }
- template <class T, class S = T, XTL_REQUIRES(integer_with_unsigned_integer<T, S>)>
- inline auto arange_impl(T start, T stop, S step = 1) noexcept
- {
- bool empty_cond = stop <= start;
- std::size_t shape = 0;
- if (!empty_cond)
- {
- shape = static_cast<std::size_t>((stop - start + step - S(1)) / step);
- }
- return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
- }
// Adapter turning a functor F, callable with an iterator pair over an index,
// into an object exposing the operator() / element access functions used here.
//
// F must provide a nested value_type and be callable as f(begin, end).
template <class F>
class fn_impl
{
public:

    using value_type = typename F::value_type;
    using size_type = std::size_t;

    // Takes ownership of the functor. The parameter is an rvalue reference,
    // so it is moved into place rather than copied.
    fn_impl(F&& f)
        : m_ft(std::move(f))
    {
    }

    // Access without indices: the functor is called with the single index 0.
    inline value_type operator()() const
    {
        size_type idx[1] = {0ul};
        return access_impl(std::begin(idx), std::end(idx));
    }

    // Access with explicit indices: they are converted to size_type, stored
    // in a local array, and the functor is called on that array.
    template <class... Args>
    inline value_type operator()(Args... args) const
    {
        size_type idx[sizeof...(Args)] = {static_cast<size_type>(args)...};
        return access_impl(std::begin(idx), std::end(idx));
    }

    // Access with an index given as the iterator range [first, last).
    template <class It>
    inline value_type element(It first, It last) const
    {
        return access_impl(first, last);
    }

private:

    F m_ft;

    template <class It>
    inline value_type access_impl(const It& begin, const It& end) const
    {
        return m_ft(begin, end);
    }
};
// Functor for a matrix with ones on the k-th diagonal.
//
// Called with an index range, it looks at the last two indices only and
// returns T(1) when last == second_to_last + k, and T(0) otherwise.
template <class T>
class eye_fn
{
public:

    using value_type = T;

    eye_fn(int k)
        : m_k(k)
    {
    }

    template <class It>
    inline T operator()(const It& /*begin*/, const It& end) const
    {
        using index_type = typename std::iterator_traits<It>::value_type;
        // k is converted to the index type, so a negative k relies on the
        // wrap-around of that type when it is unsigned.
        const index_type shift = static_cast<index_type>(m_k);
        const bool on_diagonal = (*(end - 1) == *(end - 2) + shift);
        return on_diagonal ? T(1) : T(0);
    }

private:

    std::ptrdiff_t m_k;
};
- }
- /**
- * Generates an array with ones on the diagonal.
- * @param shape shape of the resulting expression
- * @param k index of the diagonal. 0 (default) refers to the main diagonal,
- * a positive value refers to an upper diagonal, and a negative
- * value to a lower diagonal.
- * @tparam T value_type of xexpression
- * @return xgenerator that generates the values on access
- */
- template <class T = bool>
- inline auto eye(const std::vector<std::size_t>& shape, int k = 0)
- {
- return detail::make_xgenerator(detail::fn_impl<detail::eye_fn<T>>(detail::eye_fn<T>(k)), shape);
- }
/**
 * Generates a (n x n) array with ones on the diagonal.
 * @param n length of the diagonal.
 * @param k index of the diagonal. 0 (default) refers to the main diagonal,
 *          a positive value refers to an upper diagonal, and a negative
 *          value to a lower diagonal.
 * @tparam T value_type of xexpression
 * @return xgenerator that generates the values on access
 */
template <class T = bool>
inline auto eye(std::size_t n, int k = 0)
{
    // Delegate to the shape-based overload with a square shape.
    const std::vector<std::size_t> square_shape = {n, n};
    return eye<T>(square_shape, k);
}
- /**
- * Generates numbers evenly spaced within given half-open interval [start, stop).
- * @param start start of the interval
- * @param stop stop of the interval
- * @param step stepsize
- * @tparam T value_type of xexpression
- * @return xgenerator that generates the values on access
- */
- template <class T, class S = T>
- inline auto arange(T start, T stop, S step = 1) noexcept
- {
- return detail::arange_impl(start, stop, step);
- }
/**
 * Generate numbers evenly spaced within given half-open interval [0, stop)
 * with a step size of 1.
 * @param stop stop of the interval
 * @tparam T value_type of xexpression
 * @return xgenerator that generates the values on access
 */
template <class T>
inline auto arange(T stop) noexcept
{
    const T start(0);
    const T unit_step(1);
    return arange<T>(start, stop, unit_step);
}
- /**
- * Generates @a num_samples evenly spaced numbers over given interval
- * @param start start of interval
- * @param stop stop of interval
- * @param num_samples number of samples (defaults to 50)
- * @param endpoint if true, include endpoint (defaults to true)
- * @tparam T value_type of xexpression
- * @return xgenerator that generates the values on access
- */
- template <class T>
- inline auto linspace(T start, T stop, std::size_t num_samples = 50, bool endpoint = true) noexcept
- {
- using fp_type = std::common_type_t<T, double>;
- fp_type step = fp_type(stop - start) / std::fmax(fp_type(1), fp_type(num_samples - (endpoint ? 1 : 0)));
- return detail::make_xgenerator(
- detail::arange_generator<fp_type, T>(fp_type(start), fp_type(stop), step, num_samples, endpoint),
- {num_samples}
- );
- }
/**
 * Generates @a num_samples numbers evenly spaced on a log scale over given interval
 *
 * The result is ``pow(base, linspace(start, stop, num_samples, endpoint))``.
 *
 * @param start start of interval (pow(base, start) is the first value).
 * @param stop stop of interval (pow(base, stop) is the final value, except if endpoint = false)
 * @param num_samples number of samples
 * @param base the base of the log space (defaults to 10).
 * @param endpoint if true, include endpoint (defaults to true)
 * @tparam T value_type of xexpression
 * @return xgenerator that generates the values on access
 */
template <class T>
inline auto logspace(T start, T stop, std::size_t num_samples, T base = 10, bool endpoint = true) noexcept
{
    auto exponents = linspace(start, stop, num_samples, endpoint);
    return pow(std::move(base), std::move(exponents));
}
- namespace detail
- {
- template <class... CT>
- class concatenate_access
- {
- public:
- using tuple_type = std::tuple<CT...>;
- using size_type = std::size_t;
- using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
- template <class It>
- inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
- {
- // trim off extra indices if provided to match behavior of containers
- auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension();
- size_t axis_dim = *(first + axis + dim_offset);
- auto match = [&](auto& arr)
- {
- if (axis_dim >= arr.shape()[axis])
- {
- axis_dim -= arr.shape()[axis];
- return false;
- }
- return true;
- };
- auto get = [&](auto& arr)
- {
- size_t offset = 0;
- const size_t end = arr.dimension();
- for (size_t i = 0; i < end; i++)
- {
- const auto& shape = arr.shape();
- const size_t stride = std::accumulate(
- shape.begin() + i + 1,
- shape.end(),
- 1,
- std::multiplies<size_t>()
- );
- if (i == axis)
- {
- offset += axis_dim * stride;
- }
- else
- {
- const auto len = (*(first + i + dim_offset));
- offset += len * stride;
- }
- }
- const auto element = arr.begin() + offset;
- return *element;
- };
- size_type i = 0;
- for (; i < sizeof...(CT); ++i)
- {
- if (apply<bool>(i, match, t))
- {
- break;
- }
- }
- return apply<value_type>(i, get, t);
- }
- };
- template <class... CT>
- class stack_access
- {
- public:
- using tuple_type = std::tuple<CT...>;
- using size_type = std::size_t;
- using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
- template <class It>
- inline value_type access(const tuple_type& t, size_type axis, It first, It) const
- {
- auto get_item = [&](auto& arr)
- {
- size_t offset = 0;
- const size_t end = arr.dimension();
- size_t after_axis = 0;
- for (size_t i = 0; i < end; i++)
- {
- if (i == axis)
- {
- after_axis = 1;
- }
- const auto& shape = arr.shape();
- const size_t stride = std::accumulate(
- shape.begin() + i + 1,
- shape.end(),
- 1,
- std::multiplies<size_t>()
- );
- const auto len = (*(first + i + after_axis));
- offset += len * stride;
- }
- const auto element = arr.begin() + offset;
- return *element;
- };
- size_type i = *(first + axis);
- return apply<value_type>(i, get_item, t);
- }
- };
- template <class... CT>
- class vstack_access
- {
- public:
- using tuple_type = std::tuple<CT...>;
- using size_type = std::size_t;
- using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
- template <class It>
- inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
- {
- if (std::get<0>(t).dimension() == 1)
- {
- return stack.access(t, axis, first, last);
- }
- else
- {
- return concatonate.access(t, axis, first, last);
- }
- }
- private:
- concatenate_access<CT...> concatonate;
- stack_access<CT...> stack;
- };
- template <template <class...> class F, class... CT>
- class concatenate_invoker
- {
- public:
- using tuple_type = std::tuple<CT...>;
- using size_type = std::size_t;
- using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
- inline concatenate_invoker(tuple_type&& t, size_type axis)
- : m_t(std::move(t))
- , m_axis(axis)
- {
- }
- template <class... Args>
- inline value_type operator()(Args... args) const
- {
- // TODO: avoid memory allocation
- xindex index({static_cast<size_type>(args)...});
- return access_method.access(m_t, m_axis, index.begin(), index.end());
- }
- template <class It>
- inline value_type element(It first, It last) const
- {
- return access_method.access(m_t, m_axis, first, last);
- }
- private:
- F<CT...> access_method;
- tuple_type m_t;
- size_type m_axis;
- };
- template <class... CT>
- using concatenate_impl = concatenate_invoker<concatenate_access, CT...>;
- template <class... CT>
- using stack_impl = concatenate_invoker<stack_access, CT...>;
- template <class... CT>
- using vstack_impl = concatenate_invoker<vstack_access, CT...>;
// Generator functor that reads a source expression with a single index: of
// all the indices it receives, only the one at position `axis` is passed on
// to the source.
template <class CT>
class repeat_impl
{
public:

    using xexpression_type = std::decay_t<CT>;
    using size_type = typename xexpression_type::size_type;
    using value_type = typename xexpression_type::value_type;

    template <class CTA>
    repeat_impl(CTA&& source, size_type axis)
        : m_source(std::forward<CTA>(source))
        , m_axis(axis)
    {
    }

    // Access with explicit indices.
    template <class... Args>
    value_type operator()(Args... args) const
    {
        std::array<size_type, sizeof...(Args)> indices = {static_cast<size_type>(args)...};
        return m_source(indices[m_axis]);
    }

    // Access with an index given through an iterator to its first entry.
    template <class It>
    inline value_type element(It first, It) const
    {
        const auto axis_pos = static_cast<std::ptrdiff_t>(m_axis);
        return m_source(*(first + axis_pos));
    }

private:

    CT m_source;
    size_type m_axis;
};
- }
- /**
- * @brief Creates tuples from arguments for \ref concatenate and \ref stack.
- * Very similar to std::make_tuple.
- */
- template <class... Types>
- inline auto xtuple(Types&&... args)
- {
- return std::tuple<xtl::const_closure_type_t<Types>...>(std::forward<Types>(args)...);
- }
- namespace detail
- {
- template <bool... values>
- using all_true = xtl::conjunction<std::integral_constant<bool, values>...>;
- template <class X, class Y, std::size_t axis, class AxesSequence>
- struct concat_fixed_shape_impl;
- template <class X, class Y, std::size_t axis, std::size_t... Is>
- struct concat_fixed_shape_impl<X, Y, axis, std::index_sequence<Is...>>
- {
- static_assert(X::size() == Y::size(), "Concatenation requires equisized shapes");
- static_assert(axis < X::size(), "Concatenation requires a valid axis");
- static_assert(
- all_true<(axis == Is || X::template get<Is>() == Y::template get<Is>())...>::value,
- "Concatenation requires compatible shapes and axis"
- );
- using type = fixed_shape<
- (axis == Is ? X::template get<Is>() + Y::template get<Is>() : X::template get<Is>())...>;
- };
- template <std::size_t axis, class X, class Y, class... Rest>
- struct concat_fixed_shape;
- template <std::size_t axis, class X, class Y>
- struct concat_fixed_shape<axis, X, Y>
- {
- using type = typename concat_fixed_shape_impl<X, Y, axis, std::make_index_sequence<X::size()>>::type;
- };
- template <std::size_t axis, class X, class Y, class... Rest>
- struct concat_fixed_shape
- {
- using type = typename concat_fixed_shape<axis, X, typename concat_fixed_shape<axis, Y, Rest...>::type>::type;
- };
- template <std::size_t axis, class... Args>
- using concat_fixed_shape_t = typename concat_fixed_shape<axis, Args...>::type;
- template <class... CT>
- using all_fixed_shapes = detail::all_fixed<typename std::decay_t<CT>::shape_type...>;
- struct concat_shape_builder_t
- {
- template <class Shape, bool = detail::is_fixed<Shape>::value>
- struct concat_shape;
- template <class Shape>
- struct concat_shape<Shape, true>
- {
- // Convert `fixed_shape` to `static_shape` to allow runtime dimension calculation.
- using type = static_shape<typename Shape::value_type, Shape::size()>;
- };
- template <class Shape>
- struct concat_shape<Shape, false>
- {
- using type = Shape;
- };
- template <class... Args>
- static auto build(const std::tuple<Args...>& t, std::size_t axis)
- {
- using shape_type = promote_shape_t<
- typename concat_shape<typename std::decay_t<Args>::shape_type>::type...>;
- using source_shape_type = decltype(std::get<0>(t).shape());
- shape_type new_shape = xtl::forward_sequence<shape_type, source_shape_type>(
- std::get<0>(t).shape()
- );
- auto check_shape = [&axis, &new_shape](auto& arr)
- {
- std::size_t s = new_shape.size();
- bool res = s == arr.dimension();
- for (std::size_t i = 0; i < s; ++i)
- {
- res = res && (i == axis || new_shape[i] == arr.shape(i));
- }
- if (!res)
- {
- throw_concatenate_error(new_shape, arr.shape());
- }
- };
- for_each(check_shape, t);
- auto shape_at_axis = [&axis](std::size_t prev, auto& arr) -> std::size_t
- {
- return prev + arr.shape()[axis];
- };
- new_shape[axis] += accumulate(shape_at_axis, std::size_t(0), t) - new_shape[axis];
- return new_shape;
- }
- };
- } // namespace detail
- /***************
- * concatenate *
- ***************/
- /**
- * @brief Concatenates xexpressions along \em axis.
- *
- * @param t \ref xtuple of xexpressions to concatenate
- * @param axis axis along which elements are concatenated
- * @returns xgenerator evaluating to concatenated elements
- *
- * @code{.cpp}
- * xt::xarray<double> a = {{1, 2, 3}};
- * xt::xarray<double> b = {{2, 3, 4}};
- * xt::xarray<double> c = xt::concatenate(xt::xtuple(a, b)); // => {{1, 2, 3},
- * // {2, 3, 4}}
- * xt::xarray<double> d = xt::concatenate(xt::xtuple(a, b), 1); // => {{1, 2, 3, 2, 3, 4}}
- * @endcode
- */
- template <class... CT>
- inline auto concatenate(std::tuple<CT...>&& t, std::size_t axis = 0)
- {
- const auto shape = detail::concat_shape_builder_t::build(t, axis);
- return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape);
- }
- template <std::size_t axis, class... CT, typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
- inline auto concatenate(std::tuple<CT...>&& t)
- {
- using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;
- return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape_type{});
- }
- namespace detail
- {
// Returns a copy of the fixed-size sequence `arr` with `value` inserted at
// position `axis`; the result has one more element than the input.
//
// @param arr the source sequence
// @param axis position of the inserted element, expected in [0, N]
// @param value the element to insert (converted to T)
template <class T, std::size_t N>
inline std::array<T, N + 1> add_axis(std::array<T, N> arr, std::size_t axis, std::size_t value)
{
    const auto split = arr.begin() + static_cast<std::ptrdiff_t>(axis);
    std::array<T, N + 1> temp;
    // elements before the new axis, the new axis, then the remaining elements
    auto out = std::copy(arr.begin(), split, temp.begin());
    *out = static_cast<T>(value);
    std::copy(split, arr.end(), ++out);
    return temp;
}

// Same for a resizable sequence providing `insert`.
//
// `arr` is taken by value, so it is already a private copy: the element is
// inserted into it directly and it is returned, without a second copy.
//
// @param arr the source sequence
// @param axis position of the inserted element
// @param value the element to insert
template <class T>
inline T add_axis(T arr, std::size_t axis, std::size_t value)
{
    arr.insert(arr.begin() + static_cast<std::ptrdiff_t>(axis), value);
    return arr;
}
- }
- /**
- * @brief Stack xexpressions along \em axis.
- * Stacking always creates a new dimension along which elements are stacked.
- *
- * @param t \ref xtuple of xexpressions to concatenate
- * @param axis axis along which elements are stacked
- * @returns xgenerator evaluating to stacked elements
- *
- * @code{.cpp}
- * xt::xarray<double> a = {1, 2, 3};
- * xt::xarray<double> b = {5, 6, 7};
- * xt::xarray<double> s = xt::stack(xt::xtuple(a, b)); // => {{1, 2, 3},
- * // {5, 6, 7}}
- * xt::xarray<double> t = xt::stack(xt::xtuple(a, b), 1); // => {{1, 5},
- * // {2, 6},
- * // {3, 7}}
- * @endcode
- */
- template <class... CT>
- inline auto stack(std::tuple<CT...>&& t, std::size_t axis = 0)
- {
- using shape_type = promote_shape_t<typename std::decay_t<CT>::shape_type...>;
- using source_shape_type = decltype(std::get<0>(t).shape());
- auto new_shape = detail::add_axis(
- xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(t).shape()),
- axis,
- sizeof...(CT)
- );
- return detail::make_xgenerator(detail::stack_impl<CT...>(std::move(t), axis), new_shape);
- }
/**
 * @brief Stack xexpressions in sequence horizontally (column wise).
 * This is equivalent to concatenation along the second axis, except for 1-D
 * xexpressions where it concatenates along the first axis.
 *
 * The axis is chosen from the dimension of the first expression of the tuple.
 *
 * @param t \ref xtuple of xexpressions to stack
 * @return xgenerator evaluating to stacked elements
 */
template <class... CT>
inline auto hstack(std::tuple<CT...>&& t)
{
    const bool more_than_1d = std::get<0>(t).dimension() > std::size_t(1);
    const std::size_t axis = more_than_1d ? 1 : 0;
    return concatenate(std::move(t), axis);
}
- namespace detail
- {
- template <class S, class... CT>
- inline auto vstack_shape(std::tuple<CT...>& t, const S& shape)
- {
- using size_type = typename S::value_type;
- auto res = shape.size() == size_type(1)
- ? S({sizeof...(CT), shape[0]})
- : concat_shape_builder_t::build(std::move(t), size_type(0));
- return res;
- }
- template <class T, class... CT>
- inline auto vstack_shape(const std::tuple<CT...>&, std::array<T, 1> shape)
- {
- std::array<T, 2> res = {sizeof...(CT), shape[0]};
- return res;
- }
- }
- /**
- * @brief Stack xexpressions in sequence vertically (row wise).
- * This is equivalent to concatenation along the first axis after
- * 1-D arrays of shape (N) have been reshape to (1, N).
- *
- * @param t \ref xtuple of xexpressions to stack
- * @return xgenerator evaluating to stacked elements
- */
- template <class... CT>
- inline auto vstack(std::tuple<CT...>&& t)
- {
- using shape_type = promote_shape_t<typename std::decay_t<CT>::shape_type...>;
- using source_shape_type = decltype(std::get<0>(t).shape());
- auto new_shape = detail::vstack_shape(
- t,
- xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(t).shape())
- );
- return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), new_shape);
- }
- namespace detail
- {
- template <std::size_t... I, class... E>
- inline auto meshgrid_impl(std::index_sequence<I...>, E&&... e) noexcept
- {
- #if defined _MSC_VER
- const std::array<std::size_t, sizeof...(E)> shape = {e.shape()[0]...};
- return std::make_tuple(
- detail::make_xgenerator(detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I), shape)...
- );
- #else
- return std::make_tuple(detail::make_xgenerator(
- detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I),
- {e.shape()[0]...}
- )...);
- #endif
- }
- }
- /**
- * @brief Return coordinate tensors from coordinate vectors.
- * Make N-D coordinate tensor expressions for vectorized evaluations of N-D scalar/vector
- * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,..., xn.
- *
- * @param e xexpressions to concatenate
- * @returns tuple of xgenerator expressions.
- */
- template <class... E>
- inline auto meshgrid(E&&... e) noexcept
- {
- return detail::meshgrid_impl(std::make_index_sequence<sizeof...(E)>(), std::forward<E>(e)...);
- }
- namespace detail
- {
- template <class CT>
- class diagonal_fn
- {
- public:
- using xexpression_type = std::decay_t<CT>;
- using value_type = typename xexpression_type::value_type;
- template <class CTA>
- diagonal_fn(CTA&& source, int offset, std::size_t axis_1, std::size_t axis_2)
- : m_source(std::forward<CTA>(source))
- , m_offset(offset)
- , m_axis_1(axis_1)
- , m_axis_2(axis_2)
- {
- }
- template <class It>
- inline value_type operator()(It begin, It) const
- {
- xindex idx(m_source.shape().size());
- for (std::size_t i = 0; i < idx.size(); i++)
- {
- if (i != m_axis_1 && i != m_axis_2)
- {
- idx[i] = static_cast<std::size_t>(*begin++);
- }
- }
- using it_vtype = typename std::iterator_traits<It>::value_type;
- it_vtype uoffset = static_cast<it_vtype>(m_offset);
- if (m_offset >= 0)
- {
- idx[m_axis_1] = static_cast<std::size_t>(*(begin));
- idx[m_axis_2] = static_cast<std::size_t>(*(begin) + uoffset);
- }
- else
- {
- idx[m_axis_1] = static_cast<std::size_t>(*(begin) -uoffset);
- idx[m_axis_2] = static_cast<std::size_t>(*(begin));
- }
- return m_source[idx];
- }
- private:
- CT m_source;
- const int m_offset;
- const std::size_t m_axis_1;
- const std::size_t m_axis_2;
- };
/**
 * Functor producing a 2-D expression that holds a 1-D source on its k-th
 * diagonal and zero everywhere else.
 */
template <class CT>
class diag_fn
{
public:

    using xexpression_type = std::decay_t<CT>;
    using value_type = typename xexpression_type::value_type;

    /**
     * @param source 1-D expression placed on the diagonal (stored as CT)
     * @param k signed offset of the diagonal
     */
    template <class CTA>
    diag_fn(CTA&& source, int k)
        : m_source(std::forward<CTA>(source))
        , m_k(k)
    {
    }

    /**
     * @param begin iterator to the (row, column) index pair
     * @return the source value if column == row + k, value_type(0) otherwise
     */
    template <class It>
    inline value_type operator()(It begin, It) const
    {
        using index_value_type = typename std::iterator_traits<It>::value_type;

        // Offset expressed in the iterator's value type; for an unsigned
        // value type a negative k wraps, so `row + shift` still equals row + k.
        const index_value_type shift = static_cast<index_value_type>(m_k);

        if (!(*begin + shift == *(begin + 1)))
        {
            return value_type(0);
        }
        if (m_k > 0)
        {
            // Diagonal above the main one: the source is indexed by the row.
            return m_source(*begin);
        }
        // Main diagonal or below: the source is indexed by row + k.
        return m_source(*begin + shift);
    }

private:

    CT m_source;
    const int m_k;
};
/**
 * Functor keeping the elements of a source expression for which the
 * comparator accepts (row + k, column), and yielding zero for the others.
 */
template <class CT, class Comp>
class trilu_fn
{
public:

    using xexpression_type = std::decay_t<CT>;
    using value_type = typename xexpression_type::value_type;
    using signed_idx_type = long int;

    /**
     * @param source expression to filter (stored as CT)
     * @param k signed offset added to the row index before comparing
     * @param comp comparator deciding which elements are kept
     */
    template <class CTA>
    trilu_fn(CTA&& source, int k, Comp comp)
        : m_source(std::forward<CTA>(source))
        , m_k(k)
        , m_comp(comp)
    {
    }

    /**
     * @param begin iterator to the first component of the index
     * @param end iterator past the last component of the index
     * @return the source element at that index when the comparator accepts
     *         (first component + k, second component), value_type(0) otherwise
     */
    template <class It>
    inline value_type operator()(It begin, It end) const
    {
        // Both indices are converted to a signed type first: with unsigned
        // indices a negative k would otherwise wrap around.
        if (m_comp(signed_idx_type(*begin) + m_k, signed_idx_type(*(begin + 1))))
        {
            return m_source.element(begin, end);
        }
        return value_type(0);
    }

private:

    CT m_source;
    const signed_idx_type m_k;
    const Comp m_comp;
};
- }
- namespace detail
- {
// Meta-function returning the shape type to use for a diagonal.
// Any shape type other than std::array is reused unchanged.
template <class Shape, class... Rest>
struct diagonal_shape_type
{
    using type = Shape;
};

// A std::array shape encodes its length in the type, so the diagonal's shape
// type is an array with one element fewer.
template <class T, std::size_t N>
struct diagonal_shape_type<std::array<T, N>>
{
    using type = std::array<T, N - 1>;
};
- }
- /**
- * @brief Returns the elements on the diagonal of arr
- * If arr has more than two dimensions, then the axes specified by
- * axis_1 and axis_2 are used to determine the 2-D sub-array whose
- * diagonal is returned. The shape of the resulting array can be
- * determined by removing axis1 and axis2 and appending an index
- * to the right equal to the size of the resulting diagonals.
- *
- * @param arr the input array
- * @param offset offset of the diagonal from the main diagonal. Can
- * be positive or negative.
- * @param axis_1 Axis to be used as the first axis of the 2-D sub-arrays
- * from which the diagonals should be taken.
- * @param axis_2 Axis to be used as the second axis of the 2-D sub-arrays
- * from which the diagonals should be taken.
- * @returns xexpression with values of the diagonal
- *
- * @code{.cpp}
- * xt::xarray<double> a = {{1, 2, 3},
- * {4, 5, 6}
- * {7, 8, 9}};
- * auto b = xt::diagonal(a); // => {1, 5, 9}
- * @endcode
- */
- template <class E>
- inline auto diagonal(E&& arr, int offset = 0, std::size_t axis_1 = 0, std::size_t axis_2 = 1)
- {
- using CT = xclosure_t<E>;
- using shape_type = typename detail::diagonal_shape_type<typename std::decay_t<E>::shape_type>::type;
- auto shape = arr.shape();
- auto dimension = arr.dimension();
- // The following shape calculation code is an almost verbatim adaptation of NumPy:
- // https://github.com/numpy/numpy/blob/2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0/numpy/core/src/multiarray/item_selection.c#L1799
- auto ret_shape = xtl::make_sequence<shape_type>(dimension - 1, 0);
- int dim_1 = static_cast<int>(shape[axis_1]);
- int dim_2 = static_cast<int>(shape[axis_2]);
- offset >= 0 ? dim_2 -= offset : dim_1 += offset;
- auto diag_size = std::size_t(dim_2 < dim_1 ? dim_2 : dim_1);
- std::size_t i = 0;
- for (std::size_t idim = 0; idim < dimension; ++idim)
- {
- if (idim != axis_1 && idim != axis_2)
- {
- ret_shape[i++] = shape[idim];
- }
- }
- ret_shape.back() = diag_size;
- return detail::make_xgenerator(
- detail::fn_impl<detail::diagonal_fn<CT>>(
- detail::diagonal_fn<CT>(std::forward<E>(arr), offset, axis_1, axis_2)
- ),
- ret_shape
- );
- }
- /**
- * @brief xexpression with values of arr on the diagonal, zeroes otherwise
- *
- * @param arr the 1D input array of length n
- * @param k the offset of the considered diagonal
- * @returns xexpression function with shape n x n and arr on the diagonal
- *
- * @code{.cpp}
- * xt::xarray<double> a = {1, 5, 9};
- * auto b = xt::diag(a); // => {{1, 0, 0},
- * // {0, 5, 0},
- * // {0, 0, 9}}
- * @endcode
- */
- template <class E>
- inline auto diag(E&& arr, int k = 0)
- {
- using CT = xclosure_t<E>;
- std::size_t sk = std::size_t(std::abs(k));
- std::size_t s = arr.shape()[0] + sk;
- return detail::make_xgenerator(
- detail::fn_impl<detail::diag_fn<CT>>(detail::diag_fn<CT>(std::forward<E>(arr), k)),
- {s, s}
- );
- }
- /**
- * @brief Extract lower triangular matrix from xexpression. The parameter k selects the
- * offset of the diagonal.
- *
- * @param arr the input array
- * @param k the diagonal above which to zero elements. 0 (default) selects the main diagonal,
- * k < 0 is below the main diagonal, k > 0 above.
- * @returns xexpression containing lower triangle from arr, 0 otherwise
- */
- template <class E>
- inline auto tril(E&& arr, int k = 0)
- {
- using CT = xclosure_t<E>;
- auto shape = arr.shape();
- return detail::make_xgenerator(
- detail::fn_impl<detail::trilu_fn<CT, std::greater_equal<long int>>>(
- detail::trilu_fn<CT, std::greater_equal<long int>>(
- std::forward<E>(arr),
- k,
- std::greater_equal<long int>()
- )
- ),
- shape
- );
- }
- /**
- * @brief Extract upper triangular matrix from xexpression. The parameter k selects the
- * offset of the diagonal.
- *
- * @param arr the input array
- * @param k the diagonal below which to zero elements. 0 (default) selects the main diagonal,
- * k < 0 is below the main diagonal, k > 0 above.
- * @returns xexpression containing lower triangle from arr, 0 otherwise
- */
- template <class E>
- inline auto triu(E&& arr, int k = 0)
- {
- using CT = xclosure_t<E>;
- auto shape = arr.shape();
- return detail::make_xgenerator(
- detail::fn_impl<detail::trilu_fn<CT, std::less_equal<long int>>>(
- detail::trilu_fn<CT, std::less_equal<long int>>(std::forward<E>(arr), k, std::less_equal<long int>())
- ),
- shape
- );
- }
- }
- #endif
|