// xassign.hpp
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_ASSIGN_HPP
  10. #define XTENSOR_ASSIGN_HPP
  11. #include <algorithm>
  12. #include <functional>
  13. #include <type_traits>
  14. #include <utility>
  15. #include <xtl/xcomplex.hpp>
  16. #include <xtl/xsequence.hpp>
  17. #include "xexpression.hpp"
  18. #include "xfunction.hpp"
  19. #include "xiterator.hpp"
  20. #include "xstrides.hpp"
  21. #include "xtensor_config.hpp"
  22. #include "xtensor_forward.hpp"
  23. #include "xutils.hpp"
  24. #if defined(XTENSOR_USE_TBB)
  25. #include <tbb/tbb.h>
  26. #endif
  27. namespace xt
  28. {
  29. /********************
  30. * Assign functions *
  31. ********************/
    // Assigns the data of e2 to e1; `trivial` signals that a flat, linear
    // traversal of both expressions is valid (no broadcasting steppers needed).
    template <class E1, class E2>
    void assign_data(xexpression<E1>& e1, const xexpression<E2>& e2, bool trivial);

    // Resizes e1 to e2's shape and assigns e2's data to it.
    template <class E1, class E2>
    void assign_xexpression(xexpression<E1>& e1, const xexpression<E2>& e2);

    // Assignment used by computed expressions (e.g. a += b): may assign through
    // a temporary when the broadcast shape exceeds e1's current shape.
    template <class E1, class E2>
    void computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2);

    // Applies the binary functor f to every element of e1 and the scalar e2,
    // storing the result back into e1 (e.g. e1 op= scalar).
    template <class E1, class E2, class F>
    void scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f);

    // Throws a broadcast error if e2's shape cannot broadcast to e1's shape.
    template <class E1, class E2>
    void assert_compatible_shape(const xexpression<E1>& e1, const xexpression<E2>& e2);

    // Strided-loop assignment entry points; the tag dispatches on whether the
    // strided fast path is enabled for this pair of expressions.
    template <class E1, class E2>
    void strided_assign(E1& e1, const E2& e2, std::false_type /*disable*/);

    template <class E1, class E2>
    void strided_assign(E1& e1, const E2& e2, std::true_type /*enable*/);
  46. /************************
  47. * xexpression_assigner *
  48. ************************/
    // Per-expression-tag base of xexpression_assigner. Specializations provide
    // the core assign_data entry point for their expression family.
    template <class Tag>
    class xexpression_assigner_base;

    // Specialization for regular tensor expressions.
    template <>
    class xexpression_assigner_base<xtensor_expression_tag>
    {
    public:

        // Assigns the data of e2 to e1; `trivial` means a linear traversal of
        // both expressions is valid.
        template <class E1, class E2>
        static void assign_data(xexpression<E1>& e1, const xexpression<E2>& e2, bool trivial);
    };
    // Implements the assignment free functions declared above for a given
    // expression tag; inherits assign_data from the tag-specific base.
    template <class Tag>
    class xexpression_assigner : public xexpression_assigner_base<Tag>
    {
    public:

        using base_type = xexpression_assigner_base<Tag>;

        // Resizes e1 to match e2 and assigns the data.
        template <class E1, class E2>
        static void assign_xexpression(E1& e1, const E2& e2);

        // Assignment for computed expressions; may go through a temporary.
        template <class E1, class E2>
        static void computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2);

        // Element-wise e1[i] = f(e1[i], e2) for a scalar e2.
        template <class E1, class E2, class F>
        static void scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f);

        // Throws if e2's shape cannot broadcast to e1's shape.
        template <class E1, class E2>
        static void assert_compatible_shape(const xexpression<E1>& e1, const xexpression<E2>& e2);

    private:

        // Resizes e1 from e2's shape; the return value reports whether the
        // broadcast is trivial (see the definitions below).
        template <class E1, class E2>
        static bool resize(E1& e1, const E2& e2);

        // Overload for xfunction RHS: may compute triviality at compile time
        // when the function's shape is statically known.
        template <class E1, class F, class... CT>
        static bool resize(E1& e1, const xfunction<F, CT...>& e2);
    };
  77. /********************
  78. * stepper_assigner *
  79. ********************/
    // Fallback element-by-element assigner driven by steppers; handles
    // broadcasting. L is the traversal layout.
    template <class E1, class E2, layout_type L>
    class stepper_assigner
    {
    public:

        using lhs_iterator = typename E1::stepper;
        using rhs_iterator = typename E2::const_stepper;
        using shape_type = typename E1::shape_type;
        using index_type = xindex_type_t<shape_type>;
        using size_type = typename lhs_iterator::size_type;
        using difference_type = typename lhs_iterator::difference_type;

        stepper_assigner(E1& e1, const E2& e2);

        // Performs the full assignment.
        void run();

        // Stepper control, forwarded to both sides in lockstep.
        void step(size_type i);
        void step(size_type i, size_type n);
        void reset(size_type i);
        void to_end(layout_type);

    private:

        E1& m_e1;
        lhs_iterator m_lhs;
        rhs_iterator m_rhs;
        index_type m_index;  // current multi-dimensional position
    };
  102. /*******************
  103. * linear_assigner *
  104. *******************/
    // Flat (1-D) assigner used when both expressions can be traversed
    // linearly. The primary template (simd_assign == true) uses SIMD
    // load/store; the false specialization falls back to scalar iterators.
    template <bool simd_assign>
    class linear_assigner
    {
    public:

        template <class E1, class E2>
        static void run(E1& e1, const E2& e2);
    };

    template <>
    class linear_assigner<false>
    {
    public:

        template <class E1, class E2>
        static void run(E1& e1, const E2& e2);

    private:

        // Tag dispatch on convertibility so that unrelated value types
        // instantiate an empty body instead of failing to compile.
        template <class E1, class E2>
        static void run_impl(E1& e1, const E2& e2, std::true_type);

        template <class E1, class E2>
        static void run_impl(E1& e1, const E2& e2, std::false_type);
    };
  124. /*************************
  125. * strided_loop_assigner *
  126. *************************/
    namespace strided_assign_detail
    {
        // Result of analyzing a pair of expressions for the strided fast path:
        // whether it applies, the traversal order, and the split of the
        // iteration space into an outer loop and a contiguous inner loop.
        struct loop_sizes_t
        {
            bool can_do_strided_assign;   // strided fast path is usable
            bool is_row_major;            // traversal order of the outer loop
            std::size_t inner_loop_size;  // contiguous run length per outer step
            std::size_t outer_loop_size;  // number of outer iterations
            std::size_t cut;              // dimension where inner/outer split occurs
            std::size_t dimension;        // total number of dimensions
        };
    }
    // Assigner that runs a scalar or SIMD inner loop over contiguous runs,
    // advancing an outer multi-dimensional index between runs.
    template <bool simd>
    class strided_loop_assigner
    {
    public:

        using loop_sizes_t = strided_assign_detail::loop_sizes_t;

        // Runs the assignment with precomputed loop sizes
        // (is_row_major, inner_loop_size, outer_loop_size, cut).
        template <class E1, class E2>
        static void run(E1& e1, const E2& e2, const loop_sizes_t& loop_sizes);

        // Analyzes e1/e2 and computes the loop decomposition.
        template <class E1, class E2>
        static loop_sizes_t get_loop_sizes(E1& e1, const E2& e2);

        // Convenience overload: computes the loop sizes, then runs.
        template <class E1, class E2>
        static void run(E1& e1, const E2& e2);
    };
  152. /***********************************
  153. * Assign functions implementation *
  154. ***********************************/
  155. template <class E1, class E2>
  156. inline void assign_data(xexpression<E1>& e1, const xexpression<E2>& e2, bool trivial)
  157. {
  158. using tag = xexpression_tag_t<E1, E2>;
  159. xexpression_assigner<tag>::assign_data(e1, e2, trivial);
  160. }
    // Assigns e2 to e1. If E2 provides a custom assign_to (detected via
    // has_assign_to), that hook is used; otherwise the generic tag-dispatched
    // assigner is invoked. static_if keeps the untaken branch uninstantiated,
    // so expressions without assign_to still compile.
    template <class E1, class E2>
    inline void assign_xexpression(xexpression<E1>& e1, const xexpression<E2>& e2)
    {
        xtl::mpl::static_if<has_assign_to<E1, E2>::value>(
            [&](auto self)
            {
                // `self` defers name lookup so this body is only instantiated
                // when the branch is actually selected.
                self(e2).derived_cast().assign_to(e1);
            },
            /*else*/
            [&](auto /*self*/)
            {
                using tag = xexpression_tag_t<E1, E2>;
                xexpression_assigner<tag>::assign_xexpression(e1, e2);
            }
        );
    }
  177. template <class E1, class E2>
  178. inline void computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2)
  179. {
  180. using tag = xexpression_tag_t<E1, E2>;
  181. xexpression_assigner<tag>::computed_assign(e1, e2);
  182. }
  183. template <class E1, class E2, class F>
  184. inline void scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f)
  185. {
  186. using tag = xexpression_tag_t<E1, E2>;
  187. xexpression_assigner<tag>::scalar_computed_assign(e1, e2, std::forward<F>(f));
  188. }
  189. template <class E1, class E2>
  190. inline void assert_compatible_shape(const xexpression<E1>& e1, const xexpression<E2>& e2)
  191. {
  192. using tag = xexpression_tag_t<E1, E2>;
  193. xexpression_assigner<tag>::assert_compatible_shape(e1, e2);
  194. }
  195. /***************************************
  196. * xexpression_assigner implementation *
  197. ***************************************/
  198. namespace detail
  199. {
        // True when the combination of E1's and E2's static layouts is not
        // dynamic, i.e. a linear traversal is valid based on compile-time
        // information alone.
        template <class E1, class E2>
        constexpr bool linear_static_layout()
        {
            // A row_major or column_major container with a dimension <= 1 is computed as
            // layout any, leading to some performance improvements, for example when
            // assigning a col-major vector to a row-major vector etc
            return compute_layout(
                       select_layout<E1::static_layout, typename E1::shape_type>::value,
                       select_layout<E2::static_layout, typename E2::shape_type>::value
                   )
                   != layout_type::dynamic;
        }
        // Runtime check for linear assignability when E1 exposes strides:
        // either both layouts are contiguous and statically compatible, or the
        // destination is contiguous and the source can assign linearly along
        // its strides.
        template <class E1, class E2>
        inline auto is_linear_assign(const E1& e1, const E2& e2)
            -> std::enable_if_t<has_strides<E1>::value, bool>
        {
            return (E1::contiguous_layout && E2::contiguous_layout && linear_static_layout<E1, E2>())
                   || (e1.is_contiguous() && e2.has_linear_assign(e1.strides()));
        }

        // Without strides on the destination, linear assignment is never possible.
        template <class E1, class E2>
        inline auto is_linear_assign(const E1&, const E2&) -> std::enable_if_t<!has_strides<E1>::value, bool>
        {
            return false;
        }
  224. template <class E1, class E2>
  225. inline bool linear_dynamic_layout(const E1& e1, const E2& e2)
  226. {
  227. return e1.is_contiguous() && e2.is_contiguous()
  228. && compute_layout(e1.layout(), e2.layout()) != layout_type::dynamic;
  229. }
        // Detects whether E (a stepper type) provides step_leading(), which is
        // required for the strided inner-loop fast path.
        template <class E, class = void>
        struct has_step_leading : std::false_type
        {
        };

        template <class E>
        struct has_step_leading<E, void_t<decltype(std::declval<E>().step_leading())>> : std::true_type
        {
        };
        // Whether T is eligible for the strided-loop assignment fast path:
        // it must have strides, its stepper must support step_leading(), and
        // dereferencing the stepper must yield a true reference.
        template <class T>
        struct use_strided_loop
        {
            static constexpr bool stepper_deref()
            {
                return std::is_reference<typename T::stepper::reference>::value;
            }

            static constexpr bool value = has_strides<T>::value
                                          && has_step_leading<typename T::stepper>::value && stepper_deref();
        };

        // Scalars are always compatible with the strided loop.
        template <class T>
        struct use_strided_loop<xscalar<T>>
        {
            static constexpr bool value = true;
        };

        // An xfunction qualifies when every one of its operands qualifies.
        template <class F, class... CT>
        struct use_strided_loop<xfunction<F, CT...>>
        {
            static constexpr bool value = xtl::conjunction<use_strided_loop<std::decay_t<CT>>...>::value;
        };
        /**
         * Considering the assignment LHS = RHS, if the requested value type used for
         * loading simd from RHS is not complex while LHS value_type is complex,
         * the assignment fails. The reason is that SIMD batches of complex values cannot
         * be implicitly instantiated from batches of scalar values.
         * Making the constructor implicit does not fix the issue since in the end,
         * the assignment is done with vec.store(buffer) where vec is a batch of scalars
         * and buffer an array of complex. SIMD batches of scalars do not provide overloads
         * of store that accept buffer of complex values and that SHOULD NOT CHANGE.
         * Load and store overloads must accept SCALAR BUFFERS ONLY.
         * Therefore, the solution is to explicitly force the instantiation of complex
         * batches in the assignment mechanism. A common situation that triggers this
         * issue is:
         * xt::xarray<double> rhs = { 1, 2, 3 };
         * xt::xarray<std::complex<double>> lhs = rhs;
         */
        template <class T1, class T2>
        struct conditional_promote_to_complex
        {
            // Promote only when the destination is complex and the source is not.
            static constexpr bool cond = xtl::is_gen_complex<T1>::value && !xtl::is_gen_complex<T2>::value;
            // Alternative: use std::complex<T2> or xcomplex<T2, T2, bool> depending on T1
            using type = std::conditional_t<cond, T1, T2>;
        };

        template <class T1, class T2>
        using conditional_promote_to_complex_t = typename conditional_promote_to_complex<T1, T2>::type;
  283. }
    // Compile-time (and partially runtime) decision logic for choosing the
    // assignment strategy between expressions E1 (destination) and E2 (source):
    // SIMD vs scalar, linear vs strided vs stepper-based.
    template <class E1, class E2>
    class xassign_traits
    {
    private:

        using e1_value_type = typename E1::value_type;
        using e2_value_type = typename E2::value_type;

        template <class T>
        using is_bool = std::is_same<T, bool>;

        // bool -> non-bool assignments are excluded from the SIMD path.
        static constexpr bool is_bool_conversion()
        {
            return is_bool<e2_value_type>::value && !is_bool<e1_value_type>::value;
        }

        static constexpr bool contiguous_layout()
        {
            return E1::contiguous_layout && E2::contiguous_layout;
        }

        static constexpr bool convertible_types()
        {
            return std::is_convertible<e2_value_type, e1_value_type>::value && !is_bool_conversion();
        }

        // int8 batches have size > 1 only when a SIMD backend is enabled, so
        // this acts as a proxy for "xsimd is active".
        static constexpr bool use_xsimd()
        {
            return xt_simd::simd_traits<int8_t>::size > 1;
        }

        // A type participates in SIMD when its batch size exceeds 1, or when
        // it is bool and SIMD is enabled (bool is loaded via bool_load_type).
        template <class T>
        static constexpr bool simd_size_impl()
        {
            return xt_simd::simd_traits<T>::size > 1 || (is_bool<T>::value && use_xsimd());
        }

        static constexpr bool simd_size()
        {
            return simd_size_impl<e1_value_type>() && simd_size_impl<e2_value_type>();
        }

        // Both expressions must expose load/store for the requested value type.
        static constexpr bool simd_interface()
        {
            return has_simd_interface<E1, requested_value_type>()
                   && has_simd_interface<E2, requested_value_type>();
        }

    public:

        // constexpr methods instead of constexpr data members avoid the need of definitions at namespace
        // scope of these data members (since they are odr-used).

        static constexpr bool simd_assign()
        {
            return convertible_types() && simd_size() && simd_interface();
        }

        // Runtime decision: the broadcast is trivial and linear assignment is
        // possible for these concrete expressions.
        static constexpr bool linear_assign(const E1& e1, const E2& e2, bool trivial)
        {
            return trivial && detail::is_linear_assign(e1, e2);
        }

        static constexpr bool strided_assign()
        {
            return detail::use_strided_loop<E1>::value && detail::use_strided_loop<E2>::value;
        }

        // Static-only variant: layouts known contiguous at compile time.
        static constexpr bool simd_linear_assign()
        {
            return contiguous_layout() && simd_assign();
        }

        static constexpr bool simd_strided_assign()
        {
            return strided_assign() && simd_assign();
        }

        // Runtime variant: layouts checked on the concrete expressions.
        static constexpr bool simd_linear_assign(const E1& e1, const E2& e2)
        {
            return simd_assign() && detail::linear_dynamic_layout(e1, e2);
        }

        // bool sources are loaded through E2::bool_load_type; the result may
        // additionally be promoted to a complex type (see
        // conditional_promote_to_complex above).
        using e2_requested_value_type = std::
            conditional_t<is_bool<e2_value_type>::value, typename E2::bool_load_type, e2_value_type>;
        using requested_value_type = detail::conditional_promote_to_complex_t<e1_value_type, e2_requested_value_type>;
    };
    // Core dispatch of the assignment: picks the linear SIMD path, the linear
    // scalar path, the strided SIMD loop, or the generic stepper assigner,
    // based on xassign_traits and the runtime `trivial` flag.
    template <class E1, class E2>
    inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
        xexpression<E1>& e1,
        const xexpression<E2>& e2,
        bool trivial
    )
    {
        E1& de1 = e1.derived_cast();
        const E2& de2 = e2.derived_cast();
        using traits = xassign_traits<E1, E2>;

        bool linear_assign = traits::linear_assign(de1, de2, trivial);
        constexpr bool simd_assign = traits::simd_assign();
        constexpr bool simd_linear_assign = traits::simd_linear_assign();
        constexpr bool simd_strided_assign = traits::simd_strided_assign();

        if (linear_assign)
        {
            if (simd_linear_assign || traits::simd_linear_assign(de1, de2))
            {
                // Do not use linear_assigner<true> here since it will make the compiler
                // instantiate this branch even if the runtime condition is false, resulting
                // in compilation error for expressions that do not provide a SIMD interface.
                // simd_assign is true if simd_linear_assign() or simd_linear_assign(de1, de2)
                // is true.
                linear_assigner<simd_assign>::run(de1, de2);
            }
            else
            {
                linear_assigner<false>::run(de1, de2);
            }
        }
        else if (simd_strided_assign)
        {
            strided_loop_assigner<simd_strided_assign>::run(de1, de2);
        }
        else
        {
            // Generic fallback: stepper-based traversal handles broadcasting.
            stepper_assigner<E1, E2, default_assignable_layout(E1::static_layout)>(de1, de2).run();
        }
    }
  392. template <class Tag>
  393. template <class E1, class E2>
  394. inline void xexpression_assigner<Tag>::assign_xexpression(E1& e1, const E2& e2)
  395. {
  396. bool trivial_broadcast = resize(e1.derived_cast(), e2.derived_cast());
  397. base_type::assign_data(e1, e2, trivial_broadcast);
  398. }
    // Computed assignment (e.g. a += b): if the broadcast shape of the RHS
    // exceeds the LHS shape (higher rank, or some dimension lexicographically
    // greater), the result is built in a temporary and moved into the LHS;
    // otherwise the assignment happens in place.
    template <class Tag>
    template <class E1, class E2>
    inline void xexpression_assigner<Tag>::computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2)
    {
        using shape_type = typename E1::shape_type;
        using comperator_type = std::greater<typename shape_type::value_type>;

        using size_type = typename E1::size_type;

        E1& de1 = e1.derived_cast();
        const E2& de2 = e2.derived_cast();

        size_type dim2 = de2.dimension();
        shape_type shape = uninitialized_shape<shape_type>(dim2);

        bool trivial_broadcast = de2.broadcast_shape(shape, true);

        auto&& de1_shape = de1.shape();
        // True when the RHS broadcast shape is "bigger" than the LHS shape:
        // either higher rank or element-wise lexicographically greater.
        if (dim2 > de1.dimension()
            || std::lexicographical_compare(
                shape.begin(),
                shape.end(),
                de1_shape.begin(),
                de1_shape.end(),
                comperator_type()
            ))
        {
            // The LHS cannot hold the result in place (it may also appear on
            // the RHS), so compute into a temporary first.
            typename E1::temporary_type tmp(shape);
            base_type::assign_data(tmp, e2, trivial_broadcast);
            de1.assign_temporary(std::move(tmp));
        }
        else
        {
            base_type::assign_data(e1, e2, trivial_broadcast);
        }
    }
  430. template <class Tag>
  431. template <class E1, class E2, class F>
  432. inline void xexpression_assigner<Tag>::scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f)
  433. {
  434. E1& d = e1.derived_cast();
  435. using size_type = typename E1::size_type;
  436. auto dst = d.storage().begin();
  437. for (size_type i = d.size(); i > 0; --i)
  438. {
  439. *dst = f(*dst, e2);
  440. ++dst;
  441. }
  442. }
  443. template <class Tag>
  444. template <class E1, class E2>
  445. inline void
  446. xexpression_assigner<Tag>::assert_compatible_shape(const xexpression<E1>& e1, const xexpression<E2>& e2)
  447. {
  448. const E1& de1 = e1.derived_cast();
  449. const E2& de2 = e2.derived_cast();
  450. if (!broadcastable(de2.shape(), de1.shape()))
  451. {
  452. throw_broadcast_error(de2.shape(), de1.shape());
  453. }
  454. }
    namespace detail
    {
        // Compile-time broadcast triviality for a pack of operand types.
        // Only evaluated (B == true) when all shapes are statically known;
        // otherwise the value is unconditionally false.
        template <bool B, class... CT>
        struct static_trivial_broadcast;

        template <class... CT>
        struct static_trivial_broadcast<true, CT...>
        {
            static constexpr bool value = detail::promote_index<typename std::decay_t<CT>::shape_type...>::value;
        };

        template <class... CT>
        struct static_trivial_broadcast<false, CT...>
        {
            static constexpr bool value = false;
        };
    }
    // Resizes e1 to e2's shape. Returns true: a non-xfunction RHS is at least
    // potentially trivial here.
    template <class Tag>
    template <class E1, class E2>
    inline bool xexpression_assigner<Tag>::resize(E1& e1, const E2& e2)
    {
        // If our RHS is not a xfunction, we know that the RHS is at least potentially trivial
        // We check the strides of the RHS in detail::is_trivial_broadcast to see if they match up!
        // So we can skip a shape copy and a call to broadcast_shape(...)
        e1.resize(e2.shape());
        return true;
    }
    // Resizes e1 from an xfunction RHS. When the function's shape is fixed at
    // compile time, both the resize and the triviality decision are static;
    // otherwise the shape is broadcast at runtime.
    template <class Tag>
    template <class E1, class F, class... CT>
    inline bool xexpression_assigner<Tag>::resize(E1& e1, const xfunction<F, CT...>& e2)
    {
        return xtl::mpl::static_if<detail::is_fixed<typename xfunction<F, CT...>::shape_type>::value>(
            [&](auto /*self*/)
            {
                /*
                 * If the shape of the xfunction is statically known, we can compute the broadcast triviality
                 * at compile time plus we can resize right away.
                 */
                // resize in case LHS is not a fixed size container. If it is, this is a NOP
                e1.resize(typename xfunction<F, CT...>::shape_type{});
                return detail::static_trivial_broadcast<
                    detail::is_fixed<typename xfunction<F, CT...>::shape_type>::value,
                    CT...>::value;
            },
            /* else */
            [&](auto /*self*/)
            {
                // Dynamic shape: broadcast the RHS shape into a fresh index,
                // resize the LHS with it, and report runtime triviality.
                using index_type = xindex_type_t<typename E1::shape_type>;
                using size_type = typename E1::size_type;
                size_type size = e2.dimension();
                index_type shape = uninitialized_shape<index_type>(size);
                bool trivial_broadcast = e2.broadcast_shape(shape, true);
                e1.resize(std::move(shape));
                return trivial_broadcast;
            }
        );
    }
  510. /***********************************
  511. * stepper_assigner implementation *
  512. ***********************************/
    // True when converting FROM to TO may lose information: TO is arithmetic
    // and either smaller than FROM, or an integral target for a floating-point
    // source.
    template <class FROM, class TO>
    struct is_narrowing_conversion
    {
        using argument_type = std::decay_t<FROM>;
        using result_type = std::decay_t<TO>;

        static const bool value = xtl::is_arithmetic<result_type>::value
                                  && (sizeof(result_type) < sizeof(argument_type)
                                      || (xtl::is_integral<result_type>::value
                                          && std::is_floating_point<argument_type>::value));
    };
    // True when FROM and TO differ in signedness.
    template <class FROM, class TO>
    struct has_sign_conversion
    {
        using argument_type = std::decay_t<FROM>;
        using result_type = std::decay_t<TO>;

        static const bool value = xtl::is_signed<argument_type>::value != xtl::is_signed<result_type>::value;
    };
    // True when assigning FROM to TO requires an explicit cast (used to insert
    // conditional_cast and silence conversion warnings): the conversion is
    // narrowing or changes signedness.
    template <class FROM, class TO>
    struct has_assign_conversion
    {
        using argument_type = std::decay_t<FROM>;
        using result_type = std::decay_t<TO>;

        static const bool value = is_narrowing_conversion<argument_type, result_type>::value
                                  || has_sign_conversion<argument_type, result_type>::value;
    };
    // Both steppers traverse e1's shape: the RHS stepper is created over the
    // destination shape so that broadcasting of e2 happens transparently.
    template <class E1, class E2, layout_type L>
    inline stepper_assigner<E1, E2, L>::stepper_assigner(E1& e1, const E2& e2)
        : m_e1(e1)
        , m_lhs(e1.stepper_begin(e1.shape()))
        , m_rhs(e2.stepper_begin(e1.shape()))
        , m_index(xtl::make_sequence<index_type>(e1.shape().size(), size_type(0)))
    {
    }
    // Copies every element from the RHS stepper to the LHS stepper, advancing
    // both in lockstep according to layout L.
    template <class E1, class E2, layout_type L>
    inline void stepper_assigner<E1, E2, L>::run()
    {
        using tmp_size_type = typename E1::size_type;
        using argument_type = std::decay_t<decltype(*m_rhs)>;
        using result_type = std::decay_t<decltype(*m_lhs)>;
        // Insert an explicit cast only for narrowing/sign-changing conversions,
        // presumably to avoid conversion warnings — the assigned value is the same.
        constexpr bool needs_cast = has_assign_conversion<argument_type, result_type>::value;

        tmp_size_type s = m_e1.size();
        for (tmp_size_type i = 0; i < s; ++i)
        {
            *m_lhs = conditional_cast<needs_cast, result_type>(*m_rhs);
            // Advances both steppers and m_index by one position in layout L.
            stepper_tools<L>::increment_stepper(*this, m_index, m_e1.shape());
        }
    }
    // Stepper-control members: each forwards the motion to both the LHS and
    // RHS steppers so they stay in lockstep.

    template <class E1, class E2, layout_type L>
    inline void stepper_assigner<E1, E2, L>::step(size_type i)
    {
        m_lhs.step(i);
        m_rhs.step(i);
    }

    template <class E1, class E2, layout_type L>
    inline void stepper_assigner<E1, E2, L>::step(size_type i, size_type n)
    {
        m_lhs.step(i, n);
        m_rhs.step(i, n);
    }

    template <class E1, class E2, layout_type L>
    inline void stepper_assigner<E1, E2, L>::reset(size_type i)
    {
        m_lhs.reset(i);
        m_rhs.reset(i);
    }

    template <class E1, class E2, layout_type L>
    inline void stepper_assigner<E1, E2, L>::to_end(layout_type l)
    {
        m_lhs.to_end(l);
        m_rhs.to_end(l);
    }
  584. /**********************************
  585. * linear_assigner implementation *
  586. **********************************/
  587. template <bool simd_assign>
  588. template <class E1, class E2>
  589. inline void linear_assigner<simd_assign>::run(E1& e1, const E2& e2)
  590. {
  591. using lhs_align_mode = xt_simd::container_alignment_t<E1>;
  592. constexpr bool is_aligned = std::is_same<lhs_align_mode, aligned_mode>::value;
  593. using rhs_align_mode = std::conditional_t<is_aligned, inner_aligned_mode, unaligned_mode>;
  594. using e1_value_type = typename E1::value_type;
  595. using e2_value_type = typename E2::value_type;
  596. using value_type = typename xassign_traits<E1, E2>::requested_value_type;
  597. using simd_type = xt_simd::simd_type<value_type>;
  598. using size_type = typename E1::size_type;
  599. size_type size = e1.size();
  600. constexpr size_type simd_size = simd_type::size;
  601. constexpr bool needs_cast = has_assign_conversion<e1_value_type, e2_value_type>::value;
  602. size_type align_begin = is_aligned ? 0 : xt_simd::get_alignment_offset(e1.data(), size, simd_size);
  603. size_type align_end = align_begin + ((size - align_begin) & ~(simd_size - 1));
  604. for (size_type i = 0; i < align_begin; ++i)
  605. {
  606. e1.data_element(i) = conditional_cast<needs_cast, e1_value_type>(e2.data_element(i));
  607. }
  608. #if defined(XTENSOR_USE_TBB)
  609. if (size >= XTENSOR_TBB_THRESHOLD)
  610. {
  611. tbb::static_partitioner ap;
  612. tbb::parallel_for(
  613. align_begin,
  614. align_end,
  615. simd_size,
  616. [&e1, &e2](size_t i)
  617. {
  618. e1.template store_simd<lhs_align_mode>(
  619. i,
  620. e2.template load_simd<rhs_align_mode, value_type>(i)
  621. );
  622. },
  623. ap
  624. );
  625. }
  626. else
  627. {
  628. for (size_type i = align_begin; i < align_end; i += simd_size)
  629. {
  630. e1.template store_simd<lhs_align_mode>(i, e2.template load_simd<rhs_align_mode, value_type>(i));
  631. }
  632. }
  633. #elif defined(XTENSOR_USE_OPENMP)
  634. if (size >= size_type(XTENSOR_OPENMP_TRESHOLD))
  635. {
  636. #pragma omp parallel for default(none) shared(align_begin, align_end, e1, e2)
  637. #ifndef _WIN32
  638. for (size_type i = align_begin; i < align_end; i += simd_size)
  639. {
  640. e1.template store_simd<lhs_align_mode>(i, e2.template load_simd<rhs_align_mode, value_type>(i));
  641. }
  642. #else
  643. for (auto i = static_cast<std::ptrdiff_t>(align_begin); i < static_cast<std::ptrdiff_t>(align_end);
  644. i += static_cast<std::ptrdiff_t>(simd_size))
  645. {
  646. size_type ui = static_cast<size_type>(i);
  647. e1.template store_simd<lhs_align_mode>(ui, e2.template load_simd<rhs_align_mode, value_type>(ui));
  648. }
  649. #endif
  650. }
  651. else
  652. {
  653. for (size_type i = align_begin; i < align_end; i += simd_size)
  654. {
  655. e1.template store_simd<lhs_align_mode>(i, e2.template load_simd<rhs_align_mode, value_type>(i));
  656. }
  657. }
  658. #else
  659. for (size_type i = align_begin; i < align_end; i += simd_size)
  660. {
  661. e1.template store_simd<lhs_align_mode>(i, e2.template load_simd<rhs_align_mode, value_type>(i));
  662. }
  663. #endif
  664. for (size_type i = align_end; i < size; ++i)
  665. {
  666. e1.data_element(i) = conditional_cast<needs_cast, e1_value_type>(e2.data_element(i));
  667. }
  668. }
  669. template <class E1, class E2>
  670. inline void linear_assigner<false>::run(E1& e1, const E2& e2)
  671. {
  672. using is_convertible = std::
  673. is_convertible<typename std::decay_t<E2>::value_type, typename std::decay_t<E1>::value_type>;
  674. // If the types are not compatible, this function is still instantiated but never called.
  675. // To avoid compilation problems in effectively unused code trivial_assigner_run_impl is
  676. // empty in this case.
  677. run_impl(e1, e2, is_convertible());
  678. }
    // Scalar element-wise copy over the linear iterators of both expressions,
    // optionally parallelized with TBB or OpenMP.
    template <class E1, class E2>
    inline void linear_assigner<false>::run_impl(E1& e1, const E2& e2, std::true_type /*is_convertible*/)
    {
        using value_type = typename E1::value_type;
        using size_type = typename E1::size_type;
        auto src = linear_begin(e2);
        auto dst = linear_begin(e1);
        size_type n = e1.size();
#if defined(XTENSOR_USE_TBB)
        tbb::static_partitioner sp;
        tbb::parallel_for(
            std::ptrdiff_t(0),
            static_cast<std::ptrdiff_t>(n),
            [&](std::ptrdiff_t i)
            {
                *(dst + i) = static_cast<value_type>(*(src + i));
            },
            sp
        );
#elif defined(XTENSOR_USE_OPENMP)
        if (n >= XTENSOR_OPENMP_TRESHOLD)
        {
#pragma omp parallel for default(none) shared(src, dst, n)
            for (std::ptrdiff_t i = std::ptrdiff_t(0); i < static_cast<std::ptrdiff_t>(n); i++)
            {
                *(dst + i) = static_cast<value_type>(*(src + i));
            }
        }
        else
        {
            for (; n > size_type(0); --n)
            {
                *dst = static_cast<value_type>(*src);
                ++src;
                ++dst;
            }
        }
#else
        // Sequential fallback: plain forward copy with an explicit cast.
        for (; n > size_type(0); --n)
        {
            *dst = static_cast<value_type>(*src);
            ++src;
            ++dst;
        }
#endif
    }
