xreducer.hpp 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_REDUCER_HPP
  10. #define XTENSOR_REDUCER_HPP
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>

#include <xtl/xfunctional.hpp>
#include <xtl/xsequence.hpp>

#include "xaccessible.hpp"
#include "xbuilder.hpp"
#include "xeval.hpp"
#include "xexpression.hpp"
#include "xgenerator.hpp"
#include "xiterable.hpp"
#include "xtensor_config.hpp"
#include "xutils.hpp"
  29. namespace xt
  30. {
// Combines two reducer-option tuples into one, e.g.
// `xt::keep_dims | xt::evaluation_strategy::immediate`.
// Restricted via XTL_REQUIRES so it only applies to option/strategy tuples.
template <template <class...> class A, class... AX, class X, XTL_REQUIRES(is_evaluation_strategy<AX>..., is_evaluation_strategy<X>)>
auto operator|(const A<AX...>& args, const A<X>& rhs)
{
    return std::tuple_cat(args, rhs);
}

// Tag type requesting that reduced axes be kept as size-1 dimensions.
struct keep_dims_type : xt::detail::option_base
{
};

// Option instance passed to reducers: `xt::sum(a, {1}, xt::keep_dims)`.
constexpr auto keep_dims = std::tuple<keep_dims_type>{};
// Option wrapper carrying a user-provided initial value for a reduction.
template <class T = double>
struct xinitial : xt::detail::option_base
{
    // Stores the initial value; constexpr so the option can be built at compile time.
    constexpr xinitial(T val)
        : m_val(val)
    {
    }

    // Returns the wrapped initial value.
    constexpr T value() const
    {
        return m_val;
    }

    T m_val;  // the wrapped initial value
};

// Builds an options tuple holding an initial value,
// e.g. `xt::sum(a, {0}, xt::initial(10))`.
template <class T>
constexpr auto initial(T val)
{
    return std::make_tuple(xinitial<T>(val));
}
// Recursive search for type T in a std::tuple; I tracks the current index.
template <std::ptrdiff_t I, class T, class Tuple>
struct tuple_idx_of_impl;

// Base case: empty tuple -> T not found, report -1.
template <std::ptrdiff_t I, class T>
struct tuple_idx_of_impl<I, T, std::tuple<>>
{
    static constexpr std::ptrdiff_t value = -1;
};

// Match: tuple head is T -> the current index I is the answer.
template <std::ptrdiff_t I, class T, class... Types>
struct tuple_idx_of_impl<I, T, std::tuple<T, Types...>>
{
    static constexpr std::ptrdiff_t value = I;
};

// Mismatch: recurse on the tail with the index incremented.
template <std::ptrdiff_t I, class T, class U, class... Types>
struct tuple_idx_of_impl<I, T, std::tuple<U, Types...>>
{
    static constexpr std::ptrdiff_t value = tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
};

// Applies std::decay_t to every template argument of S (e.g. each tuple element).
template <class S, class... X>
struct decay_all;

template <template <class...> class S, class... X>
struct decay_all<S<X...>>
{
    using type = S<std::decay_t<X>...>;
};

// Index of (decayed) T within the element-wise decayed Tuple, or -1 when absent.
template <class T, class Tuple>
struct tuple_idx_of
{
    static constexpr std::ptrdiff_t
        value = tuple_idx_of_impl<0, std::decay_t<T>, typename decay_all<Tuple>::type>::value;
};
// Parses the options tuple T passed to a reducer and exposes the chosen
// settings (evaluation strategy, keep_dims, initial value) as nested
// types/members. R is the reducer's result type, used to store the
// initial value.
template <class R, class T>
struct reducer_options
{
    // Detects xinitial<X> elements in the options tuple.
    template <class X>
    struct initial_tester : std::false_type
    {
    };

    template <class X>
    struct initial_tester<xinitial<X>> : std::true_type
    {
    };

    // Workaround for Apple because tuple_cat is buggy!
    template <class X>
    struct initial_tester<const xinitial<X>> : std::true_type
    {
    };

    using d_t = std::decay_t<T>;

    // Index of the xinitial element in the tuple; equals the tuple size when absent.
    static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;

    reducer_options() = default;

    // Extracts the initial value from the options tuple, if one was supplied.
    reducer_options(const T& tpl)
    {
        // static_if: the extraction branch is only instantiated when an
        // initial value actually exists in the tuple.
        xtl::mpl::static_if<initial_val_idx != std::tuple_size<T>::value>(
            [this, &tpl](auto no_compile)
            {
                // use no_compile to prevent compilation if initial_val_idx is out of bounds!
                // The index expression clamps to 0 in the discarded branch so
                // std::get always receives a valid index.
                this->initial_value = no_compile(
                                          std::get < initial_val_idx != std::tuple_size<T>::value
                                              ? initial_val_idx
                                              : 0 > (tpl)
                )
                                          .value();
            },
            [](auto /*np_compile*/) {}
        );
    }

    // immediate if evaluation_strategy::immediate_type occurs in the tuple, lazy otherwise.
    using evaluation_strategy = std::conditional_t<
        tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
        xt::evaluation_strategy::immediate_type,
        xt::evaluation_strategy::lazy_type>;

    // std::true_type when keep_dims was requested, std::false_type otherwise.
    using keep_dims = std::
        conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;

    static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;

    R initial_value;  // only meaningful when has_initial_value is true

    template <class NR>
    using rebind_t = reducer_options<NR, T>;

    // Produces the same options with the initial value re-typed to NR.
    template <class NR>
    auto rebind(NR initial, const reducer_options<R, T>&) const
    {
        reducer_options<NR, T> res;
        res.initial_value = initial;
        return res;
    }
};
// Trait: detects whether a (decayed) type is a reducer-options tuple.
// Any std::tuple counts as an options pack.
template <class T>
struct is_reducer_options_impl : std::false_type
{
};

template <class... X>
struct is_reducer_options_impl<std::tuple<X...>> : std::true_type
{
};

template <class T>
struct is_reducer_options : is_reducer_options_impl<std::decay_t<T>>
{
};
  153. /**********
  154. * reduce *
  155. **********/
  156. #define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
  157. template <class ST, class X, class KD = std::false_type>
  158. struct xreducer_shape_type;
  159. template <class S1, class S2>
  160. struct fixed_xreducer_shape_type;
namespace detail
{
    // Computes the shape of the reduction result (honoring keep_dims from the
    // options type O) and resizes `result` accordingly.
    // Enabled only for runtime (non-fixed) shape types.
    template <class O, class RS, class R, class E, class AX>
    inline void shape_computation(
        RS& result_shape,
        R& result,
        E& expr,
        const AX& axes,
        std::enable_if_t<!detail::is_fixed<RS>::value, int> = 0
    )
    {
        if (typename O::keep_dims())
        {
            // keep_dims: result has the same rank; reduced axes become size 1.
            resize_container(result_shape, expr.dimension());
            for (std::size_t i = 0; i < expr.dimension(); ++i)
            {
                if (std::find(axes.begin(), axes.end(), i) == axes.end())
                {
                    // i not in axes!
                    result_shape[i] = expr.shape()[i];
                }
                else
                {
                    result_shape[i] = 1;
                }
            }
        }
        else
        {
            // reduced axes are dropped from the result shape.
            resize_container(result_shape, expr.dimension() - axes.size());
            for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
            {
                if (std::find(axes.begin(), axes.end(), i) == axes.end())
                {
                    // i not in axes!
                    result_shape[idx] = expr.shape()[i];
                    ++idx;
                }
            }
        }
        result.resize(result_shape, expr.layout());
    }

    // skip shape computation if already done at compile time
    template <class O, class RS, class R, class S, class AX>
    inline void
    shape_computation(RS&, R&, const S&, const AX&, std::enable_if_t<detail::is_fixed<RS>::value, int> = 0)
    {
    }
}
// Copies the expression `e` into `result` when the value types are
// convertible; the functor parameter is unused in this overload.
// The iteration order is chosen to match e's layout so the copy is linear.
template <class F, class E, class R, XTL_REQUIRES(std::is_convertible<typename E::value_type, typename R::value_type>)>
inline void copy_to_reduced(F&, const E& e, R& result)
{
    if (e.layout() == layout_type::row_major)
    {
        std::copy(
            e.template cbegin<layout_type::row_major>(),
            e.template cend<layout_type::row_major>(),
            result.data()
        );
    }
    else
    {
        std::copy(
            e.template cbegin<layout_type::column_major>(),
            e.template cend<layout_type::column_major>(),
            result.data()
        );
    }
}

// Fallback when the value types are NOT convertible: apply the functor `f`
// element-wise while transferring into `result`.
template <
    class F,
    class E,
    class R,
    XTL_REQUIRES(xtl::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
inline void copy_to_reduced(F& f, const E& e, R& result)
{
    if (e.layout() == layout_type::row_major)
    {
        std::transform(
            e.template cbegin<layout_type::row_major>(),
            e.template cend<layout_type::row_major>(),
            result.data(),
            f
        );
    }
    else
    {
        std::transform(
            e.template cbegin<layout_type::column_major>(),
            e.template cend<layout_type::column_major>(),
            result.data(),
            f
        );
    }
}
// Performs an immediate (eager) reduction of expression `e` over `axes` and
// returns a newly allocated container holding the result.
//
// F bundles three functors (see xreducer_functors): reduce (binary
// accumulation), init (nullary seed) and merge (combines two partial results).
// `raw_options` is an options tuple (keep_dims, initial value, ...).
// Throws std::runtime_error when axes are unsorted, duplicated or out of
// bounds, or when e's layout is neither row- nor column-major.
template <class F, class E, class X, class O>
inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
{
    using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
    using init_functor_type = typename std::decay_t<F>::init_functor_type;
    using expr_value_type = typename std::decay_t<E>::value_type;
    // result type = type of reduce(init(), element)
    using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
        std::declval<init_functor_type>()(),
        std::declval<expr_value_type>()
    ))>;

    using options_t = reducer_options<result_type, std::decay_t<O>>;
    options_t options(raw_options);

    using shape_type = typename xreducer_shape_type<
        typename std::decay_t<E>::shape_type,
        std::decay_t<X>,
        typename options_t::keep_dims>::type;
    using result_container_type = typename detail::xtype_for_shape<
        shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
    result_container_type result;

    // retrieve functors from triple struct
    auto reduce_fct = xt::get<0>(f);
    auto init_fct = xt::get<1>(f);
    auto merge_fct = xt::get<2>(f);

    // Empty axes list: "reduce" each element individually, i.e. apply
    // reduce(init(), v) element-wise and keep the original shape.
    if (axes.size() == 0)
    {
        result.resize(e.shape(), e.layout());
        auto cpf = [&reduce_fct, &init_fct](const auto& v)
        {
            return reduce_fct(static_cast<result_type>(init_fct()), v);
        };
        copy_to_reduced(cpf, e, result);
        return result;
    }

    shape_type result_shape{};
    dynamic_shape<std::size_t>
        iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>, decltype(e.shape())>(e.shape());
    dynamic_shape<std::size_t> iter_strides(e.dimension());

    // std::less is used, because as the standard says (24.4.5):
    // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
    // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
    // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
    if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
    {
        XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
    }
    if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
    {
        XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
    }
    // axes are sorted, so checking the last one suffices for bounds checking.
    if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
    {
        XTENSOR_THROW(
            std::runtime_error,
            "Axis " + std::to_string(axes[axes.size() - 1]) + " out of bounds for reduction."
        );
    }

    detail::shape_computation<options_t>(result_shape, result, e, axes);

    // Fast track for complete reduction
    if (e.dimension() == axes.size())
    {
        // seed with the user-supplied initial value when present, else init().
        result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
        result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
        return result;
    }

    // Innermost reduced axis in memory order (last for row-major, first for
    // column-major): its stride is the step of the inner accumulation loop.
    std::size_t leading_ax = axes[(e.layout() == layout_type::row_major) ? axes.size() - 1 : 0];
    auto strides_finder = e.strides().begin() + static_cast<std::ptrdiff_t>(leading_ax);
    // The computed strides contain "0" where the shape is 1 -- therefore find the next none-zero number
    std::size_t inner_stride = static_cast<std::size_t>(*strides_finder);
    auto iter_bound = e.layout() == layout_type::row_major ? e.strides().begin() : (e.strides().end() - 1);
    while (inner_stride == 0 && strides_finder != iter_bound)
    {
        (e.layout() == layout_type::row_major) ? --strides_finder : ++strides_finder;
        inner_stride = static_cast<std::size_t>(*strides_finder);
    }

    // All strides zero: every reduced extent is 1, so reduction degenerates
    // to the element-wise reduce(init(), v) copy.
    if (inner_stride == 0)
    {
        auto cpf = [&reduce_fct, &init_fct](const auto& v)
        {
            return reduce_fct(static_cast<result_type>(init_fct()), v);
        };
        copy_to_reduced(cpf, e, result);
        return result;
    }

    std::size_t inner_loop_size = static_cast<std::size_t>(inner_stride);
    std::size_t outer_loop_size = e.shape()[leading_ax];

    // The following code merges reduction axes "at the end" (or the beginning for col_major)
    // together by increasing the size of the outer loop where appropriate
    auto merge_loops = [&outer_loop_size, &e](auto it, auto end)
    {
        auto last_ax = *it;
        ++it;
        for (; it != end; ++it)
        {
            // note that we check is_sorted, so this condition is valid
            if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
            {
                last_ax = *it;
                outer_loop_size *= e.shape()[last_ax];
            }
        }
        return last_ax;
    };

    // Strides of the (kept) output dimensions, used by next_idx to map an
    // input position to the corresponding output offset.
    for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
    {
        if (std::find(axes.begin(), axes.end(), i) == axes.end())
        {
            // i not in axes!
            iter_strides[i] = static_cast<std::size_t>(result.strides(
            )[typename options_t::keep_dims() ? i : idx]);
            ++idx;
        }
    }

    if (e.layout() == layout_type::row_major)
    {
        std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
        iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
        iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
    }
    else if (e.layout() == layout_type::column_major)
    {
        // we got column_major here
        std::size_t last_ax = merge_loops(axes.begin(), axes.end());
        // erasing the front vs the back
        iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
        iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
        // and reversing, to make it work with the same next_idx function
        std::reverse(iter_shape.begin(), iter_shape.end());
        std::reverse(iter_strides.begin(), iter_strides.end());
    }
    else
    {
        XTENSOR_THROW(std::runtime_error, "Layout not supported in immediate reduction.");
    }

    // Odometer-style index over the non-merged dimensions; returns
    // (done, output offset for the current position).
    xindex temp_idx(iter_shape.size());
    auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
    {
        std::size_t i = iter_shape.size();
        for (; i > 0; --i)
        {
            if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
            {
                temp_idx[i - 1] = 0;
            }
            else
            {
                temp_idx[i - 1]++;
                break;
            }
        }
        return std::make_pair(
            i == 0,
            std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
        );
    };

    auto begin = e.data();
    auto out = result.data();
    auto out_begin = result.data();
    std::ptrdiff_t next_stride = 0;
    std::pair<bool, std::ptrdiff_t> idx_res(false, 0);

    // Remark: eventually some modifications here to make conditions faster where merge + accumulate is
    // the same function (e.g. check std::is_same<decltype(merge_fct), decltype(reduce_fct)>::value) ...
    // merge_border tracks the furthest output written so far: revisiting an
    // earlier slot means a partial result is already there and must be merged.
    auto merge_border = out;
    bool merge = false;

    // TODO there could be some performance gain by removing merge checking
    // when axes.size() == 1 and even next_idx could be removed for something simpler (next_stride
    // always the same) best way to do this would be to create a function that takes (begin, out,
    // outer_loop_size, inner_loop_size, next_idx_lambda)
    // Decide if going about it row-wise or col-wise
    if (inner_stride == 1)
    {
        // Contiguous reduced run: accumulate a whole run into one output slot.
        while (idx_res.first != true)
        {
            // for unknown reasons it's much faster to use a temporary variable and
            // std::accumulate here -- probably some cache behavior
            result_type tmp = init_fct();
            tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
            // use merge function if necessary
            *out = merge ? merge_fct(*out, tmp) : tmp;
            begin += outer_loop_size;
            idx_res = next_idx();
            next_stride = idx_res.second;
            out = out_begin + next_stride;
            if (out > merge_border)
            {
                // looped over once
                merge = false;
                merge_border = out;
            }
            else
            {
                merge = true;
            }
        };
    }
    else
    {
        // Strided case: reduce vector-wise across inner_loop_size output slots.
        while (idx_res.first != true)
        {
            std::transform(
                out,
                out + inner_loop_size,
                begin,
                out,
                [merge, &init_fct, &reduce_fct](auto&& v1, auto&& v2)
                {
                    return merge ? reduce_fct(v1, v2) :
                                 // cast because return type of identity function is not upcasted
                        reduce_fct(static_cast<result_type>(init_fct()), v2);
                }
            );
            begin += inner_stride;
            for (std::size_t i = 1; i < outer_loop_size; ++i)
            {
                std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
                begin += inner_stride;
            }
            idx_res = next_idx();
            next_stride = idx_res.second;
            out = out_begin + next_stride;
            if (out > merge_border)
            {
                // looped over once
                merge = false;
                merge_border = out;
            }
            else
            {
                merge = true;
            }
        };
    }

    // Finally fold the user-supplied initial value into every output element.
    if (options_t::has_initial_value)
    {
        std::transform(
            result.data(),
            result.data() + result.size(),
            result.data(),
            [&merge_fct, &options](auto&& v)
            {
                return merge_fct(v, options.initial_value);
            }
        );
    }
    return result;
}
  501. /*********************
  502. * xreducer functors *
  503. *********************/
