xtensor
Loading...
Searching...
No Matches
xexpression.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_EXPRESSION_HPP
11#define XTENSOR_EXPRESSION_HPP
12
#include <cstddef>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <xtl/xclosure.hpp>
#include <xtl/xmeta_utils.hpp>
#include <xtl/xtype_traits.hpp>

#include "xlayout.hpp"
#include "xshape.hpp"
#include "xtensor_forward.hpp"
#include "xutils.hpp"
25
26namespace xt
27{
28
29 /***************************
30 * xexpression declaration *
31 ***************************/
32
// Base class for xexpressions, following the curiously recurring template
// pattern: D is the concrete expression type, which derives from xexpression<D>.
// NOTE(review): this extracted listing is missing the class-head line and the
// member declarations (source lines 46, 52-54, 56, 58-59, 61-62, 64-65). The
// out-of-line definitions further down show &, const& and && overloads of
// derived_cast(), each of which downcasts `this` to D.
45 template <class D>
47 {
48 public:
49
// The concrete expression type this base was instantiated for.
50 using derived_type = D;
51
55
57
60
63
66 };
67
68 /************************************
69 * xsharable_expression declaration *
70 ************************************/
71
// Forward declarations.
// NOTE(review): the declarator lines (source lines 73, 76 and 81) are missing
// from this listing. They presumably declare xshared_expression,
// xsharable_expression and detail::make_xshared_impl(xsharable_expression<E>&&),
// which the friend declaration in xsharable_expression uses before their
// definitions. Confirm against the real header.
72 template <class E>
74
75 template <class E>
77
78 namespace detail
79 {
80 template <class E>
82 }
83
// Mixin for expressions that can be shared. It caches, in p_shared, the
// std::shared_ptr created by detail::make_xshared_impl, so sharing the same
// object again returns an xshared_expression over the same pointer.
// Its visible special members (destructor, assignment operators) are protected.
// NOTE(review): the class-head line (source line 85) and the default, copy and
// move constructor declarations (lines 89, 92, 95) are missing from this
// listing. make_xshared_impl calls derived_cast() on this type, so it
// presumably derives from xexpression<D>. Confirm.
84 template <class D>
86 {
87 protected:
88
90 ~xsharable_expression() = default;
91
93 xsharable_expression& operator=(const xsharable_expression&) = default;
94
96 xsharable_expression& operator=(xsharable_expression&&) = default;
97
98 private:
99
// Cached shared copy of the derived expression. It is null until the
// expression is first shared (the constructor initializes it to nullptr).
// NOTE(review): the defaulted assignments also copy/move this pointer, so an
// object assigned from an already-shared expression inherits its cache.
// Confirm that this is intended.
100 std::shared_ptr<D> p_shared;
101
// Grants make_xshared_impl access to p_shared.
102 friend xshared_expression<D> detail::make_xshared_impl<D>(xsharable_expression<D>&&);
103 };
104
105 /******************************
106 * xexpression implementation *
107 ******************************/
108
116 template <class D>
117 inline auto xexpression<D>::derived_cast() & noexcept -> derived_type&
118 {
119 return *static_cast<derived_type*>(this);
120 }
121
// Const lvalue overload of derived_cast(): returns a constant reference to the
// actual derived type of the xexpression.
// NOTE(review): the signature line (source line 126) is missing from this listing.
125 template <class D>
127 {
128 return *static_cast<const derived_type*>(this);
129 }
130
134 template <class D>
135 inline auto xexpression<D>::derived_cast() && noexcept -> derived_type
136 {
137 return *static_cast<derived_type*>(this);
138 }
139
141
142 /***************************************
143 * xsharable_expression implementation *
144 ***************************************/
145
// Default constructor of xsharable_expression. The expression starts out not
// shared, with no cached shared_ptr.
// NOTE(review): the constructor's signature line (source line 147) is missing
// from this listing.
146 template <class D>
148 : p_shared(nullptr)
149 {
150 }
151
162 namespace detail
163 {
164 template <template <class> class B, class E>
165 struct is_crtp_base_of_impl : std::is_base_of<B<E>, E>
166 {
167 };
168
169 template <template <class> class B, class E, template <class> class F>
170 struct is_crtp_base_of_impl<B, F<E>>
171 : xtl::disjunction<std::is_base_of<B<E>, F<E>>, std::is_base_of<B<F<E>>, F<E>>>
172 {
173 };
174 }
175
176 template <template <class> class B, class E>
177 using is_crtp_base_of = detail::is_crtp_base_of_impl<B, std::decay_t<E>>;
178
179 template <class E>
180 using is_xexpression = is_crtp_base_of<xexpression, E>;
181
182 template <class E, class R = void>
183 using enable_xexpression = typename std::enable_if<is_xexpression<E>::value, R>::type;
184
185 template <class E, class R = void>
186 using disable_xexpression = typename std::enable_if<!is_xexpression<E>::value, R>::type;
187
188 template <class... E>
189 using has_xexpression = xtl::disjunction<is_xexpression<E>...>;
190
191 template <class E>
192 using is_xsharable_expression = is_crtp_base_of<xsharable_expression, E>;
193
194 template <class E, class R = void>
195 using enable_xsharable_expression = typename std::enable_if<is_xsharable_expression<E>::value, R>::type;
196
197 template <class E, class R = void>
198 using disable_xsharable_expression = typename std::enable_if<!is_xsharable_expression<E>::value, R>::type;
199
// Whether an expression of type RHS can be assigned to LHS. It defaults to
// std::is_assignable. It is a struct rather than an alias, presumably so that it
// can be specialized for particular expression types (confirm).
template <class LHS, class RHS>
struct can_assign : std::is_assignable<LHS, RHS>
{
};

// SFINAE helper: yields Ret only when RHS can be assigned to LHS.
template <class LHS, class RHS, class Ret = void>
using enable_assignable_expression = std::enable_if_t<can_assign<LHS, RHS>::value, Ret>;

// SFINAE helper: yields Ret only when RHS cannot be assigned to LHS.
template <class LHS, class RHS, class Ret = void>
using enable_not_assignable_expression = std::enable_if_t<!can_assign<LHS, RHS>::value, Ret>;
210
211 /***********************
212 * evaluation_strategy *
213 ***********************/
214
namespace detail
{
    // Common base of every option tag; is_evaluation_strategy tests for it.
    struct option_base
    {
    };
}

namespace evaluation_strategy
{
    // Tag type requesting immediate evaluation.
    struct immediate_type : detail::option_base
    {
    };

    // Tag type requesting lazy evaluation.
    struct lazy_type : detail::option_base
    {
    };

    // Option values: each wraps its tag type in a one-element std::tuple.
    constexpr auto immediate = std::tuple<immediate_type>{};
    constexpr auto lazy = std::tuple<lazy_type>{};

    /*
    struct cached {};
    */
}

// True when T (after std::decay) derives from detail::option_base.
template <class T>
struct is_evaluation_strategy : std::is_base_of<detail::option_base, std::decay_t<T>>
{
};
246
247 /************
248 * xclosure *
249 ************/
250
// Closure types: how an expression of type E is held by value or reference.
// NOTE(review): several lines of this section are missing from the listing:
// the head of the xshared_expression specialization of xclosure (source line
// 261), the body of both disable_xexpression specializations (lines 269 and
// 284) and the head of the const_xclosure primary template (line 276). The
// forward-declared xscalar suggests that non-expressions are wrapped in xscalar.
// Confirm against the real header.
251 template <class T>
252 class xscalar;
253
// Primary template: defer to xtl's closure type for E.
254 template <class E, class EN = void>
255 struct xclosure
256 {
257 using type = xtl::closure_type_t<E>;
258 };
259
// xshared_expression is held by value. Copying it copies the shared_ptr, not
// the wrapped expression.
260 template <class E>
262 {
263 using type = xshared_expression<E>; // force copy
264 };
265
// Specialization selected for types that are not xexpressions.
266 template <class E>
267 struct xclosure<E, disable_xexpression<std::decay_t<E>>>
268 {
270 };
271
272 template <class E>
273 using xclosure_t = typename xclosure<E>::type;
274
// Const counterpart of xclosure: defer to xtl's const closure type for E.
275 template <class E, class EN = void>
277 {
278 using type = xtl::const_closure_type_t<E>;
279 };
280
// Specialization selected for types that are not xexpressions.
281 template <class E>
282 struct const_xclosure<E, disable_xexpression<std::decay_t<E>>>
283 {
285 };
286
// An lvalue xshared_expression is also held by value (the shared_ptr is copied).
287 template <class E>
288 struct const_xclosure<xshared_expression<E>&, std::enable_if_t<true>>
289 {
290 using type = xshared_expression<E>; // force copy
291 };
292
293 template <class E>
294 using const_xclosure_t = typename const_xclosure<E>::type;
295
296 /*************************
297 * expression tag system *
298 *************************/
299
// Expression tags classify expressions (xtensor_expression_tag,
// xoptional_expression_tag, ...).
// NOTE(review): many head lines are missing from this listing. Examples are the
// names of the two empty tag structs below (source lines 300 and 304, presumably
// xtensor_expression_tag and xoptional_expression_tag, both referenced later),
// the primary get_expression_tag_impl (311, 313) and the get_expression_tag and
// expression_tag_and heads. Confirm against the real header.
301 {
302 };
303
305 {
306 };
307
308 namespace extension
309 {
// Primary template (body partly missing): used when E has no nested tag.
310 template <class E, class = void_t<int>>
312 {
314 };
315
// Picks up the nested `expression_tag` typedef when std::decay_t<E> declares one.
316 template <class E>
317 struct get_expression_tag_impl<E, void_t<typename std::decay_t<E>::expression_tag>>
318 {
319 using type = typename std::decay_t<E>::expression_tag;
320 };
321
322 template <class E>
326
327 template <class E>
328 using get_expression_tag_t = typename get_expression_tag<E>::type;
329
// expression_tag_and combines several expression tags into a single tag.
330 template <class... T>
332
333 template <>
335 {
337 };
338
339 template <class T>
341 {
342 using type = T;
343 };
344
345 template <class T>
347 {
348 using type = T;
349 };
350
351 template <class T>
353 {
354 using type = T;
355 };
356
// Symmetry: (T, xtensor_expression_tag) is forwarded to the
// (xtensor_expression_tag, T) case.
357 template <class T>
358 struct expression_tag_and<T, xtensor_expression_tag> : expression_tag_and<xtensor_expression_tag, T>
359 {
360 };
361
362 template <>
367
// Variadic case: fold from the right, combining T1 with the combined tag of
// the remaining types.
368 template <class T1, class... T>
369 struct expression_tag_and<T1, T...> : expression_tag_and<T1, typename expression_tag_and<T...>::type>
370 {
371 };
372
373 template <class... T>
374 using expression_tag_and_t = typename expression_tag_and<T...>::type;
375
380 }
381
// Combined expression tag of a set of expression types. Each T is mapped
// through const_xclosure_t and decayed, its tag is looked up with
// get_expression_tag_t, and the tags are merged with expression_tag_and_t.
// NOTE(review): the struct head (source line 383) is missing from this listing.
// The xexpression_tag_t alias below shows that it is `xexpression_tag`.
382 template <class... T>
384 {
385 using type = extension::expression_tag_and_t<
386 extension::get_expression_tag_t<std::decay_t<const_xclosure_t<T>>>...>;
387 };
388
389 template <class... T>
390 using xexpression_tag_t = typename xexpression_tag<T...>::type;
391
392 template <class E>
393 struct is_xtensor_expression : std::is_same<xexpression_tag_t<E>, xtensor_expression_tag>
394 {
395 };
396
397 template <class E>
398 struct is_xoptional_expression : std::is_same<xexpression_tag_t<E>, xoptional_expression_tag>
399 {
400 };
401
402 /********************************
403 * xoptional_comparable concept *
404 ********************************/
405
// True when every E is either an xtensor expression or an xoptional expression.
// NOTE(review): the struct head (source line 407) is missing from this listing.
// Per the section banner, this is the xoptional_comparable concept.
406 template <class... E>
408 : xtl::conjunction<xtl::disjunction<is_xtensor_expression<E>, is_xoptional_expression<E>>...>
409 {
410 };
411
// Defines a const, argument-less member `name` that forwards to the wrapped
// expression held in m_ptr. std::shared_ptr::operator-> yields a non-const E*
// even in a const member function, so the call goes through a pointer-to-const.
// That selects E's const overload, which is the one the declared return type is
// computed from. The const stepper overloads of xshared_expression already
// forward this way.
#define XTENSOR_FORWARD_CONST_METHOD(name)                                    \
    auto name() const -> decltype(std::declval<xtl::constify_t<E>>().name()) \
    {                                                                         \
        return static_cast<const E*>(m_ptr.get())->name();                    \
    }
417
// Defines a non-const, argument-less member `name` that forwards to the
// wrapped expression held in m_ptr. The return type is whatever E::name() returns.
#define XTENSOR_FORWARD_METHOD(name)                  \
    auto name() -> decltype(std::declval<E>().name()) \
    {                                                 \
        return (*m_ptr).name();                       \
    }
423
// Defines the const overloads of a layout-templated iterator member `name`
// (begin, end, cbegin, ...) that forward to the wrapped expression. There is one
// overload without arguments and one taking a shape.
// - The shape overload must pass `shape` through. Calling name<L>() instead
//   returns a different iterator than the declared
//   decltype(...name<L>(shape)) return type.
// - Both overloads call through a pointer-to-const so that E's const overloads
//   are selected. shared_ptr::operator-> would otherwise reach the non-const ones.
#define XTENSOR_FORWARD_CONST_ITERATOR_METHOD(name)                                               \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL>                                          \
    auto name() const noexcept -> decltype(std::declval<xtl::constify_t<E>>().template name<L>()) \
    {                                                                                             \
        return static_cast<const E*>(m_ptr.get())->template name<L>();                            \
    }                                                                                             \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S>                                 \
    auto name(const S& shape) const noexcept                                                      \
        -> decltype(std::declval<xtl::constify_t<E>>().template name<L>(shape))                   \
    {                                                                                             \
        return static_cast<const E*>(m_ptr.get())->template name<L>(shape);                       \
    }
436
// Defines the non-const overloads of a layout-templated iterator member `name`
// (begin, end, rbegin, rend) that forward to the wrapped expression. There is
// one overload taking a shape and one without arguments. The shape overload
// must pass `shape` through. Calling name<L>() instead returns a different
// iterator than the declared decltype(...name<L>(shape)) return type.
#define XTENSOR_FORWARD_ITERATOR_METHOD(name)                                                 \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S>                             \
    auto name(const S& shape) noexcept -> decltype(std::declval<E>().template name<L>(shape)) \
    {                                                                                         \
        return m_ptr->template name<L>(shape);                                                \
    }                                                                                         \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL>                                      \
    auto name() noexcept -> decltype(std::declval<E>().template name<L>())                    \
    {                                                                                         \
        return m_ptr->template name<L>();                                                     \
    }
448
namespace detail
{
    // Each trait below names one nested type of an expression type.
    // xshared_expression passes them to xtl::mpl::eval_if_t, presumably so the
    // nested type is only looked up for the branch that is selected.

    template <class Expr>
    struct expr_strides_type
    {
        using type = typename Expr::strides_type;
    };

    template <class Expr>
    struct expr_backstrides_type
    {
        using type = typename Expr::backstrides_type;
    };

    template <class Expr>
    struct expr_inner_strides_type
    {
        using type = typename Expr::inner_strides_type;
    };

    template <class Expr>
    struct expr_inner_backstrides_type
    {
        using type = typename Expr::inner_backstrides_type;
    };

    template <class Expr>
    struct expr_storage_type
    {
        using type = typename Expr::storage_type;
    };
}
481
// Shared xexpression. It holds a std::shared_ptr<E> (m_ptr, declared at the end
// of the class) and forwards the expression interface to the pointee, so copies
// of an xshared_expression refer to the same E instead of copying it. The free
// function make_xshared is usually recommended over calling the constructor
// directly.
// NOTE(review): source line 512 is missing from this listing (presumably a
// member typedef such as expression_tag). Confirm against the real header.
507 template <class E>
508 class xshared_expression : public xexpression<xshared_expression<E>>
509 {
510 public:
511
513
// Value, reference, pointer and size types are taken from E.
514 using value_type = typename E::value_type;
515 using reference = typename E::reference;
516 using const_reference = typename E::const_reference;
517 using pointer = typename E::pointer;
518 using const_pointer = typename E::const_pointer;
519 using size_type = typename E::size_type;
520 using difference_type = typename E::difference_type;
521
// Shape types are taken from E.
522 using inner_shape_type = typename E::inner_shape_type;
523 using shape_type = typename E::shape_type;
524
// Stride types come from E when has_strides<E>; otherwise they are derived
// from shape_type through get_strides_type.
525 using strides_type = xtl::mpl::
526 eval_if_t<has_strides<E>, detail::expr_strides_type<E>, get_strides_type<shape_type>>;
527 using backstrides_type = xtl::mpl::
528 eval_if_t<has_strides<E>, detail::expr_backstrides_type<E>, get_strides_type<shape_type>>;
529 using inner_strides_type = xtl::mpl::
530 eval_if_t<has_strides<E>, detail::expr_inner_strides_type<E>, get_strides_type<shape_type>>;
531 using inner_backstrides_type = xtl::mpl::
532 eval_if_t<has_strides<E>, detail::expr_inner_backstrides_type<E>, get_strides_type<shape_type>>;
// storage_type is E's when has_storage_type<E>; otherwise it is an invalid
// placeholder (make_invalid_type<>).
533 using storage_type = xtl::mpl::eval_if_t<has_storage_type<E>, detail::expr_storage_type<E>, make_invalid_type<>>;
534
// Stepper, linear-iterator and bool-load types are taken from E.
535 using stepper = typename E::stepper;
536 using const_stepper = typename E::const_stepper;
537
538 using linear_iterator = typename E::linear_iterator;
539 using const_linear_iterator = typename E::const_linear_iterator;
540
541 using bool_load_type = typename E::bool_load_type;
542
// Static layout of E. contiguous_layout is true whenever that layout is not
// layout_type::dynamic.
543 static constexpr layout_type static_layout = E::static_layout;
544 static constexpr bool contiguous_layout = static_layout != layout_type::dynamic;
545
// Wraps an existing shared_ptr. use_count() reports how many shared_ptr
// instances own the wrapped expression.
546 explicit xshared_expression(const std::shared_ptr<E>& ptr);
547 long use_count() const noexcept;
548
549 template <class... Args>
550 auto operator()(Args... args) -> decltype(std::declval<E>()(args...))
551 {
552 return m_ptr->operator()(args...);
553 }
554
555 XTENSOR_FORWARD_CONST_METHOD(shape)
556 XTENSOR_FORWARD_CONST_METHOD(dimension)
557 XTENSOR_FORWARD_CONST_METHOD(size)
558 XTENSOR_FORWARD_CONST_METHOD(layout)
559 XTENSOR_FORWARD_CONST_METHOD(is_contiguous)
560
561 XTENSOR_FORWARD_ITERATOR_METHOD(begin)
562 XTENSOR_FORWARD_ITERATOR_METHOD(end)
563 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(begin)
564 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(end)
565 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(cbegin)
566 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(cend)
567
568 XTENSOR_FORWARD_ITERATOR_METHOD(rbegin)
569 XTENSOR_FORWARD_ITERATOR_METHOD(rend)
570 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(rbegin)
571 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(rend)
572 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crbegin)
573 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crend)
574
575 XTENSOR_FORWARD_METHOD(linear_begin)
576 XTENSOR_FORWARD_METHOD(linear_end)
577 XTENSOR_FORWARD_CONST_METHOD(linear_begin)
578 XTENSOR_FORWARD_CONST_METHOD(linear_end)
579 XTENSOR_FORWARD_CONST_METHOD(linear_cbegin)
580 XTENSOR_FORWARD_CONST_METHOD(linear_cend)
581
582 XTENSOR_FORWARD_METHOD(linear_rbegin)
583 XTENSOR_FORWARD_METHOD(linear_rend)
584 XTENSOR_FORWARD_CONST_METHOD(linear_rbegin)
585 XTENSOR_FORWARD_CONST_METHOD(linear_rend)
586 XTENSOR_FORWARD_CONST_METHOD(linear_crbegin)
587 XTENSOR_FORWARD_CONST_METHOD(linear_crend)
588
589 template <class T = E>
590 std::enable_if_t<has_strides<T>::value, const inner_strides_type&> strides() const
591 {
592 return m_ptr->strides();
593 }
594
595 template <class T = E>
596 std::enable_if_t<has_strides<T>::value, const inner_strides_type&> backstrides() const
597 {
598 return m_ptr->backstrides();
599 }
600
601 template <class T = E>
602 std::enable_if_t<has_data_interface<T>::value, pointer> data() noexcept
603 {
604 return m_ptr->data();
605 }
606
607 template <class T = E>
608 std::enable_if_t<has_data_interface<T>::value, pointer> data() const noexcept
609 {
610 return m_ptr->data();
611 }
612
613 template <class T = E>
614 std::enable_if_t<has_data_interface<T>::value, size_type> data_offset() const noexcept
615 {
616 return m_ptr->data_offset();
617 }
618
619 template <class T = E>
620 std::enable_if_t<has_data_interface<T>::value, typename T::storage_type&> storage() noexcept
621 {
622 return m_ptr->storage();
623 }
624
625 template <class T = E>
626 std::enable_if_t<has_data_interface<T>::value, const typename T::storage_type&> storage() const noexcept
627 {
628 return m_ptr->storage();
629 }
630
631 template <class It>
632 reference element(It first, It last)
633 {
634 return m_ptr->element(first, last);
635 }
636
637 template <class It>
638 const_reference element(It first, It last) const
639 {
640 return m_ptr->element(first, last);
641 }
642
643 template <class S>
644 bool broadcast_shape(S& shape, bool reuse_cache = false) const
645 {
646 return m_ptr->broadcast_shape(shape, reuse_cache);
647 }
648
649 template <class S>
650 bool has_linear_assign(const S& strides) const noexcept
651 {
652 return m_ptr->has_linear_assign(strides);
653 }
654
655 template <class S>
656 auto stepper_begin(const S& shape) noexcept -> decltype(std::declval<E>().stepper_begin(shape))
657 {
658 return m_ptr->stepper_begin(shape);
659 }
660
661 template <class S>
662 auto stepper_end(const S& shape, layout_type l) noexcept
663 -> decltype(std::declval<E>().stepper_end(shape, l))
664 {
665 return m_ptr->stepper_end(shape, l);
666 }
667
668 template <class S>
669 auto stepper_begin(const S& shape) const noexcept
670 -> decltype(std::declval<const E>().stepper_begin(shape))
671 {
672 return static_cast<const E*>(m_ptr.get())->stepper_begin(shape);
673 }
674
675 template <class S>
676 auto stepper_end(const S& shape, layout_type l) const noexcept
677 -> decltype(std::declval<const E>().stepper_end(shape, l))
678 {
679 return static_cast<const E*>(m_ptr.get())->stepper_end(shape, l);
680 }
681
682 private:
683
684 std::shared_ptr<E> m_ptr;
685 };
686
694 template <class E>
695 inline xshared_expression<E>::xshared_expression(const std::shared_ptr<E>& ptr)
696 : m_ptr(ptr)
697 {
698 }
699
// Returns the number of std::shared_ptr instances that currently own the
// wrapped expression (m_ptr.use_count()).
// NOTE(review): the signature line (source line 705) is missing from this listing.
704 template <class E>
706 {
707 return m_ptr.use_count();
708 }
709
710 namespace detail
711 {
712 template <class E>
713 inline xshared_expression<E> make_xshared_impl(xsharable_expression<E>&& expr)
714 {
715 if (expr.p_shared == nullptr)
716 {
717 expr.p_shared = std::make_shared<E>(std::move(expr).derived_cast());
718 }
719 return xshared_expression<E>(expr.p_shared);
720 }
721 }
722
// Helper function to create a shared expression from any xexpression. Per the
// listing's API summary, the signature is
// `xshared_expression<E> make_xshared(xexpression<E>&& expr)`. It passes the
// derived expression, as an rvalue, to detail::make_xshared_impl. The
// static_assert requires E to inherit from xsharable_expression.
// NOTE(review): the signature line (730) and the static_assert condition (733)
// are missing from this listing. The assertion message says "make_shared",
// but the function is make_xshared.
729 template <class E>
731 {
732 static_assert(
734 "make_shared requires E to inherit from xsharable_expression"
735 );
736 return detail::make_xshared_impl(std::move(expr.derived_cast()));
737 }
738
// Helper function to create a shared expression from any xexpression. Per the
// listing's API summary, this is the lvalue overload
// `auto share(xexpression<E>& expr)`.
// NOTE(review): despite taking an lvalue, it passes std::move(expr) to
// make_xshared. Whether expr stays usable afterwards depends on how
// make_xshared_impl extracts the derived object; confirm before relying on it.
// The signature line (source line 747) is missing from this listing.
746 template <class E>
748 {
749 return make_xshared(std::move(expr));
750 }
751
759 template <class E>
760 inline auto share(xexpression<E>&& expr)
761 {
762 return make_xshared(std::move(expr));
763 }
764
// The forwarding macros are only used by xshared_expression in this header.
// Undefine all four, not just XTENSOR_FORWARD_METHOD, so that none of them leak
// into code that includes this header.
// NOTE(review): confirm that no other header relies on these macros being
// defined after including this one.
#undef XTENSOR_FORWARD_METHOD
#undef XTENSOR_FORWARD_CONST_METHOD
#undef XTENSOR_FORWARD_ITERATOR_METHOD
#undef XTENSOR_FORWARD_CONST_ITERATOR_METHOD
766
767}
768
769#endif
Base class for xexpressions.
derived_type& derived_cast() & noexcept
Returns a reference to the actual derived type of the xexpression.
const derived_type& derived_cast() const& noexcept
Returns a constant reference to the actual derived type of the xexpression.
Shared xexpressions.
xshared_expression(const std::shared_ptr< E > &ptr)
Constructor for xshared_expression (note: the free function make_xshared is usually recommended instead).
long use_count() const noexcept
Returns the number of std::shared_ptr instances that currently own the wrapped expression.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto share(xexpression< E > &expr)
Helper function to create a shared expression from any xexpression.
layout_type
Definition xlayout.hpp:24
xshared_expression< E > make_xshared(xexpression< E > &&expr)
Helper function to create a shared expression from any xexpression.