xtensor
 
Loading...
Searching...
No Matches
xexpression.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_EXPRESSION_HPP
11#define XTENSOR_EXPRESSION_HPP
12
13#include <type_traits>
14
15#include <xtl/xclosure.hpp>
16#include <xtl/xmeta_utils.hpp>
17#include <xtl/xtype_traits.hpp>
18
19#include "../core/xlayout.hpp"
20#include "../core/xshape.hpp"
21#include "../core/xtensor_forward.hpp"
22#include "../utils/xutils.hpp"
23
24namespace xt
25{
26
27 /***************************
28 * xexpression declaration *
29 ***************************/
30
/**
 * @class xexpression
 * @brief Base class for xexpressions (CRTP).
 *
 * @tparam D the concrete expression type deriving from xexpression<D>.
 *
 * All special member functions are protected: an xexpression<D> can only be
 * constructed, copied, moved or destroyed as the base sub-object of a D.
 */
template <class D>
class xexpression
{
public:

    /// The concrete expression type.
    using derived_type = D;

    /// Returns a reference to the actual derived type of the expression.
    derived_type& derived_cast() & noexcept;
    /// Returns a constant reference to the actual derived type of the expression.
    const derived_type& derived_cast() const& noexcept;
    /// Rvalue overload: returns the derived type by value.
    derived_type derived_cast() && noexcept;

protected:

    // Move operations (defaulted).
    xexpression(xexpression&&) = default;
    xexpression& operator=(xexpression&&) = default;

    // Copy operations (defaulted).
    xexpression(const xexpression&) = default;
    xexpression& operator=(const xexpression&) = default;

    xexpression() = default;
    ~xexpression() = default;
};
65
66 /************************************
67 * xsharable_expression declaration *
68 ************************************/
69
// Forward declarations needed by xsharable_expression below.
// NOTE(review): the declarations introduced by these template headers
// (listing lines 71, 74 and 79) are missing from this extraction; presumably
// xshared_expression, make_xshared and detail::make_xshared_impl, which are
// referenced by xsharable_expression's friend declaration — confirm.
70 template <class E>
72
73 template <class E>
75
76 namespace detail
77 {
78 template <class E>
80 }
81
82 template <class D>
83 class xsharable_expression : public xexpression<D>
84 {
85 protected:
86
87 xsharable_expression();
88 ~xsharable_expression() = default;
89
90 xsharable_expression(const xsharable_expression&) = default;
91 xsharable_expression& operator=(const xsharable_expression&) = default;
92
93 xsharable_expression(xsharable_expression&&) = default;
94 xsharable_expression& operator=(xsharable_expression&&) = default;
95
96 private:
97
98 std::shared_ptr<D> p_shared;
99
100 friend xshared_expression<D> detail::make_xshared_impl<D>(xsharable_expression<D>&&);
101 };
102
103 /******************************
104 * xexpression implementation *
105 ******************************/
106
111
114 template <class D>
115 inline auto xexpression<D>::derived_cast() & noexcept -> derived_type&
116 {
117 return *static_cast<derived_type*>(this);
118 }
119
123 template <class D>
124 inline auto xexpression<D>::derived_cast() const& noexcept -> const derived_type&
125 {
126 return *static_cast<const derived_type*>(this);
127 }
128
132 template <class D>
133 inline auto xexpression<D>::derived_cast() && noexcept -> derived_type
134 {
135 return *static_cast<derived_type*>(this);
136 }
137
139
140 /***************************************
141 * xsharable_expression implementation *
142 ***************************************/
143
144 template <class D>
145 inline xsharable_expression<D>::xsharable_expression()
146 : p_shared(nullptr)
147 {
148 }
149
159
namespace detail
{
    /**
     * Default case: E itself is taken as the CRTP-derived type, so the trait
     * checks whether B<E> is a base of E.
     */
    template <template <class> class B, class E>
    struct is_crtp_base_of_impl : std::is_base_of<B<E>, E> {};

    /**
     * E is a specialization F<T> of a one-parameter template: accept either
     * B<T> or B<F<T>> as a base of F<T>. std::disjunction stops at the first
     * check that is true.
     */
    template <template <class> class B, class T, template <class> class F>
    struct is_crtp_base_of_impl<B, F<T>>
        : std::disjunction<std::is_base_of<B<T>, F<T>>, std::is_base_of<B<F<T>>, F<T>>> {};
}
173
174 template <template <class> class B, class E>
175 using is_crtp_base_of = detail::is_crtp_base_of_impl<B, std::decay_t<E>>;
176
177 template <class E>
178 using is_xexpression = is_crtp_base_of<xexpression, E>;
179
180 template <class E>
181 concept xexpression_concept = is_xexpression<E>::value;
182
183 template <class E, class R = void>
184 using enable_xexpression = typename std::enable_if<is_xexpression<E>::value, R>::type;
185
186 template <class E, class R = void>
187 using disable_xexpression = typename std::enable_if<!is_xexpression<E>::value, R>::type;
188
189 template <class... E>
190 using has_xexpression = std::disjunction<is_xexpression<E>...>;
191
192 template <class E>
193 using is_xsharable_expression = is_crtp_base_of<xsharable_expression, E>;
194
195 template <class E, class R = void>
196 using enable_xsharable_expression = typename std::enable_if<is_xsharable_expression<E>::value, R>::type;
197
198 template <class E, class R = void>
199 using disable_xsharable_expression = typename std::enable_if<!is_xsharable_expression<E>::value, R>::type;
200
/**
 * @brief Whether a value of type RHS can be assigned to LHS.
 *
 * Defaults to std::is_assignable. It is a struct rather than an alias,
 * presumably so it can be specialized elsewhere — confirm.
 */
template <class LHS, class RHS>
struct can_assign : std::is_assignable<LHS, RHS>
{
};

// SFINAE helpers keyed on can_assign: alias R (default void) when the
// assignment is (resp. is not) possible.
template <class LHS, class RHS, class R = void>
using enable_assignable_expression = std::enable_if_t<can_assign<LHS, RHS>::value, R>;

template <class LHS, class RHS, class R = void>
using enable_not_assignable_expression = std::enable_if_t<!can_assign<LHS, RHS>::value, R>;
211
212 /***********************
213 * evaluation_strategy *
214 ***********************/
215
216 namespace detail
217 {
218 struct option_base
219 {
220 };
221 }
222
223 namespace evaluation_strategy
224 {
225
226 struct immediate_type : xt::detail::option_base
227 {
228 };
229
230 constexpr auto immediate = std::tuple<immediate_type>{};
231
232 struct lazy_type : xt::detail::option_base
233 {
234 };
235
236 constexpr auto lazy = std::tuple<lazy_type>{};
237
238 /*
239 struct cached {};
240 */
241 }
242
243 template <class T>
244 struct is_evaluation_strategy : std::is_base_of<detail::option_base, std::decay_t<T>>
245 {
246 };
247
248 /************
249 * xclosure *
250 ************/
251
252 template <class T>
253 class xscalar;
254
255 template <class E, class EN = void>
256 struct xclosure
257 {
258 using type = xtl::closure_type_t<E>;
259 };
260
261 template <class E>
262 struct xclosure<xshared_expression<E>, std::enable_if_t<true>>
263 {
264 using type = xshared_expression<E>; // force copy
265 };
266
267 template <class E>
268 struct xclosure<E, disable_xexpression<std::decay_t<E>>>
269 {
270 using type = xscalar<xtl::closure_type_t<E>>;
271 };
272
273 template <class E>
274 using xclosure_t = typename xclosure<E>::type;
275
// const counterpart of xclosure: the primary template uses
// xtl::const_closure_type_t<E>.
// NOTE(review): listing line 277 (this primary template's class-head,
// presumably `struct const_xclosure`) is missing from this extraction.
276 template <class E, class EN = void>
278 {
279 using type = xtl::const_closure_type_t<E>;
280 };
281
// Specialization for operands that are not xexpressions.
// NOTE(review): listing line 285 (the member typedef of this specialization)
// is missing from this extraction.
282 template <class E>
283 struct const_xclosure<E, disable_xexpression<std::decay_t<E>>>
284 {
286 };
287
288 template <class E>
289 struct const_xclosure<xshared_expression<E>&, std::enable_if_t<true>>
290 {
291 using type = xshared_expression<E>; // force copy
292 };
293
294 template <class E>
295 using const_xclosure_t = typename const_xclosure<E>::type;
296
297 /*************************
298 * expression tag system *
299 *************************/
300
// Expression tag types.
// NOTE(review): the class-heads (listing lines 301 and 305) are missing from
// this extraction; presumably xtensor_expression_tag and
// xoptional_expression_tag, both referenced further down — confirm.
302 {
303 };
304
306 {
307 };
308
309 namespace extension
310 {
// Primary template: the expression tag defaults to xtensor_expression_tag.
// NOTE(review): the class-head (listing line 312) is missing from this
// extraction; presumably get_expression_tag_impl, specialized below.
311 template <class E, class = void_t<int>>
313 {
314 using type = xtensor_expression_tag;
315 };
316
317 template <class E>
318 struct get_expression_tag_impl<E, void_t<typename std::decay_t<E>::expression_tag>>
319 {
320 using type = typename std::decay_t<E>::expression_tag;
321 };
322
// NOTE(review): listing lines 324-326 are missing from this extraction;
// presumably the definition of get_expression_tag, used by
// get_expression_tag_t below — confirm.
323 template <class E>
327
328 template <class E>
329 using get_expression_tag_t = typename get_expression_tag<E>::type;
330
// Declaration and specializations of expression_tag_and (tag combination).
// NOTE(review): the class-head lines (listing lines 332, 335, 341, 347 and
// 353) are missing from this extraction, so the specialized arguments are
// not visible here. From what remains: a full specialization yields
// xtensor_expression_tag, and three single-parameter specializations yield T.
331 template <class... T>
333
334 template <>
336 {
337 using type = xtensor_expression_tag;
338 };
339
340 template <class T>
342 {
343 using type = T;
344 };
345
346 template <class T>
348 {
349 using type = T;
350 };
351
352 template <class T>
354 {
355 using type = T;
356 };
357
358 template <class T>
359 struct expression_tag_and<T, xtensor_expression_tag> : expression_tag_and<xtensor_expression_tag, T>
360 {
361 };
362
// NOTE(review): listing lines 364-367 (the body of this full specialization)
// are missing from this extraction.
363 template <>
368
369 template <class T1, class... T>
370 struct expression_tag_and<T1, T...> : expression_tag_and<T1, typename expression_tag_and<T...>::type>
371 {
372 };
373
374 template <class... T>
375 using expression_tag_and_t = typename expression_tag_and<T...>::type;
376
// A type exposing `expression_tag = xtensor_expression_tag`.
// NOTE(review): its class-head (listing line 377) is missing from this extraction.
378 {
379 using expression_tag = xtensor_expression_tag;
380 };
381 }
382
// Combines the expression tags of all T... (each taken from
// std::decay_t<const_xclosure_t<T>>) with extension::expression_tag_and_t.
// NOTE(review): the class-head (listing line 384) is missing from this
// extraction; presumably xexpression_tag, given xexpression_tag_t below.
383 template <class... T>
385 {
386 using type = extension::expression_tag_and_t<
387 extension::get_expression_tag_t<std::decay_t<const_xclosure_t<T>>>...>;
388 };
389
390 template <class... T>
391 using xexpression_tag_t = typename xexpression_tag<T...>::type;
392
393 template <class E>
394 struct is_xtensor_expression : std::is_same<xexpression_tag_t<E>, xtensor_expression_tag>
395 {
396 };
397
398 template <class E>
399 struct is_xoptional_expression : std::is_same<xexpression_tag_t<E>, xoptional_expression_tag>
400 {
401 };
402
403 /********************************
404 * xoptional_comparable concept *
405 ********************************/
406
// True when every E... is either an xtensor expression or an xoptional expression.
// NOTE(review): listing line 408 (the class-head, presumably
// `struct xoptional_comparable` per the section banner) is missing from this
// extraction.
407 template <class... E>
409 : std::conjunction<std::disjunction<is_xtensor_expression<E>, is_xoptional_expression<E>>...>
410 {
411 };
412
// Defines a const member function `name()` forwarding to the expression
// reached through the member `m_ptr`; the return type is what `name()`
// yields on xtl::constify_t<E>.
#define XTENSOR_FORWARD_CONST_METHOD(name)                                   \
    auto name() const -> decltype(std::declval<xtl::constify_t<E>>().name()) \
    {                                                                        \
        return (*m_ptr).name();                                              \
    }

// Defines a non-const member function `name()` forwarding to the expression
// reached through the member `m_ptr`; the return type is what `name()`
// yields on E.
#define XTENSOR_FORWARD_METHOD(name)                       \
    auto name() -> decltype(std::declval<E>().name())      \
    {                                                      \
        return (*m_ptr).name();                            \
    }
424
// Defines, in a class holding `m_ptr` (a pointer-like handle to an E), the
// const iterator accessors `name<L>()` and `name<L>(shape)` forwarding to the
// wrapped expression viewed as const. Going through `const E*` makes the
// forwarded call produce exactly the declared xtl::constify_t<E> return type
// (same pattern as xshared_expression's const stepper_begin/stepper_end).
// The shape overload must forward `shape`: calling `name<L>()` instead
// silently ignored the requested shape and returned a different iterator
// type than the one declared.
#define XTENSOR_FORWARD_CONST_ITERATOR_METHOD(name)                                               \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL>                                          \
    auto name() const noexcept -> decltype(std::declval<xtl::constify_t<E>>().template name<L>()) \
    {                                                                                             \
        return static_cast<const E*>(m_ptr.get())->template name<L>();                            \
    }                                                                                             \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S>                                 \
    auto name(const S& shape) const noexcept                                                      \
        -> decltype(std::declval<xtl::constify_t<E>>().template name<L>(shape))                   \
    {                                                                                             \
        return static_cast<const E*>(m_ptr.get())->template name<L>(shape);                       \
    }

// Non-const counterpart: `name<L>(shape)` and `name<L>()` forwarding to the
// wrapped E. As above, the shape overload forwards its `shape` argument.
#define XTENSOR_FORWARD_ITERATOR_METHOD(name)                                                 \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class S>                             \
    auto name(const S& shape) noexcept -> decltype(std::declval<E>().template name<L>(shape)) \
    {                                                                                         \
        return m_ptr->template name<L>(shape);                                                \
    }                                                                                         \
    template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL>                                      \
    auto name() noexcept -> decltype(std::declval<E>().template name<L>())                    \
    {                                                                                         \
        return m_ptr->template name<L>();                                                     \
    }
449
namespace detail
{
    // Lazily-evaluated accessors for nested types of an expression T. They
    // are meant to be used with xtl::mpl::eval_if_t, so that `typename T::...`
    // is only named on the branch that is actually selected.

    template <class T>
    struct expr_strides_type { using type = typename T::strides_type; };

    template <class T>
    struct expr_inner_strides_type { using type = typename T::inner_strides_type; };

    template <class T>
    struct expr_backstrides_type { using type = typename T::backstrides_type; };

    template <class T>
    struct expr_inner_backstrides_type { using type = typename T::inner_backstrides_type; };

    template <class T>
    struct expr_storage_type { using type = typename T::storage_type; };
}
482
508 template <class E>
509 class xshared_expression : public xexpression<xshared_expression<E>>
510 {
511 public:
512
513 using base_class = xexpression<xshared_expression<E>>;
514
515 using value_type = typename E::value_type;
516 using reference = typename E::reference;
517 using const_reference = typename E::const_reference;
518 using pointer = typename E::pointer;
519 using const_pointer = typename E::const_pointer;
520 using size_type = typename E::size_type;
521 using difference_type = typename E::difference_type;
522
523 using inner_shape_type = typename E::inner_shape_type;
524 using shape_type = typename E::shape_type;
525
526 using strides_type = xtl::mpl::
527 eval_if_t<has_strides<E>, detail::expr_strides_type<E>, get_strides_type<shape_type>>;
528 using backstrides_type = xtl::mpl::
529 eval_if_t<has_strides<E>, detail::expr_backstrides_type<E>, get_strides_type<shape_type>>;
530 using inner_strides_type = xtl::mpl::
531 eval_if_t<has_strides<E>, detail::expr_inner_strides_type<E>, get_strides_type<shape_type>>;
532 using inner_backstrides_type = xtl::mpl::
533 eval_if_t<has_strides<E>, detail::expr_inner_backstrides_type<E>, get_strides_type<shape_type>>;
534 using storage_type = xtl::mpl::eval_if_t<has_storage_type<E>, detail::expr_storage_type<E>, make_invalid_type<>>;
535
536 using stepper = typename E::stepper;
537 using const_stepper = typename E::const_stepper;
538
539 using linear_iterator = typename E::linear_iterator;
540 using const_linear_iterator = typename E::const_linear_iterator;
541
542 using bool_load_type = typename E::bool_load_type;
543
544 static constexpr layout_type static_layout = E::static_layout;
545 static constexpr bool contiguous_layout = static_layout != layout_type::dynamic;
546
547 explicit xshared_expression(const std::shared_ptr<E>& ptr);
548 long use_count() const noexcept;
549
550 template <class... Args>
551 auto operator()(Args... args) -> decltype(std::declval<E>()(args...))
552 {
553 return m_ptr->operator()(args...);
554 }
555
556 XTENSOR_FORWARD_CONST_METHOD(shape)
557 XTENSOR_FORWARD_CONST_METHOD(dimension)
558 XTENSOR_FORWARD_CONST_METHOD(size)
559 XTENSOR_FORWARD_CONST_METHOD(layout)
560 XTENSOR_FORWARD_CONST_METHOD(is_contiguous)
561
562 XTENSOR_FORWARD_ITERATOR_METHOD(begin)
563 XTENSOR_FORWARD_ITERATOR_METHOD(end)
564 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(begin)
565 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(end)
566 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(cbegin)
567 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(cend)
568
569 XTENSOR_FORWARD_ITERATOR_METHOD(rbegin)
570 XTENSOR_FORWARD_ITERATOR_METHOD(rend)
571 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(rbegin)
572 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(rend)
573 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crbegin)
574 XTENSOR_FORWARD_CONST_ITERATOR_METHOD(crend)
575
576 XTENSOR_FORWARD_METHOD(linear_begin)
577 XTENSOR_FORWARD_METHOD(linear_end)
578 XTENSOR_FORWARD_CONST_METHOD(linear_begin)
579 XTENSOR_FORWARD_CONST_METHOD(linear_end)
580 XTENSOR_FORWARD_CONST_METHOD(linear_cbegin)
581 XTENSOR_FORWARD_CONST_METHOD(linear_cend)
582
583 XTENSOR_FORWARD_METHOD(linear_rbegin)
584 XTENSOR_FORWARD_METHOD(linear_rend)
585 XTENSOR_FORWARD_CONST_METHOD(linear_rbegin)
586 XTENSOR_FORWARD_CONST_METHOD(linear_rend)
587 XTENSOR_FORWARD_CONST_METHOD(linear_crbegin)
588 XTENSOR_FORWARD_CONST_METHOD(linear_crend)
589
590 template <class T = E>
591 std::enable_if_t<has_strides<T>::value, const inner_strides_type&> strides() const
592 {
593 return m_ptr->strides();
594 }
595
596 template <class T = E>
597 std::enable_if_t<has_strides<T>::value, const inner_strides_type&> backstrides() const
598 {
599 return m_ptr->backstrides();
600 }
601
602 template <class T = E>
603 std::enable_if_t<has_data_interface<T>::value, pointer> data() noexcept
604 {
605 return m_ptr->data();
606 }
607
608 template <class T = E>
609 std::enable_if_t<has_data_interface<T>::value, pointer> data() const noexcept
610 {
611 return m_ptr->data();
612 }
613
614 template <class T = E>
615 std::enable_if_t<has_data_interface<T>::value, size_type> data_offset() const noexcept
616 {
617 return m_ptr->data_offset();
618 }
619
620 template <class T = E>
621 std::enable_if_t<has_data_interface<T>::value, typename T::storage_type&> storage() noexcept
622 {
623 return m_ptr->storage();
624 }
625
626 template <class T = E>
627 std::enable_if_t<has_data_interface<T>::value, const typename T::storage_type&> storage() const noexcept
628 {
629 return m_ptr->storage();
630 }
631
632 template <class It>
633 reference element(It first, It last)
634 {
635 return m_ptr->element(first, last);
636 }
637
638 template <class It>
639 const_reference element(It first, It last) const
640 {
641 return m_ptr->element(first, last);
642 }
643
644 template <class S>
645 bool broadcast_shape(S& shape, bool reuse_cache = false) const
646 {
647 return m_ptr->broadcast_shape(shape, reuse_cache);
648 }
649
650 template <class S>
651 bool has_linear_assign(const S& strides) const noexcept
652 {
653 return m_ptr->has_linear_assign(strides);
654 }
655
656 template <class S>
657 auto stepper_begin(const S& shape) noexcept -> decltype(std::declval<E>().stepper_begin(shape))
658 {
659 return m_ptr->stepper_begin(shape);
660 }
661
662 template <class S>
663 auto stepper_end(const S& shape, layout_type l) noexcept
664 -> decltype(std::declval<E>().stepper_end(shape, l))
665 {
666 return m_ptr->stepper_end(shape, l);
667 }
668
669 template <class S>
670 auto stepper_begin(const S& shape) const noexcept
671 -> decltype(std::declval<const E>().stepper_begin(shape))
672 {
673 return static_cast<const E*>(m_ptr.get())->stepper_begin(shape);
674 }
675
676 template <class S>
677 auto stepper_end(const S& shape, layout_type l) const noexcept
678 -> decltype(std::declval<const E>().stepper_end(shape, l))
679 {
680 return static_cast<const E*>(m_ptr.get())->stepper_end(shape, l);
681 }
682
683 private:
684
685 std::shared_ptr<E> m_ptr;
686 };
687
695 template <class E>
696 inline xshared_expression<E>::xshared_expression(const std::shared_ptr<E>& ptr)
697 : m_ptr(ptr)
698 {
699 }
700
705 template <class E>
706 inline long xshared_expression<E>::use_count() const noexcept
707 {
708 return m_ptr.use_count();
709 }
710
711 namespace detail
712 {
713 template <class E>
714 inline xshared_expression<E> make_xshared_impl(xsharable_expression<E>&& expr)
715 {
716 if (expr.p_shared == nullptr)
717 {
718 expr.p_shared = std::make_shared<E>(std::move(expr).derived_cast());
719 }
720 return xshared_expression<E>(expr.p_shared);
721 }
722 }
723
// Helper function to create a shared expression from any xexpression; E must
// derive from xsharable_expression (checked at compile time).
// NOTE(review): listing line 731 (the signature) is missing from this
// extraction; the listing's tooltip gives it as
// `xshared_expression<E> make_xshared(xexpression<E>&& expr)`.
// NOTE(review): the static_assert message says "make_shared" although the
// function is make_xshared — consider correcting the message.
730 template <class E>
732 {
733 static_assert(
734 is_xsharable_expression<E>::value,
735 "make_shared requires E to inherit from xsharable_expression"
736 );
// Passes the derived object as an rvalue to the caching implementation.
737 return detail::make_xshared_impl(std::move(expr.derived_cast()));
738 }
739
747 template <class E>
748 inline auto share(xexpression<E>& expr)
749 {
750 return make_xshared(std::move(expr));
751 }
752
760 template <class E>
761 inline auto share(xexpression<E>&& expr)
762 {
763 return make_xshared(std::move(expr));
764 }
765
766#undef XTENSOR_FORWARD_METHOD
767
768}
769
770#endif
Base class for xexpressions.
derived_type & derived_cast() & noexcept
Returns a reference to the actual derived type of the xexpression.
const derived_type & derived_cast() const & noexcept
Returns a constant reference to the actual derived type of the xexpression.
Shared xexpressions.
xshared_expression(const std::shared_ptr< E > &ptr)
Constructor for xshared expression (note: usually the free function make_xshared is recommended).
long use_count() const noexcept
Return the number of times this expression is referenced.
standard mathematical functions for xexpressions
auto share(xexpression< E > &expr)
Helper function to create shared expression from any xexpression.
layout_type
Definition xlayout.hpp:24
xshared_expression< E > make_xshared(xexpression< E > &&expr)
Helper function to create shared expression from any xexpression.