// Nullary functor returning a fixed value; used as the default "init"
// functor of a reducer (the seed of the accumulation).
template <class T>
struct const_value
{
    using value_type = T;

    constexpr const_value() = default;

    constexpr const_value(T t)
        : m_value(t)
    {
    }

    // Returns the stored constant.
    constexpr T operator()() const
    {
        return m_value;
    }

    template <class NT>
    using rebind_t = const_value<NT>;

    // Converts to a const_value of another type NT; defined out of line elsewhere.
    template <class NT>
    const_value<NT> rebind() const;

    T m_value;
};
namespace detail
{
    // evaluated_value_type<T, B>: yields T unchanged when B is false...
    template <class T, bool B>
    struct evaluated_value_type
    {
        using type = T;
    };

    // ...and the type produced by xt::eval(T) when B is true, i.e. the
    // evaluated container type when T is itself an expression.
    template <class T>
    struct evaluated_value_type<T, true>
    {
        using type = typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
    };

    template <class T, bool B>
    using evaluated_value_type_t = typename evaluated_value_type<T, B>::type;
}
  538. template <class REDUCE_FUNC, class INIT_FUNC = const_value<long int>, class MERGE_FUNC = REDUCE_FUNC>
  539. struct xreducer_functors : public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
  540. {
  541. using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
  542. using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
  543. using reduce_functor_type = REDUCE_FUNC;
  544. using init_functor_type = INIT_FUNC;
  545. using merge_functor_type = MERGE_FUNC;
  546. using init_value_type = typename init_functor_type::value_type;
  547. xreducer_functors()
  548. : base_type()
  549. {
  550. }
  551. template <class RF>
  552. xreducer_functors(RF&& reduce_func)
  553. : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
  554. {
  555. }
  556. template <class RF, class IF>
  557. xreducer_functors(RF&& reduce_func, IF&& init_func)
  558. : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
  559. {
  560. }
  561. template <class RF, class IF, class MF>
  562. xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
  563. : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
  564. {
  565. }
  566. reduce_functor_type get_reduce() const
  567. {
  568. return std::get<0>(upcast());
  569. }
  570. init_functor_type get_init() const
  571. {
  572. return std::get<1>(upcast());
  573. }
  574. merge_functor_type get_merge() const
  575. {
  576. return std::get<2>(upcast());
  577. }
  578. template <class NT>
  579. using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
  580. template <class NT>
  581. rebind_t<NT> rebind()
  582. {
  583. return make_xreducer_functor(get_reduce(), get_init().template rebind<NT>(), get_merge());
  584. }
  585. private:
  586. // Workaround for clang-cl
  587. const base_type& upcast() const
  588. {
  589. return static_cast<const base_type&>(*this);
  590. }
  591. };
  592. template <class RF>
  593. auto make_xreducer_functor(RF&& reduce_func)
  594. {
  595. using reducer_type = xreducer_functors<std::remove_reference_t<RF>>;
  596. return reducer_type(std::forward<RF>(reduce_func));
  597. }
  598. template <class RF, class IF>
  599. auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
  600. {
  601. using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
  602. return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
  603. }
  604. template <class RF, class IF, class MF>
  605. auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
  606. {
  607. using reducer_type = xreducer_functors<
  608. std::remove_reference_t<RF>,
  609. std::remove_reference_t<IF>,
  610. std::remove_reference_t<MF>>;
  611. return reducer_type(
  612. std::forward<RF>(reduce_func),
  613. std::forward<IF>(init_func),
  614. std::forward<MF>(merge_func)
  615. );
  616. }
  617. /**********************
  618. * xreducer extension *
  619. **********************/
