1
0

xcomplex.hpp 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_COMPLEX_HPP
  10. #define XTENSOR_COMPLEX_HPP
  11. #include <type_traits>
  12. #include <utility>
  13. #include <xtl/xcomplex.hpp>
  14. #include "xtensor/xbuilder.hpp"
  15. #include "xtensor/xexpression.hpp"
  16. #include "xtensor/xoffset_view.hpp"
  17. namespace xt
  18. {
  19. /**
  20. * @defgroup xt_xcomplex
  21. *
  22. * Defined in ``xtensor/xcomplex.hpp``
  23. */
  24. /******************************
  25. * real and imag declarations *
  26. ******************************/
  27. template <class E>
  28. decltype(auto) real(E&& e) noexcept;
  29. template <class E>
  30. decltype(auto) imag(E&& e) noexcept;
  31. /********************************
  32. * real and imag implementation *
  33. ********************************/
  34. namespace detail
  35. {
  36. template <bool iscomplex = true>
  37. struct complex_helper
  38. {
  39. template <class E>
  40. inline static auto real(E&& e) noexcept
  41. {
  42. using real_type = typename std::decay_t<E>::value_type::value_type;
  43. return xoffset_view<xclosure_t<E>, real_type, 0>(std::forward<E>(e));
  44. }
  45. template <class E>
  46. inline static auto imag(E&& e) noexcept
  47. {
  48. using real_type = typename std::decay_t<E>::value_type::value_type;
  49. return xoffset_view<xclosure_t<E>, real_type, sizeof(real_type)>(std::forward<E>(e));
  50. }
  51. };
  52. template <>
  53. struct complex_helper<false>
  54. {
  55. template <class E>
  56. inline static decltype(auto) real(E&& e) noexcept
  57. {
  58. return std::forward<E>(e);
  59. }
  60. template <class E>
  61. inline static auto imag(E&& e) noexcept
  62. {
  63. return zeros<typename std::decay_t<E>::value_type>(e.shape());
  64. }
  65. };
  66. template <bool isexpression = true>
  67. struct complex_expression_helper
  68. {
  69. template <class E>
  70. inline static decltype(auto) real(E&& e) noexcept
  71. {
  72. return detail::complex_helper<xtl::is_complex<typename std::decay_t<E>::value_type>::value>::real(
  73. std::forward<E>(e)
  74. );
  75. }
  76. template <class E>
  77. inline static decltype(auto) imag(E&& e) noexcept
  78. {
  79. return detail::complex_helper<xtl::is_complex<typename std::decay_t<E>::value_type>::value>::imag(
  80. std::forward<E>(e)
  81. );
  82. }
  83. };
  84. template <>
  85. struct complex_expression_helper<false>
  86. {
  87. template <class E>
  88. inline static decltype(auto) real(E&& e) noexcept
  89. {
  90. return xtl::forward_real(std::forward<E>(e));
  91. }
  92. template <class E>
  93. inline static decltype(auto) imag(E&& e) noexcept
  94. {
  95. return xtl::forward_imag(std::forward<E>(e));
  96. }
  97. };
  98. }
  99. /**
  100. * Return an xt::xexpression representing the real part of the given expression.
  101. *
  102. * The returned expression either hold a const reference to @p e or a copy
  103. * depending on whether @p e is an lvalue or an rvalue.
  104. *
  105. * @ingroup xt_xcomplex
  106. * @tparam e The xt::xexpression
  107. */
  108. template <class E>
  109. inline decltype(auto) real(E&& e) noexcept
  110. {
  111. return detail::complex_expression_helper<is_xexpression<std::decay_t<E>>::value>::real(std::forward<E>(e
  112. ));
  113. }
  114. /**
  115. * Return an xt::xexpression representing the imaginary part of the given expression.
  116. *
  117. * The returned expression either hold a const reference to @p e or a copy
  118. * depending on whether @p e is an lvalue or an rvalue.
  119. *
  120. * @ingroup xt_xcomplex
  121. * @tparam e The xt::xexpression
  122. */
  123. template <class E>
  124. inline decltype(auto) imag(E&& e) noexcept
  125. {
  126. return detail::complex_expression_helper<is_xexpression<std::decay_t<E>>::value>::imag(std::forward<E>(e
  127. ));
  128. }
  129. #define UNARY_COMPLEX_FUNCTOR(NS, NAME) \
  130. struct NAME##_fun \
  131. { \
  132. template <class T> \
  133. constexpr auto operator()(const T& t) const \
  134. { \
  135. using NS::NAME; \
  136. return NAME(t); \
  137. } \
  138. \
  139. template <class B> \
  140. constexpr auto simd_apply(const B& t) const \
  141. { \
  142. using NS::NAME; \
  143. return NAME(t); \
  144. } \
  145. }
  146. namespace math
  147. {
  148. namespace detail
  149. {
  150. template <class T>
  151. constexpr std::complex<T> conj_impl(const std::complex<T>& c)
  152. {
  153. return std::complex<T>(c.real(), -c.imag());
  154. }
  155. template <class T>
  156. constexpr std::complex<T> conj_impl(const T& real)
  157. {
  158. return std::complex<T>(real, 0);
  159. }
  160. #ifdef XTENSOR_USE_XSIMD
  161. template <class T, class A>
  162. xsimd::complex_batch_type_t<xsimd::batch<T, A>> conj_impl(const xsimd::batch<T, A>& z)
  163. {
  164. return xsimd::conj(z);
  165. }
  166. #endif
  167. }
  168. UNARY_COMPLEX_FUNCTOR(std, norm);
  169. UNARY_COMPLEX_FUNCTOR(std, arg);
  170. UNARY_COMPLEX_FUNCTOR(detail, conj_impl);
  171. }
  172. #undef UNARY_COMPLEX_FUNCTOR
  173. /**
  174. * Return an xt::xfunction evaluating to the complex conjugate of the given expression.
  175. *
  176. * @ingroup xt_xcomplex
  177. * @param e the xt::xexpression
  178. */
  179. template <class E>
  180. inline auto conj(E&& e) noexcept
  181. {
  182. using functor = math::conj_impl_fun;
  183. using type = xfunction<functor, const_xclosure_t<E>>;
  184. return type(functor(), std::forward<E>(e));
  185. }
  186. /**
  187. * Calculates the phase angle (in radians) elementwise for the complex numbers in @p e.
  188. *
  189. * @ingroup xt_xcomplex
  190. * @param e the xt::xexpression
  191. */
  192. template <class E>
  193. inline auto arg(E&& e) noexcept
  194. {
  195. using functor = math::arg_fun;
  196. using type = xfunction<functor, const_xclosure_t<E>>;
  197. return type(functor(), std::forward<E>(e));
  198. }
  199. /**
  200. * Calculates the phase angle elementwise for the complex numbers in @p e.
  201. *
  202. * Note that this function might be slightly less performant than xt::arg.
  203. *
  204. * @ingroup xt_xcomplex
  205. * @param e the xt::xexpression
  206. * @param deg calculate angle in degrees instead of radians
  207. */
  208. template <class E>
  209. inline auto angle(E&& e, bool deg = false) noexcept
  210. {
  211. using value_type = xtl::complex_value_type_t<typename std::decay_t<E>::value_type>;
  212. value_type multiplier = 1.0;
  213. if (deg)
  214. {
  215. multiplier = value_type(180) / numeric_constants<value_type>::PI;
  216. }
  217. return arg(std::forward<E>(e)) * std::move(multiplier);
  218. }
  219. /**
  220. * Calculates the squared magnitude elementwise for the complex numbers in @p e.
  221. *
  222. * Equivalent to ``xt::pow(xt::real(e), 2) + xt::pow(xt::imag(e), 2)``.
  223. * @ingroup xt_xcomplex
  224. * @param e the xt::xexpression
  225. */
  226. template <class E>
  227. inline auto norm(E&& e) noexcept
  228. {
  229. using functor = math::norm_fun;
  230. using type = xfunction<functor, const_xclosure_t<E>>;
  231. return type(functor(), std::forward<E>(e));
  232. }
  233. }
  234. #endif