xaccessible.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_ACCESSIBLE_HPP
  10. #define XTENSOR_ACCESSIBLE_HPP
#include <initializer_list>
#include <iterator>

#include "xexception.hpp"
#include "xstrides.hpp"
#include "xtensor_forward.hpp"
  14. namespace xt
  15. {
  16. /**
  17. * @class xconst_accessible
  18. * @brief Base class for implementation of common expression constant access methods.
  19. *
  20. * The xaccessible class implements constant access methods common to all expressions.
  21. *
  22. * @tparam D The derived type, i.e. the inheriting class for which xconst_accessible
  23. * provides the interface.
  24. */
  25. template <class D>
  26. class xconst_accessible
  27. {
  28. public:
  29. using derived_type = D;
  30. using inner_types = xcontainer_inner_types<D>;
  31. using reference = typename inner_types::reference;
  32. using const_reference = typename inner_types::const_reference;
  33. using size_type = typename inner_types::size_type;
  34. size_type size() const noexcept;
  35. size_type dimension() const noexcept;
  36. size_type shape(size_type index) const;
  37. template <class... Args>
  38. const_reference at(Args... args) const;
  39. template <class S>
  40. disable_integral_t<S, const_reference> operator[](const S& index) const;
  41. template <class I>
  42. const_reference operator[](std::initializer_list<I> index) const;
  43. const_reference operator[](size_type i) const;
  44. template <class... Args>
  45. const_reference periodic(Args... args) const;
  46. template <class... Args>
  47. bool in_bounds(Args... args) const;
  48. const_reference front() const;
  49. const_reference back() const;
  50. protected:
  51. xconst_accessible() = default;
  52. ~xconst_accessible() = default;
  53. xconst_accessible(const xconst_accessible&) = default;
  54. xconst_accessible& operator=(const xconst_accessible&) = default;
  55. xconst_accessible(xconst_accessible&&) = default;
  56. xconst_accessible& operator=(xconst_accessible&&) = default;
  57. private:
  58. const derived_type& derived_cast() const noexcept;
  59. };
  60. /**
  61. * @class xaccessible
  62. * @brief Base class for implementation of common expression access methods.
  63. *
  64. * The xaccessible class implements access methods common to all expressions.
  65. *
  66. * @tparam D The derived type, i.e. the inheriting class for which xaccessible
  67. * provides the interface.
  68. */
  69. template <class D>
  70. class xaccessible : public xconst_accessible<D>
  71. {
  72. public:
  73. using base_type = xconst_accessible<D>;
  74. using derived_type = typename base_type::derived_type;
  75. using reference = typename base_type::reference;
  76. using size_type = typename base_type::size_type;
  77. template <class... Args>
  78. reference at(Args... args);
  79. template <class S>
  80. disable_integral_t<S, reference> operator[](const S& index);
  81. template <class I>
  82. reference operator[](std::initializer_list<I> index);
  83. reference operator[](size_type i);
  84. template <class... Args>
  85. reference periodic(Args... args);
  86. reference front();
  87. reference back();
  88. using base_type::at;
  89. using base_type::operator[];
  90. using base_type::back;
  91. using base_type::front;
  92. using base_type::periodic;
  93. protected:
  94. xaccessible() = default;
  95. ~xaccessible() = default;
  96. xaccessible(const xaccessible&) = default;
  97. xaccessible& operator=(const xaccessible&) = default;
  98. xaccessible(xaccessible&&) = default;
  99. xaccessible& operator=(xaccessible&&) = default;
  100. private:
  101. derived_type& derived_cast() noexcept;
  102. };
  103. /************************************
  104. * xconst_accessible implementation *
  105. ************************************/
  106. /**
  107. * Returns the size of the expression.
  108. */
  109. template <class D>
  110. inline auto xconst_accessible<D>::size() const noexcept -> size_type
  111. {
  112. return compute_size(derived_cast().shape());
  113. }
  114. /**
  115. * Returns the number of dimensions of the expression.
  116. */
  117. template <class D>
  118. inline auto xconst_accessible<D>::dimension() const noexcept -> size_type
  119. {
  120. return derived_cast().shape().size();
  121. }
  122. /**
  123. * Returns the i-th dimension of the expression.
  124. */
  125. template <class D>
  126. inline auto xconst_accessible<D>::shape(size_type index) const -> size_type
  127. {
  128. return derived_cast().shape()[index];
  129. }
  130. /**
  131. * Returns a constant reference to the element at the specified position in the expression,
  132. * after dimension and bounds checking.
  133. * @param args a list of indices specifying the position in the expression. Indices
  134. * must be unsigned integers, the number of indices should be equal to the number of dimensions
  135. * of the expression.
  136. * @exception std::out_of_range if the number of argument is greater than the number of dimensions
  137. * or if indices are out of bounds.
  138. */
  139. template <class D>
  140. template <class... Args>
  141. inline auto xconst_accessible<D>::at(Args... args) const -> const_reference
  142. {
  143. check_access(derived_cast().shape(), args...);
  144. return derived_cast().operator()(args...);
  145. }
  146. /**
  147. * Returns a constant reference to the element at the specified position in the expression.
  148. * @param index a sequence of indices specifying the position in the expression. Indices
  149. * must be unsigned integers, the number of indices in the list should be equal or greater
  150. * than the number of dimensions of the expression.
  151. */
  152. template <class D>
  153. template <class S>
  154. inline auto xconst_accessible<D>::operator[](const S& index) const
  155. -> disable_integral_t<S, const_reference>
  156. {
  157. return derived_cast().element(index.cbegin(), index.cend());
  158. }
  159. template <class D>
  160. template <class I>
  161. inline auto xconst_accessible<D>::operator[](std::initializer_list<I> index) const -> const_reference
  162. {
  163. return derived_cast().element(index.begin(), index.end());
  164. }
  165. template <class D>
  166. inline auto xconst_accessible<D>::operator[](size_type i) const -> const_reference
  167. {
  168. return derived_cast().operator()(i);
  169. }
  170. /**
  171. * Returns a constant reference to the element at the specified position in the expression,
  172. * after applying periodicity to the indices (negative and 'overflowing' indices are changed).
  173. * @param args a list of indices specifying the position in the expression. Indices
  174. * must be integers, the number of indices should be equal to the number of dimensions
  175. * of the expression.
  176. */
  177. template <class D>
  178. template <class... Args>
  179. inline auto xconst_accessible<D>::periodic(Args... args) const -> const_reference
  180. {
  181. normalize_periodic(derived_cast().shape(), args...);
  182. return derived_cast()(static_cast<size_type>(args)...);
  183. }
  184. /**
  185. * Returns a constant reference to first the element of the expression
  186. */
  187. template <class D>
  188. inline auto xconst_accessible<D>::front() const -> const_reference
  189. {
  190. return *derived_cast().begin();
  191. }
  192. /**
  193. * Returns a constant reference to last the element of the expression
  194. */
  195. template <class D>
  196. inline auto xconst_accessible<D>::back() const -> const_reference
  197. {
  198. return *std::prev(derived_cast().end());
  199. }
  200. /**
  201. * Returns ``true`` only if the the specified position is a valid entry in the expression.
  202. * @param args a list of indices specifying the position in the expression.
  203. * @return bool
  204. */
  205. template <class D>
  206. template <class... Args>
  207. inline bool xconst_accessible<D>::in_bounds(Args... args) const
  208. {
  209. return check_in_bounds(derived_cast().shape(), args...);
  210. }
  211. template <class D>
  212. inline auto xconst_accessible<D>::derived_cast() const noexcept -> const derived_type&
  213. {
  214. return *static_cast<const derived_type*>(this);
  215. }
  216. /******************************
  217. * xaccessible implementation *
  218. ******************************/
  219. /**
  220. * Returns a reference to the element at the specified position in the expression,
  221. * after dimension and bounds checking.
  222. * @param args a list of indices specifying the position in the expression. Indices
  223. * must be unsigned integers, the number of indices should be equal to the number of dimensions
  224. * of the expression.
  225. * @exception std::out_of_range if the number of argument is greater than the number of dimensions
  226. * or if indices are out of bounds.
  227. */
  228. template <class D>
  229. template <class... Args>
  230. inline auto xaccessible<D>::at(Args... args) -> reference
  231. {
  232. check_access(derived_cast().shape(), args...);
  233. return derived_cast().operator()(args...);
  234. }
  235. /**
  236. * Returns a reference to the element at the specified position in the expression.
  237. * @param index a sequence of indices specifying the position in the expression. Indices
  238. * must be unsigned integers, the number of indices in the list should be equal or greater
  239. * than the number of dimensions of the expression.
  240. */
  241. template <class D>
  242. template <class S>
  243. inline auto xaccessible<D>::operator[](const S& index) -> disable_integral_t<S, reference>
  244. {
  245. return derived_cast().element(index.cbegin(), index.cend());
  246. }
  247. template <class D>
  248. template <class I>
  249. inline auto xaccessible<D>::operator[](std::initializer_list<I> index) -> reference
  250. {
  251. return derived_cast().element(index.begin(), index.end());
  252. }
  253. template <class D>
  254. inline auto xaccessible<D>::operator[](size_type i) -> reference
  255. {
  256. return derived_cast().operator()(i);
  257. }
  258. /**
  259. * Returns a reference to the element at the specified position in the expression,
  260. * after applying periodicity to the indices (negative and 'overflowing' indices are changed).
  261. * @param args a list of indices specifying the position in the expression. Indices
  262. * must be integers, the number of indices should be equal to the number of dimensions
  263. * of the expression.
  264. */
  265. template <class D>
  266. template <class... Args>
  267. inline auto xaccessible<D>::periodic(Args... args) -> reference
  268. {
  269. normalize_periodic(derived_cast().shape(), args...);
  270. return derived_cast()(args...);
  271. }
  272. /**
  273. * Returns a reference to the first element of the expression.
  274. */
  275. template <class D>
  276. inline auto xaccessible<D>::front() -> reference
  277. {
  278. return *derived_cast().begin();
  279. }
  280. /**
  281. * Returns a reference to the last element of the expression.
  282. */
  283. template <class D>
  284. inline auto xaccessible<D>::back() -> reference
  285. {
  286. return *std::prev(derived_cast().end());
  287. }
  288. template <class D>
  289. inline auto xaccessible<D>::derived_cast() noexcept -> derived_type&
  290. {
  291. return *static_cast<derived_type*>(this);
  292. }
  293. }
  294. #endif