namespace extension
{
    // Extension point: selects the base class of xreducer depending on the
    // expression tag of CT (allows injecting behavior for custom tags).
    template <class Tag, class F, class CT, class X, class O>
    struct xreducer_base_impl;

    // Plain xtensor expressions get an empty base class.
    template <class F, class CT, class X, class O>
    struct xreducer_base_impl<xtensor_expression_tag, F, CT, X, O>
    {
        using type = xtensor_empty_base;
    };

    // Dispatches on the expression tag of CT.
    template <class F, class CT, class X, class O>
    struct xreducer_base : xreducer_base_impl<xexpression_tag_t<CT>, F, CT, X, O>
    {
    };

    template <class F, class CT, class X, class O>
    using xreducer_base_t = typename xreducer_base<F, CT, X, O>::type;
}
  636. /************
  637. * xreducer *
  638. ************/
template <class F, class CT, class X, class O>
class xreducer;

template <class F, class CT, class X, class O>
class xreducer_stepper;

// Iterable-related inner types of xreducer: the reduced shape type
// (which depends on the keep_dims option in O) and the stepper used
// for iteration.
template <class F, class CT, class X, class O>
struct xiterable_inner_types<xreducer<F, CT, X, O>>
{
    using xexpression_type = std::decay_t<CT>;
    using inner_shape_type = typename xreducer_shape_type<
        typename xexpression_type::shape_type,
        std::decay_t<X>,
        typename O::keep_dims>::type;
    using const_stepper = xreducer_stepper<F, CT, X, O>;
    using stepper = const_stepper;
};
// Container-related inner types of xreducer. The value type is deduced
// from applying the reduce functor to init() and a dereferenced stepper
// of the underlying expression.
template <class F, class CT, class X, class O>
struct xcontainer_inner_types<xreducer<F, CT, X, O>>
{
    using xexpression_type = std::decay_t<CT>;
    using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
    using init_functor_type = typename std::decay_t<F>::init_functor_type;
    using merge_functor_type = typename std::decay_t<F>::merge_functor_type;
    using substepper_type = typename xexpression_type::const_stepper;
    // raw result of reduce(init(), *stepper)
    using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
        std::declval<init_functor_type>()(),
        *std::declval<substepper_type>()
    ))>;
    // If the raw result is itself an xexpression, use its evaluated container type.
    using value_type = typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
    using reference = value_type;
    using const_reference = value_type;
    using size_type = typename xexpression_type::size_type;
};
// Chooses the dim-mapping container for a shape type: the shape type itself
// in general, but compile-time fixed_shape gets a runtime std::array so the
// mapping can be filled at run time.
template <class T>
struct select_dim_mapping_type
{
    using type = T;
};

