xtensor
 
Loading...
Searching...
No Matches
xexpression.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_EXPRESSION_HPP
11#define XTENSOR_EXPRESSION_HPP
12
13#include <cstddef>
14#include <type_traits>
15#include <vector>
16
17#include <xtl/xclosure.hpp>
18#include <xtl/xmeta_utils.hpp>
19#include <xtl/xtype_traits.hpp>
20
21#include "../core/xlayout.hpp"
22#include "../core/xshape.hpp"
23#include "../core/xtensor_forward.hpp"
24#include "../utils/xutils.hpp"
25
26namespace xt
27{
28
29 /***************************
30 * xexpression declaration *
31 ***************************/
32
/**
 * CRTP base class of every xtensor expression.
 *
 * All special member functions are protected: an xexpression can only be
 * created, copied, moved or destroyed as a sub-object of a derived expression.
 *
 * @tparam D the concrete (derived) expression type
 */
template <class D>
class xexpression
{
public:

    /// The concrete expression type this base is parameterized on.
    using derived_type = D;

    /// Downcast of an lvalue expression: mutable reference to D.
    derived_type& derived_cast() & noexcept;

    /// Downcast of a const lvalue expression: const reference to D.
    const derived_type& derived_cast() const& noexcept;

    /// Downcast of an rvalue expression: D returned by value.
    derived_type derived_cast() && noexcept;

protected:

    // Construction (default, copy, move) is reserved to derived classes.
    xexpression() = default;
    xexpression(const xexpression&) = default;
    xexpression(xexpression&&) = default;

    // Assignment is reserved to derived classes as well.
    xexpression& operator=(const xexpression&) = default;
    xexpression& operator=(xexpression&&) = default;

    // Non-virtual and protected: never destroyed through a base pointer.
    ~xexpression() = default;
};
67
68 /************************************
69 * xsharable_expression declaration *
70 ************************************/
71
72 template <class E>
74
75 template <class E>
77
78 namespace detail
79 {
80 template <class E>
82 }
83
84 template <class D>
85 class xsharable_expression : public xexpression<D>
86 {
87 protected:
88
89 xsharable_expression();
90 ~xsharable_expression() = default;
91
92 xsharable_expression(const xsharable_expression&) = default;
93 xsharable_expression& operator=(const xsharable_expression&) = default;
94
95 xsharable_expression(xsharable_expression&&) = default;
96 xsharable_expression& operator=(xsharable_expression&&) = default;
97
98 private:
99
100 std::shared_ptr<D> p_shared;
101
102 friend xshared_expression<D> detail::make_xshared_impl<D>(xsharable_expression<D>&&);
103 };
104
105 /******************************
106 * xexpression implementation *
107 ******************************/
108
113
116 template <class D>
117 inline auto xexpression<D>::derived_cast() & noexcept -> derived_type&
118 {
119 return *static_cast<derived_type*>(this);
120 }
121
125 template <class D>
126 inline auto xexpression<D>::derived_cast() const& noexcept -> const derived_type&
127 {
128 return *static_cast<const derived_type*>(this);
129 }
130
134 template <class D>
135 inline auto xexpression<D>::derived_cast() && noexcept -> derived_type
136 {
137 return *static_cast<derived_type*>(this);
138 }
139
141
142 /***************************************
143 * xsharable_expression implementation *
144 ***************************************/
145
146 template <class D>
147 inline xsharable_expression<D>::xsharable_expression()
148 : p_shared(nullptr)
149 {
150 }
151
161
namespace detail
{
    /**
     * Primary template: true when E derives from B<E>, i.e. B is a CRTP base
     * of E.
     */
    template <template <class> class B, class E>
    struct is_crtp_base_of_impl : std::is_base_of<B<E>, E>
    {
    };

    /**
     * Specialization for a class template instance F<E>: true when F<E>
     * derives from B<E> or from B<F<E>>.
     */
    template <template <class> class B, class E, template <class> class F>
    struct is_crtp_base_of_impl<B, F<E>>
        : std::disjunction<
              std::is_base_of<B<E>, F<E>>,     // base parameterized on the inner type
              std::is_base_of<B<F<E>>, F<E>>>  // plain CRTP on the instance itself
    {
    };
}
175
176 template <template <class> class B, class E>
177 using is_crtp_base_of = detail::is_crtp_base_of_impl<B, std::decay_t<E>>;
178
179 template <class E>
180 using is_xexpression = is_crtp_base_of<xexpression, E>;
181
182 template <class E>
183 concept xexpression_concept = is_xexpression<E>::value;
184
185 template <class E, class R = void>
186 using enable_xexpression = typename std::enable_if<is_xexpression<E>::value, R>::type;
187
188 template <class E, class R = void>
189 using disable_xexpression = typename std::enable_if<!is_xexpression<E>::value, R>::type;
190
191 template <class... E>
192 using has_xexpression = std::disjunction<is_xexpression<E>...>;
193
194 template <class E>
195 using is_xsharable_expression = is_crtp_base_of<xsharable_expression, E>;
196
197 template <class E, class R = void>
198 using enable_xsharable_expression = typename std::enable_if<is_xsharable_expression<E>::value, R>::type;
199
200 template <class E, class R = void>
201 using disable_xsharable_expression = typename std::enable_if<!is_xsharable_expression<E>::value, R>::type;
202
203 template <class LHS, class RHS>
204 struct can_assign : std::is_assignable<LHS, RHS>
205 {
206 };
207
208 template <class LHS, class RHS, class R = void>
209 using enable_assignable_expression = typename std::enable_if<can_assign<LHS, RHS>::value, R>::type;
210
211 template <class LHS, class RHS, class R = void>
212 using enable_not_assignable_expression = typename std::enable_if<!can_assign<LHS, RHS>::value, R>::type;
213
214 /***********************
215 * evaluation_strategy *
216 ***********************/
217
218 namespace detail
219 {
220 struct option_base
221 {
222 };
223 }
224
225 namespace evaluation_strategy
226 {
227
228 struct immediate_type : xt::detail::option_base
229 {
230 };
231
232 constexpr auto immediate = std::tuple<immediate_type>{};
233
234 struct lazy_type : xt::detail::option_base
235 {
236 };
237
238 constexpr auto lazy = std::tuple<lazy_type>{};
239
240 /*
241 struct cached {};
242 */
243 }
244
245 template <class T>
246 struct is_evaluation_strategy : std::is_base_of<detail::option_base, std::decay_t<T>>
247 {
248 };
249
250 /************
251 * xclosure *
252 ************/
253
254 template <class T>
255 class xscalar;
256
257 template <class E, class EN = void>
258 struct xclosure
259 {
260 using type = xtl::closure_type_t<E>;
261 };
262
263 template <class E>
264 struct xclosure<xshared_expression<E>, std::enable_if_t<true>>
265 {
266 using type = xshared_expression<E>; // force copy
267 };
268
269 template <class E>
270 struct xclosure<E, disable_xexpression<std::decay_t<E>>>
271 {
272 using type = xscalar<xtl::closure_type_t<E>>;
273 };
274
275 template <class E>
276 using xclosure_t = typename xclosure<E>::type;
277
278 template <class E, class EN = void>
280 {
281 using type = xtl::const_closure_type_t<E>;
282 };
283
284 template <class E>
285 struct const_xclosure<E, disable_xexpression<std::decay_t<E>>>
286 {
288 };
289
290 template <class E>
291 struct const_xclosure<xshared_expression<E>&, std::enable_if_t<true>>
292 {
293 using type = xshared_expression<E>; // force copy
294 };
295
296 template <class E>
297 using const_xclosure_t = typename const_xclosure<E>::type;
298
299 /*************************
300 * expression tag system *
301 *************************/
302
304 {
305 };
306
308 {
309 };
310
311 namespace extension
312 {
313 template <class E, class = void_t<int>>
315 {
316 using type = xtensor_expression_tag;
317 };
318
319 template <class E>
320 struct get_expression_tag_impl<E, void_t<typename std::decay_t<E>::expression_tag>>
321 {
322 using type = typename std::decay_t<E>::expression_tag;
323 };
324
325 template <class E>
329
330 template <class E>
331 using get_expression_tag_t = typename get_expression_tag<E>::type;
332
333 template <class... T>
335
336 template <>
338 {
339 using type = xtensor_expression_tag;
340 };
341
342 template <class T>
344 {
345 using type = T;
346 };
347
348 template <class T>
350 {
351 using type = T;
352 };
353
354 template <class T>
356 {
357 using type = T;
358 };
359
360 template <class T>
361 struct expression_tag_and<T, xtensor_expression_tag> : expression_tag_and<xtensor_expression_tag, T>
362 {
363 };
364
365 template <>
370
371 template <class T1, class... T>
372 struct expression_tag_and<T1, T...> : expression_tag_and<T1, typename expression_tag_and<T...>::type>
373 {
374 };
375
376 template <class... T>
377 using expression_tag_and_t = typename expression_tag_and<T...>::type;
378
380 {
381 using expression_tag = xtensor_expression_tag;
382 };
383 }
384
385 template <class... T>
387 {
388 using type = extension::expression_tag_and_t<
389 extension::get_expression_tag_t<std::decay_t<const_xclosure_t<T>>>...>;
390 };
391
392 template <class... T>
393 using xexpression_tag_t = typename xexpression_tag<T...>::type;
394
395 template <class E>
396 struct is_xtensor_expression : std::is_same<xexpression_tag_t<E>, xtensor_expression_tag>
397 {
398 };
399
400 template <class E>
401 struct is_xoptional_expression : std::is_same<xexpression_tag_t<E>, xoptional_expression_tag>
402 {
403 };
404
405 /********************************
406 * xoptional_comparable concept *
407 ********************************/
408
409 template <class... E>
411 : std::conjunction<std::disjunction<is_xtensor_expression<E>, is_xoptional_expression<E>>...>
412 {
413 };
414
// Member-forwarding helpers for xshared_expression. Each expansion defines a
// member function that forwards to the member of the same name of the wrapped
// expression. The enclosing class must have a template parameter `E` and a
// pointer-like member `m_ptr` to an E. Const forwarders dereference m_ptr as a
// `const E&`, so E's const overload is invoked; its result type is exactly the
// declared return type. This matches how the const stepper forwarders work.

// Forwards an argument-less const member function.
#define XTENSOR_FORWARD_CONST_METHOD(name)                                    \
    auto name() const -> decltype(std::declval<xtl::constify_t<E>>().name()) \
    {                                                                         \
        return static_cast<const E&>(*m_ptr).name();                          \
    }

// Forwards an argument-less non-const member function.
#define XTENSOR_FORWARD_METHOD(name)                  \
    auto name() -> decltype(std::declval<E>().name()) \
    {                                                 \
        return m_ptr->name();                         \
    }

// Forwards the const overloads of an iterator-returning member template:
// name<L>() and its broadcasting form name<L>(shape). The shape must be passed
// on: the broadcasting overload returns a different iterator than name<L>().
#define XTENSOR_FORWARD_CONST_ITERATOR_METHOD(name)                                               \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL>                                          \
    auto name() const noexcept -> decltype(std::declval<xtl::constify_t<E>>().template name<L>()) \
    {                                                                                             \
        return static_cast<const E&>(*m_ptr).template name<L>();                                  \
    }                                                                                             \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S>                                 \
    auto name(const S& shape) const noexcept                                                      \
        -> decltype(std::declval<xtl::constify_t<E>>().template name<L>(shape))                   \
    {                                                                                             \
        return static_cast<const E&>(*m_ptr).template name<L>(shape);                             \
    }

// Forwards the non-const overloads of an iterator-returning member template:
// the broadcasting name<L>(shape), which forwards the shape, and name<L>().
#define XTENSOR_FORWARD_ITERATOR_METHOD(name)                                                 \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S>                             \
    auto name(const S& shape) noexcept -> decltype(std::declval<E>().template name<L>(shape)) \
    {                                                                                         \
        return m_ptr->template name<L>(shape);                                                \
    }                                                                                         \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL>                                      \
    auto name() noexcept -> decltype(std::declval<E>().template name<L>())                    \
    {                                                                                         \
        return m_ptr->template name<L>();                                                     \
    }
451
namespace detail
{
    // Lazy accessors for E's nested stride/storage types. They are meant to
    // be used as branches of xtl::mpl::eval_if_t, so that E::xxx is only
    // named when the corresponding branch is selected.

    /// Yields E::storage_type.
    template <class E>
    struct expr_storage_type
    {
        using type = typename E::storage_type;
    };

    /// Yields E::strides_type.
    template <class E>
    struct expr_strides_type
    {
        using type = typename E::strides_type;
    };

    /// Yields E::backstrides_type.
    template <class E>
    struct expr_backstrides_type
    {
        using type = typename E::backstrides_type;
    };

    /// Yields E::inner_strides_type.
    template <class E>
    struct expr_inner_strides_type
    {
        using type = typename E::inner_strides_type;
    };

    /// Yields E::inner_backstrides_type.
    template <class E>
    struct expr_inner_backstrides_type
    {
        using type = typename E::inner_backstrides_type;
    };
}
484
510 template <class E>
511 class xshared_expression : public xexpression<xshared_expression<E>>
512 {
513 public:
514
515 using base_class = xexpression<xshared_expression<E>>;
516
517 using value_type = typename E::value_type;
518 using reference = typename E::reference;
519 using const_reference = typename E::const_reference;
520 using pointer = typename E::pointer;
521 using const_pointer = typename E::const_pointer;
522 using size_type = typename E::size_type;
523 using difference_type = typename E::difference_type;
524
525 using inner_shape_type = typename E::inner_shape_type;
526 using shape_type = typename E::shape_type;
527
528 using strides_type = xtl::mpl::
529 eval_if_t<has_strides<E>, detail::expr_strides_type<E>, get_strides_type<shape_type>>;
530 using backstrides_type = xtl::mpl::
531 eval_if_t<has_strides<E>, detail::expr_backstrides_type<E>, get_strides_type<shape_type>>;
532 using inner_strides_type = xtl::mpl::
533 eval_if_t<has_strides<E>, detail::expr_inner_strides_type<E>, get_strides_type<shape_type>>;
534 using inner_backstrides_type = xtl::mpl::
535 eval_if_t<has_strides<E>, detail::expr_inner_backstrides_type<E>, get_strides_type<shape_type>>;
536 using storage_type = xtl::mpl::eval_if_t<has_storage_type<E>, detail::expr_storage_type<E>, make_invalid_type<>>;
537
538 using stepper = typename E::stepper;
539 using const_stepper = typename E::const_stepper;
540
541 using linear_iterator = typename E::linear_iterator;
542 using const_linear_iterator = typename E::const_linear_iterator;
543
544 using bool_load_type = typename E::bool_load_type;
545
546 static constexpr layout_type static_layout = E::static_layout;
547 static constexpr bool contiguous_layout = static_layout != layout_type::dynamic;
548
549 explicit xshared_expression(const std::shared_ptr<E>& ptr);
550 long use_count() const noexcept;
551
552 template <class... Args>
553 auto operator()(Args... args) -> decltype(std::declval<E>()(args...))
554 {
555 return m_ptr->operator()(args...);
556 }
557
558 XTENSOR_FORWARD_CONST_METHOD(shape)
559 XTENSOR_FORWARD_CONST_METHOD(dimension)
560 XTENSOR_FORWARD_CONST_METHOD(size)
561 XTENSOR_FORWARD_CONST_METHOD(layout)
562 XTENSOR_FORWARD_CONST_METHOD(is_contiguous)
563
564 XTENSOR_FORWARD_ITERATOR_METHOD(begin)
565 XTENSOR_FORWARD_ITERATOR_METHOD(end)
566 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(begin)
567 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(end)
568 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(cbegin)
569 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(cend)
570
571 XTENSOR_FORWARD_ITERATOR_METHOD(rbegin)
572 XTENSOR_FORWARD_ITERATOR_METHOD(rend)
573 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(rbegin)
574 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(rend)
575 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crbegin)
576 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crend)
577
578 XTENSOR_FORWARD_METHOD(linear_begin)
579 XTENSOR_FORWARD_METHOD(linear_end)
580 XTENSOR_FORWARD_CONST_METHOD(linear_begin)
581 XTENSOR_FORWARD_CONST_METHOD(linear_end)
582 XTENSOR_FORWARD_CONST_METHOD(linear_cbegin)
583 XTENSOR_FORWARD_CONST_METHOD(linear_cend)
584
585 XTENSOR_FORWARD_METHOD(linear_rbegin)
586 XTENSOR_FORWARD_METHOD(linear_rend)
587 XTENSOR_FORWARD_CONST_METHOD(linear_rbegin)
588 XTENSOR_FORWARD_CONST_METHOD(linear_rend)
589 XTENSOR_FORWARD_CONST_METHOD(linear_crbegin)
590 XTENSOR_FORWARD_CONST_METHOD(linear_crend)
591
592 template <class T = E>
593 std::enable_if_t<has_strides<T>::value, const inner_strides_type&> strides() const
594 {
595 return m_ptr->strides();
596 }
597
598 template <class T = E>
599 std::enable_if_t<has_strides<T>::value, const inner_strides_type&> backstrides() const
600 {
601 return m_ptr->backstrides();
602 }
603
604 template <class T = E>
605 std::enable_if_t<has_data_interface<T>::value, pointer> data() noexcept
606 {
607 return m_ptr->data();
608 }
609
610 template <class T = E>
611 std::enable_if_t<has_data_interface<T>::value, pointer> data() const noexcept
612 {
613 return m_ptr->data();
614 }
615
616 template <class T = E>
617 std::enable_if_t<has_data_interface<T>::value, size_type> data_offset() const noexcept
618 {
619 return m_ptr->data_offset();
620 }
621
622 template <class T = E>
623 std::enable_if_t<has_data_interface<T>::value, typename T::storage_type&> storage() noexcept
624 {
625 return m_ptr->storage();
626 }
627
628 template <class T = E>
629 std::enable_if_t<has_data_interface<T>::value, const typename T::storage_type&> storage() const noexcept
630 {
631 return m_ptr->storage();
632 }
633
634 template <class It>
635 reference element(It first, It last)
636 {
637 return m_ptr->element(first, last);
638 }
639
640 template <class It>
641 const_reference element(It first, It last) const
642 {
643 return m_ptr->element(first, last);
644 }
645
646 template <class S>
647 bool broadcast_shape(S& shape, bool reuse_cache = false) const
648 {
649 return m_ptr->broadcast_shape(shape, reuse_cache);
650 }
651
652 template <class S>
653 bool has_linear_assign(const S& strides) const noexcept
654 {
655 return m_ptr->has_linear_assign(strides);
656 }
657
658 template <class S>
659 auto stepper_begin(const S& shape) noexcept -> decltype(std::declval<E>().stepper_begin(shape))
660 {
661 return m_ptr->stepper_begin(shape);
662 }
663
664 template <class S>
665 auto stepper_end(const S& shape, layout_type l) noexcept
666 -> decltype(std::declval<E>().stepper_end(shape, l))
667 {
668 return m_ptr->stepper_end(shape, l);
669 }
670
671 template <class S>
672 auto stepper_begin(const S& shape) const noexcept
673 -> decltype(std::declval<const E>().stepper_begin(shape))
674 {
675 return static_cast<const E*>(m_ptr.get())->stepper_begin(shape);
676 }
677
678 template <class S>
679 auto stepper_end(const S& shape, layout_type l) const noexcept
680 -> decltype(std::declval<const E>().stepper_end(shape, l))
681 {
682 return static_cast<const E*>(m_ptr.get())->stepper_end(shape, l);
683 }
684
685 private:
686
687 std::shared_ptr<E> m_ptr;
688 };
689
697 template <class E>
698 inline xshared_expression<E>::xshared_expression(const std::shared_ptr<E>& ptr)
699 : m_ptr(ptr)
700 {
701 }
702
707 template <class E>
708 inline long xshared_expression<E>::use_count() const noexcept
709 {
710 return m_ptr.use_count();
711 }
712
713 namespace detail
714 {
715 template <class E>
716 inline xshared_expression<E> make_xshared_impl(xsharable_expression<E>&& expr)
717 {
718 if (expr.p_shared == nullptr)
719 {
720 expr.p_shared = std::make_shared<E>(std::move(expr).derived_cast());
721 }
722 return xshared_expression<E>(expr.p_shared);
723 }
724 }
725
732 template <class E>
734 {
735 static_assert(
736 is_xsharable_expression<E>::value,
737 "make_shared requires E to inherit from xsharable_expression"
738 );
739 return detail::make_xshared_impl(std::move(expr.derived_cast()));
740 }
741
749 template <class E>
750 inline auto share(xexpression<E>& expr)
751 {
752 return make_xshared(std::move(expr));
753 }
754
762 template <class E>
763 inline auto share(xexpression<E>&& expr)
764 {
765 return make_xshared(std::move(expr));
766 }
767
768#undef XTENSOR_FORWARD_METHOD
769
770}
771
772#endif
Base class for xexpressions.
derived_type & derived_cast() & noexcept
Returns a reference to the actual derived type of the xexpression.
const derived_type & derived_cast() const & noexcept
Returns a constant reference to the actual derived type of the xexpression.
Shared xexpressions.
xshared_expression(const std::shared_ptr< E > &ptr)
Constructor for xshared expression (note: usually the free function make_xshared is recommended).
long use_count() const noexcept
Return the number of times this expression is referenced.
standard mathematical functions for xexpressions
auto share(xexpression< E > &expr)
Helper function to create shared expression from any xexpression.
layout_type
Definition xlayout.hpp:24
xshared_expression< E > make_xshared(xexpression< E > &&expr)
Helper function to create shared expression from any xexpression.