xtensor
Loading...
Searching...
No Matches
xio.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_IO_HPP
11#define XTENSOR_IO_HPP
12
13#include <complex>
14#include <cstddef>
15#include <iomanip>
16#include <iostream>
17#include <numeric>
18#include <sstream>
19#include <string>
20
21#include "xexpression.hpp"
22#include "xmath.hpp"
23#include "xstrided_view.hpp"
24
25namespace xt
26{
27
28 template <class E>
29 inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e);
30
31 /*****************
32 * print options *
33 *****************/
34
namespace print_options
{
    /**
     * Global print options. They are used for every option that was not
     * overridden on the stream through one of the io manipulators below.
     */
    struct print_options_impl
    {
        int edge_items = 3;
        int line_width = 75;
        int threshold = 1000;
        int precision = -1;  // default precision
    };

    /**
     * Returns a reference to the lazily-constructed global options instance.
     */
    inline print_options_impl& print_options()
    {
        static print_options_impl po;
        return po;
    }

    /**
     * @brief Sets the line width. After \a line_width chars,
     *        a new line is added.
     *
     * @param line_width The line width
     */
    inline void set_line_width(int line_width)
    {
        print_options().line_width = line_width;
    }

    /**
     * @brief Sets the threshold after which summarization is triggered (default: 1000).
     *
     * @param threshold The number of elements in the xexpression that triggers
     *                  summarization in the output
     */
    inline void set_threshold(int threshold)
    {
        print_options().threshold = threshold;
    }

    /**
     * @brief Sets the number of edge items. If the summarization is
     *        triggered, this value defines how many items of each dimension
     *        are printed.
     *
     * @param edge_items The number of edge items
     */
    inline void set_edge_items(int edge_items)
    {
        print_options().edge_items = edge_items;
    }

    /**
     * @brief Sets the precision for printing floating point values.
     *
     * @param precision The number of digits for floating point output
     */
    inline void set_precision(int precision)
    {
        print_options().precision = precision;
    }

/*
 * Defines an io manipulator class NAME holding an int. Streaming an instance
 * stores the value in a per-stream slot (std::ios_base::iword) whose index is
 * allocated once with std::ios_base::xalloc. A stored value of 0 is
 * indistinguishable from "not set".
 */
#define DEFINE_LOCAL_PRINT_OPTION(NAME)                               \
    class NAME                                                        \
    {                                                                 \
    public:                                                           \
        NAME(int value)                                               \
            : m_value(value)                                          \
        {                                                             \
            id();                                                     \
        }                                                             \
        static int id()                                               \
        {                                                             \
            static int id = std::ios_base::xalloc();                  \
            return id;                                                \
        }                                                             \
        int value() const                                             \
        {                                                             \
            return m_value;                                           \
        }                                                             \
    private:                                                          \
        int m_value;                                                  \
    };                                                                \
    inline std::ostream& operator<<(std::ostream& out, const NAME& n) \
    {                                                                 \
        out.iword(NAME::id()) = n.value();                            \
        return out;                                                   \
    }

    /**
     * @class line_width
     * io manipulator used to set the width of the lines when printing
     * an expression.
     *
     * @code{.cpp}
     * namespace po = xt::print_options;
     * xt::xarray<double> a = {{1, 2, 3}, {4, 5, 6}};
     * std::cout << po::line_width(100) << a << std::endl;
     * @endcode
     */
    DEFINE_LOCAL_PRINT_OPTION(line_width)

    /**
     * @class threshold
     * io manipulator used to set the threshold after which summarization is
     * triggered.
     *
     * @code{.cpp}
     * namespace po = xt::print_options;
     * xt::xarray<double> a = xt::rand::randn<double>({2000, 500});
     * std::cout << po::threshold(50) << a << std::endl;
     * @endcode
     */
    DEFINE_LOCAL_PRINT_OPTION(threshold)

    /**
     * @class edge_items
     * io manipulator used to set the number of edge items if summarization
     * is triggered.
     *
     * @code{.cpp}
     * namespace po = xt::print_options;
     * xt::xarray<double> a = xt::rand::randn<double>({2000, 500});
     * std::cout << po::edge_items(5) << a << std::endl;
     * @endcode
     */
    DEFINE_LOCAL_PRINT_OPTION(edge_items)

