xio.hpp 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. /***************************************************************************
  2. * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
  3. * Copyright (c) QuantStack *
  4. * *
  5. * Distributed under the terms of the BSD 3-Clause License. *
  6. * *
  7. * The full license is in the file LICENSE, distributed with this software. *
  8. ****************************************************************************/
  9. #ifndef XTENSOR_IO_HPP
  10. #define XTENSOR_IO_HPP
#include <cmath>
#include <complex>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <numeric>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "xexpression.hpp"
#include "xmath.hpp"
#include "xstrided_view.hpp"
  21. namespace xt
  22. {
  23. template <class E>
  24. inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e);
  25. /*****************
  26. * print options *
  27. *****************/
  28. namespace print_options
  29. {
  30. struct print_options_impl
  31. {
  32. int edge_items = 3;
  33. int line_width = 75;
  34. int threshold = 1000;
  35. int precision = -1; // default precision
  36. };
  37. inline print_options_impl& print_options()
  38. {
  39. static print_options_impl po;
  40. return po;
  41. }
  42. /**
  43. * @brief Sets the line width. After \a line_width chars,
  44. * a new line is added.
  45. *
  46. * @param line_width The line width
  47. */
  48. inline void set_line_width(int line_width)
  49. {
  50. print_options().line_width = line_width;
  51. }
  52. /**
  53. * @brief Sets the threshold after which summarization is triggered (default: 1000).
  54. *
  55. * @param threshold The number of elements in the xexpression that triggers
  56. * summarization in the output
  57. */
  58. inline void set_threshold(int threshold)
  59. {
  60. print_options().threshold = threshold;
  61. }
  62. /**
  63. * @brief Sets the number of edge items. If the summarization is
  64. * triggered, this value defines how many items of each dimension
  65. * are printed.
  66. *
  67. * @param edge_items The number of edge items
  68. */
  69. inline void set_edge_items(int edge_items)
  70. {
  71. print_options().edge_items = edge_items;
  72. }
  73. /**
  74. * @brief Sets the precision for printing floating point values.
  75. *
  76. * @param precision The number of digits for floating point output
  77. */
  78. inline void set_precision(int precision)
  79. {
  80. print_options().precision = precision;
  81. }
  82. #define DEFINE_LOCAL_PRINT_OPTION(NAME) \
  83. class NAME \
  84. { \
  85. public: \
  86. \
  87. NAME(int value) \
  88. : m_value(value) \
  89. { \
  90. id(); \
  91. } \
  92. static int id() \
  93. { \
  94. static int id = std::ios_base::xalloc(); \
  95. return id; \
  96. } \
  97. int value() const \
  98. { \
  99. return m_value; \
  100. } \
  101. \
  102. private: \
  103. \
  104. int m_value; \
  105. }; \
  106. \
  107. inline std::ostream& operator<<(std::ostream& out, const NAME& n) \
  108. { \
  109. out.iword(NAME::id()) = n.value(); \
  110. return out; \
  111. }
  112. /**
  113. * @class line_width
  114. *
  115. * io manipulator used to set the width of the lines when printing
  116. * an expression.
  117. *
  118. * @code{.cpp}
  119. * using po = xt::print_options;
  120. * xt::xarray<double> a = {{1, 2, 3}, {4, 5, 6}};
  121. * std::cout << po::line_width(100) << a << std::endl;
  122. * @endcode
  123. */
  124. DEFINE_LOCAL_PRINT_OPTION(line_width)
  125. /**
  126. * @class threshold
  127. *
  128. * io manipulator used to set the threshold after which summarization is
  129. * triggered.
  130. *
  131. * @code{.cpp}
  132. * using po = xt::print_options;
  133. * xt::xarray<double> a = xt::rand::randn<double>({2000, 500});
  134. * std::cout << po::threshold(50) << a << std::endl;
  135. * @endcode
  136. */
  137. DEFINE_LOCAL_PRINT_OPTION(threshold)
  138. /**
  139. * @class edge_items
  140. *
  141. * io manipulator used to set the number of egde items if
  142. * the summarization is triggered.
  143. *
  144. * @code{.cpp}
  145. * using po = xt::print_options;
  146. * xt::xarray<double> a = xt::rand::randn<double>({2000, 500});
  147. * std::cout << po::edge_items(5) << a << std::endl;
  148. * @endcode
  149. */
  150. DEFINE_LOCAL_PRINT_OPTION(edge_items)
  151. /**
  152. * @class precision
  153. *
  154. * io manipulator used to set the precision of the floating point values
  155. * when printing an expression.
  156. *
  157. * @code{.cpp}
  158. * using po = xt::print_options;
  159. * xt::xarray<double> a = xt::rand::randn<double>({2000, 500});
  160. * std::cout << po::precision(5) << a << std::endl;
  161. * @endcode
  162. */
  163. DEFINE_LOCAL_PRINT_OPTION(precision)
  164. }
  165. /**************************************
  166. * xexpression ostream implementation *
  167. **************************************/
  168. namespace detail
  169. {
  170. template <class E, class F>
  171. std::ostream& xoutput(
  172. std::ostream& out,
  173. const E& e,
  174. xstrided_slice_vector& slices,
  175. F& printer,
  176. std::size_t blanks,
  177. std::streamsize element_width,
  178. std::size_t edgeitems,
  179. std::size_t line_width
  180. )
  181. {
  182. using size_type = typename E::size_type;
  183. const auto view = xt::strided_view(e, slices);
  184. if (view.dimension() == 0)
  185. {
  186. printer.print_next(out);
  187. }
  188. else
  189. {
  190. std::string indents(blanks, ' ');
  191. size_type i = 0;
  192. size_type elems_on_line = 0;
  193. const size_type ewp2 = static_cast<size_type>(element_width) + size_type(2);
  194. const size_type line_lim = static_cast<size_type>(std::floor(line_width / ewp2));
  195. out << '{';
  196. for (; i != size_type(view.shape()[0] - 1); ++i)
  197. {
  198. if (edgeitems && size_type(view.shape()[0]) > (edgeitems * 2) && i == edgeitems)
  199. {
  200. if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
  201. {
  202. out << " ...,";
  203. }
  204. else if (view.dimension() > 1)
  205. {
  206. elems_on_line = 0;
  207. out << "...," << std::endl << indents;
  208. }
  209. else
  210. {
  211. out << "..., ";
  212. }
  213. i = size_type(view.shape()[0]) - edgeitems;
  214. }
  215. if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
  216. {
  217. out << std::endl << indents;
  218. elems_on_line = 0;
  219. }
  220. slices.push_back(static_cast<int>(i));
  221. xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << ',';
  222. slices.pop_back();
  223. elems_on_line++;
  224. if ((view.dimension() == 1) && !(line_lim != 0 && elems_on_line >= line_lim))
  225. {
  226. out << ' ';
  227. }
  228. else if (view.dimension() > 1)
  229. {
  230. out << std::endl << indents;
  231. }
  232. }
  233. if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
  234. {
  235. out << std::endl << indents;
  236. }
  237. slices.push_back(static_cast<int>(i));
  238. xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << '}';
  239. slices.pop_back();
  240. }
  241. return out;
  242. }
  243. template <class F, class E>
  244. void recurser_run(F& fn, const E& e, xstrided_slice_vector& slices, std::size_t lim = 0)
  245. {
  246. using size_type = typename E::size_type;
  247. const auto view = strided_view(e, slices);
  248. if (view.dimension() == 0)
  249. {
  250. fn.update(view());
  251. }
  252. else
  253. {
  254. size_type i = 0;
  255. for (; i != static_cast<size_type>(view.shape()[0] - 1); ++i)
  256. {
  257. if (lim && size_type(view.shape()[0]) > (lim * 2) && i == lim)
  258. {
  259. i = static_cast<size_type>(view.shape()[0]) - lim;
  260. }
  261. slices.push_back(static_cast<int>(i));
  262. recurser_run(fn, e, slices, lim);
  263. slices.pop_back();
  264. }
  265. slices.push_back(static_cast<int>(i));
  266. recurser_run(fn, e, slices, lim);
  267. slices.pop_back();
  268. }
  269. }
  270. template <class T, class E = void>
  271. struct printer;
  272. template <class T>
  273. struct printer<T, std::enable_if_t<std::is_floating_point<typename T::value_type>::value>>
  274. {
  275. using value_type = std::decay_t<typename T::value_type>;
  276. using cache_type = std::vector<value_type>;
  277. using cache_iterator = typename cache_type::const_iterator;
  278. explicit printer(std::streamsize precision)
  279. : m_precision(precision)
  280. {
  281. }
  282. void init()
  283. {
  284. m_precision = m_required_precision < m_precision ? m_required_precision : m_precision;
  285. m_it = m_cache.cbegin();
  286. if (m_scientific)
  287. {
  288. // 3 = sign, number and dot and 4 = "e+00"
  289. m_width = m_precision + 7;
  290. if (m_large_exponent)
  291. {
  292. // = e+000 (additional number)
  293. m_width += 1;
  294. }
  295. }
  296. else
  297. {
  298. std::streamsize decimals = 1; // print a leading 0
  299. if (std::floor(m_max) != 0)
  300. {
  301. decimals += std::streamsize(std::log10(std::floor(m_max)));
  302. }
  303. // 2 => sign and dot
  304. m_width = 2 + decimals + m_precision;
  305. }
  306. if (!m_required_precision)
  307. {
  308. --m_width;
  309. }
  310. }
  311. std::ostream& print_next(std::ostream& out)
  312. {
  313. if (!m_scientific)
  314. {
  315. std::stringstream buf;
  316. buf.width(m_width);
  317. buf << std::fixed;
  318. buf.precision(m_precision);
  319. buf << (*m_it);
  320. if (!m_required_precision && !std::isinf(*m_it) && !std::isnan(*m_it))
  321. {
  322. buf << '.';
  323. }
  324. std::string res = buf.str();
  325. auto sit = res.rbegin();
  326. while (*sit == '0')
  327. {
  328. *sit = ' ';
  329. ++sit;
  330. }
  331. out << res;
  332. }
  333. else
  334. {
  335. if (!m_large_exponent)
  336. {
  337. out << std::scientific;
  338. out.width(m_width);
  339. out << (*m_it);
  340. }
  341. else
  342. {
  343. std::stringstream buf;
  344. buf.width(m_width);
  345. buf << std::scientific;
  346. buf.precision(m_precision);
  347. buf << (*m_it);
  348. std::string res = buf.str();
  349. if (res[res.size() - 4] == 'e')
  350. {
  351. res.erase(0, 1);
  352. res.insert(res.size() - 2, "0");
  353. }
  354. out << res;
  355. }
  356. }
  357. ++m_it;
  358. return out;
  359. }
  360. void update(const value_type& val)
  361. {
  362. if (val != 0 && !std::isinf(val) && !std::isnan(val))
  363. {
  364. if (!m_scientific || !m_large_exponent)
  365. {
  366. int exponent = 1 + int(std::log10(math::abs(val)));
  367. if (exponent <= -5 || exponent > 7)
  368. {
  369. m_scientific = true;
  370. m_required_precision = m_precision;
  371. if (exponent <= -100 || exponent >= 100)
  372. {
  373. m_large_exponent = true;
  374. }
  375. }
  376. }
  377. if (math::abs(val) > m_max)
  378. {
  379. m_max = math::abs(val);
  380. }
  381. if (m_required_precision < m_precision)
  382. {
  383. while (std::floor(val * std::pow(10, m_required_precision))
  384. != val * std::pow(10, m_required_precision))
  385. {
  386. m_required_precision++;
  387. }
  388. }
  389. }
  390. m_cache.push_back(val);
  391. }
  392. std::streamsize width()
  393. {
  394. return m_width;
  395. }
  396. private:
  397. bool m_large_exponent = false;
  398. bool m_scientific = false;
  399. std::streamsize m_width = 9;
  400. std::streamsize m_precision;
  401. std::streamsize m_required_precision = 0;
  402. value_type m_max = 0;
  403. cache_type m_cache;
  404. cache_iterator m_it;
  405. };
  406. template <class T>
  407. struct printer<
  408. T,
  409. std::enable_if_t<
  410. xtl::is_integral<typename T::value_type>::value && !std::is_same<typename T::value_type, bool>::value>>
  411. {
  412. using value_type = std::decay_t<typename T::value_type>;
  413. using cache_type = std::vector<value_type>;
  414. using cache_iterator = typename cache_type::const_iterator;
  415. explicit printer(std::streamsize)
  416. {
  417. }
  418. void init()
  419. {
  420. m_it = m_cache.cbegin();
  421. m_width = 1 + std::streamsize((m_max > 0) ? std::log10(m_max) : 0) + m_sign;
  422. }
  423. std::ostream& print_next(std::ostream& out)
  424. {
  425. // + enables printing of chars etc. as numbers
  426. // TODO should chars be printed as numbers?
  427. out.width(m_width);
  428. out << +(*m_it);
  429. ++m_it;
  430. return out;
  431. }
  432. void update(const value_type& val)
  433. {
  434. if (math::abs(val) > m_max)
  435. {
  436. m_max = math::abs(val);
  437. }
  438. if (xtl::is_signed<value_type>::value && val < 0)
  439. {
  440. m_sign = true;
  441. }
  442. m_cache.push_back(val);
  443. }
  444. std::streamsize width()
  445. {
  446. return m_width;
  447. }
  448. private:
  449. std::streamsize m_width;
  450. bool m_sign = false;
  451. value_type m_max = 0;
  452. cache_type m_cache;
  453. cache_iterator m_it;
  454. };
  455. template <class T>
  456. struct printer<T, std::enable_if_t<std::is_same<typename T::value_type, bool>::value>>
  457. {
  458. using value_type = bool;
  459. using cache_type = std::vector<bool>;
  460. using cache_iterator = typename cache_type::const_iterator;
  461. explicit printer(std::streamsize)
  462. {
  463. }
  464. void init()
  465. {
  466. m_it = m_cache.cbegin();
  467. }
  468. std::ostream& print_next(std::ostream& out)
  469. {
  470. if (*m_it)
  471. {
  472. out << " true";
  473. }
  474. else
  475. {
  476. out << "false";
  477. }
  478. // TODO: the following std::setw(5) isn't working correctly on OSX.
  479. // out << std::boolalpha << std::setw(m_width) << (*m_it);
  480. ++m_it;
  481. return out;
  482. }
  483. void update(const value_type& val)
  484. {
  485. m_cache.push_back(val);
  486. }
  487. std::streamsize width()
  488. {
  489. return m_width;
  490. }
  491. private:
  492. std::streamsize m_width = 5;
  493. cache_type m_cache;
  494. cache_iterator m_it;
  495. };
  496. template <class T>
  497. struct printer<T, std::enable_if_t<xtl::is_complex<typename T::value_type>::value>>
  498. {
  499. using value_type = std::decay_t<typename T::value_type>;
  500. using cache_type = std::vector<bool>;
  501. using cache_iterator = typename cache_type::const_iterator;
  502. explicit printer(std::streamsize precision)
  503. : real_printer(precision)
  504. , imag_printer(precision)
  505. {
  506. }
  507. void init()
  508. {
  509. real_printer.init();
  510. imag_printer.init();
  511. m_it = m_signs.cbegin();
  512. }
  513. std::ostream& print_next(std::ostream& out)
  514. {
  515. real_printer.print_next(out);
  516. if (*m_it)
  517. {
  518. out << "-";
  519. }
  520. else
  521. {
  522. out << "+";
  523. }
  524. std::stringstream buf;
  525. imag_printer.print_next(buf);
  526. std::string s = buf.str();
  527. if (s[0] == ' ')
  528. {
  529. s.erase(0, 1); // erase space for +/-
  530. }
  531. // insert j at end of number
  532. std::size_t idx = s.find_last_not_of(" ");
  533. s.insert(idx + 1, "i");
  534. out << s;
  535. ++m_it;
  536. return out;
  537. }
  538. void update(const value_type& val)
  539. {
  540. real_printer.update(val.real());
  541. imag_printer.update(std::abs(val.imag()));
  542. m_signs.push_back(std::signbit(val.imag()));
  543. }
  544. std::streamsize width()
  545. {
  546. return real_printer.width() + imag_printer.width() + 2;
  547. }
  548. private:
  549. printer<value_type> real_printer, imag_printer;
  550. cache_type m_signs;
  551. cache_iterator m_it;
  552. };
  553. template <class T>
  554. struct printer<
  555. T,
  556. std::enable_if_t<
  557. !xtl::is_fundamental<typename T::value_type>::value && !xtl::is_complex<typename T::value_type>::value>>
  558. {
  559. using const_reference = typename T::const_reference;
  560. using value_type = std::decay_t<typename T::value_type>;
  561. using cache_type = std::vector<std::string>;
  562. using cache_iterator = typename cache_type::const_iterator;
  563. explicit printer(std::streamsize)
  564. {
  565. }
  566. void init()
  567. {
  568. m_it = m_cache.cbegin();
  569. if (m_width > 20)
  570. {
  571. m_width = 0;
  572. }
  573. }
  574. std::ostream& print_next(std::ostream& out)
  575. {
  576. out.width(m_width);
  577. out << *m_it;
  578. ++m_it;
  579. return out;
  580. }
  581. void update(const_reference val)
  582. {
  583. std::stringstream buf;
  584. buf << val;
  585. std::string s = buf.str();
  586. if (int(s.size()) > m_width)
  587. {
  588. m_width = std::streamsize(s.size());
  589. }
  590. m_cache.push_back(s);
  591. }
  592. std::streamsize width()
  593. {
  594. return m_width;
  595. }
  596. private:
  597. std::streamsize m_width = 0;
  598. cache_type m_cache;
  599. cache_iterator m_it;
  600. };
  601. template <class E>
  602. struct custom_formatter
  603. {
  604. using value_type = std::decay_t<typename E::value_type>;
  605. template <class F>
  606. custom_formatter(F&& func)
  607. : m_func(func)
  608. {
  609. }
  610. std::string operator()(const value_type& val) const
  611. {
  612. return m_func(val);
  613. }
  614. private:
  615. std::function<std::string(const value_type&)> m_func;
  616. };
  617. }
  618. inline print_options::print_options_impl get_print_options(std::ostream& out)
  619. {
  620. print_options::print_options_impl res;
  621. using print_options::edge_items;
  622. using print_options::line_width;
  623. using print_options::precision;
  624. using print_options::threshold;
  625. res.edge_items = static_cast<int>(out.iword(edge_items::id()));
  626. res.line_width = static_cast<int>(out.iword(line_width::id()));
  627. res.threshold = static_cast<int>(out.iword(threshold::id()));
  628. res.precision = static_cast<int>(out.iword(precision::id()));
  629. if (!res.edge_items)
  630. {
  631. res.edge_items = print_options::print_options().edge_items;
  632. }
  633. else
  634. {
  635. out.iword(edge_items::id()) = long(0);
  636. }
  637. if (!res.line_width)
  638. {
  639. res.line_width = print_options::print_options().line_width;
  640. }
  641. else
  642. {
  643. out.iword(line_width::id()) = long(0);
  644. }
  645. if (!res.threshold)
  646. {
  647. res.threshold = print_options::print_options().threshold;
  648. }
  649. else
  650. {
  651. out.iword(threshold::id()) = long(0);
  652. }
  653. if (!res.precision)
  654. {
  655. res.precision = print_options::print_options().precision;
  656. }
  657. else
  658. {
  659. out.iword(precision::id()) = long(0);
  660. }
  661. return res;
  662. }
  663. template <class E, class F>
  664. std::ostream& pretty_print(const xexpression<E>& e, F&& func, std::ostream& out = std::cout)
  665. {
  666. xfunction<detail::custom_formatter<E>, const_xclosure_t<E>> print_fun(
  667. detail::custom_formatter<E>(std::forward<F>(func)),
  668. e
  669. );
  670. return pretty_print(print_fun, out);
  671. }
  672. namespace detail
  673. {
  674. template <class S>
  675. class fmtflags_guard
  676. {
  677. public:
  678. explicit fmtflags_guard(S& stream)
  679. : m_stream(stream)
  680. , m_flags(stream.flags())
  681. {
  682. }
  683. ~fmtflags_guard()
  684. {
  685. m_stream.flags(m_flags);
  686. }
  687. private:
  688. S& m_stream;
  689. std::ios_base::fmtflags m_flags;
  690. };
  691. }
  692. template <class E>
  693. std::ostream& pretty_print(const xexpression<E>& e, std::ostream& out = std::cout)
  694. {
  695. detail::fmtflags_guard<std::ostream> guard(out);
  696. const E& d = e.derived_cast();
  697. std::size_t lim = 0;
  698. std::size_t sz = compute_size(d.shape());
  699. auto po = get_print_options(out);
  700. if (sz > static_cast<std::size_t>(po.threshold))
  701. {
  702. lim = static_cast<std::size_t>(po.edge_items);
  703. }
  704. if (sz == 0)
  705. {
  706. out << "{}";
  707. return out;
  708. }
  709. auto temp_precision = out.precision();
  710. auto precision = temp_precision;
  711. if (po.precision != -1)
  712. {
  713. out.precision(static_cast<std::streamsize>(po.precision));
  714. precision = static_cast<std::streamsize>(po.precision);
  715. }
  716. detail::printer<E> p(precision);
  717. xstrided_slice_vector sv;
  718. detail::recurser_run(p, d, sv, lim);
  719. p.init();
  720. sv.clear();
  721. xoutput(out, d, sv, p, 1, p.width(), lim, static_cast<std::size_t>(po.line_width));
  722. out.precision(temp_precision); // restore precision
  723. return out;
  724. }
  725. template <class E>
  726. inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e)
  727. {
  728. return pretty_print(e, out);
  729. }
  730. }
  731. #endif
  732. // Backward compatibility: when compiled by the Cling interpreter (__CLING__ defined), xio.hpp also includes xmime.hpp.
  733. #ifdef __CLING__
  734. #include "xmime.hpp"
  735. #endif