template <std::size_t... I>
struct select_dim_mapping_type<fixed_shape<I...>>
{
    using type = std::array<std::size_t, sizeof...(I)>;
};
/**
 * @class xreducer
 * @brief Reducing function operating over specified axes.
 *
 * The xreducer class implements an \ref xexpression applying
 * a reducing function to an \ref xexpression over the specified
 * axes.
 *
 * @tparam F a tuple of functors (class \ref xreducer_functors or compatible)
 * @tparam CT the closure type of the \ref xexpression to reduce
 * @tparam X the list of axes
 * @tparam O the type of the reducer options (evaluation strategy, keep_dims, initial value)
 *
 * The reducer's result_type is deduced from the result type of function
 * <tt>F::reduce_functor_type</tt> when called with elements of the expression @tparam CT.
 *
 * @sa reduce
 */
template <class F, class CT, class X, class O>
class xreducer : public xsharable_expression<xreducer<F, CT, X, O>>,
                 public xconst_iterable<xreducer<F, CT, X, O>>,
                 public xaccessible<xreducer<F, CT, X, O>>,
                 public extension::xreducer_base_t<F, CT, X, O>
{
public:

    using self_type = xreducer<F, CT, X, O>;
    using inner_types = xcontainer_inner_types<self_type>;

    // Functor triple driving the reduction: reduce (binary accumulation),
    // init (produces the seed value) and merge (combines partial results).
    using reduce_functor_type = typename inner_types::reduce_functor_type;
    using init_functor_type = typename inner_types::init_functor_type;
    using merge_functor_type = typename inner_types::merge_functor_type;
    using xreducer_functors_type = xreducer_functors<reduce_functor_type, init_functor_type, merge_functor_type>;

    using xexpression_type = typename inner_types::xexpression_type;
    using axes_type = X;

    using extension_base = extension::xreducer_base_t<F, CT, X, O>;
    using expression_tag = typename extension_base::expression_tag;

    using substepper_type = typename inner_types::substepper_type;
    // Elements are computed on the fly, so references are plain values.
    using value_type = typename inner_types::value_type;
    using reference = typename inner_types::reference;
    using const_reference = typename inner_types::const_reference;
    using pointer = value_type*;
    using const_pointer = const value_type*;
    using size_type = typename inner_types::size_type;
    using difference_type = typename xexpression_type::difference_type;

    using iterable_base = xconst_iterable<self_type>;
    using inner_shape_type = typename iterable_base::inner_shape_type;
    using shape_type = inner_shape_type;
    // Maps each result dimension to the corresponding dimension of the
    // underlying expression (see select_dim_mapping_type above).
    using dim_mapping_type = typename select_dim_mapping_type<inner_shape_type>::type;

    using stepper = typename iterable_base::stepper;
    using const_stepper = typename iterable_base::const_stepper;
    using bool_load_type = typename xexpression_type::bool_load_type;

    // The result is computed lazily: no fixed layout, never contiguous.
    static constexpr layout_type static_layout = layout_type::dynamic;
    static constexpr bool contiguous_layout = false;

    template <class Func, class CTA, class AX, class OX>
    xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);

    const inner_shape_type& shape() const noexcept;
    layout_type layout() const noexcept;
    bool is_contiguous() const noexcept;

    template <class... Args>
    const_reference operator()(Args... args) const;
    template <class... Args>
    const_reference unchecked(Args... args) const;
    template <class It>
    const_reference element(It first, It last) const;

    const xexpression_type& expression() const noexcept;

    template <class S>
    bool broadcast_shape(S& shape, bool reuse_cache = false) const;
    template <class S>
    bool has_linear_assign(const S& strides) const noexcept;

    template <class S>
    const_stepper stepper_begin(const S& shape) const noexcept;
    template <class S>
    const_stepper stepper_end(const S& shape, layout_type) const noexcept;

    // Rebinds the reducer to another expression (and optionally other
    // functors / options), keeping the same axes.
    template <class E, class Func = F, class Opts = O>
    using rebind_t = xreducer<Func, E, X, Opts>;

    template <class E>
    rebind_t<E> build_reducer(E&& e) const;

    template <class E, class Func, class Opts>
    rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;

    // Returns a copy of the functor triple.
    xreducer_functors_type functors() const
    {
        return xreducer_functors_type(m_reduce, m_init, m_merge);  // TODO: understand why
                                                                   // make_xreducer_functor is throwing an
                                                                   // error
    }

    // Returns the reducer options (keep_dims, initial value, ...).
    const O& options() const
    {
        return m_options;
    }

private:

    CT m_e;
    reduce_functor_type m_reduce;
    init_functor_type m_init;
    merge_functor_type m_merge;
    axes_type m_axes;
    inner_shape_type m_shape;
    dim_mapping_type m_dim_mapping;
    O m_options;

    // The stepper needs direct access to m_e, m_axes and the functors.
    friend class xreducer_stepper<F, CT, X, O>;
};
  779. /*************************
  780. * reduce implementation *
  781. *************************/
namespace detail
{
    // Lazy strategy: build and return an unevaluated xreducer expression.
    template <class F, class E, class X, class O>
    inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
    {
        decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));

        using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
        using init_functor_type = typename std::decay_t<F>::init_functor_type;
        // Result value type: what the reduce functor yields when combining
        // the seed produced by the init functor with one element of e.
        using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
            std::declval<init_functor_type>()(),
            *std::declval<typename std::decay_t<E>::const_stepper>()
        ))>;
        using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;

        using reducer_type = xreducer<
            F,
            const_xclosure_t<E>,
            xtl::const_closure_type_t<decltype(normalized_axes)>,
            reducer_options<evaluated_value_type, std::decay_t<O>>>;
        return reducer_type(
            std::forward<F>(f),
            std::forward<E>(e),
            std::forward<decltype(normalized_axes)>(normalized_axes),
            std::forward<O>(options)
        );
    }

    // Immediate strategy: evaluate the input expression and reduce eagerly.
    template <class F, class E, class X, class O>
    inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
    {
        decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
        return reduce_immediate(
            std::forward<F>(f),
            eval(std::forward<E>(e)),
            std::forward<decltype(normalized_axes)>(normalized_axes),
            std::forward<O>(options)
        );
    }
}
  819. #define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
