xtensor
 
Loading...
Searching...
No Matches
xio.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_IO_HPP
11#define XTENSOR_IO_HPP
12
#include <cmath>
#include <complex>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "../core/xexpression.hpp"
#include "../core/xmath.hpp"
#include "../views/xstrided_view.hpp"
#include "xtl/xmasked_value_meta.hpp"
24
25namespace xt
26{
27
28 template <class E>
29 inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e);
30
31 /*****************
32 * print options *
33 *****************/
34
namespace print_options
{
    /**
     * Global print options, used when no per-stream value has been set with
     * one of the io manipulators defined below.
     */
    struct print_options_impl
    {
        int edge_items = 3;    // items printed at each end of a summarized axis
        int line_width = 75;   // maximum number of characters per printed line
        int threshold = 1000;  // element count above which output is summarized
        int precision = -1;    // -1: keep the precision of the output stream
    };

    /// Returns the process-wide options object (a function-local static).
    inline print_options_impl& print_options()
    {
        static print_options_impl po;
        return po;
    }

    /// Sets the global maximum line width used when printing expressions.
    inline void set_line_width(int line_width)
    {
        print_options().line_width = line_width;
    }

    /// Sets the global number of elements above which summarization is triggered.
    inline void set_threshold(int threshold)
    {
        print_options().threshold = threshold;
    }

    /// Sets the global number of items printed at each end of a summarized axis.
    inline void set_edge_items(int edge_items)
    {
        print_options().edge_items = edge_items;
    }

    /// Sets the global floating-point precision (-1 keeps the stream's own precision).
    inline void set_precision(int precision)
    {
        print_options().precision = precision;
    }

// Defines an io manipulator class NAME wrapping an int. Streaming a NAME into an
// std::ostream stores its value in out.iword(NAME::id()); the slot index is
// allocated once per NAME with std::ios_base::xalloc(). The constructor calls
// id() so that the slot is allocated as soon as a manipulator is created.
#define DEFINE_LOCAL_PRINT_OPTION(NAME)                                \
    class NAME                                                         \
    {                                                                  \
    public:                                                            \
                                                                       \
        NAME(int value)                                                \
            : m_value(value)                                           \
        {                                                              \
            id();                                                      \
        }                                                              \
        static int id()                                                \
        {                                                              \
            static int id = std::ios_base::xalloc();                   \
            return id;                                                 \
        }                                                              \
        int value() const                                              \
        {                                                              \
            return m_value;                                            \
        }                                                              \
                                                                       \
    private:                                                           \
                                                                       \
        int m_value;                                                   \
    };                                                                 \
                                                                       \
    inline std::ostream& operator<<(std::ostream& out, const NAME& n)  \
    {                                                                  \
        out.iword(NAME::id()) = n.value();                             \
        return out;                                                    \
    }

    /// io manipulator used to set the width of the lines when printing an expression.
    DEFINE_LOCAL_PRINT_OPTION(line_width)

    /// io manipulator used to set the threshold after which summarization is triggered.
    DEFINE_LOCAL_PRINT_OPTION(threshold)

    /// io manipulator used to set the number of edge items if summarization is triggered.
    DEFINE_LOCAL_PRINT_OPTION(edge_items)

