xtensor
 
Loading...
Searching...
No Matches
xio.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_IO_HPP
11#define XTENSOR_IO_HPP
12
#include <cmath>
#include <complex>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "../core/xexpression.hpp"
#include "../core/xmath.hpp"
#include "../views/xstrided_view.hpp"
23
24namespace xt
25{
26
27 template <class E>
28 inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e);
29
30 /*****************
31 * print options *
32 *****************/
33
34 namespace print_options
35 {
/**
 * Global print settings, used when no stream-local manipulator overrides them.
 */
struct print_options_impl
{
    int edge_items{3};    // items kept at each end of an axis when output is summarized
    int line_width{75};   // maximum number of characters per printed line
    int threshold{1000};  // element count above which output is summarized
    int precision{-1};    // -1 keeps the precision already set on the stream
};
43
44 inline print_options_impl& print_options()
45 {
46 static print_options_impl po;
47 return po;
48 }
49
56 inline void set_line_width(int line_width)
57 {
58 print_options().line_width = line_width;
59 }
60
67 inline void set_threshold(int threshold)
68 {
69 print_options().threshold = threshold;
70 }
71
79 inline void set_edge_items(int edge_items)
80 {
81 print_options().edge_items = edge_items;
82 }
83
89 inline void set_precision(int precision)
90 {
91 print_options().precision = precision;
92 }
93
// Defines an io manipulator named NAME: a small class wrapping an int value,
// plus an operator<< that stores that value in a stream-specific iword slot
// reserved once through std::ios_base::xalloc.
#define DEFINE_LOCAL_PRINT_OPTION(NAME)                                  \
    class NAME                                                           \
    {                                                                    \
    public:                                                              \
                                                                         \
        NAME(int value)                                                  \
            : m_value(value)                                             \
        {                                                                \
            /* reserve the iword slot as soon as a manipulator exists */ \
            static_cast<void>(NAME::id());                               \
        }                                                                \
                                                                         \
        static int id()                                                  \
        {                                                                \
            static const int slot_index = std::ios_base::xalloc();       \
            return slot_index;                                           \
        }                                                                \
                                                                         \
        int value() const                                                \
        {                                                                \
            return m_value;                                              \
        }                                                                \
                                                                         \
    private:                                                             \
                                                                         \
        int m_value;                                                     \
    };                                                                   \
                                                                         \
    inline std::ostream& operator<<(std::ostream& out, const NAME& n)   \
    {                                                                    \
        long& slot = out.iword(NAME::id());                              \
        slot = n.value();                                                \
        return out;                                                      \
    }
124
137 DEFINE_LOCAL_PRINT_OPTION(line_width)
138
139
151 DEFINE_LOCAL_PRINT_OPTION(threshold)
152
165 DEFINE_LOCAL_PRINT_OPTION(edge_items)
166
179 DEFINE_LOCAL_PRINT_OPTION(precision)
180 }
181
182 /**************************************
183 * xexpression ostream implementation *
184 **************************************/
185
186 namespace detail
187 {
188 template <class E, class F>
189 std::ostream& xoutput(
190 std::ostream& out,
191 const E& e,
192 xstrided_slice_vector& slices,
193 F& printer,
194 std::size_t blanks,
195 std::streamsize element_width,
196 std::size_t edgeitems,
197 std::size_t line_width
198 )
199 {
200 using size_type = typename E::size_type;
201
202 const auto view = xt::strided_view(e, slices);
203 if (view.dimension() == 0)
204 {
205 printer.print_next(out);
206 }
207 else
208 {
209 std::string indents(blanks, ' ');
210
211 size_type i = 0;
212 size_type elems_on_line = 0;
213 const size_type ewp2 = static_cast<size_type>(element_width) + size_type(2);
214 const size_type line_lim = static_cast<size_type>(std::floor(line_width / ewp2));
215
216 out << '{';
217 for (; i != size_type(view.shape()[0] - 1); ++i)
218 {
219 if (edgeitems && size_type(view.shape()[0]) > (edgeitems * 2) && i == edgeitems)
220 {
221 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
222 {
223 out << " ...,";
224 }
225 else if (view.dimension() > 1)
226 {
227 elems_on_line = 0;
228 out << "...," << std::endl << indents;
229 }
230 else
231 {
232 out << "..., ";
233 }
234 i = size_type(view.shape()[0]) - edgeitems;
235 }
236 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
237 {
238 out << std::endl << indents;
239 elems_on_line = 0;
240 }
241 slices.push_back(static_cast<int>(i));
242 xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << ',';
243 slices.pop_back();
244 elems_on_line++;
245
246 if ((view.dimension() == 1) && !(line_lim != 0 && elems_on_line >= line_lim))
247 {
248 out << ' ';
249 }
250 else if (view.dimension() > 1)
251 {
252 out << std::endl << indents;
253 }
254 }
255 if (view.dimension() == 1 && line_lim != 0 && elems_on_line >= line_lim)
256 {
257 out << std::endl << indents;
258 }
259 slices.push_back(static_cast<int>(i));
260 xoutput(out, e, slices, printer, blanks + 1, element_width, edgeitems, line_width) << '}';
261 slices.pop_back();
262 }
263 return out;
264 }
265
266 template <class F, class E>
267 void recurser_run(F& fn, const E& e, xstrided_slice_vector& slices, std::size_t lim = 0)
268 {
269 using size_type = typename E::size_type;
270 const auto view = strided_view(e, slices);
271 if (view.dimension() == 0)
272 {
273 fn.update(view());
274 }
275 else
276 {
277 size_type i = 0;
278 for (; i != static_cast<size_type>(view.shape()[0] - 1); ++i)
279 {
280 if (lim && size_type(view.shape()[0]) > (lim * 2) && i == lim)
281 {
282 i = static_cast<size_type>(view.shape()[0]) - lim;
283 }
284 slices.push_back(static_cast<int>(i));
285 recurser_run(fn, e, slices, lim);
286 slices.pop_back();
287 }
288 slices.push_back(static_cast<int>(i));
289 recurser_run(fn, e, slices, lim);
290 slices.pop_back();
291 }
292 }
293
/**
 * Element formatter for expressions of type T, partially specialized on
 * T::value_type; the second parameter is the std::enable_if_t selection hook.
 */
template <
    class T,
    class E = void>
struct printer;
296
297 template <class T>
298 struct printer<T, std::enable_if_t<std::is_floating_point<typename T::value_type>::value>>
299 {
300 using value_type = std::decay_t<typename T::value_type>;
301 using cache_type = std::vector<value_type>;
302 using cache_iterator = typename cache_type::const_iterator;
303
304 explicit printer(std::streamsize precision)
305 : m_precision(precision)
306 {
307 }
308
309 void init()
310 {
311 m_precision = m_required_precision < m_precision ? m_required_precision : m_precision;
312 m_it = m_cache.cbegin();
313 if (m_scientific)
314 {
315 // 3 = sign, number and dot and 4 = "e+00"
316 m_width = m_precision + 7;
317 if (m_large_exponent)
318 {
319 // = e+000 (additional number)
320 m_width += 1;
321 }
322 }
323 else
324 {
325 std::streamsize decimals = 1; // print a leading 0
326 if (std::floor(m_max) != 0)
327 {
328 decimals += std::streamsize(std::log10(std::floor(m_max)));
329 }
330 // 2 => sign and dot
331 m_width = 2 + decimals + m_precision;
332 }
333 if (!m_required_precision)
334 {
335 --m_width;
336 }
337 }
338
339 std::ostream& print_next(std::ostream& out)
340 {
341 if (!m_scientific)
342 {
343 std::stringstream buf;
344 buf.width(m_width);
345 buf << std::fixed;
346 buf.precision(m_precision);
347 buf << (*m_it);
348 if (!m_required_precision && !std::isinf(*m_it) && !std::isnan(*m_it))
349 {
350 buf << '.';
351 }
352 std::string res = buf.str();
353 auto sit = res.rbegin();
354 while (*sit == '0')
355 {
356 *sit = ' ';
357 ++sit;
358 }
359 out << res;
360 }
361 else
362 {
363 if (!m_large_exponent)
364 {
365 out << std::scientific;
366 out.width(m_width);
367 out << (*m_it);
368 }
369 else
370 {
371 std::stringstream buf;
372 buf.width(m_width);
373 buf << std::scientific;
374 buf.precision(m_precision);
375 buf << (*m_it);
376 std::string res = buf.str();
377
378 if (res[res.size() - 4] == 'e')
379 {
380 res.erase(0, 1);
381 res.insert(res.size() - 2, "0");
382 }
383 out << res;
384 }
385 }
386 ++m_it;
387 return out;
388 }
389
390 void update(const value_type& val)
391 {
392 if (val != 0 && !std::isinf(val) && !std::isnan(val))
393 {
394 if (!m_scientific || !m_large_exponent)
395 {
396 int exponent = 1 + int(std::log10(math::abs(val)));
397 if (exponent <= -5 || exponent > 7)
398 {
399 m_scientific = true;
400 m_required_precision = m_precision;
401 if (exponent <= -100 || exponent >= 100)
402 {
403 m_large_exponent = true;
404 }
405 }
406 }
407 if (math::abs(val) > m_max)
408 {
409 m_max = math::abs(val);
410 }
411 if (m_required_precision < m_precision)
412 {
413 while (std::floor(val * std::pow(10, m_required_precision))
414 != val * std::pow(10, m_required_precision))
415 {
416 m_required_precision++;
417 }
418 }
419 }
420 m_cache.push_back(val);
421 }
422
423 std::streamsize width()
424 {
425 return m_width;
426 }
427
428 private:
429
430 bool m_large_exponent = false;
431 bool m_scientific = false;
432 std::streamsize m_width = 9;
433 std::streamsize m_precision;
434 std::streamsize m_required_precision = 0;
435 value_type m_max = 0;
436
437 cache_type m_cache;
438 cache_iterator m_it;
439 };
440
441 template <class T>
442 struct printer<
443 T,
444 std::enable_if_t<
445 xtl::is_integral<typename T::value_type>::value && !std::is_same<typename T::value_type, bool>::value>>
446 {
447 using value_type = std::decay_t<typename T::value_type>;
448 using cache_type = std::vector<value_type>;
449 using cache_iterator = typename cache_type::const_iterator;
450
451 explicit printer(std::streamsize)
452 {
453 }
454
455 void init()
456 {
457 m_it = m_cache.cbegin();
458 m_width = 1 + std::streamsize((m_max > 0) ? std::log10(m_max) : 0) + m_sign;
459 }
460
461 std::ostream& print_next(std::ostream& out)
462 {
463 // + enables printing of chars etc. as numbers
464 // TODO should chars be printed as numbers?
465 out.width(m_width);
466 out << +(*m_it);
467 ++m_it;
468 return out;
469 }
470
471 void update(const value_type& val)
472 {
473 if (math::abs(val) > m_max)
474 {
475 m_max = math::abs(val);
476 }
477 if (xtl::is_signed<value_type>::value && val < 0)
478 {
479 m_sign = true;
480 }
481 m_cache.push_back(val);
482 }
483
484 std::streamsize width()
485 {
486 return m_width;
487 }
488
489 private:
490
491 std::streamsize m_width;
492 bool m_sign = false;
493 value_type m_max = 0;
494
495 cache_type m_cache;
496 cache_iterator m_it;
497 };
498
499 template <class T>
500 struct printer<T, std::enable_if_t<std::is_same<typename T::value_type, bool>::value>>
501 {
502 using value_type = bool;
503 using cache_type = std::vector<bool>;
504 using cache_iterator = typename cache_type::const_iterator;
505
506 explicit printer(std::streamsize)
507 {
508 }
509
510 void init()
511 {
512 m_it = m_cache.cbegin();
513 }
514
515 std::ostream& print_next(std::ostream& out)
516 {
517 if (*m_it)
518 {
519 out << " true";
520 }
521 else
522 {
523 out << "false";
524 }
525 // TODO: the following std::setw(5) isn't working correctly on OSX.
526 // out << std::boolalpha << std::setw(m_width) << (*m_it);
527 ++m_it;
528 return out;
529 }
530
531 void update(const value_type& val)
532 {
533 m_cache.push_back(val);
534 }
535
536 std::streamsize width()
537 {
538 return m_width;
539 }
540
541 private:
542
543 std::streamsize m_width = 5;
544
545 cache_type m_cache;
546 cache_iterator m_it;
547 };
548
549 template <class T>
550 struct printer<T, std::enable_if_t<xtl::is_complex<typename T::value_type>::value>>
551 {
552 using value_type = std::decay_t<typename T::value_type>;
553 using cache_type = std::vector<bool>;
554 using cache_iterator = typename cache_type::const_iterator;
555
556 explicit printer(std::streamsize precision)
557 : real_printer(precision)
558 , imag_printer(precision)
559 {
560 }
561
562 void init()
563 {
564 real_printer.init();
565 imag_printer.init();
566 m_it = m_signs.cbegin();
567 }
568
569 std::ostream& print_next(std::ostream& out)
570 {
571 real_printer.print_next(out);
572 if (*m_it)
573 {
574 out << "-";
575 }
576 else
577 {
578 out << "+";
579 }
580 std::stringstream buf;
581 imag_printer.print_next(buf);
582 std::string s = buf.str();
583 if (s[0] == ' ')
584 {
585 s.erase(0, 1); // erase space for +/-
586 }
587 // insert j at end of number
588 std::size_t idx = s.find_last_not_of(" ");
589 s.insert(idx + 1, "i");
590 out << s;
591 ++m_it;
592 return out;
593 }
594
595 void update(const value_type& val)
596 {
597 real_printer.update(val.real());
598 imag_printer.update(std::abs(val.imag()));
599 m_signs.push_back(std::signbit(val.imag()));
600 }
601
602 std::streamsize width()
603 {
604 return real_printer.width() + imag_printer.width() + 2;
605 }
606
607 private:
608
609 printer<value_type> real_printer, imag_printer;
610 cache_type m_signs;
611 cache_iterator m_it;
612 };
613
614 template <class T>
615 struct printer<
616 T,
617 std::enable_if_t<
618 !xtl::is_fundamental<typename T::value_type>::value && !xtl::is_complex<typename T::value_type>::value>>
619 {
620 using const_reference = typename T::const_reference;
621 using value_type = std::decay_t<typename T::value_type>;
622 using cache_type = std::vector<std::string>;
623 using cache_iterator = typename cache_type::const_iterator;
624
625 explicit printer(std::streamsize)
626 {
627 }
628
629 void init()
630 {
631 m_it = m_cache.cbegin();
632 if (m_width > 20)
633 {
634 m_width = 0;
635 }
636 }
637
638 std::ostream& print_next(std::ostream& out)
639 {
640 out.width(m_width);
641 out << *m_it;
642 ++m_it;
643 return out;
644 }
645
646 void update(const_reference val)
647 {
648 std::stringstream buf;
649 buf << val;
650 std::string s = buf.str();
651 if (int(s.size()) > m_width)
652 {
653 m_width = std::streamsize(s.size());
654 }
655 m_cache.push_back(s);
656 }
657
658 std::streamsize width()
659 {
660 return m_width;
661 }
662
663 private:
664
665 std::streamsize m_width = 0;
666 cache_type m_cache;
667 cache_iterator m_it;
668 };
669
/**
 * Element-wise functor used by pretty_print(e, func, out): maps each value of
 * the expression to the string produced by a user-supplied callable.
 */
template <class E>
struct custom_formatter
{
    using value_type = std::decay_t<typename E::value_type>;

    /**
     * @param func callable taking `const value_type&` and returning a value
     *             convertible to std::string.
     *
     * Constrained so it never competes with the copy/move constructors: a
     * non-const custom_formatter lvalue would otherwise select this template
     * and wrap the formatter inside its own std::function, which recurses
     * without end while copying.
     */
    template <
        class F,
        class = std::enable_if_t<!std::is_same<std::decay_t<F>, custom_formatter>::value>>
    custom_formatter(F&& func)
        : m_func(std::forward<F>(func))
    {
    }

    std::string operator()(const value_type& val) const
    {
        return m_func(val);
    }

private:

    std::function<std::string(const value_type&)> m_func;
};
690 }
691
692 inline print_options::print_options_impl get_print_options(std::ostream& out)
693 {
699
700 res.edge_items = static_cast<int>(out.iword(edge_items::id()));
701 res.line_width = static_cast<int>(out.iword(line_width::id()));
702 res.threshold = static_cast<int>(out.iword(threshold::id()));
703 res.precision = static_cast<int>(out.iword(precision::id()));
704
705 if (!res.edge_items)
706 {
707 res.edge_items = print_options::print_options().edge_items;
708 }
709 else
710 {
711 out.iword(edge_items::id()) = long(0);
712 }
713 if (!res.line_width)
714 {
715 res.line_width = print_options::print_options().line_width;
716 }
717 else
718 {
719 out.iword(line_width::id()) = long(0);
720 }
721 if (!res.threshold)
722 {
723 res.threshold = print_options::print_options().threshold;
724 }
725 else
726 {
727 out.iword(threshold::id()) = long(0);
728 }
729 if (!res.precision)
730 {
731 res.precision = print_options::print_options().precision;
732 }
733 else
734 {
735 out.iword(precision::id()) = long(0);
736 }
737
738 return res;
739 }
740
741 template <class E, class F>
742 std::ostream& pretty_print(const xexpression<E>& e, F&& func, std::ostream& out = std::cout)
743 {
744 xfunction<detail::custom_formatter<E>, const_xclosure_t<E>> print_fun(
745 detail::custom_formatter<E>(std::forward<F>(func)),
746 e
747 );
748 return pretty_print(print_fun, out);
749 }
750
751 namespace detail
752 {
753 template <class S>
754 class fmtflags_guard
755 {
756 public:
757
758 explicit fmtflags_guard(S& stream)
759 : m_stream(stream)
760 , m_flags(stream.flags())
761 {
762 }
763
764 ~fmtflags_guard()
765 {
766 m_stream.flags(m_flags);
767 }
768
769 private:
770
771 S& m_stream;
772 std::ios_base::fmtflags m_flags;
773 };
774 }
775
776 template <class E>
777 std::ostream& pretty_print(const xexpression<E>& e, std::ostream& out = std::cout)
778 {
779 detail::fmtflags_guard<std::ostream> guard(out);
780
781 const E& d = e.derived_cast();
782
783 std::size_t lim = 0;
784 std::size_t sz = compute_size(d.shape());
785
786 auto po = get_print_options(out);
787
788 if (sz > static_cast<std::size_t>(po.threshold))
789 {
790 lim = static_cast<std::size_t>(po.edge_items);
791 }
792 if (sz == 0)
793 {
794 out << "{}";
795 return out;
796 }
797
798 auto temp_precision = out.precision();
799 auto precision = temp_precision;
800 if (po.precision != -1)
801 {
802 out.precision(static_cast<std::streamsize>(po.precision));
803 precision = static_cast<std::streamsize>(po.precision);
804 }
805
806 detail::printer<E> p(precision);
807
809 detail::recurser_run(p, d, sv, lim);
810 p.init();
811 sv.clear();
812 xoutput(out, d, sv, p, 1, p.width(), lim, static_cast<std::size_t>(po.line_width));
813
814 out.precision(temp_precision); // restore precision
815
816 return out;
817 }
818
819 template <class E>
820 inline std::ostream& operator<<(std::ostream& out, const xexpression<E>& e)
821 {
822 return pretty_print(e, out);
823 }
824}
825#endif
826
// Backward compatibility: when running under cling or clang-repl, xio.hpp also includes xmime.hpp.
828
829#if defined(__CLING__) || defined(__CLANG_REPL__)
830#include "xmime.hpp"
831#endif
io manipulator used to set the number of edge items shown when summarization is triggered.
Definition xio.hpp:165
io manipulator used to set the width of the lines when printing an expression.
Definition xio.hpp:137
io manipulator used to set the precision of the floating point values when printing an expression.
Definition xio.hpp:179
io manipulator used to set the threshold after which summarization is triggered.
Definition xio.hpp:151
Base class for xexpressions.
auto operator<<(E1 &&e1, E2 &&e2) noexcept -> detail::shift_return_type_t< detail::left_shift, E1, E2 >
Bitwise left shift.
standard mathematical functions for xexpressions
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
vector of slices used to build a xstrided_view
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1824