namespace detail
{
    // Detects whether T is already an xreducer_functors triple (as opposed
    // to a plain functor that still needs make_xreducer_functor wrapping).
    template <class T>
    struct is_xreducer_functors_impl : std::false_type
    {
    };

    template <class RF, class IF, class MF>
    struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
    {
    };

    template <class T>
    using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
}
  833. /**
  834. * @brief Returns an \ref xexpression applying the specified reducing
  835. * function to an expression over the given axes.
  836. *
  837. * @param f the reducing function to apply.
  838. * @param e the \ref xexpression to reduce.
  839. * @param axes the list of axes.
  840. * @param options evaluation strategy to use (lazy (default), or immediate)
  841. *
  842. * The returned expression either hold a const reference to \p e or a copy
  843. * depending on whether \p e is an lvalue or an rvalue.
  844. */
  845. template <
  846. class F,
  847. class E,
  848. class X,
  849. class EVS = DEFAULT_STRATEGY_REDUCERS,
  850. XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, detail::is_xreducer_functors<F>)>
  851. inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
  852. {
  853. return detail::reduce_impl(
  854. std::forward<F>(f),
  855. std::forward<E>(e),
  856. std::forward<X>(axes),
  857. typename reducer_options<int, EVS>::evaluation_strategy{},
  858. std::forward<EVS>(options)
  859. );
  860. }
// Overload for a plain reducing functor: wrap it into an xreducer_functors
// triple and dispatch to the main overload above.
template <
    class F,
    class E,
    class X,
    class EVS = DEFAULT_STRATEGY_REDUCERS,
    XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, xtl::negation<detail::is_xreducer_functors<F>>)>
inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
{
    return reduce(
        make_xreducer_functor(std::forward<F>(f)),
        std::forward<E>(e),
        std::forward<X>(axes),
        std::forward<EVS>(options)
    );
}
// Overload without explicit axes: reduces over all dimensions of e.
template <
    class F,
    class E,
    class EVS = DEFAULT_STRATEGY_REDUCERS,
    XTL_REQUIRES(is_reducer_options<EVS>, detail::is_xreducer_functors<F>)>
inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
{
    // Build the axes list [0, 1, ..., dimension - 1].
    xindex_type_t<typename std::decay_t<E>::shape_type> ar;
    resize_container(ar, e.dimension());
    std::iota(ar.begin(), ar.end(), 0);
    return detail::reduce_impl(
        std::forward<F>(f),
        std::forward<E>(e),
        std::move(ar),
        typename reducer_options<int, std::decay_t<EVS>>::evaluation_strategy{},
        std::forward<EVS>(options)
    );
}
// Overload without axes for a plain functor: wrap and dispatch.
template <
    class F,
    class E,
    class EVS = DEFAULT_STRATEGY_REDUCERS,
    XTL_REQUIRES(is_reducer_options<EVS>, xtl::negation<detail::is_xreducer_functors<F>>)>
inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
{
    return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
}
// Overload taking the axes as a braced initializer list (array reference).
template <
    class F,
    class E,
    class I,
    std::size_t N,
    class EVS = DEFAULT_STRATEGY_REDUCERS,
    XTL_REQUIRES(detail::is_xreducer_functors<F>)>
inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
{
    using axes_type = std::array<std::size_t, N>;
    // Normalize the axes (e.g. resolve negative indices) into a std::array.
    auto ax = xt::forward_normalize<axes_type>(e, axes);
    return detail::reduce_impl(
        std::forward<F>(f),
        std::forward<E>(e),
        std::move(ax),
        typename reducer_options<int, EVS>::evaluation_strategy{},
        options
    );
}
// Initializer-list axes overload for a plain functor: wrap and dispatch.
template <
    class F,
    class E,
    class I,
    std::size_t N,
    class EVS = DEFAULT_STRATEGY_REDUCERS,
    XTL_REQUIRES(xtl::negation<detail::is_xreducer_functors<F>>)>
inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
{
    return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
}
  933. /********************
  934. * xreducer_stepper *
  935. ********************/
// Stepper over an xreducer: steps through the non-reduced dimensions of the
// underlying expression and aggregates over the reduced ones on dereference.
template <class F, class CT, class X, class O>
class xreducer_stepper
{
public:

    using self_type = xreducer_stepper<F, CT, X, O>;
    using xreducer_type = xreducer<F, CT, X, O>;

    using value_type = typename xreducer_type::value_type;
    // Elements are computed on demand, so dereferencing yields a value.
    using reference = typename xreducer_type::value_type;
    using pointer = typename xreducer_type::const_pointer;
    using size_type = typename xreducer_type::size_type;
    using difference_type = typename xreducer_type::difference_type;

    using xexpression_type = typename xreducer_type::xexpression_type;
    using substepper_type = typename xexpression_type::const_stepper;
    using shape_type = typename xreducer_type::shape_type;

    xreducer_stepper(
        const xreducer_type& red,
        size_type offset,
        bool end = false,
        layout_type l = default_assignable_layout(xexpression_type::static_layout)
    );

    // Performs the full reduction at the current position.
    reference operator*() const;

    void step(size_type dim);
    void step_back(size_type dim);
    void step(size_type dim, size_type n);
    void step_back(size_type dim, size_type n);
    void reset(size_type dim);
    void reset_back(size_type dim);

    void to_begin();
    void to_end(layout_type l);

private:

    reference initial_value() const;
    reference aggregate(size_type dim) const;
    reference aggregate_impl(size_type dim, /*keep_dims=*/std::false_type) const;
    reference aggregate_impl(size_type dim, /*keep_dims=*/std::true_type) const;

    substepper_type get_substepper_begin() const;
    size_type get_dim(size_type dim) const noexcept;
    size_type shape(size_type i) const noexcept;
    size_type axis(size_type i) const noexcept;

    const xreducer_type* m_reducer;
    size_type m_offset;
    // Substepper over the underlying expression; mutable because the
    // aggregation moves it from const member functions.
    mutable substepper_type m_stepper;
};
  978. /******************
  979. * xreducer utils *
  980. ******************/
namespace detail
{
    // Compile-time membership test: true if X is one of I...
    template <std::size_t X, std::size_t... I>
    struct in
    {
        static constexpr bool value = xtl::disjunction<std::integral_constant<bool, X == I>...>::value;
    };

    // Builds the fixed result shape of a reduction: walks dimensions from Z
    // down to 0, dropping those listed in J... (the reduced axes) and
    // prepending the kept extents from I... to the accumulator R...
    template <std::size_t Z, class S1, class S2, class R>
    struct fixed_xreducer_shape_type_impl;

    template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
    struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
    {
        using type = std::conditional_t<
            in<Z, J...>::value,
            typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
            typename fixed_xreducer_shape_type_impl<
                Z - 1,
                fixed_shape<I...>,
                fixed_shape<J...>,
                fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
    };

