xnpy.hpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright Leon Merten Lohse
  4. * Copyright (c) QuantStack *
  5. * *
  6. * Distributed under the terms of the BSD 3-Clause License. *
  7. * *
  8. * The full license is in the file LICENSE, distributed with this software. *
  9. ****************************************************************************/
  10. #ifndef XTENSOR_NPY_HPP
  11. #define XTENSOR_NPY_HPP
  12. // Derived from https://github.com/llohse/libnpy by Leon Merten Lohse,
  13. // relicensed from MIT License with permission
  14. #include <algorithm>
  15. #include <complex>
  16. #include <cstdint>
  17. #include <cstring>
  18. #include <fstream>
  19. #include <iostream>
  20. #include <memory>
  21. #include <regex>
  22. #include <sstream>
  23. #include <stdexcept>
  24. #include <string>
  25. #include <typeinfo>
  26. #include <vector>
  27. #include <xtl/xplatform.hpp>
  28. #include <xtl/xsequence.hpp>
  29. #include "xtensor/xadapt.hpp"
  30. #include "xtensor/xarray.hpp"
  31. #include "xtensor/xeval.hpp"
  32. #include "xtensor/xstrides.hpp"
  33. #include "xtensor_config.hpp"
  34. namespace xt
  35. {
  36. using namespace std::string_literals;
  37. namespace detail
  38. {
  39. const char magic_string[] = "\x93NUMPY";
  40. const std::size_t magic_string_length = sizeof(magic_string) - 1;
  41. template <class O>
  42. inline void write_magic(O& ostream, unsigned char v_major = 1, unsigned char v_minor = 0)
  43. {
  44. ostream.write(magic_string, magic_string_length);
  45. ostream.put(char(v_major));
  46. ostream.put(char(v_minor));
  47. }
  48. inline void read_magic(std::istream& istream, unsigned char* v_major, unsigned char* v_minor)
  49. {
  50. std::unique_ptr<char[]> buf(new char[magic_string_length + 2]);
  51. istream.read(buf.get(), magic_string_length + 2);
  52. if (!istream)
  53. {
  54. XTENSOR_THROW(std::runtime_error, "io error: failed reading file");
  55. }
  56. for (std::size_t i = 0; i < magic_string_length; i++)
  57. {
  58. if (buf[i] != magic_string[i])
  59. {
  60. XTENSOR_THROW(std::runtime_error, "this file do not have a valid npy format.");
  61. }
  62. }
  63. *v_major = static_cast<unsigned char>(buf[magic_string_length]);
  64. *v_minor = static_cast<unsigned char>(buf[magic_string_length + 1]);
  65. }
  66. template <class T>
  67. inline char map_type()
  68. {
  69. if (std::is_same<T, float>::value)
  70. {
  71. return 'f';
  72. }
  73. if (std::is_same<T, double>::value)
  74. {
  75. return 'f';
  76. }
  77. if (std::is_same<T, long double>::value)
  78. {
  79. return 'f';
  80. }
  81. if (std::is_same<T, char>::value)
  82. {
  83. return 'i';
  84. }
  85. if (std::is_same<T, signed char>::value)
  86. {
  87. return 'i';
  88. }
  89. if (std::is_same<T, short>::value)
  90. {
  91. return 'i';
  92. }
  93. if (std::is_same<T, int>::value)
  94. {
  95. return 'i';
  96. }
  97. if (std::is_same<T, long>::value)
  98. {
  99. return 'i';
  100. }
  101. if (std::is_same<T, long long>::value)
  102. {
  103. return 'i';
  104. }
  105. if (std::is_same<T, unsigned char>::value)
  106. {
  107. return 'u';
  108. }
  109. if (std::is_same<T, unsigned short>::value)
  110. {
  111. return 'u';
  112. }
  113. if (std::is_same<T, unsigned int>::value)
  114. {
  115. return 'u';
  116. }
  117. if (std::is_same<T, unsigned long>::value)
  118. {
  119. return 'u';
  120. }
  121. if (std::is_same<T, unsigned long long>::value)
  122. {
  123. return 'u';
  124. }
  125. if (std::is_same<T, bool>::value)
  126. {
  127. return 'b';
  128. }
  129. if (std::is_same<T, std::complex<float>>::value)
  130. {
  131. return 'c';
  132. }
  133. if (std::is_same<T, std::complex<double>>::value)
  134. {
  135. return 'c';
  136. }
  137. if (std::is_same<T, std::complex<long double>>::value)
  138. {
  139. return 'c';
  140. }
  141. XTENSOR_THROW(std::runtime_error, "Type not known.");
  142. }
  143. template <class T>
  144. inline char get_endianess()
  145. {
  146. constexpr char little_endian_char = '<';
  147. constexpr char big_endian_char = '>';
  148. constexpr char no_endian_char = '|';
  149. if (sizeof(T) <= sizeof(char))
  150. {
  151. return no_endian_char;
  152. }
  153. switch (xtl::endianness())
  154. {
  155. case xtl::endian::little_endian:
  156. return little_endian_char;
  157. case xtl::endian::big_endian:
  158. return big_endian_char;
  159. default:
  160. return no_endian_char;
  161. }
  162. }
  163. template <class T>
  164. inline std::string build_typestring()
  165. {
  166. std::stringstream ss;
  167. ss << get_endianess<T>() << map_type<T>() << sizeof(T);
  168. return ss.str();
  169. }
  170. // Safety check function
  171. inline void parse_typestring(std::string typestring)
  172. {
  173. std::regex re("'([<>|])([ifucb])(\\d+)'");
  174. std::smatch sm;
  175. std::regex_match(typestring, sm, re);
  176. if (sm.size() != 4)
  177. {
  178. XTENSOR_THROW(std::runtime_error, "invalid typestring");
  179. }
  180. }
  181. // Helpers for the improvised parser
  182. inline std::string unwrap_s(std::string s, char delim_front, char delim_back)
  183. {
  184. if ((s.back() == delim_back) && (s.front() == delim_front))
  185. {
  186. return s.substr(1, s.length() - 2);
  187. }
  188. else
  189. {
  190. XTENSOR_THROW(std::runtime_error, "unable to unwrap");
  191. }
  192. }
  193. inline std::string get_value_from_map(std::string mapstr)
  194. {
  195. std::size_t sep_pos = mapstr.find_first_of(":");
  196. if (sep_pos == std::string::npos)
  197. {
  198. return "";
  199. }
  200. return mapstr.substr(sep_pos + 1);
  201. }
  202. inline void pop_char(std::string& s, char c)
  203. {
  204. if (s.back() == c)
  205. {
  206. s.pop_back();
  207. }
  208. }
  209. inline void
  210. parse_header(std::string header, std::string& descr, bool* fortran_order, std::vector<std::size_t>& shape)
  211. {
  212. // The first 6 bytes are a magic string: exactly "x93NUMPY".
  213. //
  214. // The next 1 byte is an unsigned byte: the major version number of the file
  215. // format, e.g. x01.
  216. //
  217. // The next 1 byte is an unsigned byte: the minor version number of the file
  218. // format, e.g. x00. Note: the version of the file format is not tied to the
  219. // version of the NumPy package.
  220. //
  221. // The next 2 bytes form a little-endian unsigned short int: the length of the
  222. // header data HEADER_LEN.
  223. //
  224. // The next HEADER_LEN bytes form the header data describing the array's
  225. // format. It is an ASCII string which contains a Python literal expression of
  226. // a dictionary. It is terminated by a newline ('n') and padded with spaces
  227. // ('x20') to make the total length of the magic string + 4 + HEADER_LEN be
  228. // evenly divisible by 16 for alignment purposes.
  229. //
  230. // The dictionary contains three keys:
  231. //
  232. // "descr" : dtype.descr
  233. // An object that can be passed as an argument to the numpy.dtype()
  234. // constructor to create the array's dtype.
  235. // "fortran_order" : bool
  236. // Whether the array data is Fortran-contiguous or not. Since
  237. // Fortran-contiguous arrays are a common form of non-C-contiguity, we allow
  238. // them to be written directly to disk for efficiency.
  239. // "shape" : tuple of int
  240. // The shape of the array.
  241. // For repeatability and readability, this dictionary is formatted using
  242. // pprint.pformat() so the keys are in alphabetic order.
  243. // remove trailing newline
  244. if (header.back() != '\n')
  245. {
  246. XTENSOR_THROW(std::runtime_error, "invalid header");
  247. }
  248. header.pop_back();
  249. // remove all whitespaces
  250. header.erase(std::remove(header.begin(), header.end(), ' '), header.end());
  251. // unwrap dictionary
  252. header = unwrap_s(header, '{', '}');
  253. // find the positions of the 3 dictionary keys
  254. std::size_t keypos_descr = header.find("'descr'");
  255. std::size_t keypos_fortran = header.find("'fortran_order'");
  256. std::size_t keypos_shape = header.find("'shape'");
  257. // make sure all the keys are present
  258. if (keypos_descr == std::string::npos)
  259. {
  260. XTENSOR_THROW(std::runtime_error, "missing 'descr' key");
  261. }
  262. if (keypos_fortran == std::string::npos)
  263. {
  264. XTENSOR_THROW(std::runtime_error, "missing 'fortran_order' key");
  265. }
  266. if (keypos_shape == std::string::npos)
  267. {
  268. XTENSOR_THROW(std::runtime_error, "missing 'shape' key");
  269. }
  270. // Make sure the keys are in order.
  271. // Note that this violates the standard, which states that readers *must* not
  272. // depend on the correct order here.
  273. // TODO: fix
  274. if (keypos_descr >= keypos_fortran || keypos_fortran >= keypos_shape)
  275. {
  276. XTENSOR_THROW(std::runtime_error, "header keys in wrong order");
  277. }
  278. // get the 3 key-value pairs
  279. std::string keyvalue_descr;
  280. keyvalue_descr = header.substr(keypos_descr, keypos_fortran - keypos_descr);
  281. pop_char(keyvalue_descr, ',');
  282. std::string keyvalue_fortran;
  283. keyvalue_fortran = header.substr(keypos_fortran, keypos_shape - keypos_fortran);
  284. pop_char(keyvalue_fortran, ',');
  285. std::string keyvalue_shape;
  286. keyvalue_shape = header.substr(keypos_shape, std::string::npos);
  287. pop_char(keyvalue_shape, ',');
  288. // get the values (right side of `:')
  289. std::string descr_s = get_value_from_map(keyvalue_descr);
  290. std::string fortran_s = get_value_from_map(keyvalue_fortran);
  291. std::string shape_s = get_value_from_map(keyvalue_shape);
  292. parse_typestring(descr_s);
  293. descr = unwrap_s(descr_s, '\'', '\'');
  294. // convert literal Python bool to C++ bool
  295. if (fortran_s == "True")
  296. {
  297. *fortran_order = true;
  298. }
  299. else if (fortran_s == "False")
  300. {
  301. *fortran_order = false;
  302. }
  303. else
  304. {
  305. XTENSOR_THROW(std::runtime_error, "invalid fortran_order value");
  306. }
  307. // parse the shape Python tuple ( x, y, z,)
  308. // first clear the vector
  309. shape.clear();
  310. shape_s = unwrap_s(shape_s, '(', ')');
  311. // a tokenizer would be nice...
  312. std::size_t pos = 0;
  313. for (;;)
  314. {
  315. std::size_t pos_next = shape_s.find_first_of(',', pos);
  316. std::string dim_s;
  317. if (pos_next != std::string::npos)
  318. {
  319. dim_s = shape_s.substr(pos, pos_next - pos);
  320. }
  321. else
  322. {
  323. dim_s = shape_s.substr(pos);
  324. }
  325. if (dim_s.length() == 0)
  326. {
  327. if (pos_next != std::string::npos)
  328. {
  329. XTENSOR_THROW(std::runtime_error, "invalid shape");
  330. }
  331. }
  332. else
  333. {
  334. std::stringstream ss;
  335. ss << dim_s;
  336. std::size_t tmp;
  337. ss >> tmp;
  338. shape.push_back(tmp);
  339. }
  340. if (pos_next != std::string::npos)
  341. {
  342. pos = ++pos_next;
  343. }
  344. else
  345. {
  346. break;
  347. }
  348. }
  349. }
  350. template <class O, class S>
  351. inline void write_header(O& out, const std::string& descr, bool fortran_order, const S& shape)
  352. {
  353. std::ostringstream ss_header;
  354. std::string s_fortran_order;
  355. if (fortran_order)
  356. {
  357. s_fortran_order = "True";
  358. }
  359. else
  360. {
  361. s_fortran_order = "False";
  362. }
  363. std::string s_shape;
  364. std::ostringstream ss_shape;
  365. ss_shape << "(";
  366. for (auto shape_it = std::begin(shape); shape_it != std::end(shape); ++shape_it)
  367. {
  368. ss_shape << *shape_it << ", ";
  369. }
  370. s_shape = ss_shape.str();
  371. if (xtl::sequence_size(shape) > 1)
  372. {
  373. s_shape = s_shape.erase(s_shape.size() - 2);
  374. }
  375. else if (xtl::sequence_size(shape) == 1)
  376. {
  377. s_shape = s_shape.erase(s_shape.size() - 1);
  378. }
  379. s_shape += ")";
  380. ss_header << "{'descr': '" << descr << "', 'fortran_order': " << s_fortran_order
  381. << ", 'shape': " << s_shape << ", }";
  382. std::size_t header_len_pre = ss_header.str().length() + 1;
  383. std::size_t metadata_len = magic_string_length + 2 + 2 + header_len_pre;
  384. unsigned char version[2] = {1, 0};
  385. if (metadata_len >= 255 * 255)
  386. {
  387. metadata_len = magic_string_length + 2 + 4 + header_len_pre;
  388. version[0] = 2;
  389. version[1] = 0;
  390. }
  391. std::size_t padding_len = 64 - (metadata_len % 64);
  392. std::string padding(padding_len, ' ');
  393. ss_header << padding;
  394. ss_header << std::endl;
  395. std::string header = ss_header.str();
  396. // write magic
  397. write_magic(out, version[0], version[1]);
  398. // write header length
  399. if (version[0] == 1 && version[1] == 0)
  400. {
  401. char header_len_le16[2];
  402. uint16_t header_len = uint16_t(header.length());
  403. header_len_le16[0] = char((header_len >> 0) & 0xff);
  404. header_len_le16[1] = char((header_len >> 8) & 0xff);
  405. out.write(reinterpret_cast<char*>(header_len_le16), 2);
  406. }
  407. else
  408. {
  409. char header_len_le32[4];
  410. uint32_t header_len = uint32_t(header.length());
  411. header_len_le32[0] = char((header_len >> 0) & 0xff);
  412. header_len_le32[1] = char((header_len >> 8) & 0xff);
  413. header_len_le32[2] = char((header_len >> 16) & 0xff);
  414. header_len_le32[3] = char((header_len >> 24) & 0xff);
  415. out.write(reinterpret_cast<char*>(header_len_le32), 4);
  416. }
  417. out << header;
  418. }
  419. inline std::string read_header_1_0(std::istream& istream)
  420. {
  421. // read header length and convert from little endian
  422. char header_len_le16[2];
  423. istream.read(header_len_le16, 2);
  424. uint16_t header_length = uint16_t(header_len_le16[0] << 0) | uint16_t(header_len_le16[1] << 8);
  425. if ((magic_string_length + 2 + 2 + header_length) % 16 != 0)
  426. {
  427. // TODO: display warning
  428. }
  429. std::unique_ptr<char[]> buf(new char[header_length]);
  430. istream.read(buf.get(), header_length);
  431. std::string header(buf.get(), header_length);
  432. return header;
  433. }
  434. inline std::string read_header_2_0(std::istream& istream)
  435. {
  436. // read header length and convert from little endian
  437. char header_len_le32[4];
  438. istream.read(header_len_le32, 4);
  439. uint32_t header_length = uint32_t(header_len_le32[0] << 0) | uint32_t(header_len_le32[1] << 8)
  440. | uint32_t(header_len_le32[2] << 16) | uint32_t(header_len_le32[3] << 24);
  441. if ((magic_string_length + 2 + 4 + header_length) % 16 != 0)
  442. {
  443. // TODO: display warning
  444. }
  445. std::unique_ptr<char[]> buf(new char[header_length]);
  446. istream.read(buf.get(), header_length);
  447. std::string header(buf.get(), header_length);
  448. return header;
  449. }
  450. struct npy_file
  451. {
  452. npy_file() = default;
  453. npy_file(std::vector<std::size_t>& shape, bool fortran_order, std::string typestring)
  454. : m_shape(shape)
  455. , m_fortran_order(fortran_order)
  456. , m_typestring(typestring)
  457. {
  458. // Allocate memory
  459. m_word_size = std::size_t(atoi(&typestring[2]));
  460. m_n_bytes = compute_size(shape) * m_word_size;
  461. m_buffer = std::allocator<char>{}.allocate(m_n_bytes);
  462. }
  463. ~npy_file()
  464. {
  465. if (m_buffer != nullptr)
  466. {
  467. std::allocator<char>{}.deallocate(m_buffer, m_n_bytes);
  468. }
  469. }
  470. // delete copy constructor
  471. npy_file(const npy_file&) = delete;
  472. npy_file& operator=(const npy_file&) = delete;
  473. // implement move constructor and assignment
  474. npy_file(npy_file&& rhs)
  475. : m_shape(std::move(rhs.m_shape))
  476. , m_fortran_order(std::move(rhs.m_fortran_order))
  477. , m_word_size(std::move(rhs.m_word_size))
  478. , m_n_bytes(std::move(rhs.m_n_bytes))
  479. , m_typestring(std::move(rhs.m_typestring))
  480. , m_buffer(rhs.m_buffer)
  481. {
  482. rhs.m_buffer = nullptr;
  483. }
  484. npy_file& operator=(npy_file&& rhs)
  485. {
  486. if (this != &rhs)
  487. {
  488. m_shape = std::move(rhs.m_shape);
  489. m_fortran_order = std::move(rhs.m_fortran_order);
  490. m_word_size = std::move(rhs.m_word_size);
  491. m_n_bytes = std::move(rhs.m_n_bytes);
  492. m_typestring = std::move(rhs.m_typestring);
  493. m_buffer = rhs.m_buffer;
  494. rhs.m_buffer = nullptr;
  495. }
  496. return *this;
  497. }
  498. template <class T, layout_type L>
  499. auto cast_impl(bool check_type)
  500. {
  501. if (m_buffer == nullptr)
  502. {
  503. XTENSOR_THROW(std::runtime_error, "This npy_file has already been cast.");
  504. }
  505. T* ptr = reinterpret_cast<T*>(&m_buffer[0]);
  506. std::vector<std::size_t> strides(m_shape.size());
  507. std::size_t sz = compute_size(m_shape);
  508. // check if the typestring matches the given one
  509. if (check_type && m_typestring != detail::build_typestring<T>())
  510. {
  511. XTENSOR_THROW(
  512. std::runtime_error,
  513. "Cast error: formats not matching "s + m_typestring + " vs "s
  514. + detail::build_typestring<T>()
  515. );
  516. }
  517. if ((L == layout_type::column_major && !m_fortran_order)
  518. || (L == layout_type::row_major && m_fortran_order))
  519. {
  520. XTENSOR_THROW(
  521. std::runtime_error,
  522. "Cast error: layout mismatch between npy file and requested layout."
  523. );
  524. }
  525. compute_strides(
  526. m_shape,
  527. m_fortran_order ? layout_type::column_major : layout_type::row_major,
  528. strides
  529. );
  530. std::vector<std::size_t> shape(m_shape);
  531. return std::make_tuple(ptr, sz, std::move(shape), std::move(strides));
  532. }
  533. template <class T, layout_type L = layout_type::dynamic>
  534. auto cast(bool check_type = true) &&
  535. {
  536. auto cast_elems = cast_impl<T, L>(check_type);
  537. m_buffer = nullptr;
  538. return adapt(
  539. std::move(std::get<0>(cast_elems)),
  540. std::get<1>(cast_elems),
  541. acquire_ownership(),
  542. std::get<2>(cast_elems),
  543. std::get<3>(cast_elems)
  544. );
  545. }
  546. template <class T, layout_type L = layout_type::dynamic>
  547. auto cast(bool check_type = true) const&
  548. {
  549. auto cast_elems = cast_impl<T, L>(check_type);
  550. return adapt(
  551. std::get<0>(cast_elems),
  552. std::get<1>(cast_elems),
  553. no_ownership(),
  554. std::get<2>(cast_elems),
  555. std::get<3>(cast_elems)
  556. );
  557. }
  558. template <class T, layout_type L = layout_type::dynamic>
  559. auto cast(bool check_type = true) &
  560. {
  561. auto cast_elems = cast_impl<T, L>(check_type);
  562. return adapt(
  563. std::get<0>(cast_elems),
  564. std::get<1>(cast_elems),
  565. no_ownership(),
  566. std::get<2>(cast_elems),
  567. std::get<3>(cast_elems)
  568. );
  569. }
  570. char* ptr()
  571. {
  572. return m_buffer;
  573. }
  574. std::size_t n_bytes()
  575. {
  576. return m_n_bytes;
  577. }
  578. std::vector<std::size_t> m_shape;
  579. bool m_fortran_order;
  580. std::size_t m_word_size;
  581. std::size_t m_n_bytes;
  582. std::string m_typestring;
  583. char* m_buffer;
  584. };
  585. inline npy_file load_npy_file(std::istream& stream)
  586. {
  587. // check magic bytes an version number
  588. unsigned char v_major, v_minor;
  589. detail::read_magic(stream, &v_major, &v_minor);
  590. std::string header;
  591. if (v_major == 1 && v_minor == 0)
  592. {
  593. header = detail::read_header_1_0(stream);
  594. }
  595. else if (v_major == 2 && v_minor == 0)
  596. {
  597. header = detail::read_header_2_0(stream);
  598. }
  599. else
  600. {
  601. XTENSOR_THROW(std::runtime_error, "unsupported file format version");
  602. }
  603. // parse header
  604. bool fortran_order;
  605. std::string typestr;
  606. std::vector<std::size_t> shape;
  607. detail::parse_header(header, typestr, &fortran_order, shape);
  608. npy_file result(shape, fortran_order, typestr);
  609. // read the data
  610. stream.read(result.ptr(), std::streamsize((result.n_bytes())));
  611. return result;
  612. }
  613. template <class O, class E>
  614. inline void dump_npy_stream(O& stream, const xexpression<E>& e)
  615. {
  616. using value_type = typename E::value_type;
  617. const E& ex = e.derived_cast();
  618. auto&& eval_ex = eval(ex);
  619. bool fortran_order = false;
  620. if (eval_ex.layout() == layout_type::column_major && eval_ex.dimension() > 1)
  621. {
  622. fortran_order = true;
  623. }
  624. std::string typestring = detail::build_typestring<value_type>();
  625. auto shape = eval_ex.shape();
  626. detail::write_header(stream, typestring, fortran_order, shape);
  627. std::size_t size = compute_size(shape);
  628. stream.write(
  629. reinterpret_cast<const char*>(eval_ex.data()),
  630. std::streamsize((sizeof(value_type) * size))
  631. );
  632. }
  633. } // namespace detail
  634. /**
  635. * Save xexpression to NumPy npy format
  636. *
  637. * @param filename The filename or path to dump the data
  638. * @param e the xexpression
  639. */
  640. template <typename E>
  641. inline void dump_npy(const std::string& filename, const xexpression<E>& e)
  642. {
  643. std::ofstream stream(filename, std::ofstream::binary);
  644. if (!stream)
  645. {
  646. XTENSOR_THROW(std::runtime_error, "IO Error: failed to open file: "s + filename);
  647. }
  648. detail::dump_npy_stream(stream, e);
  649. }
  650. /**
  651. * Save xexpression to NumPy npy format in a string
  652. *
  653. * @param e the xexpression
  654. */
  655. template <typename E>
  656. inline std::string dump_npy(const xexpression<E>& e)
  657. {
  658. std::stringstream stream;
  659. detail::dump_npy_stream(stream, e);
  660. return stream.str();
  661. }
  662. /**
  663. * Loads a npy file (the NumPy storage format)
  664. *
  665. * @param stream An input stream from which to load the file
  666. * @tparam T select the type of the npy file (note: currently there is
  667. * no dynamic casting if types do not match)
  668. * @tparam L select layout_type::column_major if you stored data in
  669. * Fortran format
  670. * @return xarray with contents from npy file
  671. */
  672. template <typename T, layout_type L = layout_type::dynamic>
  673. inline auto load_npy(std::istream& stream)
  674. {
  675. detail::npy_file file = detail::load_npy_file(stream);
  676. return std::move(file).cast<T, L>();
  677. }
  678. /**
  679. * Loads a npy file (the NumPy storage format)
  680. *
  681. * @param filename The filename or path to the file
  682. * @tparam T select the type of the npy file (note: currently there is
  683. * no dynamic casting if types do not match)
  684. * @tparam L select layout_type::column_major if you stored data in
  685. * Fortran format
  686. * @return xarray with contents from npy file
  687. */
  688. template <typename T, layout_type L = layout_type::dynamic>
  689. inline auto load_npy(const std::string& filename)
  690. {
  691. std::ifstream stream(filename, std::ifstream::binary);
  692. if (!stream)
  693. {
  694. XTENSOR_THROW(std::runtime_error, "io error: failed to open a file.");
  695. }
  696. return load_npy<T, L>(stream);
  697. }
  698. } // namespace xt
  699. #endif