xeval.hpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_EVAL_HPP
  10. #define XTENSOR_EVAL_HPP
  11. #include "xexpression_traits.hpp"
  12. #include "xshape.hpp"
  13. #include "xtensor_forward.hpp"
  14. namespace xt
  15. {
  16. /**
  17. * @defgroup xt_xeval
  18. *
  19. * Evaluation functions.
  20. * Defined in ``xtensor/xeval.hpp``
  21. */
  22. namespace detail
  23. {
  24. template <class T>
  25. using is_container = std::is_base_of<xcontainer<std::remove_const_t<T>>, T>;
  26. }
  27. /**
  28. * Force evaluation of xexpression.
  29. *
  30. * @code{.cpp}
  31. * xt::xarray<double> a = {1, 2, 3, 4};
  32. * auto&& b = xt::eval(a); // b is a reference to a, no copy!
  33. * auto&& c = xt::eval(a + b); // c is xarray<double>, not an xexpression
  34. * @endcode
  35. *
  36. * @ingroup xt_xeval
  37. * @return xt::xarray or xt::xtensor depending on shape type
  38. */
  39. template <class T>
  40. inline auto eval(T&& t) -> std::enable_if_t<detail::is_container<std::decay_t<T>>::value, T&&>
  41. {
  42. return std::forward<T>(t);
  43. }
  44. /// @cond DOXYGEN_INCLUDE_SFINAE
  45. template <class T>
  46. inline auto eval(T&& t)
  47. -> std::enable_if_t<!detail::is_container<std::decay_t<T>>::value, temporary_type_t<T>>
  48. {
  49. return std::forward<T>(t);
  50. }
  51. /// @endcond
  52. namespace detail
  53. {
  54. /**********************************
  55. * has_same_layout implementation *
  56. **********************************/
  57. template <layout_type L = layout_type::any, class E>
  58. constexpr bool has_same_layout()
  59. {
  60. return (std::decay_t<E>::static_layout == L) || (L == layout_type::any);
  61. }
  62. template <layout_type L = layout_type::any, class E>
  63. constexpr bool has_same_layout(E&&)
  64. {
  65. return has_same_layout<L, E>();
  66. }
  67. template <class E1, class E2>
  68. constexpr bool has_same_layout(E1&&, E2&&)
  69. {
  70. return has_same_layout<std::decay_t<E1>::static_layout, E2>();
  71. }
  72. /*********************************
  73. * has_fixed_dims implementation *
  74. *********************************/
  75. template <class E>
  76. constexpr bool has_fixed_dims()
  77. {
  78. return detail::is_array<typename std::decay_t<E>::shape_type>::value;
  79. }
  80. template <class E>
  81. constexpr bool has_fixed_dims(E&&)
  82. {
  83. return has_fixed_dims<E>();
  84. }
  85. /****************************************
  86. * as_xarray_container_t implementation *
  87. ****************************************/
  88. template <class E, layout_type L>
  89. using as_xarray_container_t = xarray<typename std::decay_t<E>::value_type, layout_remove_any(L)>;
  90. /*****************************************
  91. * as_xtensor_container_t implementation *
  92. *****************************************/
  93. template <class E, layout_type L>
  94. using as_xtensor_container_t = xtensor<
  95. typename std::decay_t<E>::value_type,
  96. std::tuple_size<typename std::decay_t<E>::shape_type>::value,
  97. layout_remove_any(L)>;
  98. }
  99. /**
  100. * Force evaluation of xexpression not providing a data interface
  101. * and convert to the required layout.
  102. *
  103. * @code{.cpp}
  104. * xt::xarray<double, xt::layout_type::row_major> a = {1, 2, 3, 4};
  105. *
  106. * // take reference to a (no copy!)
  107. * auto&& b = xt::as_strided(a);
  108. *
  109. * // xarray<double> with the required layout
  110. * auto&& c = xt::as_strided<xt::layout_type::column_major>(a);
  111. *
  112. * // xexpression
  113. * auto&& a_cast = xt::cast<int>(a);
  114. *
  115. * // xarray<int>, not an xexpression
  116. * auto&& d = xt::as_strided(a_cast);
  117. *
  118. * // xarray<int> with the required layout
  119. * auto&& e = xt::as_strided<xt::layout_type::column_major>(a_cast);
  120. * @endcode
  121. *
  122. * @warning This function should be used in a local context only.
  123. * Returning the value returned by this function could lead to a dangling reference.
  124. * @ingroup xt_xeval
  125. * @return The expression when it already provides a data interface with the correct layout,
  126. * an evaluated xt::xarray or xt::xtensor depending on shape type otherwise.
  127. */
  128. template <layout_type L = layout_type::any, class E>
  129. inline auto as_strided(E&& e)
  130. -> std::enable_if_t<has_data_interface<std::decay_t<E>>::value && detail::has_same_layout<L, E>(), E&&>
  131. {
  132. return std::forward<E>(e);
  133. }
  134. /// @cond DOXYGEN_INCLUDE_SFINAE
  135. template <layout_type L = layout_type::any, class E>
  136. inline auto as_strided(E&& e) -> std::enable_if_t<
  137. (!(has_data_interface<std::decay_t<E>>::value && detail::has_same_layout<L, E>()))
  138. && detail::has_fixed_dims<E>(),
  139. detail::as_xtensor_container_t<E, L>>
  140. {
  141. return e;
  142. }
  143. template <layout_type L = layout_type::any, class E>
  144. inline auto as_strided(E&& e) -> std::enable_if_t<
  145. (!(has_data_interface<std::decay_t<E>>::value && detail::has_same_layout<L, E>()))
  146. && (!detail::has_fixed_dims<E>()),
  147. detail::as_xarray_container_t<E, L>>
  148. {
  149. return e;
  150. }
  151. /// @endcond
  152. }
  153. #endif