    // Recursion terminator (dimension 0).
    template <std::size_t... I, std::size_t... J, std::size_t... R>
    struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
    {
        using type = std::
            conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
    };

    /***************************
     * helper for return types *
     ***************************/

    // Size type used by reducer results; always std::size_t.
    template <class T>
    struct xreducer_size_type
    {
        using type = std::size_t;
    };

    template <class T>
    using xreducer_size_type_t = typename xreducer_size_type<T>::type;

    // Temporary type used when materializing reducer results; identity by default.
    template <class T>
    struct xreducer_temporary_type
    {
        using type = T;
    };

    template <class T>
    using xreducer_temporary_type_t = typename xreducer_temporary_type<T>::type;

    /********************************
     * Default const_value rebinder *
     ********************************/

    // Converts a const_value<T> into a const_value<U> holding the same value.
    template <class T, class U>
    struct const_value_rebinder
    {
        static const_value<U> run(const const_value<T>& t)
        {
            return const_value<U>(t.m_value);
        }
    };
}
  1037. /*******************************************
  1038. * Init functor const_value implementation *
  1039. *******************************************/
// Rebinds this const_value to a new value type NT, preserving the stored value.
template <class T>
template <class NT>
const_value<NT> const_value<T>::rebind() const
{
    return detail::const_value_rebinder<T, NT>::run(*this);
}
  1046. /*****************************
  1047. * fixed_xreducer_shape_type *
  1048. *****************************/
// Computes the fixed result shape of reducing fixed_shape<I...> over the
// compile-time axes J... by delegating to the recursive implementation.
template <class S1, class S2>
struct fixed_xreducer_shape_type;

template <std::size_t... I, std::size_t... J>
struct fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>
{
    using type = typename detail::
        fixed_xreducer_shape_type_impl<sizeof...(I) - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<>>::type;
};
  1057. // meta-function returning the shape type for an xreducer
  1058. template <class ST, class X, class O>
  1059. struct xreducer_shape_type
  1060. {
  1061. using type = promote_shape_t<ST, std::decay_t<X>>;
  1062. };
  1063. template <class I1, std::size_t N1, class I2, std::size_t N2>
  1064. struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::true_type>
  1065. {
  1066. using type = std::array<I2, N1>;
  1067. };
  1068. template <class I1, std::size_t N1, class I2, std::size_t N2>
  1069. struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::false_type>
  1070. {
  1071. using type = std::array<I2, N1 - N2>;
  1072. };
  1073. template <std::size_t... I, class I2, std::size_t N2>
  1074. struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::false_type>
  1075. {
  1076. using type = std::conditional_t<sizeof...(I) == N2, fixed_shape<>, std::array<I2, sizeof...(I) - N2>>;
  1077. };
namespace detail
{
    // Concatenates two std::integer_sequence instances.
    template <class S1, class S2>
    struct ixconcat;

    template <class T, T... I1, T... I2>
    struct ixconcat<std::integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
    {
        using type = std::integer_sequence<T, I1..., I2...>;
    };

    // Builds an integer_sequence repeating the value X, N times.
    template <class T, T X, std::size_t N>
    struct repeat_integer_sequence
    {
        using type = typename ixconcat<
            std::integer_sequence<T, X>,
            typename repeat_integer_sequence<T, X, N - 1>::type>::type;
    };

    // Recursion terminator.
    template <class T, T X>
    struct repeat_integer_sequence<T, X, 0>
    {
        using type = std::integer_sequence<T>;
    };

    // Shortcut specializations for small counts; redundant with the general
    // case but presumably there to cut instantiation depth — confirm before
    // removing.
    template <class T, T X>
    struct repeat_integer_sequence<T, X, 2>
    {
        using type = std::integer_sequence<T, X, X>;
    };

    template <class T, T X>
    struct repeat_integer_sequence<T, X, 1>
    {
        using type = std::integer_sequence<T, X>;
    };
}
// Fixed input shape, runtime axes, keep_dims == true: the result keeps the
// input rank; when every axis is reduced, the type collapses to a
// fixed_shape of N2 ones.
template <std::size_t... I, class I2, std::size_t N2>
struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::true_type>
{
    template <std::size_t... X>
    static constexpr auto get_type(std::index_sequence<X...>)
    {
        return fixed_shape<X...>{};
    }

    // if all axes reduced
    using type = std::conditional_t<
        sizeof...(I) == N2,
        decltype(get_type(typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
        std::array<I2, sizeof...(I)>>;
};
  1124. // Note adding "A" to prevent compilation in case nothing else matches
// Both shape and axes fixed at compile time: delegate to
// fixed_xreducer_shape_type; the keep_dims flag O is not used here.
template <std::size_t... I, std::size_t... J, class O>
struct xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>, O>
{
    using type = typename fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>::type;
};
namespace detail
{
    // Computes the reduced shape and the dimension mapping (result dim ->
    // input dim) for keep_dims == false, dynamic-shape case.
    // Assumes `axes` is sorted (validated by the xreducer constructor).
    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation(S& shape, E& e, const X& axes, M& mapping, std::false_type)
    {
        auto first = e.shape().begin();
        auto last = e.shape().end();
        auto exclude_it = axes.begin();

        using value_type = typename S::value_type;
        using difference_type = typename S::difference_type;
        auto d_first = shape.begin();
        auto map_first = mapping.begin();

        auto iter = first;
        // Copy every extent whose index is not listed in axes.
        while (iter != last && exclude_it != axes.end())
        {
            auto diff = std::distance(first, iter);
            if (diff != difference_type(*exclude_it))
            {
                *d_first++ = *iter++;
                *map_first++ = value_type(diff);
            }
            else
            {
                ++iter;
                ++exclude_it;
            }
        }
        // Past the last reduced axis: every remaining dimension is kept.
        auto diff = std::distance(first, iter);
        auto end = std::distance(iter, last);
        std::iota(map_first, map_first + end, diff);
        std::copy(iter, last, d_first);
    }

    // keep_dims == true case: reduced axes keep extent 1 and the dimension
    // mapping is the identity.
    template <class S, class E, class X, class M>
    inline void
    shape_and_mapping_computation_keep_dim(S& shape, E& e, const X& axes, M& mapping, std::false_type)
    {
        for (std::size_t i = 0; i < e.dimension(); ++i)
        {
            if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
            {
                // i not in axes!
                shape[i] = e.shape()[i];
            }
            else
            {
                shape[i] = 1;
            }
        }
        std::iota(mapping.begin(), mapping.end(), 0);
    }

    // Fixed (compile-time) shapes: nothing to compute at run time.
    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation(S&, E&, const X&, M&, std::true_type)
    {
    }

    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation_keep_dim(S&, E&, const X&, M&, std::true_type)
    {
    }
}
  1189. /***************************
  1190. * xreducer implementation *
  1191. ***************************/
  1192. /**
  1193. * @name Constructor
  1194. */
  1195. //@{
  1196. /**
  1197. * Constructs an xreducer expression applying the specified
  1198. * function to the given expression over the given axes.
  1199. *
  1200. * @param func the function to apply
  1201. * @param e the expression to reduce
  1202. * @param axes the axes along which the reduction is performed
  1203. */
template <class F, class CT, class X, class O>
template <class Func, class CTA, class AX, class OX>
inline xreducer<F, CT, X, O>::xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options)
    : m_e(std::forward<CTA>(e))
    , m_reduce(xt::get<0>(func))
    , m_init(xt::get<1>(func))
    , m_merge(xt::get<2>(func))
    , m_axes(std::forward<AX>(axes))
    // With keep_dims the result has the same rank as the input, otherwise
    // one dimension is dropped per reduced axis.
    , m_shape(xtl::make_sequence<inner_shape_type>(
          typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
          0
      ))
    , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
          typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
          0
      ))
    , m_options(std::forward<OX>(options))
{
    // std::less is used, because as the standard says (24.4.5):
    // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
    // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
    // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
    if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
    {
        XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
    }
    if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
    {
        XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
    }
    // Axes are sorted here, so checking the last one suffices.
    if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
    {
        XTENSOR_THROW(
            std::runtime_error,
            "Axis " + std::to_string(m_axes[m_axes.size() - 1]) + " out of bounds for reduction."
        );
    }
    // The is_fixed dispatch makes these calls no-ops for compile-time
    // fixed shapes, which are already fully determined.
    if (!typename O::keep_dims())
    {
        detail::shape_and_mapping_computation(
            m_shape,
            m_e,
            m_axes,
            m_dim_mapping,
            detail::is_fixed<shape_type>{}
        );
    }
    else
    {
        detail::shape_and_mapping_computation_keep_dim(
            m_shape,
            m_e,
            m_axes,
            m_dim_mapping,
            detail::is_fixed<shape_type>{}
        );
    }
}
  1262. //@}
  1263. /**
  1264. * @name Size and shape
  1265. */
  1266. /**
  1267. * Returns the shape of the expression.
  1268. */