    /**
     * @class precision
     * io manipulator used to set the precision of the floating point values
     * when printing an expression.
     *
     * @code{.cpp}
     * namespace po = xt::print_options;
     * xt::xarray<double> a = xt::rand::randn<double>({2000, 500});
     * std::cout << po::precision(5) << a << std::endl;
     * @endcode
     */
    DEFINE_LOCAL_PRINT_OPTION(precision)
}
182
183 /**************************************
184 * xexpression ostream implementation *
185 **************************************/
186
187 namespace detail
188 {
189 template <class E, class F>
190 std::ostream& xoutput(
191 std::ostream& out,
192 const E& e,
193 xstrided_slice_vector& slices,
194 F& printer,
195 std::size_t blanks,
196 std::streamsize element_width,
197 std::size_t edgeitems,
198 std::size_t line_width
199 )
200 {
201 using size_type = typename E::size_type;
202
203 const auto view = xt::strided_view(e, slices);
204 if (view.dimension() == 0)
205 {
206 printer.print_next(out);
207 }
208 else
209 {
210 std::string indents(blanks, ' ');
211
212 size_type i = 0;
213 size_type elems_on_line = 0;
214 const size_type ewp2 = static_cast<size_type>(element_width) + size_type(2);
215 const size_type line_lim = static_cast<size_type>(std::floor(line_width / ewp2));
216
217 out << '{';
218 for (; i != size_type(view.shape()[0] - 1); ++i)
219 {
220 if (edgeitems && size_type(view.shape()[0]) > (edgeitems * 2) && i == edgeitems)
221 {
222 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
223 {
224 out << " ...,";
225 }
226 else if (view.dimension() > 1)
227 {
228 elems_on_line = 0;
229 out << "...," << std::endl << indents;
230 }
231 else
232 {
233 out << "..., ";
234 }
235 i = size_type(view.shape()[0]) - edgeitems;
236 }
237 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
238 {
239 out << std::endl << indents;
240 elems_on_line = 0;
241 }
242 slices.push_back(static_cast<int>(i));
243 xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << ',';
244 slices.pop_back();
246
247 if ((view.dimension() == 1) && !(line_lim != 0 && elems_on_line >= line_lim))
248 {
249 out << ' ';
250 }
251 else if (view.dimension() > 1)
252 {
253 out << std::endl << indents;
254 }
255 }
256 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
257 {
258 out << std::endl << indents;
259 }
260 slices.push_back(static_cast<int>(i));
261 xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << '}';
262 slices.pop_back();
263 }
264 return out;
265 }
266
267 template <class F, class E>
268 void recurser_run(F& fn, const E& e, xstrided_slice_vector& slices, std::size_t lim = 0)
269 {
270 using size_type = typename E::size_type;
271 const auto view = strided_view(e, slices);
272 if (view.dimension() == 0)
273 {
274 fn.update(view());
275 }
276 else
277 {
278 size_type i = 0;
279 for (; i != static_cast<size_type>(view.shape()[0] - 1); ++i)
280 {
281 if (lim && size_type(view.shape()[0]) > (lim * 2) && i == lim)
282 {
283 i = static_cast<size_type>(view.shape()[0]) - lim;
284 }
285 slices.push_back(static_cast<int>(i));
286 recurser_run(fn, e, slices, lim);
287 slices.pop_back();
288 }
289 slices.push_back(static_cast<int>(i));
290 recurser_run(fn, e, slices, lim);
291 slices.pop_back();
292 }
293 }
294
295 template <class T, class E = void>
296 struct printer;
297
298 template <class T>
299 struct printer<T, std::enable_if_t<std::is_floating_point<typename T::value_type>::value>>
300 {
301 using value_type = std::decay_t<typename T::value_type>;
302 using cache_type = std::vector<value_type>;
303 using cache_iterator = typename cache_type::const_iterator;
304
305 explicit printer(std::streamsize precision)
306 : m_precision(precision)
307 {
308 }
309
310 void init()
311 {
312 m_precision = m_required_precision < m_precision ? m_required_precision : m_precision;
313 m_it = m_cache.cbegin();
314 if (m_scientific)
315 {
316 // 3 = sign, number and dot and 4 = "e+00"
317 m_width = m_precision + 7;
318 if (m_large_exponent)
319 {
320 // = e+000 (additional number)
321 m_width += 1;
322 }
323 }
324 else
325 {
326 std::streamsize decimals = 1; // print a leading 0
327 if (std::floor(m_max) != 0)
328 {
329 decimals += std::streamsize(std::log10(std::floor(m_max)));
330 }
331 // 2 => sign and dot
332 m_width = 2 + decimals + m_precision;
333 }
334 if (!m_required_precision)
335 {
336 --m_width;
337 }
338 }
339
340 std::ostream& print_next(std::ostream& out)
341 {
342 if (!m_scientific)
343 {
344 std::stringstream buf;
345 buf.width(m_width);
346 buf << std::fixed;
347 buf.precision(m_precision);
348 buf << (*m_it);
349 if (!m_required_precision && !std::isinf(*m_it) && !std::isnan(*m_it))
350 {
351 buf << '.';
352 }
353 std::string res = buf.str();
354 auto sit = res.rbegin();
355 while (*sit == '0')
356 {
357 *sit = ' ';
358 ++sit;
359 }
360 out << res;
361 }
362 else
363 {
364 if (!m_large_exponent)
365 {
366 out << std::scientific;
367 out.width(m_width);
368 out << (*m_it);
369 }
370 else
371 {
372 std::stringstream buf;
373 buf.width(m_width);
374 buf << std::scientific;
375 buf.precision(m_precision);
376 buf << (*m_it);
377 std::string res = buf.str();
378
379 if (res[res.size() - 4] == 'e')
380 {
381 res.erase(0, 1);
382 res.insert(res.size() - 2, "0");
383 }
384 out << res;
385 }
386 }
387 ++m_it;
388 return out;
389 }
390
391 void update(const value_type& val)
392 {
393 if (val != 0 && !std::isinf(val) && !std::isnan(val))
394 {
395 if (!m_scientific || !m_large_exponent)
396 {
397 int exponent = 1 + int(std::log10(math::abs(val)));
399 {
400 m_scientific = true;
401 m_required_precision = m_precision;
403 {
404 m_large_exponent = true;
405 }
406 }
407 }
408 if (math::abs(val) > m_max)
409 {
410 m_max = math::abs(val);
411 }
412 if (m_required_precision < m_precision)
413 {
414 while (std::floor(val * std::pow(10, m_required_precision))
415 != val * std::pow(10, m_required_precision))
416 {
417 m_required_precision++;
418 }
419 }
420 }
421 m_cache.push_back(val);
422 }
423
424 std::streamsize width()
425 {
426 return m_width;
427 }
428
429 private:
430
431 bool m_large_exponent = false;
432 bool m_scientific = false;
433 std::streamsize m_width = 9;
434 std::streamsize m_precision;
435 std::streamsize m_required_precision = 0;
436 value_type m_max = 0;
437
438 cache_type m_cache;
439 cache_iterator m_it;
440 };
441
442 template <class T>
443 struct printer<
444 T,
445 std::enable_if_t<
446 xtl::is_integral<typename T::value_type>::value && !std::is_same<typename T::value_type, bool>::value>>
447 {
448 using value_type = std::decay_t<typename T::value_type>;
449 using cache_type = std::vector<value_type>;
450 using cache_iterator = typename cache_type::const_iterator;
451
452 explicit printer(std::streamsize)
453 {
454 }
455
456 void init()
457 {
458 m_it = m_cache.cbegin();
459 m_width = 1 + std::streamsize((m_max > 0) ? std::log10(m_max) : 0) + m_sign;
460 }
461
462 std::ostream& print_next(std::ostream& out)
463 {
464 // + enables printing of chars etc. as numbers
465 // TODO should chars be printed as numbers?
466 out.width(m_width);
467 out << +(*m_it);
468 ++m_it;
469 return out;
470 }
471
472 void update(const value_type& val)
473 {
474 if (math::abs(val) > m_max)
475 {
476 m_max = math::abs(val);
477 }
478 if (xtl::is_signed<value_type>::value && val < 0)
479 {
480 m_sign = true;
481 }
482 m_cache.push_back(val);
483 }
484
485 std::streamsize width()
486 {
487 return m_width;
488 }
489
490 private:
491
492 std::streamsize m_width;
493 bool m_sign = false;
494 value_type m_max = 0;
495
496 cache_type m_cache;
497 cache_iterator m_it;
498 };
499
500 template <class T>
501 struct printer<T, std::enable_if_t<std::is_same<typename T::value_type, bool>::value>>
502 {
503 using value_type = bool;
504 using cache_type = std::vector<bool>;
505 using cache_iterator = typename cache_type::const_iterator;
506
507 explicit printer(std::streamsize)
508 {
509 }
510
511 void init()
512 {
513 m_it = m_cache.cbegin();
514 }
515
516 std::ostream& print_next(std::ostream& out)
517 {
518 if (*m_it)
519 {
520 out << " true";
521 }
522 else
523 {
524 out << "false";
525 }
526 // TODO: the following std::setw(5) isn't working correctly on OSX.
527 // out << std::boolalpha << std::setw(m_width) << (*m_it);
528 ++m_it;
529 return out;
530 }
531
532 void update(const value_type& val)
533 {
534 m_cache.push_back(val);
535 }
536
537 std::streamsize width()
538 {
539 return m_width;
540 }
541
542 private:
543
544 std::streamsize m_width = 5;
545
546 cache_type m_cache;
547 cache_iterator m_it;
548 };
549
550 template <class T>
551 struct printer<T, std::enable_if_t<xtl::is_complex<typename T::value_type>::value>>
552 {
553 using value_type = std::decay_t<typename T::value_type>;
554 using cache_type = std::vector<bool>;
555 using cache_iterator = typename cache_type::const_iterator;
556
557 explicit printer(std::streamsize precision)
558 : real_printer(precision)
559 , imag_printer(precision)
560 {
561 }
562
563 void init()
564 {
565 real_printer.init();
566 imag_printer.init();
567 m_it = m_signs.cbegin();
568 }
569
570 std::ostream& print_next(std::ostream& out)
571 {
572 real_printer.print_next(out);
573 if (*m_it)
574 {
575 out << "-";
576 }
577 else
578 {
579 out << "+";
580 }
581 std::stringstream buf;
582 imag_printer.print_next(buf);
583 std::string s = buf.str();
584 if (s[0] == ' ')
585 {
586 s.erase(0, 1); // erase space for +/-
587 }
588 // insert j at end of number
589 std::size_t idx = s.find_last_not_of(" ");
590 s.insert(idx + 1, "i");
591 out << s;
592 ++m_it;
593 return out;
594 }
595
596 void update(const value_type& val)
597 {
598 real_printer.update(val.real());
599 imag_printer.update(std::abs(val.imag()));
600 m_signs.push_back(std::signbit(val.imag()));
601 }
602
603 std::streamsize width()
604 {
605 return real_printer.width() + imag_printer.width() + 2;
606 }
607
608 private:
609
610 printer<value_type> real_printer, imag_printer;
611 cache_type m_signs;
612 cache_iterator m_it;
613 };
614
615 template <class T>
616 struct printer<
617 T,
618 std::enable_if_t<
619 !xtl::is_fundamental<typename T::value_type>::value && !xtl::is_complex<typename T::value_type>::value>>
620 {
621 using const_reference = typename T::const_reference;
622 using value_type = std::decay_t<typename T::value_type>;
623 using cache_type = std::vector<std::string>;
624 using cache_iterator = typename cache_type::const_iterator;
625
626 explicit printer(std::streamsize)
627 {
628 }
629
630 void init()
631 {
632 m_it = m_cache.cbegin();
633 if (m_width > 20)
634 {
635 m_width = 0;
636 }
637 }
638
639 std::ostream& print_next(std::ostream& out)
640 {
641 out.width(m_width);
642 out << *m_it;
643 ++m_it;
644 return out;
645 }
646
647 void update(const_reference val)
648 {
649 std::stringstream buf;
650 buf << val;
651 std::string s = buf.str();
652 if (int(s.size()) > m_width)
653 {
654 m_width = std::streamsize(s.size());
655 }
656 m_cache.push_back(s);
657 }
658
659 std::streamsize width()
660 {
661 return m_width;
662 }
663
664 private:
665
666 std::streamsize m_width = 0;
667 cache_type m_cache;
668 cache_iterator m_it;
669 };
670
/**
 * Element-wise functor wrapping a user callable that turns a value of the
 * expression's value_type into a std::string.
 *
 * @tparam E expression type whose value_type is formatted
 */
template <class E>
struct custom_formatter
{
    using value_type = std::decay_t<typename E::value_type>;

    /**
     * Stores @p func, forwarding it so temporaries are moved rather than copied.
     * Disabled for custom_formatter itself, so copying a formatter uses the
     * copy constructor instead of wrapping it in another std::function.
     */
    template <class F, class = std::enable_if_t<!std::is_same<std::decay_t<F>, custom_formatter>::value>>
    custom_formatter(F&& func)
        : m_func(std::forward<F>(func))
    {
    }

    /// Returns the text produced by the wrapped callable for @p val.
    std::string operator()(const value_type& val) const
    {
        return m_func(val);
    }

private:

    std::function<std::string(const value_type&)> m_func;
};
691 }
692
693 inline print_options::print_options_impl get_print_options(std::ostream& out)
694 {
700
701 res.edge_items = static_cast<int>(out.iword(edge_items::id()));
702 res.line_width = static_cast<int>(out.iword(line_width::id()));
703 res.threshold = static_cast<int>(out.iword(threshold::id()));
704 res.precision = static_cast<int>(out.iword(precision::id()));
705
706 if (!res.edge_items)
707 {
708 res.edge_items = print_options::print_options().edge_items;
709 }
710 else
711 {
712 out.iword(edge_items::id()) = long(0);
713 }
714 if (!res.line_width)
715 {
716 res.line_width = print_options::print_options().line_width;
717 }
718 else
719 {
720 out.iword(line_width::id()) = long(0);
721 }
722 if (!res.threshold)
723 {
724 res.threshold = print_options::print_options().threshold;
725 }
726 else
727 {
728 out.iword(threshold::id()) = long(0);
729 }
730 if (!res.precision)
731 {
732 res.precision = print_options::print_options().precision;
733 }
734 else
735 {
736 out.iword(precision::id()) = long(0);
737 }
738
739 return res;
740 }
741
742 template <class E, class F>
743 std::ostream& pretty_print(const xexpression<E>& e, F&& func, std::ostream& out = std::cout)
744 {
745 xfunction<detail::custom_formatter<E>, const_xclosure_t<E>> print_fun(
746 detail::custom_formatter<E>(std::forward<F>(func)),
747 e
748 );
749 return pretty_print(print_fun, out);
750 }
751
namespace detail
{
    /**
     * RAII guard that records the format flags of a stream on construction
     * and restores them on destruction.
     *
     * Non-copyable: a copy would restore the same flags a second time,
     * possibly clobbering changes made after the first guard ended.
     *
     * @tparam S stream type (provides flags() and flags(fmtflags))
     */
    template <class S>
    class fmtflags_guard
    {
    public:

        explicit fmtflags_guard(S& stream)
            : m_stream(stream)
            , m_flags(stream.flags())
        {
        }

        fmtflags_guard(const fmtflags_guard&) = delete;
        fmtflags_guard& operator=(const fmtflags_guard&) = delete;

        ~fmtflags_guard()
        {
            m_stream.flags(m_flags);
        }

    private:

        S& m_stream;
        std::ios_base::fmtflags m_flags;
    };
}
776
777 template <class E>
778 std::ostream& pretty_print(const xexpression<E>& e, std::ostream& out = std::cout)
779 {
780 detail::fmtflags_guard<std::ostream> guard(out);
781
782 const E& d = e.derived_cast();
783
784 std::size_t lim = 0;
785 std::size_t sz = compute_size(d.shape());
786
787 auto po = get_print_options(out);
788
789 if (sz > static_cast<std::size_t>(po.threshold))
790 {
791 lim = static_cast<std::size_t>(po.edge_items);
792 }
793 if (sz == 0)
794 {
795 out << "{}";
796 return out;
797 }
798
799 auto temp_precision = out.precision();
800 auto precision = temp_precision;
801 if (po.precision != -1)
802 {
803 out.precision(static_cast<std::streamsize>(po.precision));
804 precision = static_cast<std::streamsize>(po.precision);
805 }
806
807 detail::printer<E> p(precision);
808
810 detail::recurser_run(p, d, sv, lim);
811 p.init();
812 sv.clear();
813 xoutput(out, d, sv, p, 1, p.width(), lim, static_cast<std::size_t>(po.line_width));
814
815 out.precision(temp_precision); // restore precision
816
817 return out;
818 }
819
820 template <class E>
821 inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e)
822 {
823 return pretty_print(e, out);
824 }
825}
826#endif
827
828// Backward compatibility: include xmime.hpp in xio.hpp by default.
829
830#ifdef __CLING__
831#include "xmime.hpp"
832#endif
io manipulator used to set the number of edge items if summarization is triggered.
Definition xio.hpp:166
io manipulator used to set the width of the lines when printing an expression.
Definition xio.hpp:138
io manipulator used to set the precision of the floating point values when printing an expression.
Definition xio.hpp:180
io manipulator used to set the threshold after which summarization is triggered.
Definition xio.hpp:152
standard mathematical functions for xexpressions
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
vector of slices used to build a xstrided_view
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1834