numpy.hpp 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037
  1. #pragma once
  2. #include <any>
  3. #include <cstdint>
  4. #include <complex>
  5. #include <chrono>
  6. #include <iostream>
  7. #include <limits>
  8. #include <string>
  9. #include <tuple>
  10. #include <type_traits>
  11. #include <utility>
  12. #include <vector>
  13. // Suppress xtensor warnings if SUPPRESS_XTENSOR_WARNINGS is set
  14. #ifdef SUPPRESS_XTENSOR_WARNINGS
  15. #ifdef _MSC_VER
  16. #pragma warning(push, 0)
  17. #else
  18. #pragma GCC diagnostic push
  19. #pragma GCC diagnostic ignored "-Wall"
  20. #pragma GCC diagnostic ignored "-Wextra"
  21. #pragma GCC system_header
  22. #endif
  23. #endif
  24. #include <xtensor/xarray.hpp>
  25. #include <xtensor/xio.hpp>
  26. #include <xtensor/xmath.hpp>
  27. #include <xtensor/xrandom.hpp>
  28. #include <xtensor/xsort.hpp>
  29. #include <xtensor/xview.hpp>
  30. #ifdef SUPPRESS_XTENSOR_WARNINGS
  31. #ifdef _MSC_VER
  32. #pragma warning(pop)
  33. #else
  34. #pragma GCC diagnostic pop
  35. #endif
  36. #endif
  37. namespace pkpy {
// Type aliases
// Fixed-width numeric aliases mirroring numpy dtype names.
using int8 = int8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using uint8 = uint8_t;
using uint16 = uint16_t;
using uint32 = uint32_t;
using uint64 = uint64_t;
using int_ = int64;        // numpy's default integer type
using float32 = float;
using float64 = double;
using float_ = float64;    // numpy's default floating-point type
using bool_ = bool;
using complex64 = std::complex<float32>;
using complex128 = std::complex<float64>;
using complex_ = complex128;  // numpy's default complex type
using string = std::string;
  56. template <typename T>
  57. struct dtype_traits {
  58. constexpr const static char* name = "unknown";
  59. };
  60. #define REGISTER_DTYPE(Type, Name) \
  61. template <> \
  62. struct dtype_traits<Type> { \
  63. static constexpr const char* name = Name; \
  64. };
  65. REGISTER_DTYPE(int8_t, "int8");
  66. REGISTER_DTYPE(int16_t, "int16");
  67. REGISTER_DTYPE(int32_t, "int32");
  68. REGISTER_DTYPE(int64_t, "int64");
  69. REGISTER_DTYPE(uint8_t, "uint8");
  70. REGISTER_DTYPE(uint16_t, "uint16");
  71. REGISTER_DTYPE(uint32_t, "uint32");
  72. REGISTER_DTYPE(uint64_t, "uint64");
  73. REGISTER_DTYPE(float, "float32");
  74. REGISTER_DTYPE(float_, "float64");
  75. REGISTER_DTYPE(bool_, "bool");
  76. REGISTER_DTYPE(std::complex<float32>, "complex64");
  77. REGISTER_DTYPE(std::complex<float64>, "complex128");
// dtype names travel as plain strings (e.g. "float64").
using _Dtype = std::string;
// Shapes are plain int vectors, mirroring numpy's shape tuples.
using _ShapeLike = std::vector<int>;
namespace numpy {
template <typename T>
class ndarray;
// Trait: true iff the type is an ndarray specialization. Used below to keep
// the scalar operator overloads out of overload resolution for array operands.
template <typename T>
constexpr inline auto is_ndarray_v = false;
template <typename T>
constexpr inline auto is_ndarray_v<ndarray<T>> = true;
// N-dimensional array: a numpy-flavoured wrapper around xt::xarray<T>.
// T is the element dtype (one of the types registered with REGISTER_DTYPE).
template <typename T>
class ndarray {
public:
    // Constructors: default (empty), 0-d scalar, wrap of an existing xarray.
    ndarray() = default;
    ndarray(const T scalar) : _array(scalar) {}
    ndarray(const xt::xarray<T>& arr) : _array(arr) {}
    // Constructors for multi-dimensional arrays: nested initializer lists up to 5-D.
    ndarray(std::initializer_list<T> init_list) : _array(init_list) {}
    ndarray(std::initializer_list<std::initializer_list<T>> init_list) : _array(init_list) {}
    ndarray(std::initializer_list<std::initializer_list<std::initializer_list<T>>> init_list) : _array(init_list) {}
    ndarray(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<T>>>> init_list) :
        _array(init_list) {}
    ndarray(std::initializer_list<
                std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<T>>>>> init_list) :
        _array(init_list) {}
    // Read-only access to the underlying xtensor container.
    const xt::xarray<T>& get_array() const { return _array; }
    // Properties (numpy-style).
    _Dtype dtype() const { return dtype_traits<T>::name; }  // dtype name, e.g. "float64"
    int ndim() const { return static_cast<int>(_array.dimension()); }
    int size() const { return static_cast<int>(_array.size()); }
    _ShapeLike shape() const { return _ShapeLike(_array.shape().begin(), _array.shape().end()); }
    // Dunder methods.
    // Elementwise comparisons always yield a bool array (xtensor broadcasting rules apply).
    template <typename U>
    auto operator== (const ndarray<U>& other) const {
        return ndarray<bool_>(xt::equal(_array, other.get_array()));
    }
    template <typename U>
    auto operator!= (const ndarray<U>& other) const {
        return ndarray<bool_>(xt::not_equal(_array, other.get_array()));
    }
    // Array-array arithmetic: both operands promote to std::common_type_t<T, U>.
    template <typename U>
    auto operator+ (const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) + xt::cast<result_type>(other.get_array());
        return ndarray<result_type>(result);
    }
    // Array-scalar arithmetic; enable_if keeps this overload out of
    // consideration when the right-hand side is itself an ndarray.
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    auto operator+ (const U& other) const {
        return binary_operator_add_impl<U>(other);
    }
    // A float64 scalar (a Python float) forces a float64 result regardless of T;
    // any other scalar uses the usual common_type promotion.
    template <typename U>
    auto binary_operator_add_impl(const U& other) const {
        if constexpr(std::is_same_v<U, float_>) {
            xt::xarray<float_> result = xt::cast<float_>(_array) + other;
            return ndarray<float_>(result);
        } else {
            using result_type = std::common_type_t<T, U>;
            xt::xarray<result_type> result = xt::cast<result_type>(_array) + other;
            return ndarray<result_type>(result);
        }
    }
    template <typename U>
    auto operator- (const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) - xt::cast<result_type>(other.get_array());
        return ndarray<result_type>(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    auto operator- (const U& other) const {
        return binary_operator_sub_impl<U>(other);
    }
    // Same promotion rules as binary_operator_add_impl.
    template <typename U>
    auto binary_operator_sub_impl(const U& other) const {
        if constexpr(std::is_same_v<U, float_>) {
            xt::xarray<float_> result = xt::cast<float_>(_array) - other;
            return ndarray<float_>(result);
        } else {
            using result_type = std::common_type_t<T, U>;
            xt::xarray<result_type> result = xt::cast<result_type>(_array) - other;
            return ndarray<result_type>(result);
        }
    }
    template <typename U>
    auto operator* (const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) * xt::cast<result_type>(other.get_array());
        return ndarray<result_type>(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    auto operator* (const U& other) const {
        return binary_operator_mul_impl<U>(other);
    }
    // Same promotion rules as binary_operator_add_impl.
    template <typename U>
    auto binary_operator_mul_impl(const U& other) const {
        if constexpr(std::is_same_v<U, float_>) {
            xt::xarray<float_> result = xt::cast<float_>(_array) * other;
            return ndarray<float_>(result);
        } else {
            using result_type = std::common_type_t<T, U>;
            xt::xarray<result_type> result = xt::cast<result_type>(_array) * other;
            return ndarray<result_type>(result);
        }
    }
    // True division: bool operands promote to float64 (numpy semantics);
    // other combinations use common_type (so int/int stays integral here).
    template <typename U>
    auto operator/ (const ndarray<U>& other) const {
        using result_type = std::conditional_t<std::is_same_v<T, bool> || std::is_same_v<U, bool>, float64, std::common_type_t<T, U>>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) / xt::cast<result_type>(other.get_array());
        return ndarray<result_type>(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    auto operator/ (const U& other) const {
        return binary_operator_truediv_impl<U>(other);
    }
    // Scalar division always produces float64 (numpy true-division behaviour).
    template <typename U>
    auto binary_operator_truediv_impl(const U& other) const {
        xt::xarray<float_> result = xt::cast<float_>(_array) / static_cast<float_>(other);
        return ndarray<float_>(result);
    }
    // Elementwise power with array exponent, promoted to common_type.
    template <typename U>
    auto pow(const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result =
            xt::pow(xt::cast<result_type>(_array), xt::cast<result_type>(other.get_array()));
        return ndarray<result_type>(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    auto pow(const U& other) const {
        return pow_impl<U>(other);
    }
    // Scalar exponent: base is widened to float64 first, so the result is float64.
    template <typename U>
    auto pow_impl(const U& other) const {
        xt::xarray<float_> result = xt::pow(xt::cast<float_>(_array), other);
        return ndarray<float_>(result);
    }
    // Bitwise operators.
    // NOTE(review): the expression is computed in common_type<T, U> but returned
    // as ndarray<T> (the injected class name), so mixed-dtype operands are
    // narrowed back to T — confirm this asymmetry is intended.
    template <typename U>
    ndarray operator& (const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) & xt::cast<result_type>(other.get_array());
        return ndarray(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    ndarray operator& (const U& other) const {
        xt::xarray<T> result = _array & static_cast<T>(other);
        return ndarray(result);
    }
    template <typename U>
    ndarray operator| (const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) | xt::cast<result_type>(other.get_array());
        return ndarray(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    ndarray operator| (const U& other) const {
        xt::xarray<T> result = _array | static_cast<T>(other);
        return ndarray(result);
    }
    template <typename U>
    ndarray operator^ (const ndarray<U>& other) const {
        using result_type = std::common_type_t<T, U>;
        xt::xarray<result_type> result = xt::cast<result_type>(_array) ^ xt::cast<result_type>(other.get_array());
        return ndarray(result);
    }
    template <typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
    ndarray operator^ (const U& other) const {
        xt::xarray<T> result = _array ^ static_cast<T>(other);
        return ndarray(result);
    }
    // Elementwise bitwise / logical negation.
    ndarray operator~() const { return ndarray(~(_array)); }
    ndarray operator!() const { return ndarray(!(_array)); }
    // Single-index element read (intended for 1-D arrays).
    T operator() (int index) const { return _array(index); }
    // numpy-style arr[i]: the sub-array selected along axis 0.
    ndarray operator[] (int index) const { return ndarray(xt::view(_array, index, xt::all())); }
    // Fancy indexing: keep the listed positions along axis 0.
    ndarray operator[] (const std::vector<int>& indices) const { return ndarray(xt::view(_array, xt::keep(indices))); }
    // Basic slice (start, stop, step) along axis 0.
    ndarray operator[] (const std::tuple<int, int, int>& slice) const {
        return ndarray(xt::view(_array, xt::range(std::get<0>(slice), std::get<1>(slice), std::get<2>(slice))));
    }
    // Multi-index element read: one index per axis.
    template <typename... Args>
    T operator() (Args... args) const {
        return _array(args...);
    }
    // __setitem__-style helpers. The int-prefix overloads assign an array into
    // the view selected by the leading indices (remaining axes taken whole).
    void set_item(int index, const ndarray<T>& value) { xt::view(_array, index, xt::all()) = value.get_array(); }
    void set_item(int i1, int i2, const ndarray<T>& value) { xt::view(_array, i1, i2, xt::all()) = value.get_array(); }
    void set_item(int i1, int i2, int i3, const ndarray<T>& value) { xt::view(_array, i1, i2, i3, xt::all()) = value.get_array(); }
    void set_item(int i1, int i2, int i3, int i4, const ndarray<T>& value) { xt::view(_array, i1, i2, i3, i4, xt::all()) = value.get_array(); }
    void set_item(int i1, int i2, int i3, int i4, int i5, const ndarray<T>& value) { xt::view(_array, i1, i2, i3, i4, i5, xt::all()) = value.get_array(); }
    // Fancy-index assignment along axis 0.
    void set_item(const std::vector<int>& indices, const ndarray<T>& value) {
        xt::view(_array, xt::keep(indices)) = value.get_array();
    }
    // Slice assignment along axis 0.
    void set_item(const std::tuple<int, int, int>& slice, const ndarray<T>& value) {
        xt::view(_array, xt::range(std::get<0>(slice), std::get<1>(slice), std::get<2>(slice))) = value.get_array();
    }
    // Scalar assignment at a fully specified multi-index.
    void set_item(int i1, int i2, T value) { xt::view(_array, i1, i2) = value; }
    void set_item(int i1, int i2, int i3, T value) { xt::view(_array, i1, i2, i3) = value; }
    void set_item(int i1, int i2, int i3, int i4, T value) { xt::view(_array, i1, i2, i3, i4) = value; }
    void set_item(int i1, int i2, int i3, int i4, int i5, T value) { xt::view(_array, i1, i2, i3, i4, i5) = value; }
    // Boolean Functions
    bool all() const { return xt::all(_array); }
    bool any() const { return xt::any(_array); }
    // Aggregate Functions
    // No-argument forms reduce over the whole array to a scalar; the axis forms
    // reduce along the given axis/axes and return an array.
    T sum() const { return (xt::sum(_array))[0]; }
    ndarray<T> sum(int axis) const {
        xt::xarray<T> result = xt::sum(_array, {axis});
        return ndarray<T>(result);
    }
    ndarray<T> sum(const _ShapeLike& axis) const {
        xt::xarray<T> result = xt::sum(_array, axis);
        return ndarray<T>(result);
    }
    T prod() const { return (xt::prod(_array))[0]; }
    ndarray<T> prod(int axis) const {
        xt::xarray<T> result = xt::prod(_array, {axis});
        return ndarray<T>(result);
    }
    ndarray<T> prod(const _ShapeLike& axes) const {
        xt::xarray<T> result = xt::prod(_array, axes);
        return ndarray<T>(result);
    }
    T min() const { return (xt::amin(_array))[0]; }
    ndarray<T> min(int axis) const {
        xt::xarray<T> result = xt::amin(_array, {axis});
        return ndarray<T>(result);
    }
    ndarray<T> min(const _ShapeLike& axes) const {
        xt::xarray<T> result = xt::amin(_array, axes);
        return ndarray<T>(result);
    }
    T max() const { return (xt::amax(_array))[0]; }
    ndarray<T> max(int axis) const {
        xt::xarray<T> result = xt::amax(_array, {axis});
        return ndarray<T>(result);
    }
    ndarray<T> max(const _ShapeLike& axes) const {
        xt::xarray<T> result = xt::amax(_array, axes);
        return ndarray<T>(result);
    }
    // Statistical moments always come back as float64, regardless of T.
    pkpy::float64 mean() const { return (xt::mean(_array))[0]; }
    ndarray<pkpy::float64> mean(int axis) const {
        return ndarray<pkpy::float64>(xt::mean(_array, {axis}));
    }
    ndarray<pkpy::float64> mean(const _ShapeLike& axes) const {
        return ndarray<pkpy::float64>(xt::mean(_array, axes));
    }
    pkpy::float64 std() const { return (xt::stddev(_array))[0]; }
    ndarray<pkpy::float64> std(int axis) const {
        return ndarray<pkpy::float64>(xt::stddev(_array, {axis}));
    }
    ndarray<pkpy::float64> std(const _ShapeLike& axes) const {
        return ndarray<pkpy::float64>(xt::stddev(_array, axes));
    }
    pkpy::float64 var() const { return (xt::variance(_array))[0]; }
    ndarray<pkpy::float64> var(int axis) const {
        return ndarray<pkpy::float64>(xt::variance(_array, {axis}));
    }
    ndarray<pkpy::float64> var(const _ShapeLike& axes) const {
        return ndarray<pkpy::float64>(xt::variance(_array, axes));
    }
    // Searching and Sorting Functions
    pkpy::int64 argmin() const { return (xt::argmin(_array))[0]; }
    // NOTE(review): the axis forms of argmin/argmax/argsort store *indices* in
    // an ndarray<T>, so for floating T the indices are cast to T — confirm intended.
    ndarray<T> argmin(int axis) const {
        xt::xarray<T> result = xt::argmin(_array, {axis});
        return ndarray<T>(result);
    }
    pkpy::int64 argmax() const { return (xt::argmax(_array))[0]; }
    ndarray<T> argmax(int axis) const {
        xt::xarray<T> result = xt::argmax(_array, {axis});
        return ndarray<T>(result);
    }
    ndarray<T> argsort() const { return ndarray<T>(xt::argsort(_array)); }
    ndarray<T> argsort(int axis) const {
        xt::xarray<T> result = xt::argsort(_array, {axis});
        return ndarray<T>(result);
    }
    ndarray<T> sort() const { return ndarray<T>(xt::sort(_array)); }
    ndarray<T> sort(int axis) const {
        xt::xarray<T> result = xt::sort(_array, {axis});
        return ndarray<T>(result);
    }
    // Shape Manipulation Functions
    // Returns a reshaped copy; *this is left untouched.
    ndarray<T> reshape(const _ShapeLike& shape) const {
        xt::xarray<T> dummy = _array;
        dummy.reshape(shape);
        return ndarray<T>(dummy);
    }
    // Does not preserve elements if expected size is not equal to the current size.
    // https://github.com/xtensor-stack/xtensor/issues/1445
    ndarray<T> resize(const _ShapeLike& shape) const {
        xt::xarray<T> dummy = _array;
        dummy.resize(shape);
        return ndarray<T>(dummy);
    }
    // Drop all (or one) length-1 axes.
    ndarray<T> squeeze() const { return ndarray<T>(xt::squeeze(_array)); }
    ndarray<T> squeeze(int axis) const {
        xt::xarray<T> result = xt::squeeze(_array, {axis});
        return ndarray<T>(result);
    }
    // Axis permutations: full reverse, explicit permutation vector, or per-axis arguments.
    ndarray<T> transpose() const { return ndarray<T>(xt::transpose(_array)); }
    ndarray<T> transpose(const _ShapeLike& permutation) const { return ndarray<T>(xt::transpose(_array, permutation)); }
    template <typename... Args>
    ndarray<T> transpose(Args... args) const {
        xt::xarray<T> result = xt::transpose(_array, {args...});
        return ndarray<T>(result);
    }
    // Repeat each element along `axis` a fixed or per-element number of times.
    ndarray<T> repeat(int repeats, int axis) const { return ndarray<T>(xt::repeat(_array, repeats, axis)); }
    ndarray<T> repeat(const std::vector<size_t>& repeats, int axis) const {
        return ndarray<T>(xt::repeat(_array, repeats, axis));
    }
    ndarray<T> flatten() const { return ndarray<T>(xt::flatten(_array)); }
    // Miscellaneous Functions
    ndarray<T> round() const { return ndarray<T>(xt::round(_array)); }
    // Elementwise dtype conversion to U.
    template <typename U>
    ndarray<U> astype() const {
        xt::xarray<U> result = xt::cast<U>(_array);
        return ndarray<U>(result);
    }
    // Deep copy (xt::xarray has value semantics, so copy-construction suffices).
    ndarray<T> copy() const {
        ndarray<T> result = *this;
        return result;
    }
    // Flattened contents in iteration order as a std::vector.
    std::vector<T> to_list() const {
        std::vector<T> vec;
        for(auto &it : _array) {
            vec.push_back(it);
        }
        return vec;
    }
private:
    xt::xarray<T> _array;  // owning storage
};
  406. class random {
  407. public:
  408. random() {
  409. auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
  410. xt::random::seed(static_cast<xt::random::seed_type>(seed));
  411. }
  412. template <typename T>
  413. static T rand() {
  414. random random_instance;
  415. return (xt::random::rand<T>(std::vector{1}))[0];
  416. }
  417. template <typename T>
  418. static ndarray<T> rand(const _ShapeLike& shape) {
  419. random random_instance;
  420. return ndarray<T>(xt::random::rand<T>(shape));
  421. }
  422. template <typename T>
  423. static T randn() {
  424. random random_instance;
  425. return (xt::random::randn<T>(std::vector{1}))[0];
  426. }
  427. template <typename T>
  428. static ndarray<T> randn(const _ShapeLike& shape) {
  429. random random_instance;
  430. return ndarray<T>(xt::random::randn<T>(shape));
  431. }
  432. template <typename T>
  433. static int randint(T low, T high) {
  434. random random_instance;
  435. return (xt::random::randint<T>(std::vector{1}, low, high))[0];
  436. }
  437. template <typename T>
  438. static ndarray<T> randint(T low, T high, const _ShapeLike& shape) {
  439. random random_instance;
  440. return ndarray<T>(xt::random::randint<T>(shape, low, high));
  441. }
  442. template <typename T>
  443. static ndarray<T> uniform(T low, T high, const _ShapeLike& shape) {
  444. random random_instance;
  445. return ndarray<T>(xt::random::rand<T>(shape, low, high));
  446. }
  447. };
  448. template<typename T, typename U>
  449. xt::xarray<std::common_type_t<T, U>> matrix_mul(const xt::xarray<T>& a, const xt::xarray<U>& b) {
  450. using result_type = std::common_type_t<T, U>;
  451. using Mat = xt::xarray<result_type>;
  452. bool first_is_1d = false;
  453. bool second_is_1d = false;
  454. xt::xarray<T> a_copy = a;
  455. xt::xarray<U> b_copy = b;
  456. if (a.dimension() == 1) {
  457. first_is_1d = true;
  458. a_copy = xt::reshape_view(a_copy, {1, 3});
  459. }
  460. if(b_copy.dimension() == 1) {
  461. second_is_1d = true;
  462. b_copy = xt::reshape_view(b_copy, {3, 1});
  463. }
  464. if (a_copy.dimension() == 2 && b_copy.dimension() == 2) {
  465. int m = static_cast<int>(a_copy.shape()[0]);
  466. int n = static_cast<int>(a_copy.shape()[1]);
  467. int p = static_cast<int>(b_copy.shape()[1]);
  468. Mat result = xt::zeros<result_type>({m, p});
  469. for (int i = 0; i < m; i++) {
  470. for (int j = 0; j < p; j++) {
  471. for (int k = 0; k < n; k++) {
  472. result(i, j) = result(i, j) + a_copy(i, k) * b_copy(k, j);
  473. }
  474. }
  475. }
  476. if (first_is_1d) {
  477. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-2});
  478. }
  479. if (second_is_1d) {
  480. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-1});
  481. }
  482. return result;
  483. }
  484. else {
  485. if (a_copy.dimension() == b_copy.dimension()) {
  486. assert(a_copy.shape()[0] == b_copy.shape()[0]);
  487. size_t layers = a_copy.shape()[0];
  488. Mat sub;
  489. {
  490. Mat a0 = xt::view(a_copy, 0);
  491. Mat b0 = xt::view(b_copy, 0);
  492. sub = matrix_mul(a0, b0);
  493. }
  494. auto out_shape = sub.shape();
  495. out_shape.insert(out_shape.begin(), layers);
  496. auto result = Mat::from_shape(out_shape);
  497. xt::view(result, 0) = sub;
  498. for (size_t i = 1; i < layers; i++) {
  499. Mat ai = xt::view(a_copy, i);
  500. Mat bi = xt::view(b_copy, i);
  501. xt::view(result, i) = matrix_mul(ai, bi);
  502. }
  503. if (first_is_1d) {
  504. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-2});
  505. }
  506. if (second_is_1d) {
  507. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-1});
  508. }
  509. return result;
  510. } else if (a_copy.dimension() > b_copy.dimension()) {
  511. assert(a_copy.dimension() > b_copy.dimension());
  512. size_t layers = a_copy.shape()[0];
  513. Mat sub;
  514. {
  515. Mat a0 = xt::view(a_copy, 0);
  516. sub = matrix_mul(a0, b_copy);
  517. }
  518. auto out_shape = sub.shape();
  519. out_shape.insert(out_shape.begin(), layers);
  520. auto result = Mat::from_shape(out_shape);
  521. xt::view(result, 0) = sub;
  522. for (size_t i = 1; i < layers; i++) {
  523. Mat ai = xt::view(a_copy, i);
  524. xt::view(result, i) = matrix_mul(ai, b_copy);
  525. }
  526. if (first_is_1d) {
  527. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-2});
  528. }
  529. if (second_is_1d) {
  530. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-1});
  531. }
  532. return result;
  533. } else {
  534. assert(a_copy.dimension() < b_copy.dimension());
  535. size_t layers = b_copy.shape()[0];
  536. Mat sub;
  537. {
  538. Mat b0 = xt::view(b_copy, 0);
  539. sub = matrix_mul(a_copy, b0);
  540. }
  541. auto out_shape = sub.shape();
  542. out_shape.insert(out_shape.begin(), layers);
  543. auto result = Mat::from_shape(out_shape);
  544. xt::view(result, 0) = sub;
  545. for (size_t i = 1; i < layers; i++) {
  546. Mat bi = xt::view(b_copy, i);
  547. xt::view(result, i) = matrix_mul(a_copy, bi);
  548. }
  549. if (first_is_1d) {
  550. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-2});
  551. }
  552. if (second_is_1d) {
  553. result = xt::squeeze(result, std::vector<std::size_t>{result.dimension()-1});
  554. }
  555. return result;
  556. }
  557. }
  558. }
  559. template <typename T, typename U>
  560. ndarray<std::common_type_t<T, U>> matmul(const ndarray<T>& a, const ndarray<U>& b) {
  561. return ndarray<std::common_type_t<T, U>>(matrix_mul(a.get_array(), b.get_array()));
  562. }
  563. template <typename T>
  564. ndarray<T> adapt(const std::vector<T>& init_list) {
  565. return ndarray<T>(xt::adapt(init_list));
  566. }
  567. template <typename T>
  568. ndarray<T> adapt(const std::vector<std::vector<T>>& init_list) {
  569. std::vector<T> flat_list;
  570. for(auto row: init_list) {
  571. for(auto elem: row) {
  572. flat_list.push_back(elem);
  573. }
  574. }
  575. std::vector<size_t> sh = {init_list.size(), init_list[0].size()};
  576. return ndarray<T>(xt::adapt(flat_list, sh));
  577. }
  578. template <typename T>
  579. ndarray<T> adapt(const std::vector<std::vector<std::vector<T>>>& init_list) {
  580. std::vector<T> flat_list;
  581. for(auto row: init_list) {
  582. for(auto elem: row) {
  583. for(auto val: elem) {
  584. flat_list.push_back(val);
  585. }
  586. }
  587. }
  588. std::vector<size_t> sh = {init_list.size(), init_list[0].size(), init_list[0][0].size()};
  589. return ndarray<T>(xt::adapt(flat_list, sh));
  590. }
  591. template <typename T>
  592. ndarray<T> adapt(const std::vector<std::vector<std::vector<std::vector<T>>>>& init_list) {
  593. std::vector<T> flat_list;
  594. for(auto row: init_list) {
  595. for(auto elem: row) {
  596. for(auto val: elem) {
  597. for(auto v: val) {
  598. flat_list.push_back(v);
  599. }
  600. }
  601. }
  602. }
  603. std::vector<size_t> sh = {init_list.size(), init_list[0].size(), init_list[0][0].size(), init_list[0][0][0].size()};
  604. return ndarray<T>(xt::adapt(flat_list, sh));
  605. }
  606. template <typename T>
  607. ndarray<T> adapt(const std::vector<std::vector<std::vector<std::vector<std::vector<T>>>>>& init_list) {
  608. std::vector<T> flat_list;
  609. for(auto row: init_list) {
  610. for(auto elem: row) {
  611. for(auto val: elem) {
  612. for(auto v: val) {
  613. for(auto v1: v) {
  614. flat_list.push_back(v1);
  615. }
  616. }
  617. }
  618. }
  619. }
  620. std::vector<size_t> sh = {init_list.size(),
  621. init_list[0].size(),
  622. init_list[0][0].size(),
  623. init_list[0][0][0].size(),
  624. init_list[0][0][0][0].size()};
  625. return ndarray<T>(xt::adapt(flat_list, sh));
  626. }
  627. // Array Creation
  628. template <typename U, typename T>
  629. ndarray<U> array(const std::vector<T>& vec, const _ShapeLike& shape = {}) {
  630. if(shape.empty()) {
  631. return ndarray<U>(xt::cast<U>(xt::adapt(vec)));
  632. } else {
  633. return ndarray<U>(xt::cast<U>(xt::adapt(vec, shape)));
  634. }
  635. }
  636. template <typename T>
  637. ndarray<T> zeros(const _ShapeLike& shape) {
  638. return ndarray<T>(xt::zeros<T>(shape));
  639. }
  640. template <typename T>
  641. ndarray<T> ones(const _ShapeLike& shape) {
  642. return ndarray<T>(xt::ones<T>(shape));
  643. }
  644. template <typename T>
  645. ndarray<T> full(const _ShapeLike& shape, const T& fill_value) {
  646. xt::xarray<T> result = xt::ones<T>(shape);
  647. for(auto it = result.begin(); it != result.end(); ++it) {
  648. *it = fill_value;
  649. }
  650. return ndarray<T>(result);
  651. }
  652. template <typename T>
  653. ndarray<T> identity(int n) {
  654. return ndarray<T>(xt::eye<T>(n));
  655. }
  656. template <typename T>
  657. ndarray<T> arange(const T& stop) {
  658. return ndarray<T>(xt::arange<T>(stop));
  659. }
  660. template <typename T>
  661. ndarray<T> arange(const T& start, const T& stop) {
  662. return ndarray<T>(xt::arange<T>(start, stop));
  663. }
  664. template <typename T>
  665. ndarray<T> arange(const T& start, const T& stop, const T& step) {
  666. return ndarray<T>(xt::arange<T>(start, stop, step));
  667. }
  668. template <typename T>
  669. ndarray<T> linspace(const T& start, const T& stop, int num = 50, bool endpoint = true) {
  670. return ndarray<T>(xt::linspace<T>(start, stop, num, endpoint));
  671. }
  672. // Trigonometry
  673. template <typename T>
  674. ndarray<float_> sin(const ndarray<T>& arr) {
  675. return ndarray<float_>(xt::sin(arr.get_array()));
  676. }
  677. ndarray<complex_> sin(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::sin(arr.get_array())); }
  678. ndarray<complex_> sin(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::sin(arr.get_array())); }
  679. template <typename T>
  680. ndarray<float_> cos(const ndarray<T>& arr) {
  681. return ndarray<float_>(xt::cos(arr.get_array()));
  682. }
  683. ndarray<complex_> cos(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::cos(arr.get_array())); }
  684. ndarray<complex_> cos(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::cos(arr.get_array())); }
  685. template <typename T>
  686. ndarray<float_> tan(const ndarray<T>& arr) {
  687. return ndarray<float_>(xt::tan(arr.get_array()));
  688. }
  689. ndarray<complex_> tan(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::tan(arr.get_array())); }
  690. ndarray<complex_> tan(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::tan(arr.get_array())); }
  691. template <typename T>
  692. ndarray<float_> arcsin(const ndarray<T>& arr) {
  693. return ndarray<float_>(xt::asin(arr.get_array()));
  694. }
  695. ndarray<complex_> arcsin(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::asin(arr.get_array())); }
  696. ndarray<complex_> arcsin(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::asin(arr.get_array())); }
  697. template <typename T>
  698. ndarray<float_> arccos(const ndarray<T>& arr) {
  699. return ndarray<float_>(xt::acos(arr.get_array()));
  700. }
  701. ndarray<complex_> arccos(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::acos(arr.get_array())); }
  702. ndarray<complex_> arccos(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::acos(arr.get_array())); }
  703. template <typename T>
  704. ndarray<float_> arctan(const ndarray<T>& arr) {
  705. return ndarray<float_>(xt::atan(arr.get_array()));
  706. }
  707. ndarray<complex_> arctan(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::atan(arr.get_array())); }
  708. ndarray<complex_> arctan(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::atan(arr.get_array())); }
  709. // Exponents and Logarithms
  710. template <typename T>
  711. ndarray<float_> exp(const ndarray<T>& arr) {
  712. return ndarray<float_>(xt::exp(arr.get_array()));
  713. }
  714. ndarray<complex_> exp(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::exp(arr.get_array())); }
  715. ndarray<complex_> exp(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::exp(arr.get_array())); }
  716. template <typename T>
  717. ndarray<float_> log(const ndarray<T>& arr) {
  718. return ndarray<float_>(xt::log(arr.get_array()));
  719. }
  720. ndarray<complex_> log(const ndarray<complex64>& arr) { return ndarray<complex_>(xt::log(arr.get_array())); }
  721. ndarray<complex_> log(const ndarray<complex128>& arr) { return ndarray<complex_>(xt::log(arr.get_array())); }
  722. template <typename T>
  723. ndarray<float_> log2(const ndarray<T>& arr) {
  724. return ndarray<float_>(xt::log2(arr.get_array()));
  725. }
  726. template <typename T>
  727. ndarray<float_> log10(const ndarray<T>& arr) {
  728. return ndarray<float_>(xt::log10(arr.get_array()));
  729. }
