xsort.hpp 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_SORT_HPP
  10. #define XTENSOR_SORT_HPP
  11. #include <algorithm>
  12. #include <cmath>
  13. #include <iterator>
  14. #include <utility>
  15. #include <xtl/xcompare.hpp>
  16. #include "xadapt.hpp"
  17. #include "xarray.hpp"
  18. #include "xeval.hpp"
  19. #include "xindex_view.hpp"
  20. #include "xmanipulation.hpp"
  21. #include "xmath.hpp"
  22. #include "xslice.hpp" // for xnone
  23. #include "xtensor.hpp"
  24. #include "xtensor_config.hpp"
  25. #include "xtensor_forward.hpp"
  26. #include "xview.hpp"
  27. namespace xt
  28. {
  29. /**
  30. * @defgroup xt_xsort Sorting functions.
  31. *
  32. * Because sorting functions need to access the tensor data repeatedly, they evaluate their
  33. * input and may allocate temporaries.
  34. */
  35. namespace detail
  36. {
  37. template <class T>
  38. std::ptrdiff_t adjust_secondary_stride(std::ptrdiff_t stride, T shape)
  39. {
  40. return stride != 0 ? stride : static_cast<std::ptrdiff_t>(shape);
  41. }
  42. template <class E>
  43. inline std::ptrdiff_t get_secondary_stride(const E& ev)
  44. {
  45. if (ev.layout() == layout_type::row_major)
  46. {
  47. return adjust_secondary_stride(ev.strides()[ev.dimension() - 2], *(ev.shape().end() - 1));
  48. }
  49. return adjust_secondary_stride(ev.strides()[1], *(ev.shape().begin()));
  50. }
  51. template <class E>
  52. inline std::size_t leading_axis_n_iters(const E& ev)
  53. {
  54. if (ev.layout() == layout_type::row_major)
  55. {
  56. return std::accumulate(
  57. ev.shape().begin(),
  58. ev.shape().end() - 1,
  59. std::size_t(1),
  60. std::multiplies<>()
  61. );
  62. }
  63. return std::accumulate(ev.shape().begin() + 1, ev.shape().end(), std::size_t(1), std::multiplies<>());
  64. }
// Invoke `fct(begin, end)` on every 1-D lane along the leading axis of
// `ev`, passing raw pointers into its buffer. Requires dimension() >= 2
// and direct data access (uses .data()).
template <class E, class F>
inline void call_over_leading_axis(E& ev, F&& fct)
{
    XTENSOR_ASSERT(ev.dimension() >= 2);
    const std::size_t n_iters = leading_axis_n_iters(ev);
    const std::ptrdiff_t secondary_stride = get_secondary_stride(ev);
    const auto begin = ev.data();
    const auto end = begin + n_iters * secondary_stride;
    // Each lane is `secondary_stride` elements long and lanes are laid out
    // back-to-back, so a single strided sweep covers the whole buffer.
    for (auto iter = begin; iter != end; iter += secondary_stride)
    {
        fct(iter, iter + secondary_stride);
    }
}
// Invoke `fct(begin1, end1, begin2, end2)` on corresponding lanes of two
// same-shaped expressions (e.g. an index result and the input values).
template <class E1, class E2, class F>
inline void call_over_leading_axis(E1& e1, E2& e2, F&& fct)
{
    XTENSOR_ASSERT(e1.dimension() >= 2);
    XTENSOR_ASSERT(e1.dimension() == e2.dimension());
    const std::size_t n_iters = leading_axis_n_iters(e1);
    const std::ptrdiff_t secondary_stride1 = get_secondary_stride(e1);
    const std::ptrdiff_t secondary_stride2 = get_secondary_stride(e2);
    // Lanes must have equal length for the pairwise sweep to stay aligned.
    XTENSOR_ASSERT(secondary_stride1 == secondary_stride2);
    const auto begin1 = e1.data();
    const auto end1 = begin1 + n_iters * secondary_stride1;
    const auto begin2 = e2.data();
    const auto end2 = begin2 + n_iters * secondary_stride2;
    auto iter1 = begin1;
    auto iter2 = begin2;
    for (; (iter1 != end1) && (iter2 != end2); iter1 += secondary_stride1, iter2 += secondary_stride2)
    {
        fct(iter1, iter1 + secondary_stride1, iter2, iter2 + secondary_stride2);
    }
}
  98. template <class E>
  99. inline std::size_t leading_axis(const E& e)
  100. {
  101. if (e.layout() == layout_type::row_major)
  102. {
  103. return e.dimension() - 1;
  104. }
  105. else if (e.layout() == layout_type::column_major)
  106. {
  107. return 0;
  108. }
  109. XTENSOR_THROW(std::runtime_error, "Layout not supported.");
  110. }
  111. // get permutations to transpose and reverse-transpose array
  112. inline std::pair<dynamic_shape<std::size_t>, dynamic_shape<std::size_t>>
  113. get_permutations(std::size_t dim, std::size_t ax, layout_type layout)
  114. {
  115. dynamic_shape<std::size_t> permutation(dim);
  116. std::iota(permutation.begin(), permutation.end(), std::size_t(0));
  117. permutation.erase(permutation.begin() + std::ptrdiff_t(ax));
  118. if (layout == layout_type::row_major)
  119. {
  120. permutation.push_back(ax);
  121. }
  122. else
  123. {
  124. permutation.insert(permutation.begin(), ax);
  125. }
  126. // TODO find a more clever way to get reverse permutation?
  127. dynamic_shape<std::size_t> reverse_permutation;
  128. for (std::size_t i = 0; i < dim; ++i)
  129. {
  130. auto it = std::find(permutation.begin(), permutation.end(), i);
  131. reverse_permutation.push_back(std::size_t(std::distance(permutation.begin(), it)));
  132. }
  133. return std::make_pair(std::move(permutation), std::move(reverse_permutation));
  134. }
// Evaluate `e` into a container of type R and apply `lambda(begin, end)`
// over every 1-D lane along `axis`. When the axis is not the layout's
// leading axis, the data is transposed so it becomes leading, processed,
// and transposed back.
template <class R, class E, class F>
inline R map_axis(const E& e, std::ptrdiff_t axis, F&& lambda)
{
    if (e.dimension() == 1)
    {
        // One dimension: the whole buffer is a single lane.
        R res = e;
        lambda(res.begin(), res.end());
        return res;
    }
    const std::size_t ax = normalize_axis(e.dimension(), axis);
    if (ax == detail::leading_axis(e))
    {
        // Fast path: lanes are already contiguous.
        R res = e;
        detail::call_over_leading_axis(res, std::forward<F>(lambda));
        return res;
    }
    // Slow path: transpose `ax` into the leading position and back.
    dynamic_shape<std::size_t> permutation, reverse_permutation;
    std::tie(permutation, reverse_permutation) = get_permutations(e.dimension(), ax, e.layout());
    R res = transpose(e, permutation);
    detail::call_over_leading_axis(res, std::forward<F>(lambda));
    res = transpose(res, reverse_permutation);
    return res;
}
// Result type for sorting a flattened expression: statically-shaped
// tensors collapse to their 1-D counterpart, everything else keeps its
// (dynamic) container type.
template <class VT>
struct flatten_sort_result_type_impl
{
    using type = VT;
};

template <class VT, std::size_t N, layout_type L>
struct flatten_sort_result_type_impl<xtensor<VT, N, L>>
{
    using type = xtensor<VT, 1, L>;
};

template <class VT, class S, layout_type L>
struct flatten_sort_result_type_impl<xtensor_fixed<VT, S, L>>
{
    // Fixed-shape tensors flatten to a fixed 1-D shape of the same size.
    using type = xtensor_fixed<VT, xshape<fixed_compute_size<S>::value>, L>;
};

template <class VT>
struct flatten_sort_result_type : flatten_sort_result_type_impl<common_tensor_type_t<VT>>
{
};

template <class VT>
using flatten_sort_result_type_t = typename flatten_sort_result_type<VT>::type;
  179. template <class E, class R = flatten_sort_result_type_t<E>>
  180. inline auto flat_sort_impl(const xexpression<E>& e)
  181. {
  182. const auto& de = e.derived_cast();
  183. R ev;
  184. ev.resize({static_cast<typename R::shape_type::value_type>(de.size())});
  185. std::copy(de.cbegin(), de.cend(), ev.begin());
  186. std::sort(ev.begin(), ev.end());
  187. return ev;
  188. }
  189. }
/**
 * Sort the flattened expression.
 *
 * @ingroup xt_xsort
 * @param e xexpression to sort
 *
 * @return sorted 1-D array (copy)
 */
template <class E>
inline auto sort(const xexpression<E>& e, placeholders::xtuph /*t*/)
{
    return detail::flat_sort_impl(e);
}
namespace detail
{
    // Evaluated (mutable temporary) container type used by the sorting
    // functions. xtensor_fixed maps to a plain xtensor of the same rank,
    // since sorting works on a resizable copy.
    template <class T>
    struct sort_eval_type
    {
        using type = typename T::temporary_type;
    };

    template <class T, std::size_t... I, layout_type L>
    struct sort_eval_type<xtensor_fixed<T, fixed_shape<I...>, L>>
    {
        using type = xtensor<T, sizeof...(I), L>;
    };
}
  208. /**
  209. * Sort xexpression (optionally along axis)
  210. * The sort is performed using the ``std::sort`` functions.
  211. * A copy of the xexpression is created and returned.
  212. *
  213. * @ingroup xt_xsort
  214. * @param e xexpression to sort
  215. * @param axis axis along which sort is performed
  216. *
  217. * @return sorted array (copy)
  218. */
  219. template <class E>
  220. inline auto sort(const xexpression<E>& e, std::ptrdiff_t axis = -1)
  221. {
  222. using eval_type = typename detail::sort_eval_type<E>::type;
  223. return detail::map_axis<eval_type>(
  224. e.derived_cast(),
  225. axis,
  226. [](auto begin, auto end)
  227. {
  228. std::sort(begin, end);
  229. }
  230. );
  231. }
  232. /*****************************
  233. * Implementation of argsort *
  234. *****************************/
  235. /**
  236. * Sorting method.
  237. * Predefined methods for performing indirect sorting.
  238. * @see argsort(const xexpression<E>&, std::ptrdiff_t, sorting_method)
  239. */
// Selects between std::sort and std::stable_sort in the argsort overloads.
enum class sorting_method
{
    /**
     * Faster method but with no guarantee on preservation of order of equal elements
     * https://en.cppreference.com/w/cpp/algorithm/sort.
     */
    quick,
    /**
     * Slower method but with guarantee on preservation of order of equal elements
     * https://en.cppreference.com/w/cpp/algorithm/stable_sort.
     */
    stable,
};
  253. namespace detail
  254. {
  255. template <class ConstRandomIt, class RandomIt, class Compare, class Method>
  256. inline void argsort_iter(
  257. ConstRandomIt data_begin,
  258. ConstRandomIt data_end,
  259. RandomIt idx_begin,
  260. RandomIt idx_end,
  261. Compare comp,
  262. Method method
  263. )
  264. {
  265. XTENSOR_ASSERT(std::distance(data_begin, data_end) >= 0);
  266. XTENSOR_ASSERT(std::distance(idx_begin, idx_end) == std::distance(data_begin, data_end));
  267. (void) idx_end; // TODO(C++17) [[maybe_unused]] only used in assertion.
  268. std::iota(idx_begin, idx_end, 0);
  269. switch (method)
  270. {
  271. case (sorting_method::quick):
  272. {
  273. std::sort(
  274. idx_begin,
  275. idx_end,
  276. [&](const auto i, const auto j)
  277. {
  278. return comp(*(data_begin + i), *(data_begin + j));
  279. }
  280. );
  281. }
  282. case (sorting_method::stable):
  283. {
  284. std::stable_sort(
  285. idx_begin,
  286. idx_end,
  287. [&](const auto i, const auto j)
  288. {
  289. return comp(*(data_begin + i), *(data_begin + j));
  290. }
  291. );
  292. }
  293. }
  294. }
  295. template <class ConstRandomIt, class RandomIt, class Method>
  296. inline void
  297. argsort_iter(ConstRandomIt data_begin, ConstRandomIt data_end, RandomIt idx_begin, RandomIt idx_end, Method method)
  298. {
  299. return argsort_iter(
  300. std::move(data_begin),
  301. std::move(data_end),
  302. std::move(idx_begin),
  303. std::move(idx_end),
  304. [](const auto& x, const auto& y) -> bool
  305. {
  306. return x < y;
  307. },
  308. method
  309. );
  310. }
// Rebind a container type to a different value type VT, preserving the
// container kind, rank, layout and static shape where applicable. The
// generic fallback is a dynamic-layout xarray.
template <class VT, class T>
struct rebind_value_type
{
    using type = xarray<VT, xt::layout_type::dynamic>;
};

template <class VT, class EC, layout_type L>
struct rebind_value_type<VT, xarray<EC, L>>
{
    using type = xarray<VT, L>;
};

template <class VT, class EC, std::size_t N, layout_type L>
struct rebind_value_type<VT, xtensor<EC, N, L>>
{
    using type = xtensor<VT, N, L>;
};

template <class VT, class ET, class S, layout_type L>
struct rebind_value_type<VT, xtensor_fixed<ET, S, L>>
{
    using type = xtensor_fixed<VT, S, L>;
};

// Like rebind_value_type, but additionally flattens statically-sized
// tensors to one dimension (used for flat argsort results).
template <class VT, class T>
struct flatten_rebind_value_type
{
    using type = typename rebind_value_type<VT, T>::type;
};

template <class VT, class EC, std::size_t N, layout_type L>
struct flatten_rebind_value_type<VT, xtensor<EC, N, L>>
{
    using type = xtensor<VT, 1, L>;
};

template <class VT, class ET, class S, layout_type L>
struct flatten_rebind_value_type<VT, xtensor_fixed<ET, S, L>>
{
    using type = xtensor_fixed<VT, xshape<fixed_compute_size<S>::value>, L>;
};
// Index container returned by argsort: same shape/kind as the input's
// temporary type, holding size_type indices instead of values.
template <class T>
struct argsort_result_type
{
    using type = typename rebind_value_type<typename T::temporary_type::size_type, typename T::temporary_type>::type;
};

// Flattened (1-D) variant of argsort_result_type.
template <class T>
struct linear_argsort_result_type
{
    using type = typename flatten_rebind_value_type<
        typename T::temporary_type::size_type,
        typename T::temporary_type>::type;
};
  358. template <class E, class R = typename detail::linear_argsort_result_type<E>::type, class Method>
  359. inline auto flatten_argsort_impl(const xexpression<E>& e, Method method)
  360. {
  361. const auto& de = e.derived_cast();
  362. auto cit = de.template begin<layout_type::row_major>();
  363. using const_iterator = decltype(cit);
  364. auto ad = xiterator_adaptor<const_iterator, const_iterator>(cit, cit, de.size());
  365. using result_type = R;
  366. result_type result;
  367. result.resize({de.size()});
  368. detail::argsort_iter(de.cbegin(), de.cend(), result.begin(), result.end(), method);
  369. return result;
  370. }
  371. }
/**
 * Argsort the flattened expression.
 *
 * @ingroup xt_xsort
 * @param e xexpression to argsort
 * @param method sorting algorithm to use
 *
 * @return 1-D array of indices into the flattened expression
 */
template <class E>
inline auto
argsort(const xexpression<E>& e, placeholders::xtuph /*t*/, sorting_method method = sorting_method::quick)
{
    return detail::flatten_argsort_impl(e, method);
}
  378. /**
  379. * Argsort xexpression (optionally along axis)
  380. * Performs an indirect sort along the given axis. Returns an xarray
  381. * of indices of the same shape as e that index data along the given axis in
  382. * sorted order.
  383. *
  384. * @ingroup xt_xsort
  385. * @param e xexpression to argsort
  386. * @param axis axis along which argsort is performed
  387. * @param method sorting algorithm to use
  388. *
  389. * @return argsorted index array
  390. *
  391. * @see xt::sorting_method
  392. */
template <class E>
inline auto
argsort(const xexpression<E>& e, std::ptrdiff_t axis = -1, sorting_method method = sorting_method::quick)
{
    using eval_type = typename detail::sort_eval_type<E>::type;
    using result_type = typename detail::argsort_result_type<eval_type>::type;
    const auto& de = e.derived_cast();
    std::size_t ax = normalize_axis(de.dimension(), axis);
    if (de.dimension() == 1)
    {
        // 1-D input: identical to argsorting the flat buffer.
        return detail::flatten_argsort_impl<E, result_type>(e, method);
    }
    // Argsort one lane: fill `res` with indices ordering the values in `ev`.
    const auto argsort = [&method](auto res_begin, auto res_end, auto ev_begin, auto ev_end)
    {
        detail::argsort_iter(ev_begin, ev_end, res_begin, res_end, method);
    };
    if (ax == detail::leading_axis(de))
    {
        // Fast path: the requested axis is already contiguous.
        result_type res = result_type::from_shape(de.shape());
        detail::call_over_leading_axis(res, de, argsort);
        return res;
    }
    // Slow path: transpose `ax` into the leading position, argsort each
    // lane, then transpose the index array back.
    dynamic_shape<std::size_t> permutation, reverse_permutation;
    std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
    eval_type ev = transpose(de, permutation);
    result_type res = result_type::from_shape(ev.shape());
    detail::call_over_leading_axis(res, ev, argsort);
    res = transpose(res, reverse_permutation);
    return res;
}
  423. /************************************************
  424. * Implementation of partition and argpartition *
  425. ************************************************/
  426. namespace detail
  427. {
  428. /**
  429. * Partition a given random iterator.
  430. *
  431. * @param data_begin Start of the data to partition.
  432. * @param data_end Past end of the data to partition.
  433. * @param kth_start Start of the indices to partition.
  434. * Indices must be sorted in decreasing order.
  435. * @param kth_end Past end of the indices to partition.
  436. * Indices must be sorted in decreasing order.
  437. * @param comp Comparison function for `x < y`.
  438. */
  439. template <class RandomIt, class Iter, class Compare>
  440. inline void
  441. partition_iter(RandomIt data_begin, RandomIt data_end, Iter kth_begin, Iter kth_end, Compare comp)
  442. {
  443. XTENSOR_ASSERT(std::distance(data_begin, data_end) >= 0);
  444. XTENSOR_ASSERT(std::distance(kth_begin, kth_end) >= 0);
  445. using idx_type = typename std::iterator_traits<Iter>::value_type;
  446. idx_type k_last = static_cast<idx_type>(std::distance(data_begin, data_end));
  447. for (; kth_begin != kth_end; ++kth_begin)
  448. {
  449. std::nth_element(data_begin, data_begin + *kth_begin, data_begin + k_last, std::move(comp));
  450. k_last = *kth_begin;
  451. }
  452. }
  453. template <class RandomIt, class Iter>
  454. inline void partition_iter(RandomIt data_begin, RandomIt data_end, Iter kth_begin, Iter kth_end)
  455. {
  456. return partition_iter(
  457. std::move(data_begin),
  458. std::move(data_end),
  459. std::move(kth_begin),
  460. std::move(kth_end),
  461. [](const auto& x, const auto& y) -> bool
  462. {
  463. return x < y;
  464. }
  465. );
  466. }
  467. }
  468. /**
  469. * Partially sort xexpression
  470. *
  471. * Partition shuffles the xexpression in a way so that the kth element
  472. * in the returned xexpression is in the place it would appear in a sorted
  473. * array and all elements smaller than this entry are placed (unsorted) before.
  474. *
  475. * The optional third parameter can either be an axis or ``xnone()`` in which case
  476. * the xexpression will be flattened.
  477. *
  478. * This function uses ``std::nth_element`` internally.
  479. *
  480. * @code{cpp}
  481. * xt::xarray<float> a = {1, 10, -10, 123};
  482. * std::cout << xt::partition(a, 0) << std::endl; // {-10, 1, 123, 10} the correct entry at index 0
  483. * std::cout << xt::partition(a, 3) << std::endl; // {1, 10, -10, 123} the correct entry at index 3
  484. * std::cout << xt::partition(a, {0, 3}) << std::endl; // {-10, 1, 10, 123} the correct entries at index 0
  485. * and 3 \endcode
  486. *
  487. * @ingroup xt_xsort
  488. * @param e input xexpression
  489. * @param kth_container a container of ``indices`` that should contain the correctly sorted value
  490. * @param axis either integer (default = -1) to sort along last axis or ``xnone()`` to flatten before
  491. * sorting
  492. *
  493. * @return partially sorted xcontainer
  494. */
template <
    class E,
    class C,
    class R = detail::flatten_sort_result_type_t<E>,
    class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
inline R partition(const xexpression<E>& e, C kth_container, placeholders::xtuph /*ax*/)
{
    const auto& de = e.derived_cast();
    R ev = R::from_shape({de.size()});
    // partition_iter consumes the kth indices in decreasing order, hence
    // sort ascending here and walk with reverse iterators below.
    std::sort(kth_container.begin(), kth_container.end());
    std::copy(de.linear_cbegin(), de.linear_cend(), ev.linear_begin());  // flatten
    detail::partition_iter(ev.linear_begin(), ev.linear_end(), kth_container.rbegin(), kth_container.rend());
    return ev;
}
  509. template <class E, class I, std::size_t N, class R = detail::flatten_sort_result_type_t<E>>
  510. inline R partition(const xexpression<E>& e, const I (&kth_container)[N], placeholders::xtuph tag)
  511. {
  512. return partition(
  513. e,
  514. xtl::forward_sequence<std::array<std::size_t, N>, decltype(kth_container)>(kth_container),
  515. tag
  516. );
  517. }
  518. template <class E, class R = detail::flatten_sort_result_type_t<E>>
  519. inline R partition(const xexpression<E>& e, std::size_t kth, placeholders::xtuph tag)
  520. {
  521. return partition(e, std::array<std::size_t, 1>({kth}), tag);
  522. }
  523. template <class E, class C, class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
  524. inline auto partition(const xexpression<E>& e, C kth_container, std::ptrdiff_t axis = -1)
  525. {
  526. using eval_type = typename detail::sort_eval_type<E>::type;
  527. std::sort(kth_container.begin(), kth_container.end());
  528. return detail::map_axis<eval_type>(
  529. e.derived_cast(),
  530. axis,
  531. [&kth_container](auto begin, auto end)
  532. {
  533. detail::partition_iter(begin, end, kth_container.rbegin(), kth_container.rend());
  534. }
  535. );
  536. }
  537. template <class E, class T, std::size_t N>
  538. inline auto partition(const xexpression<E>& e, const T (&kth_container)[N], std::ptrdiff_t axis = -1)
  539. {
  540. return partition(
  541. e,
  542. xtl::forward_sequence<std::array<std::size_t, N>, decltype(kth_container)>(kth_container),
  543. axis
  544. );
  545. }
  546. template <class E>
  547. inline auto partition(const xexpression<E>& e, std::size_t kth, std::ptrdiff_t axis = -1)
  548. {
  549. return partition(e, std::array<std::size_t, 1>({kth}), axis);
  550. }
  551. /**
  552. * Partially sort arguments
  553. *
  554. * Argpartition shuffles the indices to a xexpression in a way so that the index for the
  555. * kth element in the returned xexpression is in the place it would appear in a sorted
  556. * array and all elements smaller than this entry are placed (unsorted) before.
  557. *
  558. * The optional third parameter can either be an axis or ``xnone()`` in which case
  559. * the xexpression will be flattened.
  560. *
  561. * This function uses ``std::nth_element`` internally.
  562. *
  563. * @code{cpp}
  564. * xt::xarray<float> a = {1, 10, -10, 123};
  565. * std::cout << xt::argpartition(a, 0) << std::endl; // {2, 0, 3, 1} the correct entry at index 0
  566. * std::cout << xt::argpartition(a, 3) << std::endl; // {0, 1, 2, 3} the correct entry at index 3
  567. * std::cout << xt::argpartition(a, {0, 3}) << std::endl; // {2, 0, 1, 3} the correct entries at index 0
  568. * and 3 \endcode
  569. *
  570. * @ingroup xt_xsort
  571. * @param e input xexpression
  572. * @param kth_container a container of ``indices`` that should contain the correctly sorted value
  573. * @param axis either integer (default = -1) to sort along last axis or ``xnone()`` to flatten before
  574. * sorting
  575. *
  576. * @return xcontainer with indices of partial sort of input
  577. */
template <
    class E,
    class C,
    class R = typename detail::linear_argsort_result_type<typename detail::sort_eval_type<E>::type>::type,
    class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
inline R argpartition(const xexpression<E>& e, C kth_container, placeholders::xtuph)
{
    using eval_type = typename detail::sort_eval_type<E>::type;
    using result_type = typename detail::linear_argsort_result_type<eval_type>::type;
    const auto& de = e.derived_cast();
    result_type res = result_type::from_shape({de.size()});
    // partition_iter consumes the kth indices in decreasing order.
    std::sort(kth_container.begin(), kth_container.end());
    std::iota(res.linear_begin(), res.linear_end(), 0);
    // Partition the index array, comparing the values the indices refer to.
    detail::partition_iter(
        res.linear_begin(),
        res.linear_end(),
        kth_container.rbegin(),
        kth_container.rend(),
        [&de](std::size_t a, std::size_t b)
        {
            return de[a] < de[b];
        }
    );
    return res;
}
  603. template <class E, class I, std::size_t N>
  604. inline auto argpartition(const xexpression<E>& e, const I (&kth_container)[N], placeholders::xtuph tag)
  605. {
  606. return argpartition(
  607. e,
  608. xtl::forward_sequence<std::array<std::size_t, N>, decltype(kth_container)>(kth_container),
  609. tag
  610. );
  611. }
  612. template <class E>
  613. inline auto argpartition(const xexpression<E>& e, std::size_t kth, placeholders::xtuph tag)
  614. {
  615. return argpartition(e, std::array<std::size_t, 1>({kth}), tag);
  616. }
template <class E, class C, class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
inline auto argpartition(const xexpression<E>& e, C kth_container, std::ptrdiff_t axis = -1)
{
    using eval_type = typename detail::sort_eval_type<E>::type;
    using result_type = typename detail::argsort_result_type<eval_type>::type;
    const auto& de = e.derived_cast();
    if (de.dimension() == 1)
    {
        // 1-D input: identical to the flattened argpartition.
        return argpartition<E, C, result_type>(e, std::forward<C>(kth_container), xnone());
    }
    // partition_iter consumes the kth indices in decreasing order.
    std::sort(kth_container.begin(), kth_container.end());
    // Argpartition one lane: fill `res` with 0..n-1 and partition those
    // indices by the values they refer to in `ev`.
    const auto argpartition_w_kth =
        [&kth_container](auto res_begin, auto res_end, auto ev_begin, auto /*ev_end*/)
    {
        std::iota(res_begin, res_end, 0);
        detail::partition_iter(
            res_begin,
            res_end,
            kth_container.rbegin(),
            kth_container.rend(),
            [&ev_begin](auto const& i, auto const& j)
            {
                return *(ev_begin + i) < *(ev_begin + j);
            }
        );
    };
    const std::size_t ax = normalize_axis(de.dimension(), axis);
    if (ax == detail::leading_axis(de))
    {
        // Fast path: the requested axis is already contiguous.
        result_type res = result_type::from_shape(de.shape());
        detail::call_over_leading_axis(res, de, argpartition_w_kth);
        return res;
    }
    // Slow path: transpose, process each lane, transpose back.
    dynamic_shape<std::size_t> permutation, reverse_permutation;
    std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
    eval_type ev = transpose(de, permutation);
    result_type res = result_type::from_shape(ev.shape());
    detail::call_over_leading_axis(res, ev, argpartition_w_kth);
    res = transpose(res, reverse_permutation);
    return res;
}
  658. template <class E, class I, std::size_t N>
  659. inline auto argpartition(const xexpression<E>& e, const I (&kth_container)[N], std::ptrdiff_t axis = -1)
  660. {
  661. return argpartition(
  662. e,
  663. xtl::forward_sequence<std::array<std::size_t, N>, decltype(kth_container)>(kth_container),
  664. axis
  665. );
  666. }
  667. template <class E>
  668. inline auto argpartition(const xexpression<E>& e, std::size_t kth, std::ptrdiff_t axis = -1)
  669. {
  670. return argpartition(e, std::array<std::size_t, 1>({kth}), axis);
  671. }
  672. /******************
  673. * xt::quantile *
  674. ******************/
  675. namespace detail
  676. {
// Recursively enumerate all multi-indices of `shape` whose coordinate
// along `axis` is restricted to `indices`, every other coordinate running
// over its full extent. Completed multi-indices are appended to `out`.
template <class S, class I, class K, class O>
inline void select_indices_impl(
    const S& shape,
    const I& indices,
    std::size_t axis,
    std::size_t current_dim,
    const K& current_index,
    O& out
)
{
    using id_t = typename K::value_type;
    if ((current_dim < shape.size() - 1) && (current_dim == axis))
    {
        // Inner dimension restricted to the selected indices: recurse.
        for (auto i : indices)
        {
            auto idx = current_index;
            idx[current_dim] = i;
            select_indices_impl(shape, indices, axis, current_dim + 1, idx, out);
        }
    }
    else if ((current_dim < shape.size() - 1) && (current_dim != axis))
    {
        // Inner unrestricted dimension: iterate the full extent and recurse.
        for (id_t i = 0; xtl::cmp_less(i, shape[current_dim]); ++i)
        {
            auto idx = current_index;
            idx[current_dim] = i;
            select_indices_impl(shape, indices, axis, current_dim + 1, idx, out);
        }
    }
    else if ((current_dim == shape.size() - 1) && (current_dim == axis))
    {
        // Last dimension, restricted: emit one index per selected value.
        for (auto i : indices)
        {
            auto idx = current_index;
            idx[current_dim] = i;
            out.push_back(std::move(idx));
        }
    }
    else if ((current_dim == shape.size() - 1) && (current_dim != axis))
    {
        // Last dimension, unrestricted: emit the full extent.
        for (id_t i = 0; xtl::cmp_less(i, shape[current_dim]); ++i)
        {
            auto idx = current_index;
            idx[current_dim] = i;
            out.push_back(std::move(idx));
        }
    }
}
  725. template <class S, class I>
  726. inline auto select_indices(const S& shape, const I& indices, std::size_t axis)
  727. {
  728. using index_type = get_strides_t<S>;
  729. auto out = std::vector<index_type>();
  730. select_indices_impl(shape, indices, axis, 0, xtl::make_sequence<index_type>(shape.size()), out);
  731. return out;
  732. }
// TODO remove when fancy index views are implemented
// Poor man's indexing along a single axis as in NumPy a[:, [1, 3, 4]]
template <class E, class I>
inline auto fancy_indexing(E&& e, const I& indices, std::ptrdiff_t axis)
{
    const std::size_t ax = normalize_axis(e.dimension(), axis);
    using shape_t = get_strides_t<typename std::decay_t<E>::shape_type>;
    // The result keeps the input shape, except the selected axis shrinks
    // to the number of chosen indices.
    auto shape = xtl::forward_sequence<shape_t, decltype(e.shape())>(e.shape());
    shape[ax] = indices.size();
    // index_view yields a flat selection; reshape restores the rank.
    return reshape_view(
        index_view(std::forward<E>(e), select_indices(e.shape(), indices, ax)),
        std::move(shape)
    );
}
// For each probability, compute the two partition indices (k, k+1) bracketing the
// quantile and the interpolation weights (1-gamma, gamma), following
// (Hyndman and Fan, 1996). `n` is the sample count along the quantile axis and
// `alpha`/`beta` are the plotting-position parameters.
// NOTE(review): template parameter `I` is unused in this body — presumably kept so
// callers can spell the index type explicitly; confirm before removing.
template <class T, class I, class P>
inline auto quantile_kth_gamma(std::size_t n, const P& probas, T alpha, T beta)
{
    // Plotting-position offset m = alpha + p * (1 - alpha - beta).
    const auto m = alpha + probas * (T(1) - alpha - beta);
    // Evaluating eagerly since reused a lot (a lazy expression would recompute it).
    const auto p_n_m = eval(probas * static_cast<T>(n) + m - 1);
    // Previous (virtual) index, may be out of bounds
    const auto j = floor(p_n_m);
    // Both virtual indices j and j+1, stacked.
    const auto j_jp1 = concatenate(xtuple(j, j + 1));
    // Both interpolation indices, k and k+1, clamped into [0, n-1]
    const auto k_kp1 = xt::cast<std::size_t>(clip(j_jp1, T(0), T(n - 1)));
    // Both interpolation coefficients, 1-gamma and gamma
    const auto omg_g = concatenate(xtuple(T(1) - (p_n_m - j), p_n_m - j));
    // Evaluated so the returned pair owns its data (the lazy expressions above
    // reference locals that die at return).
    return std::make_pair(eval(k_kp1), eval(omg_g));
}
  762. // TODO should implement unsqueeze rather
  763. template <class S>
  764. inline auto unsqueeze_shape(const S& shape, std::size_t axis)
  765. {
  766. XTENSOR_ASSERT(axis <= shape.size());
  767. auto new_shape = xtl::forward_sequence<xt::svector<std::size_t>, decltype(shape)>(shape);
  768. new_shape.insert(new_shape.begin() + axis, 1);
  769. return new_shape;
  770. }
  771. }
  772. /**
  773. * Compute quantiles over the given axis.
  774. *
* In a sorted array representing a distribution of numbers, the quantile of a probability ``p``
* is the cut value ``q`` such that a fraction ``p`` of the distribution is lesser or equal
  777. * to ``q``.
* When the cutpoint falls between two elements of the sample distribution, an interpolation is
* computed using the @p alpha and @p beta coefficients, as described in
  780. * (Hyndman and Fan, 1996).
  781. *
  782. * The algorithm partially sorts entries in a copy along the @p axis axis.
  783. *
  784. * @ingroup xt_xsort
  785. * @param e Expression containing the distribution over which the quantiles are computed.
* @param probas A list of probabilities associated with each desired quantile.
* All elements must be in the range ``[0, 1]``.
  788. * @param axis The dimension in which to compute the quantiles, *i.e* the axis representing the
  789. * distribution.
* @param alpha Interpolation parameter. Must be in the range ``[0, 1]``.
* @param beta Interpolation parameter. Must be in the range ``[0, 1]``.
  792. * @tparam T The type in which the quantile are computed.
  793. * @return An expression with as many dimensions as the input @p e.
* The first axis corresponds to the quantiles.
  795. * The other axes are the axes that remain after the reduction of @p e.
  796. * @see (Hyndman and Fan, 1996) R. J. Hyndman and Y. Fan,
  797. * "Sample quantiles in statistical packages", The American Statistician,
  798. * 50(4), pp. 361-365, 1996
  799. * @see https://en.wikipedia.org/wiki/Quantile
  800. */
  801. template <class T = double, class E, class P>
  802. inline auto quantile(E&& e, const P& probas, std::ptrdiff_t axis, T alpha, T beta)
  803. {
  804. XTENSOR_ASSERT(all(0. <= probas));
  805. XTENSOR_ASSERT(all(probas <= 1.));
  806. XTENSOR_ASSERT(0. <= alpha);
  807. XTENSOR_ASSERT(alpha <= 1.);
  808. XTENSOR_ASSERT(0. <= beta);
  809. XTENSOR_ASSERT(beta <= 1.);
  810. using tmp_shape_t = get_strides_t<typename std::decay_t<E>::shape_type>;
  811. using id_t = typename tmp_shape_t::value_type;
  812. const std::size_t ax = normalize_axis(e.dimension(), axis);
  813. const std::size_t n = e.shape()[ax];
  814. auto kth_gamma = detail::quantile_kth_gamma<T, id_t, P>(n, probas, alpha, beta);
  815. // Select relevant values for computing interpolating quantiles
  816. auto e_partition = xt::partition(std::forward<E>(e), kth_gamma.first, ax);
  817. auto e_kth = detail::fancy_indexing(std::move(e_partition), std::move(kth_gamma.first), ax);
  818. // Reshape interpolation coefficients
  819. auto gm1_g_shape = xtl::make_sequence<tmp_shape_t>(e.dimension(), 1);
  820. gm1_g_shape[ax] = kth_gamma.second.size();
  821. auto gm1_g_reshaped = reshape_view(std::move(kth_gamma.second), std::move(gm1_g_shape));
  822. // Compute interpolation
  823. // TODO(C++20) use (and create) xt::lerp in C++
  824. auto e_kth_g = std::move(e_kth) * std::move(gm1_g_reshaped);
  825. // Reshape pairwise interpolate for suming along new axis
  826. auto e_kth_g_shape = detail::unsqueeze_shape(e_kth_g.shape(), ax);
  827. e_kth_g_shape[ax] = 2;
  828. e_kth_g_shape[ax + 1] /= 2;
  829. auto quantiles = xt::sum(reshape_view(std::move(e_kth_g), std::move(e_kth_g_shape)), ax);
  830. // Cannot do a transpose on a non-strided expression so we have to eval
  831. return moveaxis(eval(std::move(quantiles)), ax, 0);
  832. }
  833. // Static proba array overload
  834. template <class T = double, class E, std::size_t N>
  835. inline auto quantile(E&& e, const T (&probas)[N], std::ptrdiff_t axis, T alpha, T beta)
  836. {
  837. return quantile(std::forward<E>(e), adapt(probas, {N}), axis, alpha, beta);
  838. }
  839. /**
  840. * Compute quantiles of the whole expression.
  841. *
* The quantiles are computed over the whole expression, as if flattened into a
* one-dimensional expression.
  844. *
  845. * @ingroup xt_xsort
  846. * @see xt::quantile(E&& e, P const& probas, std::ptrdiff_t axis, T alpha, T beta)
  847. */
  848. template <class T = double, class E, class P>
  849. inline auto quantile(E&& e, const P& probas, T alpha, T beta)
  850. {
  851. return quantile(xt::ravel(std::forward<E>(e)), probas, 0, alpha, beta);
  852. }
  853. // Static proba array overload
  854. template <class T = double, class E, std::size_t N>
  855. inline auto quantile(E&& e, const T (&probas)[N], T alpha, T beta)
  856. {
  857. return quantile(std::forward<E>(e), adapt(probas, {N}), alpha, beta);
  858. }
  859. /**
  860. * Quantile interpolation method.
  861. *
  862. * Predefined methods for interpolating quantiles, as defined in (Hyndman and Fan, 1996).
  863. *
  864. * @ingroup xt_xsort
  865. * @see (Hyndman and Fan, 1996) R. J. Hyndman and Y. Fan,
  866. * "Sample quantiles in statistical packages", The American Statistician,
  867. * 50(4), pp. 361-365, 1996
  868. * @see xt::quantile(E&& e, P const& probas, std::ptrdiff_t axis, xt::quantile_method method)
  869. */
enum class quantile_method
{
    // Enumerator values start at 4 so they match the method numbering of
    // (Hyndman and Fan, 1996); the following enumerators continue as 5..9.
    /** Method 4 of (Hyndman and Fan, 1996) with ``alpha=0`` and ``beta=1``. */
    interpolated_inverted_cdf = 4,
    /** Method 5 of (Hyndman and Fan, 1996) with ``alpha=1/2`` and ``beta=1/2``. */
    hazen,
    /** Method 6 of (Hyndman and Fan, 1996) with ``alpha=0`` and ``beta=0``. */
    weibull,
    /** Method 7 of (Hyndman and Fan, 1996) with ``alpha=1`` and ``beta=1``. */
    linear,
    /** Method 8 of (Hyndman and Fan, 1996) with ``alpha=1/3`` and ``beta=1/3``. */
    median_unbiased,
    /** Method 9 of (Hyndman and Fan, 1996) with ``alpha=3/8`` and ``beta=3/8``. */
    normal_unbiased,
};
  885. /**
  886. * Compute quantiles over the given axis.
  887. *
* The function takes the name of a predefined method used to interpolate between values.
  889. *
  890. * @ingroup xt_xsort
  891. * @see xt::quantile_method
  892. * @see xt::quantile(E&& e, P const& probas, std::ptrdiff_t axis, T alpha, T beta)
  893. */
  894. template <class T = double, class E, class P>
  895. inline auto
  896. quantile(E&& e, const P& probas, std::ptrdiff_t axis, quantile_method method = quantile_method::linear)
  897. {
  898. T alpha = 0.;
  899. T beta = 0.;
  900. switch (method)
  901. {
  902. case (quantile_method::interpolated_inverted_cdf):
  903. {
  904. alpha = 0.;
  905. beta = 1.;
  906. break;
  907. }
  908. case (quantile_method::hazen):
  909. {
  910. alpha = 0.5;
  911. beta = 0.5;
  912. break;
  913. }
  914. case (quantile_method::weibull):
  915. {
  916. alpha = 0.;
  917. beta = 0.;
  918. break;
  919. }
  920. case (quantile_method::linear):
  921. {
  922. alpha = 1.;
  923. beta = 1.;
  924. break;
  925. }
  926. case (quantile_method::median_unbiased):
  927. {
  928. alpha = 1. / 3.;
  929. beta = 1. / 3.;
  930. break;
  931. }
  932. case (quantile_method::normal_unbiased):
  933. {
  934. alpha = 3. / 8.;
  935. beta = 3. / 8.;
  936. break;
  937. }
  938. }
  939. return quantile(std::forward<E>(e), probas, axis, alpha, beta);
  940. }
  941. // Static proba array overload
  942. template <class T = double, class E, std::size_t N>
  943. inline auto
  944. quantile(E&& e, const T (&probas)[N], std::ptrdiff_t axis, quantile_method method = quantile_method::linear)
  945. {
  946. return quantile(std::forward<E>(e), adapt(probas, {N}), axis, method);
  947. }
  948. /**
  949. * Compute quantiles of the whole expression.
  950. *
* The quantiles are computed over the whole expression, as if flattened into a
* one-dimensional expression.
* The function takes the name of a predefined method used to interpolate between values.
  954. *
  955. * @ingroup xt_xsort
  956. * @see xt::quantile_method
  957. * @see xt::quantile(E&& e, P const& probas, std::ptrdiff_t axis, xt::quantile_method method)
  958. */
  959. template <class T = double, class E, class P>
  960. inline auto quantile(E&& e, const P& probas, quantile_method method = quantile_method::linear)
  961. {
  962. return quantile(xt::ravel(std::forward<E>(e)), probas, 0, method);
  963. }
  964. // Static proba array overload
  965. template <class T = double, class E, std::size_t N>
  966. inline auto quantile(E&& e, const T (&probas)[N], quantile_method method = quantile_method::linear)
  967. {
  968. return quantile(std::forward<E>(e), adapt(probas, {N}), method);
  969. }
  970. /****************
  971. * xt::median *
  972. ****************/
  973. template <class E>
  974. inline typename std::decay_t<E>::value_type median(E&& e)
  975. {
  976. using value_type = typename std::decay_t<E>::value_type;
  977. auto sz = e.size();
  978. if (sz % 2 == 0)
  979. {
  980. std::size_t szh = sz / 2; // integer floor div
  981. std::array<std::size_t, 2> kth = {szh - 1, szh};
  982. auto values = xt::partition(xt::flatten(e), kth);
  983. return (values[kth[0]] + values[kth[1]]) / value_type(2);
  984. }
  985. else
  986. {
  987. std::array<std::size_t, 1> kth = {(sz - 1) / 2};
  988. auto values = xt::partition(xt::flatten(e), kth);
  989. return values[kth[0]];
  990. }
  991. }
  992. /**
  993. * Find the median along the specified axis
  994. *
  995. * Given a vector V of length N, the median of V is the middle value of a
* sorted copy of V, V_sorted, i.e. V_sorted[(N-1)/2], when N is odd,
  997. * and the average of the two middle values of V_sorted when N is even.
  998. *
  999. * @ingroup xt_xsort
  1000. * @param axis axis along which the medians are computed.
  1001. * If not set, computes the median along a flattened version of the input.
  1002. * @param e input xexpression
  1003. * @return median value
  1004. */
  1005. template <class E>
  1006. inline auto median(E&& e, std::ptrdiff_t axis)
  1007. {
  1008. std::size_t ax = normalize_axis(e.dimension(), axis);
  1009. std::size_t sz = e.shape()[ax];
  1010. xstrided_slice_vector sv(e.dimension(), xt::all());
  1011. if (sz % 2 == 0)
  1012. {
  1013. std::size_t szh = sz / 2; // integer floor div
  1014. std::array<std::size_t, 2> kth = {szh - 1, szh};
  1015. auto values = xt::partition(std::forward<E>(e), kth, static_cast<ptrdiff_t>(ax));
  1016. sv[ax] = xt::range(szh - 1, szh + 1);
  1017. return xt::mean(xt::strided_view(std::move(values), std::move(sv)), {ax});
  1018. }
  1019. else
  1020. {
  1021. std::size_t szh = (sz - 1) / 2;
  1022. std::array<std::size_t, 1> kth = {(sz - 1) / 2};
  1023. auto values = xt::partition(std::forward<E>(e), kth, static_cast<ptrdiff_t>(ax));
  1024. sv[ax] = xt::range(szh, szh + 1);
  1025. return xt::mean(xt::strided_view(std::move(values), std::move(sv)), {ax});
  1026. }
  1027. }
  1028. namespace detail
  1029. {
// Maps the input expression type to the container type returned by argmin/argmax:
// a dynamically-shaped xarray in the general case...
template <class T>
struct argfunc_result_type
{
    using type = xarray<std::size_t>;
};
// ...and a static-rank xtensor with one dimension fewer when the input is an
// xtensor of known rank (one axis is reduced away).
template <class T, std::size_t N>
struct argfunc_result_type<xtensor<T, N>>
{
    using type = xtensor<std::size_t, N - 1>;
};
// Shared implementation of argmin/argmax along an axis. `cmp` is std::less for
// argmin and std::greater for argmax; for each 1-D slice along `axis` the position
// of the first element that "wins" under `cmp` is recorded.
template <layout_type L, class E, class F>
inline typename argfunc_result_type<E>::type arg_func_impl(const E& e, std::size_t axis, F&& cmp)
{
    using eval_type = typename detail::sort_eval_type<E>::type;
    using value_type = typename E::value_type;
    using result_type = typename argfunc_result_type<E>::type;
    using result_shape_type = typename result_type::shape_type;
    // 1-D input: delegate to the std algorithms and return a 0-d result.
    if (e.dimension() == 1)
    {
        auto begin = e.template begin<L>();
        auto end = e.template end<L>();
        // todo C++17 : constexpr
        // Runtime dispatch on the comparator type; both branches compile for any F.
        if (std::is_same<F, std::less<value_type>>::value)
        {
            std::size_t i = static_cast<std::size_t>(std::distance(begin, std::min_element(begin, end)));
            return xtensor<size_t, 0>{i};
        }
        else
        {
            std::size_t i = static_cast<std::size_t>(std::distance(begin, std::max_element(begin, end)));
            return xtensor<size_t, 0>{i};
        }
    }
    // Result shape: the input shape with `axis` removed.
    result_shape_type alt_shape;
    xt::resize_container(alt_shape, e.dimension() - 1);
    // Excluding copy, copy all of shape except for axis
    std::copy(e.shape().cbegin(), e.shape().cbegin() + std::ptrdiff_t(axis), alt_shape.begin());
    std::copy(
        e.shape().cbegin() + std::ptrdiff_t(axis) + 1,
        e.shape().cend(),
        alt_shape.begin() + std::ptrdiff_t(axis)
    );
    result_type result = result_type::from_shape(std::move(alt_shape));
    auto result_iter = result.template begin<L>();
    // For one contiguous slice [begin, end), record the position of the extremum;
    // the first occurrence wins, matching std::min_element/std::max_element.
    auto arg_func_lambda = [&result_iter, &cmp](auto begin, auto end)
    {
        std::size_t idx = 0;
        value_type val = *begin;
        ++begin;
        for (std::size_t i = 1; begin != end; ++begin, ++i)
        {
            if (cmp(*begin, val))
            {
                val = *begin;
                idx = i;
            }
        }
        *result_iter = idx;
        ++result_iter;
    };
    if (axis != detail::leading_axis(e))
    {
        // Bring `axis` to the leading position so each slice is contiguous in memory.
        dynamic_shape<std::size_t> permutation, reverse_permutation;
        std::tie(
            permutation,
            reverse_permutation
        ) = detail::get_permutations(e.dimension(), axis, e.layout());
        // note: creating copy
        eval_type input = transpose(e, permutation);
        detail::call_over_leading_axis(input, arg_func_lambda);
        return result;
    }
    else
    {
        // Axis already leads: evaluate (no-op for containers) and scan in place.
        auto&& input = eval(e);
        detail::call_over_leading_axis(input, arg_func_lambda);
        return result;
    }
}
  1109. }
  1110. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
  1111. inline auto argmin(const xexpression<E>& e)
  1112. {
  1113. using value_type = typename E::value_type;
  1114. auto&& ed = eval(e.derived_cast());
  1115. auto begin = ed.template begin<L>();
  1116. auto end = ed.template end<L>();
  1117. std::size_t i = static_cast<std::size_t>(std::distance(begin, std::min_element(begin, end)));
  1118. return xtensor<size_t, 0>{i};
  1119. }
  1120. /**
  1121. * Find position of minimal value in xexpression.
  1122. * By default, the returned index is into the flattened array.
  1123. * If `axis` is specified, the indices are along the specified axis.
*
* @ingroup xt_xsort
* @param e input xexpression
  1126. * @param axis select axis (optional)
  1127. *
  1128. * @return returns xarray with positions of minimal value
  1129. */
  1130. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
  1131. inline auto argmin(const xexpression<E>& e, std::ptrdiff_t axis)
  1132. {
  1133. using value_type = typename E::value_type;
  1134. auto&& ed = eval(e.derived_cast());
  1135. std::size_t ax = normalize_axis(ed.dimension(), axis);
  1136. return detail::arg_func_impl<L>(ed, ax, std::less<value_type>());
  1137. }
  1138. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
  1139. inline auto argmax(const xexpression<E>& e)
  1140. {
  1141. using value_type = typename E::value_type;
  1142. auto&& ed = eval(e.derived_cast());
  1143. auto begin = ed.template begin<L>();
  1144. auto end = ed.template end<L>();
  1145. std::size_t i = static_cast<std::size_t>(std::distance(begin, std::max_element(begin, end)));
  1146. return xtensor<size_t, 0>{i};
  1147. }
  1148. /**
  1149. * Find position of maximal value in xexpression
  1150. * By default, the returned index is into the flattened array.
  1151. * If `axis` is specified, the indices are along the specified axis.
  1152. *
  1153. * @ingroup xt_xsort
  1154. * @param e input xexpression
  1155. * @param axis select axis (optional)
  1156. *
  1157. * @return returns xarray with positions of maximal value
  1158. */
  1159. template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
  1160. inline auto argmax(const xexpression<E>& e, std::ptrdiff_t axis)
  1161. {
  1162. using value_type = typename E::value_type;
  1163. auto&& ed = eval(e.derived_cast());
  1164. std::size_t ax = normalize_axis(ed.dimension(), axis);
  1165. return detail::arg_func_impl<L>(ed, ax, std::greater<value_type>());
  1166. }
  1167. /**
  1168. * Find unique elements of a xexpression. This returns a flattened xtensor with
  1169. * sorted, unique elements from the original expression.
  1170. *
  1171. * @ingroup xt_xsort
  1172. * @param e input xexpression (will be flattened)
  1173. */
  1174. template <class E>
  1175. inline auto unique(const xexpression<E>& e)
  1176. {
  1177. auto sorted = sort(e, xnone());
  1178. auto end = std::unique(sorted.begin(), sorted.end());
  1179. std::size_t sz = static_cast<std::size_t>(std::distance(sorted.begin(), end));
  1180. // TODO check if we can shrink the vector without reallocation
  1181. using value_type = typename E::value_type;
  1182. auto result = xtensor<value_type, 1>::from_shape({sz});
  1183. std::copy(sorted.begin(), end, result.begin());
  1184. return result;
  1185. }
  1186. /**
  1187. * Find the set difference of two xexpressions. This returns a flattened xtensor with
  1188. * the sorted, unique values in ar1 that are not in ar2.
  1189. *
  1190. * @ingroup xt_xsort
  1191. * @param ar1 input xexpression (will be flattened)
  1192. * @param ar2 input xexpression
  1193. */
  1194. template <class E1, class E2>
  1195. inline auto setdiff1d(const xexpression<E1>& ar1, const xexpression<E2>& ar2)
  1196. {
  1197. using value_type = typename E1::value_type;
  1198. auto unique1 = unique(ar1);
  1199. auto unique2 = unique(ar2);
  1200. auto tmp = xtensor<value_type, 1>::from_shape({unique1.size()});
  1201. auto end = std::set_difference(unique1.begin(), unique1.end(), unique2.begin(), unique2.end(), tmp.begin());
  1202. std::size_t sz = static_cast<std::size_t>(std::distance(tmp.begin(), end));
  1203. auto result = xtensor<value_type, 1>::from_shape({sz});
  1204. std::copy(tmp.begin(), end, result.begin());
  1205. return result;
  1206. }
  1207. }
  1208. #endif