// Overload selected when the rhs value_type is NOT convertible to the lhs
// value_type. It exists only so the dispatch in run() compiles; actually
// reaching it at runtime is an internal error.
template <class E1, class E2>
inline void linear_assigner<false>::run_impl(E1&, const E2&, std::false_type /*is_convertible*/)
{
    XTENSOR_PRECONDITION(false, "Internal error: linear_assigner called with unrelated types.");
}
  730. /****************************************
  731. * strided_loop_assigner implementation *
  732. ****************************************/
  733. namespace strided_assign_detail
  734. {
// Layout-specific outer-index helpers; only the row_major and column_major
// specializations below are defined.
template <layout_type layout>
struct idx_tools;
template <>
struct idx_tools<layout_type::row_major>
{
    // Increment a multi-dimensional index in row-major order (last dimension
    // varies fastest): bump the last coordinate, carrying into the previous
    // dimension whenever a coordinate would reach its shape bound.
    // NOTE(review): wraps silently back to all-zeros after the last index.
    template <class T>
    static void next_idx(T& outer_index, T& outer_shape)
    {
        auto i = outer_index.size();
        for (; i > 0; --i)
        {
            if (outer_index[i - 1] + 1 >= outer_shape[i - 1])
            {
                outer_index[i - 1] = 0;
            }
            else
            {
                outer_index[i - 1]++;
                break;
            }
        }
    }

    // Write into outer_index the row-major multi-dimensional index that
    // corresponds to the n-th flat position within outer_shape.
    template <class T>
    static void nth_idx(size_t n, T& outer_index, const T& outer_shape)
    {
        dynamic_shape<std::size_t> stride_sizes;
        xt::resize_container(stride_sizes, outer_shape.size());
        // compute strides: the innermost (last) dimension has stride 1,
        // each previous one is the product of the shapes after it
        using size_type = typename T::size_type;
        for (size_type i = outer_shape.size(); i > 0; i--)
        {
            stride_sizes[i - 1] = (i == outer_shape.size()) ? 1 : stride_sizes[i] * outer_shape[i];
        }
        // compute index by successive division by the strides,
        // consuming the remainder of n at each dimension
        for (size_type i = 0; i < outer_shape.size(); i++)
        {
            auto d_idx = n / stride_sizes[i];
            outer_index[i] = d_idx;
            n -= d_idx * stride_sizes[i];
        }
    }
};
template <>
struct idx_tools<layout_type::column_major>
{
    // Increment a multi-dimensional index in column-major order (first
    // dimension varies fastest): bump the first coordinate, carrying into
    // the next dimension whenever a coordinate would reach its shape bound.
    // NOTE(review): wraps silently back to all-zeros after the last index.
    template <class T>
    static void next_idx(T& outer_index, T& outer_shape)
    {
        using size_type = typename T::size_type;
        size_type i = 0;
        auto sz = outer_index.size();
        for (; i < sz; ++i)
        {
            if (outer_index[i] + 1 >= outer_shape[i])
            {
                outer_index[i] = 0;
            }
            else
            {
                outer_index[i]++;
                break;
            }
        }
    }

    // Write into outer_index the column-major multi-dimensional index that
    // corresponds to the n-th flat position within outer_shape.
    template <class T>
    static void nth_idx(size_t n, T& outer_index, const T& outer_shape)
    {
        dynamic_shape<std::size_t> stride_sizes;
        xt::resize_container(stride_sizes, outer_shape.size());
        using size_type = typename T::size_type;
        // compute required strides: the first dimension has stride 1,
        // each following one is the product of the shapes before it
        for (size_type i = 0; i < outer_shape.size(); i++)
        {
            stride_sizes[i] = (i == 0) ? 1 : stride_sizes[i - 1] * outer_shape[i - 1];
        }
        // compute index by successive division, from the slowest-varying
        // (last) dimension down to the fastest-varying (first) one
        for (size_type i = outer_shape.size(); i > 0;)
        {
            i--;
            auto d_idx = n / stride_sizes[i];
            outer_index[i] = d_idx;
            n -= d_idx * stride_sizes[i];
        }
    }
};
// Visitor over an expression tree that computes the "cut": the boundary
// between the dimensions whose strides differ between the destination
// strides (m_strides) and the strides of some operand, and those that
// agree. Dimensions past the cut (row-major) or before it (column-major)
// can be handled by a single strided inner loop.
template <layout_type L, class S>
struct check_strides_functor
{
    using strides_type = S;

    // m_cut starts at the "no mismatch" extreme: 0 for row-major,
    // strides.size() for column-major.
    check_strides_functor(const S& strides)
        : m_cut(L == layout_type::row_major ? 0 : strides.size())
        , m_strides(strides)
    {
    }

    template <class T, layout_type LE = L>
    std::enable_if_t<LE == layout_type::row_major, std::size_t> operator()(const T& el)
    {
        // All dimensions less than var have differing strides;
        // keep the largest mismatch boundary seen so far.
        auto var = check_strides_overlap<layout_type::row_major>::get(m_strides, el.strides());
        if (var > m_cut)
        {
            m_cut = var;
        }
        return m_cut;
    }

    template <class T, layout_type LE = L>
    std::enable_if_t<LE == layout_type::column_major, std::size_t> operator()(const T& el)
    {
        auto var = check_strides_overlap<layout_type::column_major>::get(m_strides, el.strides());
        // All dimensions >= var have differing strides;
        // keep the smallest mismatch boundary seen so far.
        if (var < m_cut)
        {
            m_cut = var;
        }
        return m_cut;
    }

    // Scalars broadcast everywhere and never constrain the cut.
    template <class T>
    std::size_t operator()(const xt::xscalar<T>& /*el*/)
    {
        return m_cut;
    }

    // Functions recurse into all their arguments.
    template <class F, class... CT>
    std::size_t operator()(const xt::xfunction<F, CT...>& xf)
    {
        xt::for_each(*this, xf.arguments());
        return m_cut;
    }

private:

    std::size_t m_cut;
    const strides_type& m_strides;
};
  866. template <bool possible = true, class E1, class E2, std::enable_if_t<!has_strides<E1>::value || !possible, bool> = true>
  867. loop_sizes_t get_loop_sizes(const E1& e1, const E2&)
  868. {
  869. return {false, true, 1, e1.size(), e1.dimension(), e1.dimension()};
  870. }
// Analyze the destination strides (and, through check_strides_functor, the
// strides of the rhs operands) to decide whether the assignment can run as
// an outer loop over `cut` dimensions plus a contiguous inner loop over the
// remaining ones, and compute both loop sizes.
template <bool possible = true, class E1, class E2, std::enable_if_t<has_strides<E1>::value && possible, bool> = true>
loop_sizes_t get_loop_sizes(const E1& e1, const E2& e2)
{
    using shape_value_type = typename E1::shape_type::value_type;
    bool is_row_major = true;
    // Try to find a row-major scheme first, where the outer loop is on the first N = `cut`
    // dimensions, and the inner loop is on the remaining trailing dimensions.
    is_row_major = true;
    auto is_zero = [](auto i)
    {
        return i == 0;
    };
    auto&& strides = e1.strides();
    // Zero strides mark broadcast dimensions; the first non-zero stride
    // from the back (resp. front) must be 1 for the destination to be
    // contiguous in its innermost row-major (resp. column-major) dimension.
    auto it_bwd = std::find_if_not(strides.rbegin(), strides.rend(), is_zero);
    bool de1_row_contiguous = it_bwd != strides.rend() && *it_bwd == 1;
    auto it_fwd = std::find_if_not(strides.begin(), strides.end(), is_zero);
    bool de1_col_contiguous = it_fwd != strides.end() && *it_fwd == 1;
    if (de1_row_contiguous)
    {
        is_row_major = true;
    }
    else if (de1_col_contiguous)
    {
        is_row_major = false;
    }
    else
    {
        // No strided loop possible.
        return {false, true, 1, e1.size(), e1.dimension(), e1.dimension()};
    }
    // Cut is the number of dimensions in the outer loop
    std::size_t cut = 0;
    if (is_row_major)
    {
        auto csf = check_strides_functor<layout_type::row_major, decltype(e1.strides())>(e1.strides());
        cut = csf(e2);
        // This makes that only one dimension will be treated in the inner loop.
        if (cut < e1.strides().size() - 1)
        {
            // Only make the inner loop go over one dimension by default for now
            cut = e1.strides().size() - 1;
        }
    }
    else if (!is_row_major)
    {
        auto csf = check_strides_functor<layout_type::column_major, decltype(e1.strides())>(e1.strides()
        );
        cut = csf(e2);
        if (cut > 1)
        {
            // Only make the inner loop go over one dimension by default for now
            cut = 1;
        }
    } // can't reach here because this would have already triggered the fallback
    // The outer loop spans shape[0, cut), the inner loop shape[cut, dim);
    // for column-major the two sizes are swapped below.
    std::size_t outer_loop_size = static_cast<std::size_t>(std::accumulate(
        e1.shape().begin(),
        e1.shape().begin() + static_cast<std::ptrdiff_t>(cut),
        shape_value_type(1),
        std::multiplies<shape_value_type>{}
    ));
    std::size_t inner_loop_size = static_cast<std::size_t>(std::accumulate(
        e1.shape().begin() + static_cast<std::ptrdiff_t>(cut),
        e1.shape().end(),
        shape_value_type(1),
        std::multiplies<shape_value_type>{}
    ));
    if (!is_row_major)
    {
        std::swap(outer_loop_size, inner_loop_size);
    }
    // A strided assignment only pays off when the inner loop is longer than
    // one element.
    return {inner_loop_size > 1, is_row_major, inner_loop_size, outer_loop_size, cut, e1.dimension()};
}
  943. }
  944. template <bool simd>
  945. template <class E1, class E2>
  946. inline strided_assign_detail::loop_sizes_t strided_loop_assigner<simd>::get_loop_sizes(E1& e1, const E2& e2)
  947. {
  948. return strided_assign_detail::get_loop_sizes<simd>(e1, e2);
  949. }
  950. #define strided_parallel_assign