template <class F, class CT, class X, class O>
inline auto xreducer<F, CT, X, O>::shape() const noexcept -> const inner_shape_type&
{
    // The shape is computed once in the constructor.
    return m_shape;
}
  1274. /**
* Returns the layout of the expression.
  1276. */
template <class F, class CT, class X, class O>
inline layout_type xreducer<F, CT, X, O>::layout() const noexcept
{
    // Always layout_type::dynamic (see static_layout).
    return static_layout;
}
template <class F, class CT, class X, class O>
inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
{
    // Elements are computed on the fly; there is no contiguous storage.
    return false;
}
  1287. //@}
  1288. /**
  1289. * @name Data
  1290. */
  1291. /**
  1292. * Returns a constant reference to the element at the specified position in the reducer.
  1293. * @param args a list of indices specifying the position in the reducer. Indices
  1294. * must be unsigned integers, the number of indices should be equal or greater than
  1295. * the number of dimensions of the reducer.
  1296. */
template <class F, class CT, class X, class O>
template <class... Args>
inline auto xreducer<F, CT, X, O>::operator()(Args... args) const -> const_reference
{
    XTENSOR_TRY(check_index(shape(), args...));
    XTENSOR_CHECK_DIMENSION(shape(), args...);
    // Funnel the index pack through element(), which drives the stepper.
    std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
    return element(arg_array.cbegin(), arg_array.cend());
}
  1306. /**
  1307. * Returns a constant reference to the element at the specified position in the reducer.
  1308. * @param args a list of indices specifying the position in the reducer. Indices
  1309. * must be unsigned integers, the number of indices must be equal to the number of
  1310. * dimensions of the reducer, else the behavior is undefined.
  1311. *
  1312. * @warning This method is meant for performance, for expressions with a dynamic
  1313. * number of dimensions (i.e. not known at compile time). Since it may have
  1314. * undefined behavior (see parameters), operator() should be preferred whenever
  1315. * it is possible.
  1316. * @warning This method is NOT compatible with broadcasting, meaning the following
  1317. * code has undefined behavior:
  1318. * @code{.cpp}
  1319. * xt::xarray<double> a = {{0, 1}, {2, 3}};
  1320. * xt::xarray<double> b = {0, 1};
  1321. * auto fd = a + b;
* double res = fd.unchecked(0, 1);
  1323. * @endcode
  1324. */
template <class F, class CT, class X, class O>
template <class... Args>
inline auto xreducer<F, CT, X, O>::unchecked(Args... args) const -> const_reference
{
    // Same as operator() but without the index / dimension checks.
    std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
    return element(arg_array.cbegin(), arg_array.cend());
}
  1332. /**
  1333. * Returns a constant reference to the element at the specified position in the reducer.
  1334. * @param first iterator starting the sequence of indices
  1335. * @param last iterator ending the sequence of indices
  1336. * The number of indices in the sequence should be equal to or greater
  1337. * than the number of dimensions of the reducer.
  1338. */
template <class F, class CT, class X, class O>
template <class It>
inline auto xreducer<F, CT, X, O>::element(It first, It last) const -> const_reference
{
    XTENSOR_TRY(check_element_index(shape(), first, last));
    auto stepper = const_stepper(*this, 0);
    if (first != last)
    {
        size_type dim = 0;
        // drop left most elements
        // `size` is negative when more indices than dimensions are given,
        // in which case the extra leading indices are skipped; when fewer
        // are given, the missing leading indices are treated as zero.
        auto size = std::ptrdiff_t(this->dimension()) - std::distance(first, last);
        auto begin = first - size;
        while (begin != last)
        {
            if (begin < first)
            {
                stepper.step(dim++, std::size_t(0));
                begin++;
            }
            else
            {
                stepper.step(dim++, std::size_t(*begin++));
            }
        }
    }
    // Dereferencing aggregates over the reduced axes at this position.
    return *stepper;
}
  1366. /**
  1367. * Returns a constant reference to the underlying expression of the reducer.
  1368. */
template <class F, class CT, class X, class O>
inline auto xreducer<F, CT, X, O>::expression() const noexcept -> const xexpression_type&
{
    // The closure m_e holds either a reference or a copy of the expression.
    return m_e;
}
  1374. //@}
  1375. /**
  1376. * @name Broadcasting
  1377. */
  1378. //@{
  1379. /**
  1380. * Broadcast the shape of the reducer to the specified parameter.
  1381. * @param shape the result shape
  1382. * @param reuse_cache parameter for internal optimization
  1383. * @return a boolean indicating whether the broadcasting is trivial
  1384. */
template <class F, class CT, class X, class O>
template <class S>
inline bool xreducer<F, CT, X, O>::broadcast_shape(S& shape, bool) const
{
    // The reuse_cache argument is ignored: there is nothing to cache here.
    return xt::broadcast_shape(m_shape, shape);
}
  1391. /**
  1392. * Checks whether the xreducer can be linearly assigned to an expression
  1393. * with the specified strides.
  1394. * @return a boolean indicating whether a linear assign is possible
  1395. */
template <class F, class CT, class X, class O>
template <class S>
inline bool xreducer<F, CT, X, O>::has_linear_assign(const S& /*strides*/) const noexcept
{
    // Elements are computed lazily; linear assignment is never possible.
    return false;
}
  1402. //@}