// Miscellaneous
  731. template <typename T>
  732. ndarray<T> round(const ndarray<T>& arr) {
  733. return ndarray<T>(xt::round(arr.get_array()));
  734. }
  735. template <typename T>
  736. ndarray<T> floor(const ndarray<T>& arr) {
  737. return ndarray<T>(xt::floor(arr.get_array()));
  738. }
  739. template <typename T>
  740. ndarray<T> ceil(const ndarray<T>& arr) {
  741. return ndarray<T>(xt::ceil(arr.get_array()));
  742. }
  743. template <typename T>
  744. auto abs(const ndarray<T>& arr) {
  745. if constexpr(std::is_same_v<T, complex64> || std::is_same_v<T, complex128>) {
  746. return ndarray<float_>(xt::abs(arr.get_array()));
  747. } else {
  748. return ndarray<T>(xt::abs(arr.get_array()));
  749. }
  750. }
  751. // Xtensor only supports concatenation of initialized objects.
  752. // https://github.com/xtensor-stack/xtensor/issues/1450
  753. template <typename T, typename U>
  754. auto concatenate(const ndarray<T>& arr1, const ndarray<U>& arr2, int axis = 0) {
  755. using result_type = std::common_type_t<T, U>;
  756. xt::xarray<result_type> xarr1 = xt::cast<result_type>(arr1.get_array());
  757. xt::xarray<result_type> xarr2 = xt::cast<result_type>(arr2.get_array());
  758. return ndarray<result_type>(xt::concatenate(xt::xtuple(xarr1, xarr2), axis));
  759. }
  760. // Constants
  761. constexpr float_ pi = xt::numeric_constants<double>::PI;
  762. constexpr double inf = std::numeric_limits<double>::infinity();
  763. // Testing Functions
  764. template <typename T, typename U>
  765. bool allclose(const ndarray<T>& arr1, const ndarray<U>& arr2, float_ rtol = 1e-5, float_ atol = 1e-8) {
  766. return xt::allclose(arr1.get_array(), arr2.get_array(), rtol, atol);
  767. }
  768. // Reverse Dunder Methods
  769. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  770. auto operator+ (const U& scalar, const ndarray<T>& array) {
  771. xt::xarray<T> arr = array.get_array();
  772. if constexpr(std::is_same_v<U, float_>) {
  773. xt::xarray<float_> result = scalar + xt::cast<float_>(arr);
  774. return ndarray<float_>(result);
  775. } else {
  776. using result_type = std::common_type_t<T, U>;
  777. xt::xarray<result_type> result = scalar + xt::cast<result_type>(arr);
  778. return ndarray<result_type>(result);
  779. }
  780. }
  781. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  782. auto operator- (const U& scalar, const ndarray<T>& array) {
  783. xt::xarray<T> arr = array.get_array();
  784. if constexpr(std::is_same_v<U, float_>) {
  785. xt::xarray<float_> result = scalar - xt::cast<float_>(arr);
  786. return ndarray<float_>(result);
  787. } else {
  788. using result_type = std::common_type_t<T, U>;
  789. xt::xarray<result_type> result = scalar - xt::cast<result_type>(arr);
  790. return ndarray<result_type>(result);
  791. }
  792. }
  793. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  794. auto operator* (const U& scalar, const ndarray<T>& array) {
  795. xt::xarray<T> arr = array.get_array();
  796. if constexpr(std::is_same_v<U, float_>) {
  797. xt::xarray<float_> result = scalar * xt::cast<float_>(arr);
  798. return ndarray<float_>(result);
  799. } else {
  800. using result_type = std::common_type_t<T, U>;
  801. xt::xarray<result_type> result = scalar * xt::cast<result_type>(arr);
  802. return ndarray<result_type>(result);
  803. }
  804. }
  805. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  806. auto operator/ (const U& scalar, const ndarray<T>& array) {
  807. xt::xarray<T> arr = array.get_array();
  808. xt::xarray<float_> result = static_cast<float_>(scalar) / xt::cast<float_>(arr);
  809. return ndarray<float_>(result);
  810. }
  811. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  812. auto pow(const U& scalar, const ndarray<T>& array) {
  813. xt::xarray<T> arr = array.get_array();
  814. xt::xarray<float_> result = xt::pow(scalar, xt::cast<float_>(arr));
  815. return ndarray<float_>(result);
  816. }
  817. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  818. auto operator& (const U& scalar, const ndarray<T>& array) {
  819. return array & scalar;
  820. }
  821. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  822. auto operator| (const U& scalar, const ndarray<T>& array) {
  823. return array | scalar;
  824. }
  825. template <typename T, typename U, typename = std::enable_if_t<!is_ndarray_v<U>>>
  826. auto operator^ (const U& scalar, const ndarray<T>& array) {
  827. return array ^ scalar;
  828. }
  829. template <typename T>
  830. std::ostream& operator<< (std::ostream& os, const ndarray<T>& arr) {
  831. os << arr.get_array();
  832. return os;
  833. }
  834. } // namespace numpy
  835. } // namespace pkpy