xbroadcast.hpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_BROADCAST_HPP
  10. #define XTENSOR_BROADCAST_HPP
  11. #include <algorithm>
  12. #include <array>
  13. #include <cstddef>
  14. #include <iterator>
  15. #include <numeric>
  16. #include <type_traits>
  17. #include <utility>
  18. #include <xtl/xsequence.hpp>
  19. #include "xaccessible.hpp"
  20. #include "xexpression.hpp"
  21. #include "xiterable.hpp"
  22. #include "xscalar.hpp"
  23. #include "xstrides.hpp"
  24. #include "xtensor_config.hpp"
  25. #include "xutils.hpp"
  26. namespace xt
  27. {
  28. /*************
  29. * broadcast *
  30. *************/
  31. template <class E, class S>
  32. auto broadcast(E&& e, const S& s);
  33. template <class E, class I, std::size_t L>
  34. auto broadcast(E&& e, const I (&s)[L]);
  35. /*************************
  36. * xbroadcast extensions *
  37. *************************/
  38. namespace extension
  39. {
  40. template <class Tag, class CT, class X>
  41. struct xbroadcast_base_impl;
  42. template <class CT, class X>
  43. struct xbroadcast_base_impl<xtensor_expression_tag, CT, X>
  44. {
  45. using type = xtensor_empty_base;
  46. };
  47. template <class CT, class X>
  48. struct xbroadcast_base : xbroadcast_base_impl<xexpression_tag_t<CT>, CT, X>
  49. {
  50. };
  51. template <class CT, class X>
  52. using xbroadcast_base_t = typename xbroadcast_base<CT, X>::type;
  53. }
  54. /**************
  55. * xbroadcast *
  56. **************/
  57. template <class CT, class X>
  58. class xbroadcast;
  59. template <class CT, class X>
  60. struct xiterable_inner_types<xbroadcast<CT, X>>
  61. {
  62. using xexpression_type = std::decay_t<CT>;
  63. using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;
  64. using const_stepper = typename xexpression_type::const_stepper;
  65. using stepper = const_stepper;
  66. };
  67. template <class CT, class X>
  68. struct xcontainer_inner_types<xbroadcast<CT, X>>
  69. {
  70. using xexpression_type = std::decay_t<CT>;
  71. using reference = typename xexpression_type::const_reference;
  72. using const_reference = typename xexpression_type::const_reference;
  73. using size_type = typename xexpression_type::size_type;
  74. };
  75. /*****************************
  76. * linear_begin / linear_end *
  77. *****************************/
  78. template <class CT, class X>
  79. XTENSOR_CONSTEXPR_RETURN auto linear_begin(xbroadcast<CT, X>& c) noexcept
  80. {
  81. return linear_begin(c.expression());
  82. }
  83. template <class CT, class X>
  84. XTENSOR_CONSTEXPR_RETURN auto linear_end(xbroadcast<CT, X>& c) noexcept
  85. {
  86. return linear_end(c.expression());
  87. }
  88. template <class CT, class X>
  89. XTENSOR_CONSTEXPR_RETURN auto linear_begin(const xbroadcast<CT, X>& c) noexcept
  90. {
  91. return linear_begin(c.expression());
  92. }
  93. template <class CT, class X>
  94. XTENSOR_CONSTEXPR_RETURN auto linear_end(const xbroadcast<CT, X>& c) noexcept
  95. {
  96. return linear_end(c.expression());
  97. }
  98. /*************************************
  99. * overlapping_memory_checker_traits *
  100. *************************************/
  101. template <class E>
  102. struct overlapping_memory_checker_traits<
  103. E,
  104. std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
  105. {
  106. static bool check_overlap(const E& expr, const memory_range& dst_range)
  107. {
  108. if (expr.size() == 0)
  109. {
  110. return false;
  111. }
  112. else
  113. {
  114. using ChildE = std::decay_t<decltype(expr.expression())>;
  115. return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
  116. }
  117. }
  118. };
  119. /**
  120. * @class xbroadcast
  121. * @brief Broadcasted xexpression to a specified shape.
  122. *
  123. * The xbroadcast class implements the broadcasting of an \ref xexpression
  124. * to a specified shape. xbroadcast is not meant to be used directly, but
  125. * only with the \ref broadcast helper functions.
  126. *
  127. * @tparam CT the closure type of the \ref xexpression to broadcast
  128. * @tparam X the type of the specified shape.
  129. *
  130. * @sa broadcast
  131. */
  132. template <class CT, class X>
  133. class xbroadcast : public xsharable_expression<xbroadcast<CT, X>>,
  134. public xconst_iterable<xbroadcast<CT, X>>,
  135. public xconst_accessible<xbroadcast<CT, X>>,
  136. public extension::xbroadcast_base_t<CT, X>
  137. {
  138. public:
  139. using self_type = xbroadcast<CT, X>;
  140. using xexpression_type = std::decay_t<CT>;
  141. using accessible_base = xconst_accessible<self_type>;
  142. using extension_base = extension::xbroadcast_base_t<CT, X>;
  143. using expression_tag = typename extension_base::expression_tag;
  144. using inner_types = xcontainer_inner_types<self_type>;
  145. using value_type = typename xexpression_type::value_type;
  146. using reference = typename inner_types::reference;
  147. using const_reference = typename inner_types::const_reference;
  148. using pointer = typename xexpression_type::const_pointer;
  149. using const_pointer = typename xexpression_type::const_pointer;
  150. using size_type = typename inner_types::size_type;
  151. using difference_type = typename xexpression_type::difference_type;
  152. using iterable_base = xconst_iterable<self_type>;
  153. using inner_shape_type = typename iterable_base::inner_shape_type;
  154. using shape_type = inner_shape_type;
  155. using stepper = typename iterable_base::stepper;
  156. using const_stepper = typename iterable_base::const_stepper;
  157. using bool_load_type = typename xexpression_type::bool_load_type;
  158. static constexpr layout_type static_layout = layout_type::dynamic;
  159. static constexpr bool contiguous_layout = false;
  160. template <class CTA, class S>
  161. xbroadcast(CTA&& e, const S& s);
  162. template <class CTA>
  163. xbroadcast(CTA&& e, shape_type&& s);
  164. using accessible_base::size;
  165. const inner_shape_type& shape() const noexcept;
  166. layout_type layout() const noexcept;
  167. bool is_contiguous() const noexcept;
  168. using accessible_base::shape;
  169. template <class... Args>
  170. const_reference operator()(Args... args) const;
  171. template <class... Args>
  172. const_reference unchecked(Args... args) const;
  173. template <class It>
  174. const_reference element(It first, It last) const;
  175. const xexpression_type& expression() const noexcept;
  176. template <class S>
  177. bool broadcast_shape(S& shape, bool reuse_cache = false) const;
  178. template <class S>
  179. bool has_linear_assign(const S& strides) const noexcept;
  180. template <class S>
  181. const_stepper stepper_begin(const S& shape) const noexcept;
  182. template <class S>
  183. const_stepper stepper_end(const S& shape, layout_type l) const noexcept;
  184. template <class E, class XCT = CT, class = std::enable_if_t<xt::is_xscalar<XCT>::value>>
  185. void assign_to(xexpression<E>& e) const;
  186. template <class E>
  187. using rebind_t = xbroadcast<E, X>;
  188. template <class E>
  189. rebind_t<E> build_broadcast(E&& e) const;
  190. private:
  191. CT m_e;
  192. inner_shape_type m_shape;
  193. };
  194. /****************************
  195. * broadcast implementation *
  196. ****************************/
  197. /**
  198. * @brief Returns an \ref xexpression broadcasting the given expression to
  199. * a specified shape.
  200. *
  201. * @tparam e the \ref xexpression to broadcast
  202. * @tparam s the specified shape to broadcast.
  203. *
  204. * The returned expression either hold a const reference to \p e or a copy
  205. * depending on whether \p e is an lvalue or an rvalue.
  206. */
  207. template <class E, class S>
  208. inline auto broadcast(E&& e, const S& s)
  209. {
  210. using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
  211. using broadcast_type = xbroadcast<const_xclosure_t<E>, shape_type>;
  212. return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
  213. }
  214. template <class E, class I, std::size_t L>
  215. inline auto broadcast(E&& e, const I (&s)[L])
  216. {
  217. using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>;
  218. using shape_type = typename broadcast_type::shape_type;
  219. return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
  220. }
  221. /*****************************
  222. * xbroadcast implementation *
  223. *****************************/
  224. /**
  225. * @name Constructor
  226. */
  227. //@{
  228. /**
  229. * Constructs an xbroadcast expression broadcasting the specified
  230. * \ref xexpression to the given shape
  231. *
  232. * @param e the expression to broadcast
  233. * @param s the shape to apply
  234. */
  235. template <class CT, class X>
  236. template <class CTA, class S>
  237. inline xbroadcast<CT, X>::xbroadcast(CTA&& e, const S& s)
  238. : m_e(std::forward<CTA>(e))
  239. {
  240. if (s.size() < m_e.dimension())
  241. {
  242. XTENSOR_THROW(xt::broadcast_error, "Broadcast shape has fewer elements than original expression.");
  243. }
  244. xt::resize_container(m_shape, s.size());
  245. std::copy(s.begin(), s.end(), m_shape.begin());
  246. xt::broadcast_shape(m_e.shape(), m_shape);
  247. }
  248. /**
  249. * Constructs an xbroadcast expression broadcasting the specified
  250. * \ref xexpression to the given shape
  251. *
  252. * @param e the expression to broadcast
  253. * @param s the shape to apply
  254. */
  255. template <class CT, class X>
  256. template <class CTA>
  257. inline xbroadcast<CT, X>::xbroadcast(CTA&& e, shape_type&& s)
  258. : m_e(std::forward<CTA>(e))
  259. , m_shape(std::move(s))
  260. {
  261. xt::broadcast_shape(m_e.shape(), m_shape);
  262. }
  263. //@}
  264. /**
  265. * @name Size and shape
  266. */
  267. //@{
  268. /**
  269. * Returns the shape of the expression.
  270. */
  271. template <class CT, class X>
  272. inline auto xbroadcast<CT, X>::shape() const noexcept -> const inner_shape_type&
  273. {
  274. return m_shape;
  275. }
  276. /**
  277. * Returns the layout_type of the expression.
  278. */
  279. template <class CT, class X>
  280. inline layout_type xbroadcast<CT, X>::layout() const noexcept
  281. {
  282. return m_e.layout();
  283. }
  284. template <class CT, class X>
  285. inline bool xbroadcast<CT, X>::is_contiguous() const noexcept
  286. {
  287. return false;
  288. }
  289. //@}
  290. /**
  291. * @name Data
  292. */
  293. //@{
  294. /**
  295. * Returns a constant reference to the element at the specified position in the expression.
  296. * @param args a list of indices specifying the position in the function. Indices
  297. * must be unsigned integers, the number of indices should be equal or greater than
  298. * the number of dimensions of the expression.
  299. */
  300. template <class CT, class X>
  301. template <class... Args>
  302. inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference
  303. {
  304. return m_e(args...);
  305. }
  306. /**
  307. * Returns a constant reference to the element at the specified position in the expression.
  308. * @param args a list of indices specifying the position in the expression. Indices
  309. * must be unsigned integers, the number of indices must be equal to the number of
  310. * dimensions of the expression, else the behavior is undefined.
  311. *
  312. * @warning This method is meant for performance, for expressions with a dynamic
  313. * number of dimensions (i.e. not known at compile time). Since it may have
  314. * undefined behavior (see parameters), operator() should be preferred whenever
  315. * it is possible.
  316. * @warning This method is NOT compatible with broadcasting, meaning the following
  317. * code has undefined behavior:
  318. * @code{.cpp}
  319. * xt::xarray<double> a = {{0, 1}, {2, 3}};
  320. * xt::xarray<double> b = {0, 1};
  321. * auto fd = a + b;
  322. * double res = fd.uncheked(0, 1);
  323. * @endcode
  324. */
  325. template <class CT, class X>
  326. template <class... Args>
  327. inline auto xbroadcast<CT, X>::unchecked(Args... args) const -> const_reference
  328. {
  329. return this->operator()(args...);
  330. }
  331. /**
  332. * Returns a constant reference to the element at the specified position in the expression.
  333. * @param first iterator starting the sequence of indices
  334. * @param last iterator ending the sequence of indices
  335. * The number of indices in the sequence should be equal to or greater
  336. * than the number of dimensions of the function.
  337. */
  338. template <class CT, class X>
  339. template <class It>
  340. inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference
  341. {
  342. return m_e.element(last - this->dimension(), last);
  343. }
  344. /**
  345. * Returns a constant reference to the underlying expression of the broadcast expression.
  346. */
  347. template <class CT, class X>
  348. inline auto xbroadcast<CT, X>::expression() const noexcept -> const xexpression_type&
  349. {
  350. return m_e;
  351. }
  352. //@}
  353. /**
  354. * @name Broadcasting
  355. */
  356. //@{
  357. /**
  358. * Broadcast the shape of the function to the specified parameter.
  359. * @param shape the result shape
  360. * @param reuse_cache parameter for internal optimization
  361. * @return a boolean indicating whether the broadcasting is trivial
  362. */
  363. template <class CT, class X>
  364. template <class S>
  365. inline bool xbroadcast<CT, X>::broadcast_shape(S& shape, bool) const
  366. {
  367. return xt::broadcast_shape(m_shape, shape);
  368. }
  369. /**
  370. * Checks whether the xbroadcast can be linearly assigned to an expression
  371. * with the specified strides.
  372. * @return a boolean indicating whether a linear assign is possible
  373. */
  374. template <class CT, class X>
  375. template <class S>
  376. inline bool xbroadcast<CT, X>::has_linear_assign(const S& strides) const noexcept
  377. {
  378. return this->dimension() == m_e.dimension()
  379. && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
  380. && m_e.has_linear_assign(strides);
  381. }
  382. //@}
  383. template <class CT, class X>
  384. template <class S>
  385. inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper
  386. {
  387. // Could check if (broadcastable(shape, m_shape)
  388. return m_e.stepper_begin(shape);
  389. }
  390. template <class CT, class X>
  391. template <class S>
  392. inline auto xbroadcast<CT, X>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
  393. {
  394. // Could check if (broadcastable(shape, m_shape)
  395. return m_e.stepper_end(shape, l);
  396. }
  397. template <class CT, class X>
  398. template <class E, class XCT, class>
  399. inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e) const
  400. {
  401. auto& ed = e.derived_cast();
  402. ed.resize(m_shape);
  403. std::fill(ed.begin(), ed.end(), m_e());
  404. }
  405. template <class CT, class X>
  406. template <class E>
  407. inline auto xbroadcast<CT, X>::build_broadcast(E&& e) const -> rebind_t<E>
  408. {
  409. return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
  410. }
  411. }
  412. #endif