xtensor
 
Loading...
Searching...
No Matches
index_mapper.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_INDEX_MAPPER_HPP
11#define XTENSOR_INDEX_MAPPER_HPP
12
13#include "xview.hpp"
14
15namespace xt
16{
17
/**
 * Maps indices of a view onto indices of its underlying container.
 *
 * The primary template is only declared; the mapping itself is provided by
 * specializations (the xt::xview specialization follows in this header).
 */
template <class UndefinedView>
class index_mapper;

/**
 * Defines the access policy for the underlying container.
 */
enum class access_t
{
    SAFE,   ///< Use .at() accessor (bounds checked).
    UNSAFE  ///< Use operator() accessor (no bounds checking).
};
30
55 template <class UnderlyingContainer, class... Slices>
56 class index_mapper<xt::xview<UnderlyingContainer, Slices...>>
57 {
58 public:
59
61 using view_type = xt::xview<UnderlyingContainer, Slices...>;
62
64 using reference = typename xt::xview<UnderlyingContainer, Slices...>::reference;
65
67 using const_reference = typename xt::xview<UnderlyingContainer, Slices...>::const_reference;
68
70 static constexpr size_t n_slices = sizeof...(Slices);
71
73 static constexpr size_t nb_integral_slices = (std::is_integral_v<Slices> + ...);
74
76 static constexpr size_t nb_new_axis_slices = (xt::detail::is_newaxis_v<Slices> + ...);
77
82 template <std::integral... Indices>
83 static constexpr size_t n_indices_full_v = size_t(sizeof...(Indices) + nb_integral_slices);
84
92 template <std::integral... Indices>
93 reference map(UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
94
102 template <std::integral... Indices>
104 cmap(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
105
113 template <std::integral... Indices>
114 reference map_at(UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
115
123 template <std::integral... Indices>
125 cmap_at(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
126
128 size_t dimension(const UnderlyingContainer& container) const;
129
130 private:
131
133 template <bool IS_CONST>
134 using conditional_reference = std::conditional_t<IS_CONST, const_reference, reference>;
135
137 template <size_t I>
138 using slice_type = std::tuple_element_t<I, std::tuple<Slices...>>;
139
141 template <size_t I>
142 static consteval bool is_slice_integral();
143
145 template <size_t I>
146 static consteval bool is_slice_new_axis();
147
155 template <size_t first, size_t bound, size_t... indices>
156 struct indices_sequence_helper
157 {
158 // we add the current axis
159 using not_new_axis_type = typename indices_sequence_helper<first + 1, bound, indices..., first>::type;
160
161 // we skip the current axis
162 using new_axis_type = typename indices_sequence_helper<first + 1, bound, indices...>::type;
163
164 // NOTE: is_slice_new_axis works even if first >= sizeof...(Slices)
165 using type = std::conditional_t<is_slice_new_axis<first>(), new_axis_type, not_new_axis_type>;
166 };
167
169 template <size_t bound, size_t... indices>
170 struct indices_sequence_helper<bound, bound, indices...>
171 {
172 using type = std::index_sequence<indices...>;
173 };
174
176 template <size_t bound>
177 using indices_sequence = indices_sequence_helper<0, bound>::type;
178
194 template <size_t I, std::integral Index>
195 size_t map_ith_index(const view_type& view, const Index i) const;
196
210 template <bool IS_CONST, access_t ACCESS, std::integral FirstIndice, std::integral... OtherIndices>
211 conditional_reference<IS_CONST> map_main(
212 std::bool_constant<IS_CONST> /* is_const */,
213 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
214 std::integral_constant<access_t, ACCESS> /* access */,
215 const view_type& view,
216 const FirstIndice firstIndice,
217 const OtherIndices... otherIndices
218 ) const;
219
230 template <bool IS_CONST, access_t ACCESS>
231 conditional_reference<IS_CONST> map_main(
232 std::bool_constant<IS_CONST> /* is_const */,
233 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
234 std::integral_constant<access_t, ACCESS> /* access */,
235 const view_type& view
236 ) const;
237
252 template <bool IS_CONST, access_t ACCESS, size_t n_indices, size_t... Is>
253 conditional_reference<IS_CONST> map_all_indices(
254 std::bool_constant<IS_CONST> /* is_const */,
255 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
256 std::integral_constant<access_t, ACCESS> /* access */,
257 const view_type& view,
258 std::index_sequence<Is...> /* is_seq */,
259 const std::array<size_t, n_indices>& indices
260 ) const;
261
263 template <std::integral... Indices>
264 std::array<size_t, n_indices_full_v<Indices...>> get_indices_full(const Indices... indices) const;
265 };
266
267 /*******************************
268 * index_mapper implementation *
269 *******************************/
270
271 template <class UnderlyingContainer, class... Slices>
272 template <size_t I>
273 consteval bool index_mapper<xt::xview<UnderlyingContainer, Slices...>>::is_slice_integral()
274 {
275 if constexpr (I < sizeof...(Slices))
276 {
277 return std::is_integral_v<slice_type<I>>;
278 }
279 else
280 {
281 return false;
282 }
283 }
284
285 template <class UnderlyingContainer, class... Slices>
286 template <size_t I>
287 consteval bool index_mapper<xt::xview<UnderlyingContainer, Slices...>>::is_slice_new_axis()
288 {
289 if constexpr (I < sizeof...(Slices))
290 {
291 return xt::detail::is_newaxis_v<slice_type<I>>;
292 }
293 else
294 {
295 return false;
296 }
297 }
298
299 template <class UnderlyingContainer, class... Slices>
300 template <std::integral... Indices>
301 auto
302 index_mapper<xt::xview<UnderlyingContainer, Slices...>>::get_indices_full(const Indices... indices) const
303 -> std::array<size_t, n_indices_full_v<Indices...>>
304 {
305 constexpr size_t n_indices_full = n_indices_full_v<Indices...>;
306
307 std::array<size_t, sizeof...(indices)> args{size_t(indices)...};
308 std::array<size_t, n_indices_full> args_full;
309
310 const auto fill_args_full = [&args_full, &args]<size_t... Is>(std::index_sequence<Is...>)
311 {
312 auto it = std::cbegin(args);
313
314 ((args_full[Is] = (is_slice_integral<Is>()) ? size_t(0) : *it++), ...);
315 };
316
317 fill_args_full(std::make_index_sequence<n_indices_full>{});
318
319 return args_full;
320 }
321
322 template <class UnderlyingContainer, class... Slices>
323 template <std::integral... Indices>
324 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map(
325 UnderlyingContainer& container,
326 const view_type& view,
327 const Indices... indices
328 ) const -> reference
329 {
330 return map_main(
331 std::false_type{},
332 container,
333 std::integral_constant<access_t, access_t::UNSAFE>{},
334 view,
335 indices...
336 );
337 }
338
339 template <class UnderlyingContainer, class... Slices>
340 template <std::integral... Indices>
341 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::cmap(
342 const UnderlyingContainer& container,
343 const view_type& view,
344 const Indices... indices
345 ) const -> const_reference
346 {
347 return map_main(
348 std::true_type{},
349 container,
350 std::integral_constant<access_t, access_t::UNSAFE>{},
351 view,
352 indices...
353 );
354 }
355
356 template <class UnderlyingContainer, class... Slices>
357 template <std::integral... Indices>
358 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_at(
359 UnderlyingContainer& container,
360 const view_type& view,
361 const Indices... indices
362 ) const -> reference
363 {
364 return map_main(
365 std::false_type{},
366 container,
367 std::integral_constant<access_t, access_t::SAFE>{},
368 view,
369 indices...
370 );
371 }
372
373 template <class UnderlyingContainer, class... Slices>
374 template <std::integral... Indices>
375 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::cmap_at(
376 const UnderlyingContainer& container,
377 const view_type& view,
378 const Indices... indices
379 ) const -> const_reference
380 {
381 return map_main(
382 std::true_type{},
383 container,
384 std::integral_constant<access_t, access_t::SAFE>{},
385 view,
386 indices...
387 );
388 }
389
390 template <class UnderlyingContainer, class... Slices>
391 template <bool IS_CONST, access_t ACCESS, std::integral FirstIndice, std::integral... OtherIndices>
392 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_main(
393 std::bool_constant<IS_CONST> is_const,
394 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
395 std::integral_constant<access_t, ACCESS> access,
396 const view_type& view,
397 const FirstIndice firstIndice,
398 const OtherIndices... otherIndices
399 ) const -> conditional_reference<IS_CONST>
400 {
401 constexpr size_t n_indices_full = n_indices_full_v<FirstIndice, OtherIndices...>;
402
403 constexpr size_t underlying_n_dimensions = xt::static_dimension<
404 typename std::decay_t<UnderlyingContainer>::shape_type>::value;
405
406 // If there is too many indices, we need to drop the first ones.
407 // If the number of dimensions of the underlying container is known at compile time we can drop them
408 // at compile time Else a runtime-test is requires, which, breaks vectorization.
409 // I don't know if we can do it in another way.
410
411 if constexpr (underlying_n_dimensions != size_t(-1))
412 {
413 // the number of dimensions of the underlying container is known at compile time.
414 constexpr size_t n_dimensions = underlying_n_dimensions - nb_integral_slices + nb_new_axis_slices;
415
416 // we can perform compile time checks
417 if constexpr (1 + sizeof...(OtherIndices) > n_dimensions)
418 {
419 return map_main(is_const, container, access, view, otherIndices...);
420 }
421 else
422 {
423 return map_all_indices(
424 is_const,
425 container,
426 access,
427 view,
428 indices_sequence<n_indices_full>{},
429 get_indices_full(firstIndice, otherIndices...)
430 );
431 }
432 }
433 else
434 {
435 // we need execution time checks
436 if (1 + sizeof...(OtherIndices) > dimension(container))
437 {
438 return map_main(is_const, container, access, view, otherIndices...);
439 }
440 else
441 {
442 return map_all_indices(
443 is_const,
444 container,
445 access,
446 view,
447 indices_sequence<n_indices_full>{},
448 get_indices_full(firstIndice, otherIndices...)
449 );
450 }
451 }
452 }
453
454 template <class UnderlyingContainer, class... Slices>
455 template <bool IS_CONST, access_t ACCESS>
456 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_main(
457 std::bool_constant<IS_CONST> is_const,
458 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
459 std::integral_constant<access_t, ACCESS> access,
460 const view_type& view
461 ) const -> conditional_reference<IS_CONST>
462 {
463 constexpr size_t n_indices_full = n_indices_full_v<>;
464
465 return map_all_indices(
466 is_const,
467 container,
468 access,
469 view,
470 indices_sequence<n_indices_full>{},
471 get_indices_full()
472 );
473 }
474
475 template <class UnderlyingContainer, class... Slices>
476 template <bool IS_CONST, access_t ACCESS, size_t n_indices, size_t... Is>
477 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_all_indices(
478 std::bool_constant<IS_CONST> /* is_const */,
479 std::conditional_t<IS_CONST, const UnderlyingContainer&, UnderlyingContainer&> container,
480 std::integral_constant<access_t, ACCESS> /* access */,
481 const view_type& view,
482 std::index_sequence<Is...> /* is_seq */,
483 const std::array<size_t, n_indices>& indices
484 ) const -> conditional_reference<IS_CONST>
485 {
486 if constexpr (ACCESS == access_t::SAFE)
487 {
488 return container.at(map_ith_index<Is>(view, indices[Is])...);
489 }
490 else
491 {
492 return container(map_ith_index<Is>(view, indices[Is])...);
493 }
494 }
495
496 template <class UnderlyingContainer, class... Slices>
497 template <size_t I, std::integral Index>
498 auto
499 index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_ith_index(const view_type& view, const Index i) const
500 -> size_t
501 {
502 if constexpr (I < sizeof...(Slices))
503 {
504 // if the slice is explicitly specified, use it
505 using current_slice = std::tuple_element_t<I, std::tuple<Slices...>>;
506
507 static_assert(not xt::detail::is_newaxis_v<current_slice>);
508
509 const auto& slice = std::get<I>(view.slices());
510
511 if constexpr (std::is_integral_v<current_slice>)
512 {
513 assert(i == 0);
514 return size_t(slice);
515 }
516 else
517 {
518 assert(i < slice.size());
519 return size_t(slice(i));
520 }
521 }
522 else
523 {
524 // else assume xt::all
525 return i;
526 }
527 }
528
529 template <class UnderlyingContainer, class... Slices>
530 auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::dimension(const UnderlyingContainer& container
531 ) const -> size_t
532 {
533 return container.dimension() - nb_integral_slices + nb_new_axis_slices;
534 }
535
536} // namespace xt
537
538#endif // XTENSOR_INDEX_MAPPER_HPP
xt::xview< UnderlyingContainer, Slices... > view_type
The view type this mapper works with.
static constexpr size_t nb_integral_slices
Number of slices that are integral constants (fixed indices)
typename xt::xview< UnderlyingContainer, Slices... >::const_reference const_reference
Const reference type of the underlying view.
const_reference cmap_at(const UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container const_reference using SAFE access.
typename xt::xview< UnderlyingContainer, Slices... >::reference reference
Reference type of the underlying view.
size_t dimension(const UnderlyingContainer &container) const
Return the dimensionality of the view.
const_reference cmap(const UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container const_reference using UNSAFE access.
static constexpr size_t nb_new_axis_slices
Number of slices that are xt::newaxis (insert a dimension)
static constexpr size_t n_indices_full_v
Compute how many indices are needed to address the underlying container when given N indices in the view.
static constexpr size_t n_slices
Total number of explicitly passed slices in the view.
reference map_at(UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container reference using SAFE access.
reference map(UnderlyingContainer &container, const view_type &view, const Indices... indices) const
Map view indices to container reference using UNSAFE access.
Multidimensional view with tensor semantics.
Definition xview.hpp:365
standard mathematical functions for xexpressions
access_t
Defines the access policy for the underlying container.
@ UNSAFE
Use operator() accessor (no bounds checking).
@ SAFE
Use .at() accessor (bounds checked).
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1823
A helper class for mapping indices between views and their underlying containers.