xblockwise_reducer_functors.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. #ifndef XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
  2. #define XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
  3. #include <sstream>
  4. #include <string>
  5. #include <tuple>
  6. #include <typeinfo>
  7. #include "xarray.hpp"
  8. #include "xbuilder.hpp"
  9. #include "xchunked_array.hpp"
  10. #include "xchunked_assign.hpp"
  11. #include "xchunked_view.hpp"
  12. #include "xexpression.hpp"
  13. #include "xmath.hpp"
  14. #include "xnorm.hpp"
  15. #include "xreducer.hpp"
  16. #include "xtl/xclosure.hpp"
  17. #include "xtl/xsequence.hpp"
  18. #include "xutils.hpp"
  19. namespace xt
  20. {
  21. namespace detail
  22. {
  23. namespace blockwise
  24. {
  25. struct empty_reduction_variable
  26. {
  27. };
  28. struct simple_functor_base
  29. {
  30. template <class E>
  31. auto reduction_variable(const E&) const
  32. {
  33. return empty_reduction_variable();
  34. }
  35. template <class MR, class E, class R>
  36. void finalize(const MR&, E&, const R&) const
  37. {
  38. }
  39. };
  40. template <class T_E, class T_I = void>
  41. struct sum_functor : public simple_functor_base
  42. {
  43. using value_type = typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
  44. template <class E, class A, class O>
  45. auto compute(const E& input, const A& axes, const O& options) const
  46. {
  47. return xt::sum<value_type>(input, axes, options);
  48. }
  49. template <class BR, class E, class MR>
  50. auto merge(const BR& block_result, bool first, E& result, MR&) const
  51. {
  52. if (first)
  53. {
  54. xt::noalias(result) = block_result;
  55. }
  56. else
  57. {
  58. xt::noalias(result) += block_result;
  59. }
  60. }
  61. };
  62. template <class T_E, class T_I = void>
  63. struct prod_functor : public simple_functor_base
  64. {
  65. using value_type = typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
  66. template <class E, class A, class O>
  67. auto compute(const E& input, const A& axes, const O& options) const
  68. {
  69. return xt::prod<value_type>(input, axes, options);
  70. }
  71. template <class BR, class E, class MR>
  72. auto merge(const BR& block_result, bool first, E& result, MR&) const
  73. {
  74. if (first)
  75. {
  76. xt::noalias(result) = block_result;
  77. }
  78. else
  79. {
  80. xt::noalias(result) *= block_result;
  81. }
  82. }
  83. };
  84. template <class T_E, class T_I = void>
  85. struct amin_functor : public simple_functor_base
  86. {
  87. using value_type = typename std::decay_t<decltype(xt::amin<T_I>(std::declval<xarray<T_E>>()))>::value_type;
  88. template <class E, class A, class O>
  89. auto compute(const E& input, const A& axes, const O& options) const
  90. {
  91. return xt::amin(input, axes, options);
  92. }
  93. template <class BR, class E, class MR>
  94. auto merge(const BR& block_result, bool first, E& result, MR&) const
  95. {
  96. if (first)
  97. {
  98. xt::noalias(result) = block_result;
  99. }
  100. else
  101. {
  102. xt::noalias(result) = xt::minimum(block_result, result);
  103. }
  104. }
  105. };
  106. template <class T_E, class T_I = void>
  107. struct amax_functor : public simple_functor_base
  108. {
  109. using value_type = typename std::decay_t<decltype(xt::amax<T_I>(std::declval<xarray<T_E>>()))>::value_type;
  110. template <class E, class A, class O>
  111. auto compute(const E& input, const A& axes, const O& options) const
  112. {
  113. return xt::amax(input, axes, options);
  114. }
  115. template <class BR, class E, class MR>
  116. auto merge(const BR& block_result, bool first, E& result, MR&) const
  117. {
  118. if (first)
  119. {
  120. xt::noalias(result) = block_result;
  121. }
  122. else
  123. {
  124. xt::noalias(result) = xt::maximum(block_result, result);
  125. }
  126. }
  127. };
  128. template <class T_E, class T_I = void>
  129. struct mean_functor
  130. {
  131. using value_type = typename std::decay_t<decltype(xt::mean<T_I>(std::declval<xarray<T_E>>()))>::value_type;
  132. template <class E, class A, class O>
  133. auto compute(const E& input, const A& axes, const O& options) const
  134. {
  135. return xt::sum<value_type>(input, axes, options);
  136. }
  137. template <class E>
  138. auto reduction_variable(const E&) const
  139. {
  140. return empty_reduction_variable();
  141. }
  142. template <class BR, class E>
  143. auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
  144. {
  145. if (first)
  146. {
  147. xt::noalias(result) = block_result;
  148. }
  149. else
  150. {
  151. xt::noalias(result) += block_result;
  152. }
  153. }
  154. template <class E, class R>
  155. void finalize(const empty_reduction_variable&, E& results, const R& reducer) const
  156. {
  157. const auto& axes = reducer.axes();
  158. std::decay_t<decltype(reducer.input_shape()[0])> factor = 1;
  159. for (auto a : axes)
  160. {
  161. factor *= reducer.input_shape()[a];
  162. }
  163. xt::noalias(results) /= static_cast<typename E::value_type>(factor);
  164. }
  165. };
  166. template <class T_E, class T_I = void>
  167. struct variance_functor
  168. {
  169. using value_type = typename std::decay_t<decltype(xt::variance<T_I>(std::declval<xarray<T_E>>())
  170. )>::value_type;
  171. template <class E, class A, class O>
  172. auto compute(const E& input, const A& axes, const O& options) const
  173. {
  174. double weight = 1.0;
  175. for (auto a : axes)
  176. {
  177. weight *= static_cast<double>(input.shape()[a]);
  178. }
  179. return std::make_tuple(
  180. xt::variance<value_type>(input, axes, options),
  181. xt::mean<value_type>(input, axes, options),
  182. weight
  183. );
  184. }
  185. template <class E>
  186. auto reduction_variable(const E&) const
  187. {
  188. return std::make_tuple(xarray<value_type>(), 0.0);
  189. }
  190. template <class BR, class E, class MR>
  191. auto merge(const BR& block_result, bool first, E& variance_a, MR& mr) const
  192. {
  193. auto& mean_a = std::get<0>(mr);
  194. auto& n_a = std::get<1>(mr);
  195. const auto& variance_b = std::get<0>(block_result);
  196. const auto& mean_b = std::get<1>(block_result);
  197. const auto& n_b = std::get<2>(block_result);
  198. if (first)
  199. {
  200. xt::noalias(variance_a) = variance_b;
  201. xt::noalias(mean_a) = mean_b;
  202. n_a += n_b;
  203. }
  204. else
  205. {
  206. auto new_mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b);
  207. auto new_variance = (n_a * variance_a + n_b * variance_b
  208. + n_a * xt::pow(mean_a - new_mean, 2)
  209. + n_b * xt::pow(mean_b - new_mean, 2))
  210. / (n_a + n_b);
  211. xt::noalias(variance_a) = new_variance;
  212. xt::noalias(mean_a) = new_mean;
  213. n_a += n_b;
  214. }
  215. }
  216. template <class MR, class E, class R>
  217. void finalize(const MR&, E&, const R&) const
  218. {
  219. }
  220. };
  221. template <class T_E, class T_I = void>
  222. struct stddev_functor : public variance_functor<T_E, T_I>
  223. {
  224. template <class MR, class E, class R>
  225. void finalize(const MR&, E& results, const R&) const
  226. {
  227. xt::noalias(results) = xt::sqrt(results);
  228. }
  229. };
  230. template <class T_E>
  231. struct norm_l0_functor : public simple_functor_base
  232. {
  233. using value_type = typename std::decay_t<decltype(xt::norm_l0(std::declval<xarray<T_E>>()))>::value_type;
  234. template <class E, class A, class O>
  235. auto compute(const E& input, const A& axes, const O& options) const
  236. {
  237. return xt::sum<value_type>(xt::not_equal(input, xt::zeros<T_E>(input.shape())), axes, options);
  238. }
  239. template <class BR, class E, class MR>
  240. auto merge(const BR& block_result, bool first, E& result, MR&) const
  241. {
  242. if (first)
  243. {
  244. xt::noalias(result) = block_result;
  245. }
  246. else
  247. {
  248. xt::noalias(result) += block_result;
  249. }
  250. }
  251. };
  252. template <class T_E>
  253. struct norm_l1_functor : public simple_functor_base
  254. {
  255. using value_type = typename std::decay_t<decltype(xt::norm_l1(std::declval<xarray<T_E>>()))>::value_type;
  256. template <class E, class A, class O>
  257. auto compute(const E& input, const A& axes, const O& options) const
  258. {
  259. return xt::sum<value_type>(xt::abs(input), axes, options);
  260. }
  261. template <class BR, class E, class MR>
  262. auto merge(const BR& block_result, bool first, E& result, MR&) const
  263. {
  264. if (first)
  265. {
  266. xt::noalias(result) = block_result;
  267. }
  268. else
  269. {
  270. xt::noalias(result) += block_result;
  271. }
  272. }
  273. };
  274. template <class T_E>
  275. struct norm_l2_functor
  276. {
  277. using value_type = typename std::decay_t<decltype(xt::norm_l2(std::declval<xarray<T_E>>()))>::value_type;
  278. template <class E, class A, class O>
  279. auto compute(const E& input, const A& axes, const O& options) const
  280. {
  281. return xt::sum<value_type>(xt::square(input), axes, options);
  282. }
  283. template <class E>
  284. auto reduction_variable(const E&) const
  285. {
  286. return empty_reduction_variable();
  287. }
  288. template <class BR, class E>
  289. auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
  290. {
  291. if (first)
  292. {
  293. xt::noalias(result) = block_result;
  294. }
  295. else
  296. {
  297. xt::noalias(result) += block_result;
  298. }
  299. }
  300. template <class E, class R>
  301. void finalize(const empty_reduction_variable&, E& results, const R&) const
  302. {
  303. xt::noalias(results) = xt::sqrt(results);
  304. }
  305. };
  306. template <class T_E>
  307. struct norm_sq_functor : public simple_functor_base
  308. {
  309. using value_type = typename std::decay_t<decltype(xt::norm_sq(std::declval<xarray<T_E>>()))>::value_type;
  310. template <class E, class A, class O>
  311. auto compute(const E& input, const A& axes, const O& options) const
  312. {
  313. return xt::sum<value_type>(xt::square(input), axes, options);
  314. }
  315. template <class BR, class E, class MR>
  316. auto merge(const BR& block_result, bool first, E& result, MR&) const
  317. {
  318. if (first)
  319. {
  320. xt::noalias(result) = block_result;
  321. }
  322. else
  323. {
  324. xt::noalias(result) += block_result;
  325. }
  326. }
  327. };
  328. template <class T_E>
  329. struct norm_linf_functor : public simple_functor_base
  330. {
  331. using value_type = typename std::decay_t<decltype(xt::norm_linf(std::declval<xarray<T_E>>()))>::value_type;
  332. template <class E, class A, class O>
  333. auto compute(const E& input, const A& axes, const O& options) const
  334. {
  335. return xt::amax<value_type>(xt::abs(input), axes, options);
  336. }
  337. template <class BR, class E, class MR>
  338. auto merge(const BR& block_result, bool first, E& result, MR&) const
  339. {
  340. if (first)
  341. {
  342. xt::noalias(result) = block_result;
  343. }
  344. else
  345. {
  346. xt::noalias(result) = xt::maximum(block_result, result);
  347. }
  348. }
  349. };
  350. template <class T_E>
  351. class norm_lp_to_p_functor
  352. {
  353. public:
  354. using value_type = typename std::decay_t<
  355. decltype(xt::norm_lp_to_p(std::declval<xarray<T_E>>(), 1.0))>::value_type;
  356. norm_lp_to_p_functor(double p)
  357. : m_p(p)
  358. {
  359. }
  360. template <class E, class A, class O>
  361. auto compute(const E& input, const A& axes, const O& options) const
  362. {
  363. return xt::sum<value_type>(xt::pow(input, m_p), axes, options);
  364. }
  365. template <class E>
  366. auto reduction_variable(const E&) const
  367. {
  368. return empty_reduction_variable();
  369. }
  370. template <class BR, class E>
  371. auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
  372. {
  373. if (first)
  374. {
  375. xt::noalias(result) = block_result;
  376. }
  377. else
  378. {
  379. xt::noalias(result) += block_result;
  380. }
  381. }
  382. template <class E, class R>
  383. void finalize(const empty_reduction_variable&, E&, const R&) const
  384. {
  385. }
  386. private:
  387. double m_p;
  388. };
  389. template <class T_E>
  390. class norm_lp_functor
  391. {
  392. public:
  393. norm_lp_functor(double p)
  394. : m_p(p)
  395. {
  396. }
  397. using value_type = typename std::decay_t<decltype(xt::norm_lp(std::declval<xarray<T_E>>(), 1.0)
  398. )>::value_type;
  399. template <class E, class A, class O>
  400. auto compute(const E& input, const A& axes, const O& options) const
  401. {
  402. return xt::sum<value_type>(xt::pow(input, m_p), axes, options);
  403. }
  404. template <class E>
  405. auto reduction_variable(const E&) const
  406. {
  407. return empty_reduction_variable();
  408. }
  409. template <class BR, class E>
  410. auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
  411. {
  412. if (first)
  413. {
  414. xt::noalias(result) = block_result;
  415. }
  416. else
  417. {
  418. xt::noalias(result) += block_result;
  419. }
  420. }
  421. template <class E, class R>
  422. void finalize(const empty_reduction_variable&, E& results, const R&) const
  423. {
  424. results = xt::pow(results, 1.0 / m_p);
  425. }
  426. private:
  427. double m_p;
  428. };
  429. }
  430. }
  431. }
  432. #endif