21#include "xexpression.hpp"
23#include "xstrided_view.hpp"
29 inline std::ostream& operator<<(std::ostream& out,
const xexpression<E>& e);
35 namespace print_options
57 inline void set_line_width(
int line_width)
59 print_options().line_width = line_width;
68 inline void set_threshold(
int threshold)
70 print_options().threshold = threshold;
80 inline void set_edge_items(
int edge_items)
82 print_options().edge_items = edge_items;
90 inline void set_precision(
int precision)
92 print_options().precision = precision;
95#define DEFINE_LOCAL_PRINT_OPTION(NAME) \
107 static int id = std::ios_base::xalloc(); \
120 inline std::ostream& operator<<(std::ostream& out, const NAME& n) \
122 out.iword(NAME::id()) = n.value(); \
189 template <
class E,
class F>
190 std::ostream& xoutput(
201 using size_type =
typename E::size_type;
204 if (
view.dimension() == 0)
206 printer.print_next(
out);
218 for (;
i != size_type(
view.shape()[0] - 1); ++
i)
226 else if (
view.dimension() > 1)
242 slices.push_back(
static_cast<int>(
i));
251 else if (
view.dimension() > 1)
260 slices.push_back(
static_cast<int>(
i));
267 template <
class F,
class E>
270 using size_type =
typename E::size_type;
272 if (
view.dimension() == 0)
279 for (;
i !=
static_cast<size_type
>(
view.shape()[0] - 1); ++
i)
283 i =
static_cast<size_type
>(
view.shape()[0]) -
lim;
285 slices.push_back(
static_cast<int>(
i));
286 recurser_run(
fn,
e, slices,
lim);
289 slices.push_back(
static_cast<int>(
i));
290 recurser_run(
fn,
e, slices,
lim);
295 template <
class T,
class E =
void>
299 struct printer<T, std::
enable_if_t<std::is_floating_point<typename T::value_type>::value>>
301 using value_type = std::decay_t<typename T::value_type>;
302 using cache_type = std::vector<value_type>;
303 using cache_iterator =
typename cache_type::const_iterator;
305 explicit printer(std::streamsize
precision)
312 m_precision = m_required_precision < m_precision ? m_required_precision : m_precision;
313 m_it = m_cache.cbegin();
317 m_width = m_precision + 7;
318 if (m_large_exponent)
327 if (std::floor(m_max) != 0)
329 decimals += std::streamsize(std::log10(std::floor(m_max)));
332 m_width = 2 +
decimals + m_precision;
334 if (!m_required_precision)
340 std::ostream& print_next(std::ostream&
out)
344 std::stringstream
buf;
347 buf.precision(m_precision);
349 if (!m_required_precision && !std::isinf(*m_it) && !std::isnan(*m_it))
353 std::string
res =
buf.str();
364 if (!m_large_exponent)
366 out << std::scientific;
372 std::stringstream
buf;
374 buf << std::scientific;
375 buf.precision(m_precision);
377 std::string
res =
buf.str();
379 if (
res[
res.size() - 4] ==
'e')
382 res.insert(
res.size() - 2,
"0");
391 void update(
const value_type&
val)
393 if (
val != 0 && !std::isinf(
val) && !std::isnan(
val))
395 if (!m_scientific || !m_large_exponent)
401 m_required_precision = m_precision;
404 m_large_exponent =
true;
408 if (math::abs(
val) > m_max)
410 m_max = math::abs(
val);
412 if (m_required_precision < m_precision)
414 while (std::floor(
val * std::pow(10, m_required_precision))
415 !=
val * std::pow(10, m_required_precision))
417 m_required_precision++;
421 m_cache.push_back(
val);
424 std::streamsize width()
431 bool m_large_exponent =
false;
432 bool m_scientific =
false;
433 std::streamsize m_width = 9;
434 std::streamsize m_precision;
435 std::streamsize m_required_precision = 0;
436 value_type m_max = 0;
446 xtl::is_integral<typename T::value_type>::value && !std::is_same<typename T::value_type, bool>::value>>
448 using value_type = std::decay_t<typename T::value_type>;
449 using cache_type = std::vector<value_type>;
450 using cache_iterator =
typename cache_type::const_iterator;
452 explicit printer(std::streamsize)
458 m_it = m_cache.cbegin();
459 m_width = 1 + std::streamsize((m_max > 0) ? std::log10(m_max) : 0) + m_sign;
462 std::ostream& print_next(std::ostream&
out)
472 void update(
const value_type&
val)
474 if (math::abs(
val) > m_max)
476 m_max = math::abs(
val);
478 if (xtl::is_signed<value_type>::value &&
val < 0)
482 m_cache.push_back(
val);
485 std::streamsize width()
492 std::streamsize m_width;
494 value_type m_max = 0;
501 struct printer<T, std::
enable_if_t<std::is_same<typename T::value_type, bool>::value>>
503 using value_type =
bool;
504 using cache_type = std::vector<bool>;
505 using cache_iterator =
typename cache_type::const_iterator;
507 explicit printer(std::streamsize)
513 m_it = m_cache.cbegin();
516 std::ostream& print_next(std::ostream&
out)
532 void update(
const value_type&
val)
534 m_cache.push_back(
val);
537 std::streamsize width()
544 std::streamsize m_width = 5;
551 struct printer<T, std::
enable_if_t<xtl::is_complex<typename T::value_type>::value>>
553 using value_type = std::decay_t<typename T::value_type>;
554 using cache_type = std::vector<bool>;
555 using cache_iterator =
typename cache_type::const_iterator;
557 explicit printer(std::streamsize
precision)
567 m_it = m_signs.cbegin();
570 std::ostream& print_next(std::ostream&
out)
572 real_printer.print_next(
out);
581 std::stringstream
buf;
582 imag_printer.print_next(
buf);
583 std::string
s =
buf.str();
589 std::size_t idx =
s.find_last_not_of(
" ");
590 s.insert(idx + 1,
"i");
596 void update(
const value_type&
val)
598 real_printer.update(
val.real());
599 imag_printer.update(std::abs(
val.imag()));
600 m_signs.push_back(std::signbit(
val.imag()));
603 std::streamsize width()
605 return real_printer.width() + imag_printer.width() + 2;
619 !xtl::is_fundamental<typename T::value_type>::value && !xtl::is_complex<typename T::value_type>::value>>
621 using const_reference =
typename T::const_reference;
622 using value_type = std::decay_t<typename T::value_type>;
623 using cache_type = std::vector<std::string>;
624 using cache_iterator =
typename cache_type::const_iterator;
626 explicit printer(std::streamsize)
632 m_it = m_cache.cbegin();
639 std::ostream& print_next(std::ostream&
out)
647 void update(const_reference
val)
649 std::stringstream
buf;
651 std::string
s =
buf.str();
652 if (
int(
s.size()) > m_width)
654 m_width = std::streamsize(
s.size());
656 m_cache.push_back(
s);
659 std::streamsize width()
666 std::streamsize m_width = 0;
672 struct custom_formatter
674 using value_type = std::decay_t<typename E::value_type>;
677 custom_formatter(
F&&
func)
682 std::string operator()(
const value_type&
val)
const
689 std::function<std::string(
const value_type&)> m_func;
701 res.edge_items =
static_cast<int>(
out.iword(edge_items::id()));
702 res.line_width =
static_cast<int>(
out.iword(line_width::id()));
703 res.threshold =
static_cast<int>(
out.iword(threshold::id()));
704 res.precision =
static_cast<int>(
out.iword(precision::id()));
708 res.edge_items = print_options::print_options().edge_items;
712 out.iword(edge_items::id()) =
long(0);
716 res.line_width = print_options::print_options().line_width;
720 out.iword(line_width::id()) = long(0);
724 res.threshold = print_options::print_options().threshold;
728 out.iword(threshold::id()) = long(0);
732 res.precision = print_options::print_options().precision;
736 out.iword(precision::id()) = long(0);
742 template <
class E,
class F>
743 std::ostream& pretty_print(
const xexpression<E>& e, F&& func, std::ostream& out = std::cout)
745 xfunction<detail::custom_formatter<E>, const_xclosure_t<E>> print_fun(
746 detail::custom_formatter<E>(std::forward<F>(func)),
749 return pretty_print(print_fun, out);
759 explicit fmtflags_guard(S& stream)
761 , m_flags(stream.flags())
767 m_stream.flags(m_flags);
773 std::ios_base::fmtflags m_flags;
778 std::ostream& pretty_print(
const xexpression<E>& e, std::ostream& out = std::cout)
780 detail::fmtflags_guard<std::ostream> guard(out);
782 const E& d = e.derived_cast();
785 std::size_t sz = compute_size(d.shape());
787 auto po = get_print_options(out);
789 if (sz >
static_cast<std::size_t
>(po.threshold))
791 lim =
static_cast<std::size_t
>(po.edge_items);
799 auto temp_precision = out.precision();
800 auto precision = temp_precision;
801 if (po.precision != -1)
803 out.precision(
static_cast<std::streamsize
>(po.precision));
804 precision =
static_cast<std::streamsize
>(po.precision);
807 detail::printer<E> p(precision);
810 detail::recurser_run(p, d, sv, lim);
813 xoutput(out, d, sv, p, 1, p.width(), lim,
static_cast<std::size_t
>(po.line_width));
815 out.precision(temp_precision);
821 inline std::ostream& operator<<(std::ostream& out,
const xexpression<E>& e)
823 return pretty_print(e, out);
io manipulator used to set the number of edge items shown when summarization is triggered.
io manipulator used to set the width of the lines when printing an expression.
io manipulator used to set the precision of the floating point values when printing an expression.
io manipulator used to set the threshold after which summarization is triggered.
standard mathematical functions for xexpressions
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
vector of slices used to build an xstrided_view
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.