    /// io manipulator used to set the precision of the floating point values when printing an expression.
    DEFINE_LOCAL_PRINT_OPTION(precision)
}
182
183 /**************************************
184 * xexpression ostream implementation *
185 **************************************/
186
187 namespace detail
188 {
189 template <class E, class F>
190 std::ostream& xoutput(
191 std::ostream& out,
192 const E& e,
193 xstrided_slice_vector& slices,
194 F& printer,
195 std::size_t blanks,
196 std::streamsize element_width,
197 std::size_t edgeitems,
198 std::size_t line_width
199 )
200 {
201 using size_type = typename E::size_type;
202
203 const auto view = xt::strided_view(e, slices);
204 if (view.dimension() == 0)
205 {
206 printer.print_next(out);
207 }
208 else
209 {
210 std::string indents(blanks, ' ');
211
212 size_type i = 0;
213 size_type elems_on_line = 0;
214 const size_type ewp2 = static_cast<size_type>(element_width) + size_type(2);
215 const size_type line_lim = static_cast<size_type>(std::floor(line_width / ewp2));
216
217 out << '{';
218 for (; i != size_type(view.shape()[0] - 1); ++i)
219 {
220 if (edgeitems && size_type(view.shape()[0]) > (edgeitems * 2) && i == edgeitems)
221 {
222 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
223 {
224 out << " ...,";
225 }
226 else if (view.dimension() > 1)
227 {
228 elems_on_line = 0;
229 out << "...," << std::endl << indents;
230 }
231 else
232 {
233 out << "..., ";
234 }
235 i = size_type(view.shape()[0]) - edgeitems;
236 }
237 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
238 {
239 out << std::endl << indents;
240 elems_on_line = 0;
241 }
242 slices.push_back(static_cast<int>(i));
243 xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << ',';
244 slices.pop_back();
245 elems_on_line++;
246
247 if ((view.dimension() == 1) && !(line_lim != 0 && elems_on_line >= line_lim))
248 {
249 out << ' ';
250 }
251 else if (view.dimension() > 1)
252 {
253 out << std::endl << indents;
254 }
255 }
256 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
257 {
258 out << std::endl << indents;
259 }
260 slices.push_back(static_cast<int>(i));
261 xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << '}';
262 slices.pop_back();
263 }
264 return out;
265 }
266
267 template <class F, class E>
268 void recurser_run(F& fn, const E& e, xstrided_slice_vector& slices, std::size_t lim = 0)
269 {
270 using size_type = typename E::size_type;
271 const auto view = strided_view(e, slices);
272 if (view.dimension() == 0)
273 {
274 fn.update(view());
275 }
276 else
277 {
278 size_type i = 0;
279 for (; i != static_cast<size_type>(view.shape()[0] - 1); ++i)
280 {
281 if (lim && size_type(view.shape()[0]) > (lim * 2) && i == lim)
282 {
283 i = static_cast<size_type>(view.shape()[0]) - lim;
284 }
285 slices.push_back(static_cast<int>(i));
286 recurser_run(fn, e, slices, lim);
287 slices.pop_back();
288 }
289 slices.push_back(static_cast<int>(i));
290 recurser_run(fn, e, slices, lim);
291 slices.pop_back();
292 }
293 }
294
template <class T, class E = void>
struct printer;

/**
 * Printer for expressions whose value_type is a floating-point type.
 *
 * Protocol: update() is called once per displayed element, in display order, to
 * cache the value and collect formatting statistics. init() then fixes the
 * notation (fixed or scientific), the number of decimals and the common field
 * width. Finally, print_next() writes the cached values one after the other.
 */
template <class T>
struct printer<T, std::enable_if_t<std::is_floating_point<typename T::value_type>::value>>
{
    using value_type = std::decay_t<typename T::value_type>;
    using cache_type = std::vector<value_type>;
    using cache_iterator = typename cache_type::const_iterator;

    /// @param precision maximum number of decimals to print
    explicit printer(std::streamsize precision)
        : m_precision(precision)
    {
    }

    /// Finalizes precision and field width; call after the last update().
    void init()
    {
        // Never print more decimals than the cached values need.
        m_precision = m_required_precision < m_precision ? m_required_precision : m_precision;
        m_it = m_cache.cbegin();
        if (m_scientific)
        {
            // 3 = sign, number and dot and 4 = "e+00"
            m_width = m_precision + 7;
            if (m_large_exponent)
            {
                // = e+000 (additional number)
                m_width += 1;
            }
        }
        else
        {
            std::streamsize decimals = 1;  // print a leading 0
            if (std::floor(m_max) != 0)
            {
                decimals += std::streamsize(std::log10(std::floor(m_max)));
            }
            // 2 => sign and dot
            m_width = 2 + decimals + m_precision;
        }
        if (!m_required_precision)
        {
            --m_width;
        }
    }

    /// Writes the next cached value to ``out`` and advances the cache cursor.
    std::ostream& print_next(std::ostream& out)
    {
        if (!m_scientific)
        {
            std::stringstream buf;
            buf.width(m_width);
            buf << std::fixed;
            buf.precision(m_precision);
            buf << (*m_it);
            if (!m_required_precision && !std::isinf(*m_it) && !std::isnan(*m_it))
            {
                // integral values are printed with a trailing dot, e.g. "3."
                buf << '.';
            }
            std::string res = buf.str();
            // Blank out trailing zeros so columns stay aligned; the rend() check
            // keeps the scan inside the string.
            for (auto sit = res.rbegin(); sit != res.rend() && *sit == '0'; ++sit)
            {
                *sit = ' ';
            }
            out << res;
        }
        else if (!m_large_exponent)
        {
            out << std::scientific;
            // Set the precision explicitly: ``out`` is not necessarily the stream
            // configured by pretty_print (the complex printer passes a fresh
            // std::stringstream, whose default precision is 6).
            out.precision(m_precision);
            out.width(m_width);
            out << (*m_it);
        }
        else
        {
            std::stringstream buf;
            buf.width(m_width);
            buf << std::scientific;
            buf.precision(m_precision);
            buf << (*m_it);
            std::string res = buf.str();

            // Widen a two-digit exponent ("e+05") to three digits ("e+005") so it
            // fills the extra column reserved for large exponents.
            if (res[res.size() - 4] == 'e')
            {
                res.erase(0, 1);
                res.insert(res.size() - 2, "0");
            }
            out << res;
        }
        ++m_it;
        return out;
    }

    /// Caches ``val`` and updates notation, precision and magnitude statistics.
    void update(const value_type& val)
    {
        if (val != 0 && !std::isinf(val) && !std::isnan(val))
        {
            const value_type magnitude = std::abs(val);
            if (!m_scientific || !m_large_exponent)
            {
                int exponent = 1 + int(std::log10(magnitude));
                if (exponent <= -5 || exponent > 7)
                {
                    m_scientific = true;
                    m_required_precision = m_precision;
                    if (exponent <= -100 || exponent >= 100)
                    {
                        m_large_exponent = true;
                    }
                }
            }
            if (magnitude > m_max)
            {
                m_max = magnitude;
            }
            if (m_required_precision < m_precision)
            {
                // Smallest number of decimals n for which val * 10^n is integral.
                auto scaled = val * std::pow(10, m_required_precision);
                while (std::floor(scaled) != scaled)
                {
                    ++m_required_precision;
                    scaled = val * std::pow(10, m_required_precision);
                }
            }
        }
        m_cache.push_back(val);
    }

    /// Field width computed by init().
    std::streamsize width()
    {
        return m_width;
    }

private:

    bool m_large_exponent = false;
    bool m_scientific = false;
    std::streamsize m_width = 9;
    std::streamsize m_precision;
    std::streamsize m_required_precision = 0;
    value_type m_max = 0;

    cache_type m_cache;
    cache_iterator m_it;
};
441
442 template <class T>
443 struct printer<
444 T,
445 std::enable_if_t<
446 xtl::is_integral<typename T::value_type>::value && !std::is_same<typename T::value_type, bool>::value>>
447 {
448 using value_type = std::decay_t<typename T::value_type>;
449 using cache_type = std::vector<value_type>;
450 using cache_iterator = typename cache_type::const_iterator;
451
452 explicit printer(std::streamsize)
453 {
454 }
455
456 void init()
457 {
458 m_it = m_cache.cbegin();
459 m_width = 1 + std::streamsize((m_max > 0) ? std::log10(m_max) : 0) + m_sign;
460 }
461
462 std::ostream& print_next(std::ostream& out)
463 {
464 // + enables printing of chars etc. as numbers
465 // TODO should chars be printed as numbers?
466 out.width(m_width);
467 out << +(*m_it);
468 ++m_it;
469 return out;
470 }
471
472 void update(const value_type& val)
473 {
474 if (math::abs(val) > m_max)
475 {
476 m_max = math::abs(val);
477 }
478 if (xtl::is_signed<value_type>::value && val < 0)
479 {
480 m_sign = true;
481 }
482 m_cache.push_back(val);
483 }
484
485 std::streamsize width()
486 {
487 return m_width;
488 }
489
490 private:
491
492 std::streamsize m_width;
493 bool m_sign = false;
494 value_type m_max = 0;
495
496 cache_type m_cache;
497 cache_iterator m_it;
498 };
499
500 template <class T>
501 struct printer<T, std::enable_if_t<std::is_same<typename T::value_type, bool>::value>>
502 {
503 using value_type = bool;
504 using cache_type = std::vector<bool>;
505 using cache_iterator = typename cache_type::const_iterator;
506
507 explicit printer(std::streamsize)
508 {
509 }
510
511 void init()
512 {
513 m_it = m_cache.cbegin();
514 }
515
516 std::ostream& print_next(std::ostream& out)
517 {
518 if (*m_it)
519 {
520 out << " true";
521 }
522 else
523 {
524 out << "false";
525 }
526 // TODO: the following std::setw(5) isn't working correctly on OSX.
527 // out << std::boolalpha << std::setw(m_width) << (*m_it);
528 ++m_it;
529 return out;
530 }
531
532 void update(const value_type& val)
533 {
534 m_cache.push_back(val);
535 }
536
537 std::streamsize width()
538 {
539 return m_width;
540 }
541
542 private:
543
544 std::streamsize m_width = 5;
545
546 cache_type m_cache;
547 cache_iterator m_it;
548 };
549
550 template <class T>
551 struct printer<T, std::enable_if_t<xtl::is_complex<typename T::value_type>::value>>
552 {
553 using value_type = std::decay_t<typename T::value_type>;
554 using cache_type = std::vector<bool>;
555 using cache_iterator = typename cache_type::const_iterator;
556
557 explicit printer(std::streamsize precision)
558 : real_printer(precision)
559 , imag_printer(precision)
560 {
561 }
562
563 void init()
564 {
565 real_printer.init();
566 imag_printer.init();
567 m_it = m_signs.cbegin();
568 }
569
570 std::ostream& print_next(std::ostream& out)
571 {
572 real_printer.print_next(out);
573 if (*m_it)
574 {
575 out << "-";
576 }
577 else
578 {
579 out << "+";
580 }
581 std::stringstream buf;
582 imag_printer.print_next(buf);
583 std::string s = buf.str();
584 if (s[0] == ' ')
585 {
586 s.erase(0, 1); // erase space for +/-
587 }
588 // insert j at end of number
589 std::size_t idx = s.find_last_not_of(" ");
590 s.insert(idx + 1, "i");
591 out << s;
592 ++m_it;
593 return out;
594 }
595
596 void update(const value_type& val)
597 {
598 real_printer.update(val.real());
599 imag_printer.update(std::abs(val.imag()));
600 m_signs.push_back(std::signbit(val.imag()));
601 }
602
603 std::streamsize width()
604 {
605 return real_printer.width() + imag_printer.width() + 2;
606 }
607
608 private:
609
610 printer<value_type> real_printer, imag_printer;
611 cache_type m_signs;
612 cache_iterator m_it;
613 };
614
615 template <class T>
616 struct printer<
617 T,
618 std::enable_if_t<
619 !xtl::is_fundamental<typename T::value_type>::value && !xtl::is_complex<typename T::value_type>::value>>
620 {
621 using const_reference = typename T::const_reference;
622 using value_type = std::decay_t<typename T::value_type>;
623 using cache_type = std::vector<std::string>;
624 using cache_iterator = typename cache_type::const_iterator;
625
626 explicit printer(std::streamsize)
627 {
628 }
629
630 void init()
631 {
632 m_it = m_cache.cbegin();
633 if (m_width > 20)
634 {
635 m_width = 0;
636 }
637 }
638
639 std::ostream& print_next(std::ostream& out)
640 {
641 out.width(m_width);
642 out << *m_it;
643 ++m_it;
644 return out;
645 }
646
647 void update(const_reference val)
648 {
649 std::stringstream buf;
650 if constexpr (xtl::is_xmasked_value<value_type>::value)
651 {
652 buf << +val;
653 }
654 else
655 {
656 buf << val;
657 }
658 std::string s = buf.str();
659 if (int(s.size()) > m_width)
660 {
661 m_width = std::streamsize(s.size());
662 }
663 m_cache.push_back(s);
664 }
665
666 std::streamsize width()
667 {
668 return m_width;
669 }
670
671 private:
672
673 std::streamsize m_width = 0;
674 cache_type m_cache;
675 cache_iterator m_it;
676 };
677
/**
 * Functor turning an element of expression type ``E`` into a string with a
 * user-supplied callable. It is used by pretty_print(e, func, out) as the
 * function of an xfunction.
 */
template <class E>
struct custom_formatter
{
    using value_type = std::decay_t<typename E::value_type>;

    /**
     * @param func callable invocable as ``func(const value_type&)`` and returning
     *             something convertible to std::string; it is forwarded into
     *             the stored std::function.
     *
     * The constructor is disabled for custom_formatter itself. Otherwise, copying
     * a non-const lvalue would pick this template over the copy constructor and
     * wrap the formatter in another std::function layer.
     */
    template <
        class F,
        class = std::enable_if_t<!std::is_same<std::decay_t<F>, custom_formatter>::value>>
    custom_formatter(F&& func)
        : m_func(std::forward<F>(func))
    {
    }

    /// Formats one element.
    std::string operator()(const value_type& val) const
    {
        return m_func(val);
    }

private:

    std::function<std::string(const value_type&)> m_func;
};
698 }
699
700 inline print_options::print_options_impl get_print_options(std::ostream& out)
701 {
707
708 res.edge_items = static_cast<int>(out.iword(edge_items::id()));
709 res.line_width = static_cast<int>(out.iword(line_width::id()));
710 res.threshold = static_cast<int>(out.iword(threshold::id()));
711 res.precision = static_cast<int>(out.iword(precision::id()));
712
713 if (!res.edge_items)
714 {
715 res.edge_items = print_options::print_options().edge_items;
716 }
717 else
718 {
719 out.iword(edge_items::id()) = long(0);
720 }
721 if (!res.line_width)
722 {
723 res.line_width = print_options::print_options().line_width;
724 }
725 else
726 {
727 out.iword(line_width::id()) = long(0);
728 }
729 if (!res.threshold)
730 {
731 res.threshold = print_options::print_options().threshold;
732 }
733 else
734 {
735 out.iword(threshold::id()) = long(0);
736 }
737 if (!res.precision)
738 {
739 res.precision = print_options::print_options().precision;
740 }
741 else
742 {
743 out.iword(precision::id()) = long(0);
744 }
745
746 return res;
747 }
748
749 template <class E, class F>
750 std::ostream& pretty_print(const xexpression<E>& e, F&& func, std::ostream& out = std::cout)
751 {
752 xfunction<detail::custom_formatter<E>, const_xclosure_t<E>> print_fun(
753 detail::custom_formatter<E>(std::forward<F>(func)),
754 e
755 );
756 return pretty_print(print_fun, out);
757 }
758
namespace detail
{
    /**
     * RAII guard that saves the format flags of ``stream`` on construction and
     * restores them on destruction, including during stack unwinding.
     *
     * Non-copyable: a copy would restore the saved flags a second time, from a
     * different scope.
     */
    template <class S>
    class fmtflags_guard
    {
    public:

