xpad.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_PAD_HPP
  10. #define XTENSOR_PAD_HPP
  11. #include "xarray.hpp"
  12. #include "xstrided_view.hpp"
  13. #include "xtensor.hpp"
  14. #include "xview.hpp"
  15. using namespace xt::placeholders; // to enable _ syntax
  16. namespace xt
  17. {
  /**
   * @brief Defines different algorithms to be used in ``xt::pad``.
   *
   * Examples for the data ``a b c d`` padded by 2 on both sides:
   * - ``constant``:  ``k k | a b c d | k k`` -- pads with a constant value ``k``.
   * - ``symmetric``: ``b a | a b c d | d c`` -- mirror image that includes the
   *   edge value.
   * - ``reflect``:   ``c b | a b c d | c b`` -- mirror image that excludes the
   *   edge value (reflection about the first and last values).
   * - ``wrap``:      ``c d | a b c d | a b`` -- the last values pad the beginning
   *   and the first values pad the end.
   * - ``periodic``:  same as ``wrap`` (periodic repetition of the data).
   * - ``edge``:      ``a a | a b c d | d d`` -- repeats the first / last value
   *   along each axis.
   *
   * OpenCV to xtensor:
   * - ``BORDER_CONSTANT == constant``
   * - ``BORDER_REFLECT == symmetric``
   * - ``BORDER_REFLECT_101 == reflect``
   * - ``BORDER_WRAP == wrap``
   * - ``BORDER_REPLICATE == edge``
   */
  enum class pad_mode
  {
    constant,
    symmetric,
    reflect,
    wrap,
    periodic,
    edge
  };
  namespace detail
  {
    /**
     * @brief Validates the ``pad_width`` argument of ``xt::pad``.
     *
     * @param pad_width Padding specification, expected to be
     *        ``{{before_1, after_1}, ..., {before_N, after_N}}``.
     * @param shape Shape of the array being padded.
     * @return ``true`` when there is exactly one entry per dimension of ``shape``
     *         and every entry holds exactly two values (before, after);
     *         ``false`` otherwise.
     */
    template <class S, class T>
    inline bool check_pad_width(const std::vector<std::vector<S>>& pad_width, const T& shape)
    {
      if (pad_width.size() != shape.size())
      {
        return false;
      }
      // pad() reads pad_width[axis][0] and pad_width[axis][1] for every axis,
      // so a shorter entry would be read out of bounds.
      for (const auto& before_after : pad_width)
      {
        if (before_after.size() != 2)
        {
          return false;
        }
      }
      return true;
    }
  }
  55. /**
  56. * @brief Pad an array.
  57. *
  58. * @param e The array.
  59. * @param pad_width Number of values padded to the edges of each axis:
  60. * `{{before_1, after_1}, ..., {before_N, after_N}}`.
  61. * @param mode The type of algorithm to use. [default: `xt::pad_mode::constant`].
  62. * @param constant_value The value to set the padded values for each axis
  63. * (used in `xt::pad_mode::constant`).
  64. * @return The padded array.
  65. */
  66. template <class E, class S = typename std::decay_t<E>::size_type, class V = typename std::decay_t<E>::value_type>
  67. inline auto
  68. pad(E&& e,
  69. const std::vector<std::vector<S>>& pad_width,
  70. pad_mode mode = pad_mode::constant,
  71. V constant_value = 0)
  72. {
  73. XTENSOR_ASSERT(detail::check_pad_width(pad_width, e.shape()));
  74. using size_type = typename std::decay_t<E>::size_type;
  75. using return_type = temporary_type_t<E>;
  76. // place the original array in the center
  77. auto new_shape = e.shape();
  78. xt::xstrided_slice_vector sv;
  79. sv.reserve(e.shape().size());
  80. for (size_type axis = 0; axis < e.shape().size(); ++axis)
  81. {
  82. size_type nb = static_cast<size_type>(pad_width[axis][0]);
  83. size_type ne = static_cast<size_type>(pad_width[axis][1]);
  84. size_type ns = nb + e.shape(axis) + ne;
  85. new_shape[axis] = ns;
  86. sv.push_back(xt::range(nb, nb + e.shape(axis)));
  87. }
  88. if (mode == pad_mode::constant)
  89. {
  90. return_type out(new_shape, constant_value);
  91. xt::strided_view(out, sv) = e;
  92. return out;
  93. }
  94. return_type out(new_shape);
  95. xt::strided_view(out, sv) = e;
  96. // construct padded regions based on original image
  97. xt::xstrided_slice_vector svs(e.shape().size(), xt::all());
  98. xt::xstrided_slice_vector svt(e.shape().size(), xt::all());
  99. for (size_type axis = 0; axis < e.shape().size(); ++axis)
  100. {
  101. size_type nb = static_cast<size_type>(pad_width[axis][0]);
  102. size_type ne = static_cast<size_type>(pad_width[axis][1]);
  103. if (nb > static_cast<size_type>(0))
  104. {
  105. svt[axis] = xt::range(0, nb);
  106. if (mode == pad_mode::wrap || mode == pad_mode::periodic)
  107. {
  108. XTENSOR_ASSERT(nb <= e.shape(axis));
  109. svs[axis] = xt::range(e.shape(axis), nb + e.shape(axis));
  110. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  111. }
  112. else if (mode == pad_mode::symmetric)
  113. {
  114. XTENSOR_ASSERT(nb <= e.shape(axis));
  115. svs[axis] = xt::range(2 * nb - 1, nb - 1, -1);
  116. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  117. }
  118. else if (mode == pad_mode::reflect)
  119. {
  120. XTENSOR_ASSERT(nb <= e.shape(axis) - 1);
  121. svs[axis] = xt::range(2 * nb, nb, -1);
  122. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  123. }
  124. else if (mode == pad_mode::edge)
  125. {
  126. svs[axis] = xt::range(nb, nb + 1);
  127. xt::strided_view(out, svt) = xt::broadcast(
  128. xt::strided_view(out, svs),
  129. xt::strided_view(out, svt).shape()
  130. );
  131. }
  132. }
  133. if (ne > static_cast<size_type>(0))
  134. {
  135. svt[axis] = xt::range(out.shape(axis) - ne, out.shape(axis));
  136. if (mode == pad_mode::wrap || mode == pad_mode::periodic)
  137. {
  138. XTENSOR_ASSERT(ne <= e.shape(axis));
  139. svs[axis] = xt::range(nb, nb + ne);
  140. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  141. }
  142. else if (mode == pad_mode::symmetric)
  143. {
  144. XTENSOR_ASSERT(ne <= e.shape(axis));
  145. if (ne == nb + e.shape(axis))
  146. {
  147. svs[axis] = xt::range(nb + e.shape(axis) - 1, _, -1);
  148. }
  149. else
  150. {
  151. svs[axis] = xt::range(nb + e.shape(axis) - 1, nb + e.shape(axis) - ne - 1, -1);
  152. }
  153. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  154. }
  155. else if (mode == pad_mode::reflect)
  156. {
  157. XTENSOR_ASSERT(ne <= e.shape(axis) - 1);
  158. if (ne == nb + e.shape(axis) - 1)
  159. {
  160. svs[axis] = xt::range(nb + e.shape(axis) - 2, _, -1);
  161. }
  162. else
  163. {
  164. svs[axis] = xt::range(nb + e.shape(axis) - 2, nb + e.shape(axis) - ne - 2, -1);
  165. }
  166. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  167. }
  168. else if (mode == pad_mode::edge)
  169. {
  170. svs[axis] = xt::range(out.shape(axis) - ne - 1, out.shape(axis) - ne);
  171. xt::strided_view(out, svt) = xt::broadcast(
  172. xt::strided_view(out, svs),
  173. xt::strided_view(out, svt).shape()
  174. );
  175. }
  176. }
  177. svs[axis] = xt::all();
  178. svt[axis] = xt::all();
  179. }
  180. return out;
  181. }
  182. /**
  183. * @brief Pad an array.
  184. *
  185. * @param e The array.
  186. * @param pad_width Number of values padded to the edges of each axis:
  187. * `{before, after}`.
  188. * @param mode The type of algorithm to use. [default: `xt::pad_mode::constant`].
  189. * @param constant_value The value to set the padded values for each axis
  190. * (used in `xt::pad_mode::constant`).
  191. * @return The padded array.
  192. */
  193. template <class E, class S = typename std::decay_t<E>::size_type, class V = typename std::decay_t<E>::value_type>
  194. inline auto
  195. pad(E&& e, const std::vector<S>& pad_width, pad_mode mode = pad_mode::constant, V constant_value = 0)
  196. {
  197. std::vector<std::vector<S>> pw(e.shape().size(), pad_width);
  198. return pad(e, pw, mode, constant_value);
  199. }
  200. /**
  201. * @brief Pad an array.
  202. *
  203. * @param e The array.
  204. * @param pad_width Number of values padded to the edges of each axis.
  205. * @param mode The type of algorithm to use. [default: `xt::pad_mode::constant`].
  206. * @param constant_value The value to set the padded values for each axis
  207. * (used in `xt::pad_mode::constant`).
  208. * @return The padded array.
  209. */
  210. template <class E, class S = typename std::decay_t<E>::size_type, class V = typename std::decay_t<E>::value_type>
  211. inline auto pad(E&& e, S pad_width, pad_mode mode = pad_mode::constant, V constant_value = 0)
  212. {
  213. std::vector<std::vector<S>> pw(e.shape().size(), {pad_width, pad_width});
  214. return pad(e, pw, mode, constant_value);
  215. }
  216. namespace detail
  217. {
  218. template <class E, class S>
  219. inline auto tile(E&& e, const S& reps)
  220. {
  221. using size_type = typename std::decay_t<E>::size_type;
  222. using return_type = temporary_type_t<E>;
  223. XTENSOR_ASSERT(e.shape().size() == reps.size());
  224. using new_shape_type = typename return_type::shape_type;
  225. auto new_shape = xtl::make_sequence<new_shape_type>(e.shape().size());
  226. xt::xstrided_slice_vector sv(reps.size());
  227. for (size_type axis = 0; axis < reps.size(); ++axis)
  228. {
  229. new_shape[axis] = e.shape(axis) * reps[axis];
  230. sv[axis] = xt::range(0, e.shape(axis));
  231. }
  232. return_type out(new_shape);
  233. xt::strided_view(out, sv) = e;
  234. xt::xstrided_slice_vector svs(e.shape().size(), xt::all());
  235. xt::xstrided_slice_vector svt(e.shape().size(), xt::all());
  236. for (size_type axis = 0; axis < e.shape().size(); ++axis)
  237. {
  238. for (size_type i = 1; i < static_cast<size_type>(reps[axis]); ++i)
  239. {
  240. svs[axis] = xt::range(0, e.shape(axis));
  241. svt[axis] = xt::range(i * e.shape(axis), (i + 1) * e.shape(axis));
  242. xt::strided_view(out, svt) = xt::strided_view(out, svs);
  243. svs[axis] = xt::all();
  244. svt[axis] = xt::all();
  245. }
  246. }
  247. return out;
  248. }
  249. }
  250. /**
  251. * @brief Tile an array.
  252. *
  253. * @param e The array.
  254. * @param reps The number of repetitions of A along each axis.
  255. * @return The tiled array.
  256. */
  257. template <class E, class S = typename std::decay_t<E>::size_type>
  258. inline auto tile(E&& e, std::initializer_list<S> reps)
  259. {
  260. return detail::tile(std::forward<E>(e), std::vector<S>{reps});
  261. }
  262. template <class E, class C, XTL_REQUIRES(xtl::negation<xtl::is_integral<C>>)>
  263. inline auto tile(E&& e, const C& reps)
  264. {
  265. return detail::tile(std::forward<E>(e), reps);
  266. }
  267. /**
  268. * @brief Tile an array.
  269. *
  270. * @param e The array.
  271. * @param reps The number of repetitions of A along the first axis.
  272. * @return The tiled array.
  273. */
  274. template <class E, class S = typename std::decay_t<E>::size_type, XTL_REQUIRES(xtl::is_integral<S>)>
  275. inline auto tile(E&& e, S reps)
  276. {
  277. std::vector<S> tw(e.shape().size(), static_cast<S>(1));
  278. tw[0] = reps;
  279. return detail::tile(std::forward<E>(e), tw);
  280. }
  281. }
  282. #endif