// Assign e2 to e1 using the decomposition computed by get_loop_sizes: an
// outer loop over the multi-dimensional index of the first (row-major) or
// last (column-major) `cut` dimensions, and a SIMD inner loop over the
// remaining contiguous elements. Parallelized with OpenMP or TBB when those
// backends are enabled and the macro strided_parallel_assign is defined.
template <bool simd>
template <class E1, class E2>
inline void strided_loop_assigner<simd>::run(E1& e1, const E2& e2, const loop_sizes_t& loop_sizes)
{
    bool is_row_major = loop_sizes.is_row_major;
    std::size_t inner_loop_size = loop_sizes.inner_loop_size;
    std::size_t outer_loop_size = loop_sizes.outer_loop_size;
    std::size_t cut = loop_sizes.cut;
    // TODO can we get rid of this and use `shape_type`?
    // idx is the running multi-dimensional index of the outer loop;
    // max_shape holds the shape of the outer dimensions only.
    dynamic_shape<std::size_t> idx, max_shape;
    if (is_row_major)
    {
        xt::resize_container(idx, cut);
        max_shape.assign(e1.shape().begin(), e1.shape().begin() + static_cast<std::ptrdiff_t>(cut));
    }
    else
    {
        xt::resize_container(idx, e1.shape().size() - cut);
        max_shape.assign(e1.shape().begin() + static_cast<std::ptrdiff_t>(cut), e1.shape().end());
    }
    // add this when we have std::array index!
    // std::fill(idx.begin(), idx.end(), 0);
    using e1_value_type = typename E1::value_type;
    using e2_value_type = typename E2::value_type;
    constexpr bool needs_cast = has_assign_conversion<e1_value_type, e2_value_type>::value;
    using value_type = typename xassign_traits<E1, E2>::requested_value_type;
    using simd_type = std::conditional_t<
        std::is_same<e1_value_type, bool>::value,
        xt_simd::simd_bool_type<value_type>,
        xt_simd::simd_type<value_type>>;
    // Number of full SIMD batches and scalar remainder per inner-loop pass.
    std::size_t simd_size = inner_loop_size / simd_type::size;
    std::size_t simd_rest = inner_loop_size % simd_type::size;
    auto fct_stepper = e2.stepper_begin(e1.shape());
    auto res_stepper = e1.stepper_begin(e1.shape());
    // TODO in 1D case this is ambiguous -- could be RM or CM.
    // Use default layout to make decision
    // step_dim is the stepper offset of the first outer-loop dimension.
    std::size_t step_dim = 0;
    if (!is_row_major) // column-major case: outer dimensions start at `cut`
    {
        step_dim = cut;
    }
#if defined(XTENSOR_USE_OPENMP) && defined(strided_parallel_assign)
    if (outer_loop_size >= XTENSOR_OPENMP_TRESHOLD / inner_loop_size)
    {
        // NOTE(review): boolean flag stored in a std::size_t; it is
        // firstprivate so each thread re-seeds its steppers once per chunk.
        std::size_t first_step = true;
#pragma omp parallel for schedule(static) firstprivate(first_step, fct_stepper, res_stepper, idx)
        for (std::size_t ox = 0; ox < outer_loop_size; ++ox)
        {
            // On the first iteration of a thread's chunk, jump the steppers
            // directly to the ox-th outer index instead of stepping from 0.
            if (first_step)
            {
                is_row_major
                    ? strided_assign_detail::idx_tools<layout_type::row_major>::nth_idx(ox, idx, max_shape)
                    : strided_assign_detail::idx_tools<layout_type::column_major>::nth_idx(ox, idx, max_shape);
                for (std::size_t i = 0; i < idx.size(); ++i)
                {
                    fct_stepper.step(i + step_dim, idx[i]);
                    res_stepper.step(i + step_dim, idx[i]);
                }
                first_step = false;
            }
            // SIMD inner loop, then scalar remainder.
            for (std::size_t i = 0; i < simd_size; ++i)
            {
                res_stepper.template store_simd(fct_stepper.template step_simd<value_type>());
            }
            for (std::size_t i = 0; i < simd_rest; ++i)
            {
                *(res_stepper) = conditional_cast<needs_cast, e1_value_type>(*(fct_stepper));
                res_stepper.step_leading();
                fct_stepper.step_leading();
            }
            // next unaligned index
            is_row_major
                ? strided_assign_detail::idx_tools<layout_type::row_major>::next_idx(idx, max_shape)
                : strided_assign_detail::idx_tools<layout_type::column_major>::next_idx(idx, max_shape);
            fct_stepper.to_begin();
            // need to step E1 as well if not contigous assign (e.g. view)
            if (!E1::contiguous_layout)
            {
                res_stepper.to_begin();
                for (std::size_t i = 0; i < idx.size(); ++i)
                {
                    fct_stepper.step(i + step_dim, idx[i]);
                    res_stepper.step(i + step_dim, idx[i]);
                }
            }
            else
            {
                for (std::size_t i = 0; i < idx.size(); ++i)
                {
                    fct_stepper.step(i + step_dim, idx[i]);
                }
            }
        }
    }
    else
    {
#elif defined(strided_parallel_assign) && defined(XTENSOR_USE_TBB)
    if (outer_loop_size > XTENSOR_TBB_THRESHOLD / inner_loop_size)
    {
        tbb::static_partitioner sp;
        tbb::parallel_for(
            tbb::blocked_range<size_t>(0ul, outer_loop_size),
            [&e1, &e2, is_row_major, step_dim, simd_size, simd_rest, &max_shape, &idx_ = idx](
                const tbb::blocked_range<size_t>& r
            )
            {
                // Per-task copies: each task owns its index and steppers.
                auto idx = idx_;
                auto fct_stepper = e2.stepper_begin(e1.shape());
                auto res_stepper = e1.stepper_begin(e1.shape());
                // NOTE(review): boolean flag stored in a std::size_t.
                std::size_t first_step = true;
                // #pragma omp parallel for schedule(static) firstprivate(first_step, fct_stepper,
                // res_stepper, idx)
                for (std::size_t ox = r.begin(); ox < r.end(); ++ox)
                {
                    // Seed the steppers at the first outer index of this range.
                    if (first_step)
                    {
                        is_row_major
                            ? strided_assign_detail::idx_tools<layout_type::row_major>::nth_idx(ox, idx, max_shape)
                            : strided_assign_detail::idx_tools<layout_type::column_major>::nth_idx(
                                  ox,
                                  idx,
                                  max_shape
                              );
                        for (std::size_t i = 0; i < idx.size(); ++i)
                        {
                            fct_stepper.step(i + step_dim, idx[i]);
                            res_stepper.step(i + step_dim, idx[i]);
                        }
                        first_step = false;
                    }
                    // SIMD inner loop, then scalar remainder.
                    for (std::size_t i = 0; i < simd_size; ++i)
                    {
                        res_stepper.template store_simd(fct_stepper.template step_simd<value_type>());
                    }
                    for (std::size_t i = 0; i < simd_rest; ++i)
                    {
                        *(res_stepper) = conditional_cast<needs_cast, e1_value_type>(*(fct_stepper));
                        res_stepper.step_leading();
                        fct_stepper.step_leading();
                    }
                    // next unaligned index
                    is_row_major
                        ? strided_assign_detail::idx_tools<layout_type::row_major>::next_idx(idx, max_shape)
                        : strided_assign_detail::idx_tools<layout_type::column_major>::next_idx(idx, max_shape);
                    fct_stepper.to_begin();
                    // need to step E1 as well if not contigous assign (e.g. view)
                    if (!E1::contiguous_layout)
                    {
                        res_stepper.to_begin();
                        for (std::size_t i = 0; i < idx.size(); ++i)
                        {
                            fct_stepper.step(i + step_dim, idx[i]);
                            res_stepper.step(i + step_dim, idx[i]);
                        }
                    }
                    else
                    {
                        for (std::size_t i = 0; i < idx.size(); ++i)
                        {
                            fct_stepper.step(i + step_dim, idx[i]);
                        }
                    }
                }
            },
            sp
        );
    }
    else
    {
#endif
        // Serial path (also the tail of the OpenMP/TBB if/else above).
        for (std::size_t ox = 0; ox < outer_loop_size; ++ox)
        {
            // SIMD inner loop, then scalar remainder.
            for (std::size_t i = 0; i < simd_size; ++i)
            {
                res_stepper.store_simd(fct_stepper.template step_simd<value_type>());
            }
            for (std::size_t i = 0; i < simd_rest; ++i)
            {
                *(res_stepper) = conditional_cast<needs_cast, e1_value_type>(*(fct_stepper));
                res_stepper.step_leading();
                fct_stepper.step_leading();
            }
            is_row_major
                ? strided_assign_detail::idx_tools<layout_type::row_major>::next_idx(idx, max_shape)
                : strided_assign_detail::idx_tools<layout_type::column_major>::next_idx(idx, max_shape);
            fct_stepper.to_begin();
            // need to step E1 as well if not contigous assign (e.g. view)
            if (!E1::contiguous_layout)
            {
                res_stepper.to_begin();
                for (std::size_t i = 0; i < idx.size(); ++i)
                {
                    fct_stepper.step(i + step_dim, idx[i]);
                    res_stepper.step(i + step_dim, idx[i]);
                }
            }
            else
            {
                for (std::size_t i = 0; i < idx.size(); ++i)
                {
                    fct_stepper.step(i + step_dim, idx[i]);
                }
            }
        }
#if (defined(XTENSOR_USE_OPENMP) || defined(XTENSOR_USE_TBB)) && defined(strided_parallel_assign)
    }
#endif
}
  1159. template <>
  1160. template <class E1, class E2>
  1161. inline void strided_loop_assigner<true>::run(E1& e1, const E2& e2)
  1162. {
  1163. strided_assign_detail::loop_sizes_t loop_sizes = strided_loop_assigner<true>::get_loop_sizes(e1, e2);
  1164. if (loop_sizes.can_do_strided_assign)
  1165. {
  1166. run(e1, e2, loop_sizes);
  1167. }
  1168. else
  1169. {
  1170. // trigger the fallback assigner
  1171. stepper_assigner<E1, E2, default_assignable_layout(E1::static_layout)>(e1, e2).run();
  1172. }
  1173. }
// Intentionally empty: the non-SIMD specialization never executes a strided
// loop (its two-argument run() below always delegates to the stepper
// assigner), so this overload exists only to satisfy instantiation.
template <>
template <class E1, class E2>
inline void strided_loop_assigner<false>::run(E1& /*e1*/, const E2& /*e2*/, const loop_sizes_t&)
{
}
// Non-SIMD entry point: strided assignment is disabled, so unconditionally
// delegate to the stepper-based assigner.
template <>
template <class E1, class E2>
inline void strided_loop_assigner<false>::run(E1& e1, const E2& e2)
{
    // trigger the fallback assigner
    stepper_assigner<E1, E2, default_assignable_layout(E1::static_layout)>(e1, e2).run();
}
  1186. }
  1187. #endif