template <class F, class CT, class X, class O>
template <class S>
inline auto xreducer<F, CT, X, O>::stepper_begin(const S& shape) const noexcept -> const_stepper
{
    // offset counts the extra leading (broadcast) dimensions of `shape`.
    size_type offset = shape.size() - this->dimension();
    return const_stepper(*this, offset);
}
template <class F, class CT, class X, class O>
template <class S>
inline auto xreducer<F, CT, X, O>::stepper_end(const S& shape, layout_type l) const noexcept
    -> const_stepper
{
    // Same offset computation as stepper_begin; `true` positions at the end.
    size_type offset = shape.size() - this->dimension();
    return const_stepper(*this, offset, true, l);
}
// Rebinds this reducer to another expression, reusing the same functors,
// axes and options.
template <class F, class CT, class X, class O>
template <class E>
inline auto xreducer<F, CT, X, O>::build_reducer(E&& e) const -> rebind_t<E>
{
    return rebind_t<E>(
        std::make_tuple(m_reduce, m_init, m_merge),
        std::forward<E>(e),
        axes_type(m_axes),
        m_options
    );
}
// Rebinds this reducer to another expression with new functors and options,
// keeping only the axes.
template <class F, class CT, class X, class O>
template <class E, class Func, class Opts>
inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts) const
    -> rebind_t<E, Func, Opts>
{
    return rebind_t<E, Func, Opts>(
        std::forward<Func>(func),
        std::forward<E>(e),
        axes_type(m_axes),
        std::forward<Opts>(opts)
    );
}
  1441. /***********************************
  1442. * xreducer_stepper implementation *
  1443. ***********************************/
template <class F, class CT, class X, class O>
inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
    const xreducer_type& red,
    size_type offset,
    bool end,
    layout_type l
)
    : m_reducer(&red)
    , m_offset(offset)
    // Start from the beginning of the underlying expression; move to the
    // end afterwards if requested.
    , m_stepper(get_substepper_begin())
{
    if (end)
    {
        to_end(l);
    }
}
  1460. template <class F, class CT, class X, class O>
  1461. inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
  1462. {
  1463. reference r = aggregate(0);
  1464. return r;
  1465. }
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
{
    // Dimensions below m_offset come from broadcasting and have no
    // counterpart in the underlying expression.
    if (dim >= m_offset)
    {
        m_stepper.step(get_dim(dim - m_offset));
    }
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
{
    // Broadcast dimensions (dim < m_offset) are ignored.
    if (dim >= m_offset)
    {
        m_stepper.step_back(get_dim(dim - m_offset));
    }
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
{
    // n-step variant; broadcast dimensions (dim < m_offset) are ignored.
    if (dim >= m_offset)
    {
        m_stepper.step(get_dim(dim - m_offset), n);
    }
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
{
    // n-step variant; broadcast dimensions (dim < m_offset) are ignored.
    if (dim >= m_offset)
    {
        m_stepper.step_back(get_dim(dim - m_offset), n);
    }
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
{
    if (dim >= m_offset)
    {
        // Because the reducer uses `reset` to reset the non-reducing axes,
        // we need to prevent resetting here in the keep_dims case when `dim`
        // is one of the reduced axes.
        if (typename O::keep_dims()
            && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
        {
            // If keep dim activated, and dim is in the axes, do nothing!
            return;
        }
        m_stepper.reset(get_dim(dim - m_offset));
    }
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
{
    if (dim >= m_offset)
    {
        // Note that for *not* KD this is not going to do anything
        if (typename O::keep_dims()
            && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
        {
            // If keep dim activated, and dim is in the axes, do nothing!
            return;
        }
        m_stepper.reset_back(get_dim(dim - m_offset));
    }
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::to_begin()
{
    // Delegates to the underlying expression's stepper.
    m_stepper.to_begin();
}
template <class F, class CT, class X, class O>
inline void xreducer_stepper<F, CT, X, O>::to_end(layout_type l)
{
    // Delegates to the underlying expression's stepper.
    m_stepper.to_end(l);
}
  1539. template <class F, class CT, class X, class O>
  1540. inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
  1541. {
  1542. return O::has_initial_value ? m_reducer->m_options.initial_value
  1543. : static_cast<reference>(m_reducer->m_init());
  1544. }
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
{
    reference res;
    // Empty expression: nothing to reduce, the result is the initial value.
    if (m_reducer->m_e.size() == size_type(0))
    {
        res = initial_value();
    }
    // 0-d expression or no reduction axes: fold the single current element
    // into the initial value.
    else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
    {
        res = m_reducer->m_reduce(initial_value(), *m_stepper);
    }
    else
    {
        // General case: recurse over the reduction axes, dispatching on
        // whether reduced dimensions are kept in the result shape.
        res = aggregate_impl(dim, typename O::keep_dims());
        // Merge a user-supplied initial value exactly once, at the
        // outermost recursion level only (dim == 0).
        if (O::has_initial_value && dim == 0)
        {
            res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
        }
    }
    return res;
}
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
{
    // Non-keep_dims variant: `dim` indexes into the list of reduction axes,
    // not into the expression's dimensions.
    // reference can be std::array, hence the {} initializer
    reference res = {};
    // Axis of the underlying expression reduced at this level, and its extent.
    size_type index = axis(dim);
    size_type size = shape(index);
    if (dim != m_reducer->m_axes.size() - 1)
    {
        // Not the innermost reduction axis: recurse for each position along
        // this axis and merge the partial results together.
        res = aggregate_impl(dim + 1, typename O::keep_dims());
        for (size_type i = 1; i != size; ++i)
        {
            m_stepper.step(index);
            res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
        }
    }
    else
    {
        // Innermost reduction axis: fold every element along it, seeding the
        // fold with the functor's default initializer.
        res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
        for (size_type i = 1; i != size; ++i)
        {
            m_stepper.step(index);
            res = m_reducer->m_reduce(res, *m_stepper);
        }
    }
    // Rewind this axis so the stepper is back where the caller left it.
    m_stepper.reset(index);
    return res;
}
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
{
    // keep_dims variant: `dim` walks the expression's dimensions directly,
    // so each level checks whether the current dimension is a reduction axis.
    // reference can be std::array, hence the {} initializer
    reference res = {};
    auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
    if (ax_it != m_reducer->m_axes.end())
    {
        size_type index = dim;
        size_type size = m_reducer->m_e.shape()[index];
        if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
        {
            // Not the last reduction axis (and non-empty): recurse for each
            // position along this axis and merge the partial results.
            res = aggregate_impl(dim + 1, typename O::keep_dims());
            for (size_type i = 1; i != size; ++i)
            {
                m_stepper.step(index);
                res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
            }
        }
        else
        {
            // Last reduction axis: fold every element along it, seeding the
            // fold with the functor's default initializer.
            // NOTE(review): this branch is also taken when size == 0, which
            // appears unreachable since aggregate() bails out earlier when
            // the expression is empty — confirm.
            res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
            for (size_type i = 1; i != size; ++i)
            {
                m_stepper.step(index);
                res = m_reducer->m_reduce(res, *m_stepper);
            }
        }
        // Rewind this axis so the stepper is back where the caller left it.
        m_stepper.reset(index);
    }
    else
    {
        // Non-reduced dimension: skip it and keep recursing until the
        // expression's dimensions are exhausted.
        if (dim < m_reducer->m_e.dimension())
        {
            res = aggregate_impl(dim + 1, typename O::keep_dims());
        }
    }
    return res;
}
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
{
    // Fresh stepper positioned at the beginning of the underlying expression.
    return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
}
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim) const noexcept -> size_type
{
    // Map a (offset-adjusted) result dimension to the corresponding
    // dimension of the underlying expression.
    return m_reducer->m_dim_mapping[dim];
}
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i) const noexcept -> size_type
{
    // Extent of dimension i of the underlying expression.
    return m_reducer->m_e.shape()[i];
}
template <class F, class CT, class X, class O>
inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i) const noexcept -> size_type
{
    // i-th reduction axis of the reducer.
    return m_reducer->m_axes[i];
}
  1654. }
  1655. #endif