        explicit fmtflags_guard(S& stream)
            : m_stream(stream)
            , m_flags(stream.flags())
        {
        }

        ~fmtflags_guard()
        {
            m_stream.flags(m_flags);
        }

        fmtflags_guard(const fmtflags_guard&) = delete;
        fmtflags_guard& operator=(const fmtflags_guard&) = delete;

    private:

        S& m_stream;
        std::ios_base::fmtflags m_flags;
    };
}
783
784 template <class E>
785 std::ostream& pretty_print(const xexpression<E>& e, std::ostream& out = std::cout)
786 {
787 detail::fmtflags_guard<std::ostream> guard(out);
788
789 const E& d = e.derived_cast();
790
791 std::size_t lim = 0;
792 std::size_t sz = compute_size(d.shape());
793
794 auto po = get_print_options(out);
795
796 if (sz > static_cast<std::size_t>(po.threshold))
797 {
798 lim = static_cast<std::size_t>(po.edge_items);
799 }
800 if (sz == 0)
801 {
802 out << "{}";
803 return out;
804 }
805
806 auto temp_precision = out.precision();
807 auto precision = temp_precision;
808 if (po.precision != -1)
809 {
810 out.precision(static_cast<std::streamsize>(po.precision));
811 precision = static_cast<std::streamsize>(po.precision);
812 }
813
814 detail::printer<E> p(precision);
815
817 detail::recurser_run(p, d, sv, lim);
818 p.init();
819 sv.clear();
820 xoutput(out, d, sv, p, 1, p.width(), lim, static_cast<std::size_t>(po.line_width));
821
822 out.precision(temp_precision); // restore precision
823
824 return out;
825 }
826
827 template <class E>
828 inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e)
829 {
830 return pretty_print(e, out);
831 }
832}
833#endif
834
835// Backward compatibility: include xmime.hpp in xio.hpp by default.
836
837#if defined(__CLING__) || defined(__CLANG_REPL__)
838#include "xmime.hpp"
839#endif
io manipulator used to set the number of edge items if the summarization is triggered.
Definition xio.hpp:166
io manipulator used to set the width of the lines when printing an expression.
Definition xio.hpp:138
io manipulator used to set the precision of the floating point values when printing an expression.
Definition xio.hpp:180
io manipulator used to set the threshold after which summarization is triggered.
Definition xio.hpp:152
Base class for xexpressions.
auto operator<<(E1 &&e1, E2 &&e2) noexcept -> detail::shift_return_type_t< detail::left_shift, E1, E2 >
Bitwise left shift.
standard mathematical functions for xexpressions
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
vector of slices used to build a xstrided_view
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1824