xtensor
 
Loading...
Searching...
No Matches
index_mapper.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_INDEX_MAPPER_HPP
11#define XTENSOR_INDEX_MAPPER_HPP
12
13#include <cassert>
14
15#include "xview.hpp"
16
17namespace xt
18{
19
/**
 * A helper class for mapping indices between views and their underlying containers.
 *
 * Primary template: it is only declared, never defined, so index_mapper can only be
 * used through a partial specialization (such as the one for xt::xview).
 */
template <class UndefinedView>
class index_mapper;
/// Defines the access policy for the underlying container.
enum class access_t
{
    UNSAFE,  ///< Use operator() accessor (no bounds checking).
    SAFE     ///< Use .at() accessor (bounds checked).
};
32
57 template <class UnderlyingContainer, class... Slices>
58 class index_mapper<xt::xview<UnderlyingContainer, Slices...>>
59 {
60 public:
61
63 using view_type = xt::xview<UnderlyingContainer, Slices...>;
64
66 using reference = typename xt::xview<UnderlyingContainer, Slices...>::reference;
67
69 using const_reference = typename xt::xview<UnderlyingContainer, Slices...>::const_reference;
70
72 static constexpr size_t n_slices = sizeof...(Slices);
73
75 static constexpr size_t nb_integral_slices = (std::is_integral_v<Slices> + ...);
76
78 static constexpr size_t nb_new_axis_slices = (xt::detail::is_newaxis_v<Slices> + ...);
79
84 template <std::integral... Indices>
85 static constexpr size_t n_indices_full_v = size_t(sizeof...(Indices) + nb_integral_slices);
86
94 template <std::integral... Indices>
95 reference map(UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
96
104 template <std::integral... Indices>
106 cmap(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
107
115 template <std::integral... Indices>
116 reference map_at(UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
117
125 template <std::integral... Indices>
127 cmap_at(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
128
130 size_t dimension(const UnderlyingContainer& container) const;
131
132 private:
133
135 template <bool IS_CONST>
136 using conditional_reference = std::conditional_t<IS_CONST, const_reference, reference>;
137
139 template <size_t I>
140 using slice_type = std::tuple_element_t<I, std::tuple<Slices...>>;
141
143 template <size_t I>
144 static consteval bool is_slice_integral();
145
147 template <size_t I>
148 static consteval bool is_slice_new_axis();
149
157 template <size_t first, size_t bound, size_t... indices>
158 struct indices_sequence_helper
159 {
160 // we add the current axis
161 using not_new_axis_type = typename indices_sequence_helper<first + 1, bound, indices..., first>::type;
162
163 // we skip the current axis
164 using new_axis_type = typename indices_sequence_helper<first + 1, bound, indices...>::type;
165
166 // NOTE: is_slice_new_axis works even if first >= sizeof...(Slices)
167 using type = std::conditional_t<is_slice_new_axis<first>(), new_axis_type, not_new_axis_type>;
168 };
169
171 template <size_t bound, size_t... indices>
172 struct indices_sequence_helper<bound, bound, indices...>
173 {
174 using type = std::index_sequence<indices...>;
175 };
176
178 template <size_t bound>
179 using indices_sequence = indices_sequence_helper<0, bound>::type;
180
196 template <size_t I, std::integral Index>
197 size_t map_ith_index(const view_type& view, const Index i) const;
198
212 template <bool IS_CONST, access_t ACCESS, std::integral FirstIndice, std::integral... OtherIndices>
213 conditional_reference<IS_CONST> map_main(
214 std::bool_constant<IS_CONST> /* is_const */,
215 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
216 std::integral_constant<access_t, ACCESS> /* access */,
217 const view_type& view,
218 const FirstIndice firstIndice,
219 const OtherIndices... otherIndices
220 ) const;
221
232 template <bool IS_CONST, access_t ACCESS>
233 conditional_reference<IS_CONST> map_main(
234 std::bool_constant<IS_CONST> /* is_const */,
235 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
236 std::integral_constant<access_t, ACCESS> /* access */,
237 const view_type& view
238 ) const;
239
254 template <bool IS_CONST, access_t ACCESS, size_t n_indices, size_t... Is>
255 conditional_reference<IS_CONST> map_all_indices(
256 std::bool_constant<IS_CONST> /* is_const */,
257 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
258 std::integral_constant<access_t, ACCESS> /* access */,
259 const view_type& view,
260 std::index_sequence<Is...> /* is_seq */,
261 const std::array<size_t, n_indices>& indices
262 ) const;
263
265 template <std::integral... Indices>
266 std::array<size_t, n_indices_full_v<Indices...>> get_indices_full(const Indices... indices) const;
267 };
268
269 /*******************************
270 * index_mapper implementation *
271 *******************************/
272
273 template <class UnderlyingContainer, class... Slices>
274 template <size_t I>
275 consteval bool index_mapper<xt::xview<UnderlyingContainer, Slices...>>::is_slice_integral()
276 {
277 if constexpr (I < sizeof...(Slices))
278 {
279 return std::is_integral_v<slice_type<I>>;
280 }
281 else
282 {
283 return false;
284 }
285 }
286
287 template <class UnderlyingContainer, class... Slices>
288 template <size_t I>
289 consteval bool index_mapper<xt::xview<UnderlyingContainer, Slices...>>::is_slice_new_axis()
290 {
291 if constexpr (I < sizeof...(Slices))
292 {
293 return xt::detail::is_newaxis_v<slice_type<I>>;
294 }
295 else
296 {
297 return false;
298 }
299 }
300
301 template <class UnderlyingContainer, class... Slices>
302 template <std::integral... Indices>
303 auto
304 index_mapper<xt::xview<UnderlyingContainer, Slices...>>::get_indices_full(const Indices... indices) const
305 -> std::array<size_t, n_indices_full_v<Indices...>>
306 {
307 constexpr size_t n_indices_full = n_indices_full_v<Indices...>;
308
309 std::array<size_t, sizeof...(indices)> args{size_t(indices)...};
310 std::array<size_t, n_indices_full> args_full;
311
312 const auto fill_args_full = [&args_full, &args]<size_t... Is>(std::index_sequence<Is...>)
313 {
314 auto it = std::cbegin(args);
315
316 ((args_full[Is] = (is_slice_integral<Is>()) ? size_t(0) : *it++), ...);
317 };
318
319 fill_args_full(std::make_index_sequence<n_indices_full>{});
320
321 return args_full;
322 }
323
324 template <class UnderlyingContainer, class... Slices>
325 template <std::integral... Indices>
326 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map(
327 UnderlyingContainer& container,
328 const view_type& view,
329 const Indices... indices
330 ) const -> reference
331 {
332 return map_main(
333 std::false_type{},
334 container,
335 std::integral_constant<access_t, access_t::UNSAFE>{},
336 view,
337 indices...
338 );
339 }
340
341 template <class UnderlyingContainer, class... Slices>
342 template <std::integral... Indices>
343 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::cmap(
344 const UnderlyingContainer& container,
345 const view_type& view,
346 const Indices... indices
347 ) const -> const_reference
348 {
349 return map_main(
350 std::true_type{},
351 container,
352 std::integral_constant<access_t, access_t::UNSAFE>{},
353 view,
354 indices...
355 );
356 }
357
358 template <class UnderlyingContainer, class... Slices>
359 template <std::integral... Indices>
360 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_at(
361 UnderlyingContainer& container,
362 const view_type& view,
363 const Indices... indices
364 ) const -> reference
365 {
366 return map_main(
367 std::false_type{},
368 container,
369 std::integral_constant<access_t, access_t::SAFE>{},
370 view,
371 indices...
372 );
373 }
374
375 template <class UnderlyingContainer, class... Slices>
376 template <std::integral... Indices>
377 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::cmap_at(
378 const UnderlyingContainer& container,
379 const view_type& view,
380 const Indices... indices
381 ) const -> const_reference
382 {
383 return map_main(
384 std::true_type{},
385 container,
386 std::integral_constant<access_t, access_t::SAFE>{},
387 view,
388 indices...
389 );
390 }
391
392 template <class UnderlyingContainer, class... Slices>
393 template <bool IS_CONST, access_t ACCESS, std::integral FirstIndice, std::integral... OtherIndices>
394 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_main(
395 std::bool_constant<IS_CONST> is_const,
396 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
397 std::integral_constant<access_t, ACCESS> access,
398 const view_type& view,
399 const FirstIndice firstIndice,
400 const OtherIndices... otherIndices
401 ) const -> conditional_reference<IS_CONST>
402 {
403 constexpr size_t n_indices_full = n_indices_full_v<FirstIndice, OtherIndices...>;
404
405 constexpr size_t underlying_n_dimensions = static_cast<size_t>(
406 xt::static_dimension<typename std::decay_t<UnderlyingContainer>::shape_type>::value
407 );
408
409 // If there is too many indices, we need to drop the first ones.
410 // If the number of dimensions of the underlying container is known at compile time we can drop them
411 // at compile time Else a runtime-test is requires, which, breaks vectorization.
412 // I don't know if we can do it in another way.
413
414 if constexpr (underlying_n_dimensions != size_t(-1))
415 {
416 // the number of dimensions of the underlying container is known at compile time.
417 constexpr size_t n_dimensions = underlying_n_dimensions - nb_integral_slices + nb_new_axis_slices;
418
419 // we can perform compile time checks
420 if constexpr (1 + sizeof...(OtherIndices) > n_dimensions)
421 {
422 return map_main(is_const, container, access, view, otherIndices...);
423 }
424 else
425 {
426 return map_all_indices(
427 is_const,
428 container,
429 access,
430 view,
431 indices_sequence<n_indices_full>{},
432 get_indices_full(firstIndice, otherIndices...)
433 );
434 }
435 }
436 else
437 {
438 // we need execution time checks
439 if (1 + sizeof...(OtherIndices) > dimension(container))
440 {
441 return map_main(is_const, container, access, view, otherIndices...);
442 }
443 else
444 {
445 return map_all_indices(
446 is_const,
447 container,
448 access,
449 view,
450 indices_sequence<n_indices_full>{},
451 get_indices_full(firstIndice, otherIndices...)
452 );
453 }
454 }
455 }
456
457 template <class UnderlyingContainer, class... Slices>
458 template <bool IS_CONST, access_t ACCESS>
459 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_main(
460 std::bool_constant<IS_CONST> is_const,
461 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
462 std::integral_constant<access_t, ACCESS> access,
463 const view_type& view
464 ) const -> conditional_reference<IS_CONST>
465 {
466 // Work around compilers failing to deduce nb_integral_slices as a non-type template argument inline
467 // (error: use of variable template 'n_indices_full_v' requires template arguments)
468 constexpr size_t n_indices_full = nb_integral_slices;
469
470 return map_all_indices(
471 is_const,
472 container,
473 access,
474 view,
475 indices_sequence<n_indices_full>{},
476 get_indices_full()
477 );
478 }
479
480 template <class UnderlyingContainer, class... Slices>
481 template <bool IS_CONST, access_t ACCESS, size_t n_indices, size_t... Is>
482 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_all_indices(
483 std::bool_constant<IS_CONST> /* is_const */,
484 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
485 std::integral_constant<access_t, ACCESS> /* access */,
486 const view_type& view,
487 std::index_sequence<Is...> /* is_seq */,
488 const std::array<size_t, n_indices>& indices
489 ) const -> conditional_reference<IS_CONST>
490 {
491 if constexpr (ACCESS == access_t::SAFE)
492 {
493 return container.at(map_ith_index<Is>(view, indices[Is])...);
494 }
495 else
496 {
497 return container(map_ith_index<Is>(view, indices[Is])...);
498 }
499 }
500
501 template <class UnderlyingContainer, class... Slices>
502 template <size_t I, std::integral Index>
503 auto
504 index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_ith_index(const view_type& view, const Index i) const
505 -> size_t
506 {
507 if constexpr (I < sizeof...(Slices))
508 {
509 // if the slice is explicitly specified, use it
510 using current_slice = std::tuple_element_t<I, std::tuple<Slices...>>;
511
512 static_assert(not xt::detail::is_newaxis_v<current_slice>);
513
514 const auto& slice = std::get<I>(view.slices());
515
516 if constexpr (std::is_integral_v<current_slice>)
517 {
518 assert(i == 0);
519 return size_t(slice);
520 }
521 else
522 {
523 using slice_size_type = typename current_slice::size_type;
524 assert(i < slice.size());
525 return size_t(slice(static_cast<slice_size_type>(i)));
526 }
527 }
528 else
529 {
530 // else assume xt::all
531 return i;
532 }
533 }
534
535 template <class UnderlyingContainer, class... Slices>
536 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::dimension(const UnderlyingContainer& container
537 ) const -> size_t
538 {
539 return container.dimension() - nb_integral_slices + nb_new_axis_slices;
540 }
541
542} // namespace xt
543
544#endif // XTENSOR_INDEX_MAPPER_HPP
xt::xview< UnderlyingContainer, Slices... > view_type
The view type this mapper works with.
static constexpr size_t nb_integral_slices
Number of slices that are integral constants (fixed indices)
typename xt::xview< UnderlyingContainer, Slices... >::const_reference const_reference
Const reference type of the underlying view.
const_reference cmap_at(const UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container const_reference using SAFE access.
typename xt::xview< UnderlyingContainer, Slices... >::reference reference
Reference type of the underlying view.
size_t dimension(const UnderlyingContainer &container) const
Return the dimensionality of the view.
const_reference cmap(const UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container const_reference using UNSAFE access.
static constexpr size_t nb_new_axis_slices
Number of slices that are xt::newaxis (insert a dimension)
static constexpr size_t n_indices_full_v
Compute how many indices are needed to address the underlying container when given N indices in the view.
static constexpr size_t n_slices
Total number of explicitly passed slices in the view.
reference map_at(UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container reference using SAFE access.
reference map(UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container reference using UNSAFE access.
Multidimensional view with tensor semantics.
Definition xview.hpp:365
standard mathematical functions for xexpressions
access_t
Defines the access policy for the underlying container.
@ UNSAFE
Use operator() accessor (no bounds checking).
@ SAFE
Use .at() accessor (bounds checked).
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1824
A helper class for mapping indices between views and their underlying containers.