xoperation.hpp 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_OPERATION_HPP
  10. #define XTENSOR_OPERATION_HPP
  11. #include <algorithm>
  12. #include <functional>
  13. #include <type_traits>
  14. #include <xtl/xsequence.hpp>
  15. #include "xfunction.hpp"
  16. #include "xscalar.hpp"
  17. #include "xstrided_view.hpp"
  18. #include "xstrides.hpp"
  19. namespace xt
  20. {
  21. /***********
  22. * helpers *
  23. ***********/
  24. #define UNARY_OPERATOR_FUNCTOR(NAME, OP) \
  25. struct NAME \
  26. { \
  27. template <class A1> \
  28. constexpr auto operator()(const A1& arg) const \
  29. { \
  30. return OP arg; \
  31. } \
  32. template <class B> \
  33. constexpr auto simd_apply(const B& arg) const \
  34. { \
  35. return OP arg; \
  36. } \
  37. }
  38. #define DEFINE_COMPLEX_OVERLOAD(OP) \
  39. template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
  40. constexpr auto operator OP(const std::complex<T1>& arg1, const std::complex<T2>& arg2) \
  41. { \
  42. using result_type = typename xtl::promote_type_t<std::complex<T1>, std::complex<T2>>; \
  43. return (result_type(arg1) OP result_type(arg2)); \
  44. } \
  45. \
  46. template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
  47. constexpr auto operator OP(const T1& arg1, const std::complex<T2>& arg2) \
  48. { \
  49. using result_type = typename xtl::promote_type_t<T1, std::complex<T2>>; \
  50. return (result_type(arg1) OP result_type(arg2)); \
  51. } \
  52. \
  53. template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
  54. constexpr auto operator OP(const std::complex<T1>& arg1, const T2& arg2) \
  55. { \
  56. using result_type = typename xtl::promote_type_t<std::complex<T1>, T2>; \
  57. return (result_type(arg1) OP result_type(arg2)); \
  58. }
  59. #define BINARY_OPERATOR_FUNCTOR(NAME, OP) \
  60. struct NAME \
  61. { \
  62. template <class T1, class T2> \
  63. constexpr auto operator()(T1&& arg1, T2&& arg2) const \
  64. { \
  65. using xt::detail::operator OP; \
  66. return (std::forward<T1>(arg1) OP std::forward<T2>(arg2)); \
  67. } \
  68. template <class B> \
  69. constexpr auto simd_apply(const B& arg1, const B& arg2) const \
  70. { \
  71. return (arg1 OP arg2); \
  72. } \
  73. }
  74. namespace detail
  75. {
  76. DEFINE_COMPLEX_OVERLOAD(+);
  77. DEFINE_COMPLEX_OVERLOAD(-);
  78. DEFINE_COMPLEX_OVERLOAD(*);
  79. DEFINE_COMPLEX_OVERLOAD(/);
  80. DEFINE_COMPLEX_OVERLOAD(%);
  81. DEFINE_COMPLEX_OVERLOAD(||);
  82. DEFINE_COMPLEX_OVERLOAD(&&);
  83. DEFINE_COMPLEX_OVERLOAD(|);
  84. DEFINE_COMPLEX_OVERLOAD(&);
  85. DEFINE_COMPLEX_OVERLOAD(^);
  86. DEFINE_COMPLEX_OVERLOAD(<<);
  87. DEFINE_COMPLEX_OVERLOAD(>>);
  88. DEFINE_COMPLEX_OVERLOAD(<);
  89. DEFINE_COMPLEX_OVERLOAD(<=);
  90. DEFINE_COMPLEX_OVERLOAD(>);
  91. DEFINE_COMPLEX_OVERLOAD(>=);
  92. DEFINE_COMPLEX_OVERLOAD(==);
  93. DEFINE_COMPLEX_OVERLOAD(!=);
  94. UNARY_OPERATOR_FUNCTOR(identity, +);
  95. UNARY_OPERATOR_FUNCTOR(negate, -);
  96. BINARY_OPERATOR_FUNCTOR(plus, +);
  97. BINARY_OPERATOR_FUNCTOR(minus, -);
  98. BINARY_OPERATOR_FUNCTOR(multiplies, *);
  99. BINARY_OPERATOR_FUNCTOR(divides, /);
  100. BINARY_OPERATOR_FUNCTOR(modulus, %);
  101. BINARY_OPERATOR_FUNCTOR(logical_or, ||);
  102. BINARY_OPERATOR_FUNCTOR(logical_and, &&);
  103. UNARY_OPERATOR_FUNCTOR(logical_not, !);
  104. BINARY_OPERATOR_FUNCTOR(bitwise_or, |);
  105. BINARY_OPERATOR_FUNCTOR(bitwise_and, &);
  106. BINARY_OPERATOR_FUNCTOR(bitwise_xor, ^);
  107. UNARY_OPERATOR_FUNCTOR(bitwise_not, ~);
  108. BINARY_OPERATOR_FUNCTOR(left_shift, <<);
  109. BINARY_OPERATOR_FUNCTOR(right_shift, >>);
  110. BINARY_OPERATOR_FUNCTOR(less, <);
  111. BINARY_OPERATOR_FUNCTOR(less_equal, <=);
  112. BINARY_OPERATOR_FUNCTOR(greater, >);
  113. BINARY_OPERATOR_FUNCTOR(greater_equal, >=);
  114. BINARY_OPERATOR_FUNCTOR(equal_to, ==);
  115. BINARY_OPERATOR_FUNCTOR(not_equal_to, !=);
  116. struct conditional_ternary
  117. {
  118. template <class B>
  119. using get_batch_bool = typename xt_simd::simd_traits<typename xt_simd::revert_simd_traits<B>::type>::bool_type;
  120. template <class B, class A1, class A2>
  121. constexpr auto operator()(const B& cond, const A1& v1, const A2& v2) const noexcept
  122. {
  123. return xtl::select(cond, v1, v2);
  124. }
  125. template <class B>
  126. constexpr B simd_apply(const get_batch_bool<B>& t1, const B& t2, const B& t3) const noexcept
  127. {
  128. return xt_simd::select(t1, t2, t3);
  129. }
  130. };
  131. template <class R>
  132. struct cast
  133. {
  134. struct functor
  135. {
  136. using result_type = R;
  137. template <class A1>
  138. constexpr result_type operator()(const A1& arg) const
  139. {
  140. return static_cast<R>(arg);
  141. }
  142. // SIMD conversion disabled for now since it does not make sense
  143. // in most of the cases
  144. /*constexpr simd_result_type simd_apply(const simd_value_type& arg) const
  145. {
  146. return static_cast<R>(arg);
  147. }*/
  148. };
  149. };
  150. template <class Tag, class F, class... E>
  151. struct select_xfunction_expression;
  152. template <class F, class... E>
  153. struct select_xfunction_expression<xtensor_expression_tag, F, E...>
  154. {
  155. using type = xfunction<F, E...>;
  156. };
  157. template <class F, class... E>
  158. struct select_xfunction_expression<xoptional_expression_tag, F, E...>
  159. {
  160. using type = xfunction<F, E...>;
  161. };
  162. template <class Tag, class F, class... E>
  163. using select_xfunction_expression_t = typename select_xfunction_expression<Tag, F, E...>::type;
  164. template <class F, class... E>
  165. struct xfunction_type
  166. {
  167. using expression_tag = xexpression_tag_t<E...>;
  168. using functor_type = F;
  169. using type = select_xfunction_expression_t<expression_tag, functor_type, const_xclosure_t<E>...>;
  170. };
  171. template <class F, class... E>
  172. inline auto make_xfunction(E&&... e) noexcept
  173. {
  174. using function_type = xfunction_type<F, E...>;
  175. using functor_type = typename function_type::functor_type;
  176. using type = typename function_type::type;
  177. return type(functor_type(), std::forward<E>(e)...);
  178. }
  179. // On MSVC, the second argument of enable_if_t is always evaluated, even if the condition is false.
  180. // Wrapping the xfunction type in the xfunction_type metafunction avoids this evaluation when
  181. // the condition is false, since it leads to a tricky bug preventing from using operator+ and
  182. // operator- on vector and arrays iterators.
  183. template <class F, class... E>
  184. using xfunction_type_t = typename std::
  185. enable_if_t<has_xexpression<std::decay_t<E>...>::value, xfunction_type<F, E...>>::type;
  186. }
  187. #undef UNARY_OPERATOR_FUNCTOR
  188. #undef BINARY_OPERATOR_FUNCTOR
  189. /*************
  190. * operators *
  191. *************/
  192. /**
  193. * @defgroup arithmetic_operators Arithmetic operators
  194. */
  195. /**
  196. * @ingroup arithmetic_operators
  197. * @brief Identity
  198. *
  199. * Returns an \ref xfunction for the element-wise identity
  200. * of \a e.
  201. * @param e an \ref xexpression
  202. * @return an \ref xfunction
  203. */
  204. template <class E>
  205. inline auto operator+(E&& e) noexcept -> detail::xfunction_type_t<detail::identity, E>
  206. {
  207. return detail::make_xfunction<detail::identity>(std::forward<E>(e));
  208. }
  209. /**
  210. * @ingroup arithmetic_operators
  211. * @brief Opposite
  212. *
  213. * Returns an \ref xfunction for the element-wise opposite
  214. * of \a e.
  215. * @param e an \ref xexpression
  216. * @return an \ref xfunction
  217. */
  218. template <class E>
  219. inline auto operator-(E&& e) noexcept -> detail::xfunction_type_t<detail::negate, E>
  220. {
  221. return detail::make_xfunction<detail::negate>(std::forward<E>(e));
  222. }
  223. /**
  224. * @ingroup arithmetic_operators
  225. * @brief Addition
  226. *
  227. * Returns an \ref xfunction for the element-wise addition
  228. * of \a e1 and \a e2.
  229. * @param e1 an \ref xexpression or a scalar
  230. * @param e2 an \ref xexpression or a scalar
  231. * @return an \ref xfunction
  232. */
  233. template <class E1, class E2>
  234. inline auto operator+(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::plus, E1, E2>
  235. {
  236. return detail::make_xfunction<detail::plus>(std::forward<E1>(e1), std::forward<E2>(e2));
  237. }
  238. /**
  239. * @ingroup arithmetic_operators
  240. * @brief Substraction
  241. *
  242. * Returns an \ref xfunction for the element-wise substraction
  243. * of \a e2 to \a e1.
  244. * @param e1 an \ref xexpression or a scalar
  245. * @param e2 an \ref xexpression or a scalar
  246. * @return an \ref xfunction
  247. */
  248. template <class E1, class E2>
  249. inline auto operator-(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::minus, E1, E2>
  250. {
  251. return detail::make_xfunction<detail::minus>(std::forward<E1>(e1), std::forward<E2>(e2));
  252. }
  253. /**
  254. * @ingroup arithmetic_operators
  255. * @brief Multiplication
  256. *
  257. * Returns an \ref xfunction for the element-wise multiplication
  258. * of \a e1 by \a e2.
  259. * @param e1 an \ref xexpression or a scalar
  260. * @param e2 an \ref xexpression or a scalar
  261. * @return an \ref xfunction
  262. */
  263. template <class E1, class E2>
  264. inline auto operator*(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::multiplies, E1, E2>
  265. {
  266. return detail::make_xfunction<detail::multiplies>(std::forward<E1>(e1), std::forward<E2>(e2));
  267. }
  268. /**
  269. * @ingroup arithmetic_operators
  270. * @brief Division
  271. *
  272. * Returns an \ref xfunction for the element-wise division
  273. * of \a e1 by \a e2.
  274. * @param e1 an \ref xexpression or a scalar
  275. * @param e2 an \ref xexpression or a scalar
  276. * @return an \ref xfunction
  277. */
  278. template <class E1, class E2>
  279. inline auto operator/(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::divides, E1, E2>
  280. {
  281. return detail::make_xfunction<detail::divides>(std::forward<E1>(e1), std::forward<E2>(e2));
  282. }
  283. /**
  284. * @ingroup arithmetic_operators
  285. * @brief Modulus
  286. *
  287. * Returns an \ref xfunction for the element-wise modulus
  288. * of \a e1 by \a e2.
  289. * @param e1 an \ref xexpression or a scalar
  290. * @param e2 an \ref xexpression or a scalar
  291. * @return an \ref xfunction
  292. */
  293. template <class E1, class E2>
  294. inline auto operator%(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::modulus, E1, E2>
  295. {
  296. return detail::make_xfunction<detail::modulus>(std::forward<E1>(e1), std::forward<E2>(e2));
  297. }
  298. /**
  299. * @defgroup logical_operators Logical operators
  300. */
  301. /**
  302. * @ingroup logical_operators
  303. * @brief Or
  304. *
  305. * Returns an \ref xfunction for the element-wise or
  306. * of \a e1 and \a e2.
  307. * @param e1 an \ref xexpression or a scalar
  308. * @param e2 an \ref xexpression or a scalar
  309. * @return an \ref xfunction
  310. */
  311. template <class E1, class E2>
  312. inline auto operator||(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::logical_or, E1, E2>
  313. {
  314. return detail::make_xfunction<detail::logical_or>(std::forward<E1>(e1), std::forward<E2>(e2));
  315. }
  316. /**
  317. * @ingroup logical_operators
  318. * @brief And
  319. *
  320. * Returns an \ref xfunction for the element-wise and
  321. * of \a e1 and \a e2.
  322. * @param e1 an \ref xexpression or a scalar
  323. * @param e2 an \ref xexpression or a scalar
  324. * @return an \ref xfunction
  325. */
  326. template <class E1, class E2>
  327. inline auto operator&&(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::logical_and, E1, E2>
  328. {
  329. return detail::make_xfunction<detail::logical_and>(std::forward<E1>(e1), std::forward<E2>(e2));
  330. }
  331. /**
  332. * @ingroup logical_operators
  333. * @brief Not
  334. *
  335. * Returns an \ref xfunction for the element-wise not
  336. * of \a e.
  337. * @param e an \ref xexpression
  338. * @return an \ref xfunction
  339. */
  340. template <class E>
  341. inline auto operator!(E&& e) noexcept -> detail::xfunction_type_t<detail::logical_not, E>
  342. {
  343. return detail::make_xfunction<detail::logical_not>(std::forward<E>(e));
  344. }
  345. /**
  346. * @defgroup bitwise_operators Bitwise operators
  347. */
  348. /**
  349. * @ingroup bitwise_operators
  350. * @brief Bitwise and
  351. *
  352. * Returns an \ref xfunction for the element-wise bitwise and
  353. * of \a e1 and \a e2.
  354. * @param e1 an \ref xexpression or a scalar
  355. * @param e2 an \ref xexpression or a scalar
  356. * @return an \ref xfunction
  357. */
  358. template <class E1, class E2>
  359. inline auto operator&(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::bitwise_and, E1, E2>
  360. {
  361. return detail::make_xfunction<detail::bitwise_and>(std::forward<E1>(e1), std::forward<E2>(e2));
  362. }
  363. /**
  364. * @ingroup bitwise_operators
  365. * @brief Bitwise or
  366. *
  367. * Returns an \ref xfunction for the element-wise bitwise or
  368. * of \a e1 and \a e2.
  369. * @param e1 an \ref xexpression or a scalar
  370. * @param e2 an \ref xexpression or a scalar
  371. * @return an \ref xfunction
  372. */
  373. template <class E1, class E2>
  374. inline auto operator|(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::bitwise_or, E1, E2>
  375. {
  376. return detail::make_xfunction<detail::bitwise_or>(std::forward<E1>(e1), std::forward<E2>(e2));
  377. }
  378. /**
  379. * @ingroup bitwise_operators
  380. * @brief Bitwise xor
  381. *
  382. * Returns an \ref xfunction for the element-wise bitwise xor
  383. * of \a e1 and \a e2.
  384. * @param e1 an \ref xexpression or a scalar
  385. * @param e2 an \ref xexpression or a scalar
  386. * @return an \ref xfunction
  387. */
  388. template <class E1, class E2>
  389. inline auto operator^(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::bitwise_xor, E1, E2>
  390. {
  391. return detail::make_xfunction<detail::bitwise_xor>(std::forward<E1>(e1), std::forward<E2>(e2));
  392. }
  393. /**
  394. * @ingroup bitwise_operators
  395. * @brief Bitwise not
  396. *
  397. * Returns an \ref xfunction for the element-wise bitwise not
  398. * of \a e.
  399. * @param e an \ref xexpression
  400. * @return an \ref xfunction
  401. */
  402. template <class E>
  403. inline auto operator~(E&& e) noexcept -> detail::xfunction_type_t<detail::bitwise_not, E>
  404. {
  405. return detail::make_xfunction<detail::bitwise_not>(std::forward<E>(e));
  406. }
  407. /**
  408. * @ingroup bitwise_operators
  409. * @brief Bitwise left shift
  410. *
  411. * Returns an \ref xfunction for the element-wise bitwise left shift of e1
  412. * by e2.
  413. * @param e1 an \ref xexpression
  414. * @param e2 an \ref xexpression
  415. * @return an \ref xfunction
  416. */
  417. template <class E1, class E2>
  418. inline auto left_shift(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::left_shift, E1, E2>
  419. {
  420. return detail::make_xfunction<detail::left_shift>(std::forward<E1>(e1), std::forward<E2>(e2));
  421. }
  422. /**
  423. * @ingroup bitwise_operators
  424. * @brief Bitwise left shift
  425. *
  426. * Returns an \ref xfunction for the element-wise bitwise left shift of e1
  427. * by e2.
  428. * @param e1 an \ref xexpression
  429. * @param e2 an \ref xexpression
  430. * @return an \ref xfunction
  431. */
  432. template <class E1, class E2>
  433. inline auto right_shift(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::right_shift, E1, E2>
  434. {
  435. return detail::make_xfunction<detail::right_shift>(std::forward<E1>(e1), std::forward<E2>(e2));
  436. }
  437. namespace detail
  438. {
  439. // Shift operator is not available for all the types, so the xfunction type instantiation
  440. // has to be delayed, enable_if_t is not sufficient
  441. template <class F, class E1, class E2>
  442. struct shift_function_getter
  443. {
  444. using type = xfunction_type_t<F, E1, E2>;
  445. };
  446. template <bool B, class T>
  447. struct eval_enable_if
  448. {
  449. using type = typename T::type;
  450. };
  451. template <class T>
  452. struct eval_enable_if<false, T>
  453. {
  454. };
  455. template <bool B, class T>
  456. using eval_enable_if_t = typename eval_enable_if<B, T>::type;
  457. template <class F, class E1, class E2>
  458. using shift_return_type_t = eval_enable_if_t<
  459. is_xexpression<std::decay_t<E1>>::value,
  460. shift_function_getter<F, E1, E2>>;
  461. }
  462. /**
  463. * @ingroup bitwise_operators
  464. * @brief Bitwise left shift
  465. *
  466. * Returns an \ref xfunction for the element-wise bitwise left shift of e1
  467. * by e2.
  468. * @param e1 an \ref xexpression
  469. * @param e2 an \ref xexpression
  470. * @return an \ref xfunction
  471. * @sa left_shift
  472. */
  473. template <class E1, class E2>
  474. inline auto operator<<(E1&& e1, E2&& e2) noexcept
  475. -> detail::shift_return_type_t<detail::left_shift, E1, E2>
  476. {
  477. return left_shift(std::forward<E1>(e1), std::forward<E2>(e2));
  478. }
  479. /**
  480. * @ingroup bitwise_operators
  481. * @brief Bitwise right shift
  482. *
  483. * Returns an \ref xfunction for the element-wise bitwise right shift of e1
  484. * by e2.
  485. * @param e1 an \ref xexpression
  486. * @param e2 an \ref xexpression
  487. * @return an \ref xfunction
  488. * @sa right_shift
  489. */
  490. template <class E1, class E2>
  491. inline auto operator>>(E1&& e1, E2&& e2) -> detail::shift_return_type_t<detail::right_shift, E1, E2>
  492. {
  493. return right_shift(std::forward<E1>(e1), std::forward<E2>(e2));
  494. }
  495. /**
  496. * @defgroup comparison_operators Comparison operators
  497. */
  498. /**
  499. * @ingroup comparison_operators
  500. * @brief Lesser than
  501. *
  502. * Returns an \ref xfunction for the element-wise
  503. * lesser than comparison of \a e1 and \a e2.
  504. * @param e1 an \ref xexpression or a scalar
  505. * @param e2 an \ref xexpression or a scalar
  506. * @return an \ref xfunction
  507. */
  508. template <class E1, class E2>
  509. inline auto operator<(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::less, E1, E2>
  510. {
  511. return detail::make_xfunction<detail::less>(std::forward<E1>(e1), std::forward<E2>(e2));
  512. }
  513. /**
  514. * @ingroup comparison_operators
  515. * @brief Lesser or equal
  516. *
  517. * Returns an \ref xfunction for the element-wise
  518. * lesser or equal comparison of \a e1 and \a e2.
  519. * @param e1 an \ref xexpression or a scalar
  520. * @param e2 an \ref xexpression or a scalar
  521. * @return an \ref xfunction
  522. */
  523. template <class E1, class E2>
  524. inline auto operator<=(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::less_equal, E1, E2>
  525. {
  526. return detail::make_xfunction<detail::less_equal>(std::forward<E1>(e1), std::forward<E2>(e2));
  527. }
  528. /**
  529. * @ingroup comparison_operators
  530. * @brief Greater than
  531. *
  532. * Returns an \ref xfunction for the element-wise
  533. * greater than comparison of \a e1 and \a e2.
  534. * @param e1 an \ref xexpression or a scalar
  535. * @param e2 an \ref xexpression or a scalar
  536. * @return an \ref xfunction
  537. */
  538. template <class E1, class E2>
  539. inline auto operator>(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::greater, E1, E2>
  540. {
  541. return detail::make_xfunction<detail::greater>(std::forward<E1>(e1), std::forward<E2>(e2));
  542. }
  543. /**
  544. * @ingroup comparison_operators
  545. * @brief Greater or equal
  546. *
  547. * Returns an \ref xfunction for the element-wise
  548. * greater or equal comparison of \a e1 and \a e2.
  549. * @param e1 an \ref xexpression or a scalar
  550. * @param e2 an \ref xexpression or a scalar
  551. * @return an \ref xfunction
  552. */
  553. template <class E1, class E2>
  554. inline auto operator>=(E1&& e1, E2&& e2) noexcept
  555. -> detail::xfunction_type_t<detail::greater_equal, E1, E2>
  556. {
  557. return detail::make_xfunction<detail::greater_equal>(std::forward<E1>(e1), std::forward<E2>(e2));
  558. }
  559. /**
  560. * @ingroup comparison_operators
  561. * @brief Equality
  562. *
  563. * Returns true if \a e1 and \a e2 have the same shape
  564. * and hold the same values. Unlike other comparison
  565. * operators, this does not return an \ref xfunction.
  566. * @param e1 an \ref xexpression or a scalar
  567. * @param e2 an \ref xexpression or a scalar
  568. * @return a boolean
  569. */
  570. template <class E1, class E2>
  571. inline std::enable_if_t<xoptional_comparable<E1, E2>::value, bool>
  572. operator==(const xexpression<E1>& e1, const xexpression<E2>& e2)
  573. {
  574. const E1& de1 = e1.derived_cast();
  575. const E2& de2 = e2.derived_cast();
  576. bool res = de1.dimension() == de2.dimension()
  577. && std::equal(de1.shape().begin(), de1.shape().end(), de2.shape().begin());
  578. auto iter1 = de1.begin();
  579. auto iter2 = de2.begin();
  580. auto iter_end = de1.end();
  581. while (res && iter1 != iter_end)
  582. {
  583. res = (*iter1++ == *iter2++);
  584. }
  585. return res;
  586. }
  587. /**
  588. * @ingroup comparison_operators
  589. * @brief Inequality
  590. *
  591. * Returns true if \a e1 and \a e2 have different shapes
  592. * or hold the different values. Unlike other comparison
  593. * operators, this does not return an \ref xfunction.
  594. * @param e1 an \ref xexpression or a scalar
  595. * @param e2 an \ref xexpression or a scalar
  596. * @return a boolean
  597. */
  598. template <class E1, class E2>
  599. inline bool operator!=(const xexpression<E1>& e1, const xexpression<E2>& e2)
  600. {
  601. return !(e1 == e2);
  602. }
  603. /**
  604. * @ingroup comparison_operators
  605. * @brief Element-wise equality
  606. *
  607. * Returns an \ref xfunction for the element-wise
  608. * equality of \a e1 and \a e2.
  609. * @param e1 an \ref xexpression or a scalar
  610. * @param e2 an \ref xexpression or a scalar
  611. * @return an \ref xfunction
  612. */
  613. template <class E1, class E2>
  614. inline auto equal(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::equal_to, E1, E2>
  615. {
  616. return detail::make_xfunction<detail::equal_to>(std::forward<E1>(e1), std::forward<E2>(e2));
  617. }
  618. /**
  619. * @ingroup comparison_operators
  620. * @brief Element-wise inequality
  621. *
  622. * Returns an \ref xfunction for the element-wise
  623. * inequality of \a e1 and \a e2.
  624. * @param e1 an \ref xexpression or a scalar
  625. * @param e2 an \ref xexpression or a scalar
  626. * @return an \ref xfunction
  627. */
  628. template <class E1, class E2>
  629. inline auto not_equal(E1&& e1, E2&& e2) noexcept -> detail::xfunction_type_t<detail::not_equal_to, E1, E2>
  630. {
  631. return detail::make_xfunction<detail::not_equal_to>(std::forward<E1>(e1), std::forward<E2>(e2));
  632. }
  633. /**
  634. * @ingroup comparison_operators
  635. * @brief Lesser than
  636. *
  637. * Returns an \ref xfunction for the element-wise
  638. * lesser than comparison of \a e1 and \a e2. This
  639. * function is equivalent to operator<(E1&&, E2&&).
  640. * @param e1 an \ref xexpression or a scalar
  641. * @param e2 an \ref xexpression or a scalar
  642. * @return an \ref xfunction
  643. */
  644. template <class E1, class E2>
  645. inline auto less(E1&& e1, E2&& e2) noexcept -> decltype(std::forward<E1>(e1) < std::forward<E2>(e2))
  646. {
  647. return std::forward<E1>(e1) < std::forward<E2>(e2);
  648. }
  649. /**
  650. * @ingroup comparison_operators
  651. * @brief Lesser or equal
  652. *
  653. * Returns an \ref xfunction for the element-wise
  654. * lesser or equal comparison of \a e1 and \a e2. This
  655. * function is equivalent to operator<=(E1&&, E2&&).
  656. * @param e1 an \ref xexpression or a scalar
  657. * @param e2 an \ref xexpression or a scalar
  658. * @return an \ref xfunction
  659. */
  660. template <class E1, class E2>
  661. inline auto less_equal(E1&& e1, E2&& e2) noexcept -> decltype(std::forward<E1>(e1) <= std::forward<E2>(e2))
  662. {
  663. return std::forward<E1>(e1) <= std::forward<E2>(e2);
  664. }
  665. /**
  666. * @ingroup comparison_operators
  667. * @brief Greater than
  668. *
  669. * Returns an \ref xfunction for the element-wise
  670. * greater than comparison of \a e1 and \a e2. This
  671. * function is equivalent to operator>(E1&&, E2&&).
  672. * @param e1 an \ref xexpression or a scalar
  673. * @param e2 an \ref xexpression or a scalar
  674. * @return an \ref xfunction
  675. */
  676. template <class E1, class E2>
  677. inline auto greater(E1&& e1, E2&& e2) noexcept -> decltype(std::forward<E1>(e1) > std::forward<E2>(e2))
  678. {
  679. return std::forward<E1>(e1) > std::forward<E2>(e2);
  680. }
  681. /**
  682. * @ingroup comparison_operators
  683. * @brief Greater or equal
  684. *
  685. * Returns an \ref xfunction for the element-wise
  686. * greater or equal comparison of \a e1 and \a e2.
  687. * This function is equivalent to operator>=(E1&&, E2&&).
  688. * @param e1 an \ref xexpression or a scalar
  689. * @param e2 an \ref xexpression or a scalar
  690. * @return an \ref xfunction
  691. */
  692. template <class E1, class E2>
  693. inline auto greater_equal(E1&& e1, E2&& e2) noexcept
  694. -> decltype(std::forward<E1>(e1) >= std::forward<E2>(e2))
  695. {
  696. return std::forward<E1>(e1) >= std::forward<E2>(e2);
  697. }
  698. /**
  699. * @ingroup logical_operators
  700. * @brief Ternary selection
  701. *
  702. * Returns an \ref xfunction for the element-wise
  703. * ternary selection (i.e. operator ? :) of \a e1,
  704. * \a e2 and \a e3.
  705. * @param e1 a boolean \ref xexpression
  706. * @param e2 an \ref xexpression or a scalar
  707. * @param e3 an \ref xexpression or a scalar
  708. * @return an \ref xfunction
  709. */
  710. template <class E1, class E2, class E3>
  711. inline auto where(E1&& e1, E2&& e2, E3&& e3) noexcept
  712. -> detail::xfunction_type_t<detail::conditional_ternary, E1, E2, E3>
  713. {
  714. return detail::make_xfunction<detail::conditional_ternary>(
  715. std::forward<E1>(e1),
  716. std::forward<E2>(e2),
  717. std::forward<E3>(e3)
  718. );
  719. }
  720. namespace detail
  721. {
  722. template <layout_type L>
  723. struct next_idx_impl;
  724. template <>
  725. struct next_idx_impl<layout_type::row_major>
  726. {
  727. template <class S, class I>
  728. inline auto operator()(const S& shape, I& idx)
  729. {
  730. for (std::size_t j = shape.size(); j > 0; --j)
  731. {
  732. std::size_t i = j - 1;
  733. if (idx[i] >= shape[i] - 1)
  734. {
  735. idx[i] = 0;
  736. }
  737. else
  738. {
  739. idx[i]++;
  740. return idx;
  741. }
  742. }
  743. // return empty index, happens at last iteration step, but remains unused
  744. return I();
  745. }
  746. };
  747. template <>
  748. struct next_idx_impl<layout_type::column_major>
  749. {
  750. template <class S, class I>
  751. inline auto operator()(const S& shape, I& idx)
  752. {
  753. for (std::size_t i = 0; i < shape.size(); ++i)
  754. {
  755. if (idx[i] >= shape[i] - 1)
  756. {
  757. idx[i] = 0;
  758. }
  759. else
  760. {
  761. idx[i]++;
  762. return idx;
  763. }
  764. }
  765. // return empty index, happens at last iteration step, but remains unused
  766. return I();
  767. }
  768. };
  769. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S, class I>
  770. inline auto next_idx(const S& shape, I& idx)
  771. {
  772. next_idx_impl<L> nii;
  773. return nii(shape, idx);
  774. }
  775. }
  776. /**
  777. * @ingroup logical_operators
  778. * @brief return vector of indices where T is not zero
  779. *
  780. * @param arr input array
  781. * @return vector of vectors, one for each dimension of arr, containing
  782. * the indices of the non-zero elements in that dimension
  783. */
  784. template <class T>
  785. inline auto nonzero(const T& arr)
  786. {
  787. auto shape = arr.shape();
  788. using index_type = xindex_type_t<typename T::shape_type>;
  789. using size_type = typename T::size_type;
  790. auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
  791. std::vector<std::vector<size_type>> indices(arr.dimension());
  792. size_type total_size = compute_size(shape);
  793. for (size_type i = 0; i < total_size; i++, detail::next_idx(shape, idx))
  794. {
  795. if (arr.element(std::begin(idx), std::end(idx)))
  796. {
  797. for (std::size_t n = 0; n < indices.size(); ++n)
  798. {
  799. indices.at(n).push_back(idx[n]);
  800. }
  801. }
  802. }
  803. return indices;
  804. }
  805. /**
  806. * @ingroup logical_operators
  807. * @brief return vector of indices where condition is true
  808. * (equivalent to \a nonzero(condition))
  809. *
  810. * @param condition input array
  811. * @return vector of \a index_types where condition is not equal to zero
  812. */
  813. template <class T>
  814. inline auto where(const T& condition)
  815. {
  816. return nonzero(condition);
  817. }
  818. /**
  819. * @ingroup logical_operators
  820. * @brief return vector of indices where arr is not zero
  821. *
  822. * @tparam L the traversal order
  823. * @param arr input array
  824. * @return vector of index_types where arr is not equal to zero (use `xt::from_indices` to convert)
  825. *
  826. * @sa xt::from_indices
  827. */
  828. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class T>
  829. inline auto argwhere(const T& arr)
  830. {
  831. auto shape = arr.shape();
  832. using index_type = xindex_type_t<typename T::shape_type>;
  833. using size_type = typename T::size_type;
  834. auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
  835. std::vector<index_type> indices;
  836. size_type total_size = compute_size(shape);
  837. for (size_type i = 0; i < total_size; i++, detail::next_idx<L>(shape, idx))
  838. {
  839. if (arr.element(std::begin(idx), std::end(idx)))
  840. {
  841. indices.push_back(idx);
  842. }
  843. }
  844. return indices;
  845. }
  846. /**
  847. * @ingroup logical_operators
  848. * @brief Any
  849. *
  850. * Returns true if any of the values of \a e is truthy,
  851. * false otherwise.
  852. * @param e an \ref xexpression
  853. * @return a boolean
  854. */
  855. template <class E>
  856. inline bool any(E&& e)
  857. {
  858. using xtype = std::decay_t<E>;
  859. using value_type = typename xtype::value_type;
  860. return std::any_of(
  861. e.cbegin(),
  862. e.cend(),
  863. [](const value_type& el)
  864. {
  865. return el;
  866. }
  867. );
  868. }
  869. /**
  870. * @ingroup logical_operators
  871. * @brief Any
  872. *
  873. * Returns true if all of the values of \a e are truthy,
  874. * false otherwise.
  875. * @param e an \ref xexpression
  876. * @return a boolean
  877. */
  878. template <class E>
  879. inline bool all(E&& e)
  880. {
  881. using xtype = std::decay_t<E>;
  882. using value_type = typename xtype::value_type;
  883. return std::all_of(
  884. e.cbegin(),
  885. e.cend(),
  886. [](const value_type& el)
  887. {
  888. return el;
  889. }
  890. );
  891. }
  892. /**
  893. * @defgroup casting_operators Casting operators
  894. */
  895. /**
  896. * @ingroup casting_operators
  897. * @brief Element-wise ``static_cast``.
  898. *
  899. * Returns an \ref xfunction for the element-wise
  900. * static_cast of \a e to type R.
  901. *
  902. * @param e an \ref xexpression or a scalar
  903. * @return an \ref xfunction
  904. */
  905. template <class R, class E>
  906. inline auto cast(E&& e) noexcept -> detail::xfunction_type_t<typename detail::cast<R>::functor, E>
  907. {
  908. return detail::make_xfunction<typename detail::cast<R>::functor>(std::forward<E>(e));
  909. }
  910. }
  911. #endif