11#ifndef XTENSOR_NPY_HPP
12#define XTENSOR_NPY_HPP
31#include <xtl/xplatform.hpp>
32#include <xtl/xsequence.hpp>
34#include "xtensor/xadapt.hpp"
35#include "xtensor/xarray.hpp"
36#include "xtensor/xeval.hpp"
37#include "xtensor/xstrides.hpp"
39#include "xtensor_config.hpp"
43 using namespace std::string_literals;
// Magic prefix that opens every npy file: the byte 0x93 followed by "NUMPY".
48 const char magic_string[] =
"\x93NUMPY";
// Length of the magic prefix, not counting the literal's terminating NUL.
49 const std::size_t magic_string_length =
sizeof(magic_string) - 1;
52 inline void write_magic(O& ostream,
unsigned char v_major = 1,
unsigned char v_minor = 0)
54 ostream.write(magic_string, magic_string_length);
55 ostream.put(
char(v_major));
56 ostream.put(
char(v_minor));
// Reads the npy preamble from `istream`: the magic prefix plus two version
// bytes, which are stored through `v_major` and `v_minor`.
// Failures are reported through XTENSOR_THROW(std::runtime_error, ...).
59 inline void read_magic(std::istream& istream,
unsigned char* v_major,
unsigned char* v_minor)
// Scratch buffer sized for the magic prefix and the two version bytes.
61 std::unique_ptr<char[]> buf(
new char[magic_string_length + 2]);
62 istream.read(buf.get(), magic_string_length + 2);
// NOTE(review): the condition guarding this error is not visible in this
// excerpt - presumably a stream-state check after the read; confirm.
66 XTENSOR_THROW(std::runtime_error,
"io error: failed reading file");
// Compare the leading bytes one by one against the expected magic prefix.
69 for (std::size_t i = 0; i < magic_string_length; i++)
71 if (buf[i] != magic_string[i])
73 XTENSOR_THROW(std::runtime_error,
"this file do not have a valid npy format.");
// The two bytes right after the prefix are the major and minor version.
77 *v_major =
static_cast<unsigned char>(buf[magic_string_length]);
78 *v_minor =
static_cast<unsigned char>(buf[magic_string_length + 1]);
// Maps the element type T to the one-character type code of an npy descriptor.
// NOTE(review): the value returned by each branch is not visible in this
// excerpt - presumably the NumPy kind characters 'f', 'i', 'u', 'b' and 'c'
// (the set accepted by parse_typestring); confirm against the full file.
// A type matching none of the branches is reported as "Type not known.".
82 inline char map_type()
// Floating-point types.
84 if (std::is_same<T, float>::value)
88 if (std::is_same<T, double>::value)
92 if (std::is_same<T, long double>::value)
// Plain char and the signed integer types.
97 if (std::is_same<T, char>::value)
101 if (std::is_same<T, signed char>::value)
105 if (std::is_same<T, short>::value)
109 if (std::is_same<T, int>::value)
113 if (std::is_same<T, long>::value)
117 if (std::is_same<T, long long>::value)
// Unsigned integer types.
122 if (std::is_same<T, unsigned char>::value)
126 if (std::is_same<T, unsigned short>::value)
130 if (std::is_same<T, unsigned int>::value)
134 if (std::is_same<T, unsigned long>::value)
138 if (std::is_same<T, unsigned long long>::value)
// Boolean.
143 if (std::is_same<T, bool>::value)
// Complex floating-point types.
148 if (std::is_same<T, std::complex<float>>::value)
152 if (std::is_same<T, std::complex<double>>::value)
156 if (std::is_same<T, std::complex<long double>>::value)
// No branch matched.
161 XTENSOR_THROW(std::runtime_error,
"Type not known.");
165 inline char get_endianess()
167 constexpr char little_endian_char =
'<';
168 constexpr char big_endian_char =
'>';
169 constexpr char no_endian_char =
'|';
171 if (
sizeof(T) <=
sizeof(
char))
173 return no_endian_char;
176 switch (xtl::endianness())
178 case xtl::endian::little_endian:
179 return little_endian_char;
180 case xtl::endian::big_endian:
181 return big_endian_char;
183 return no_endian_char;
188 inline std::string build_typestring()
190 std::stringstream ss;
191 ss << get_endianess<T>() << map_type<T>() <<
sizeof(T);
// Checks a quoted npy type descriptor such as '<f8' against the expected
// pattern: a byte-order mark (<, > or |), one kind character out of
// i, f, u, c, b, and a decimal size, all wrapped in single quotes.
196 inline void parse_typestring(std::string typestring)
198 std::regex re(
"'([<>|])([ifucb])(\\d+)'");
// NOTE(review): the match-result object `sm` is declared on a line that is
// not visible in this excerpt.
201 std::regex_match(typestring, sm, re);
// NOTE(review): the guarding condition is not visible here - presumably the
// error is raised when the match above did not succeed; confirm.
204 XTENSOR_THROW(std::runtime_error,
"invalid typestring");
209 inline std::string unwrap_s(std::string s,
char delim_front,
char delim_back)
211 if ((s.back() == delim_back) && (s.front() == delim_front))
213 return s.substr(1, s.length() - 2);
217 XTENSOR_THROW(std::runtime_error,
"unable to unwrap");
// Extracts the value part of a "key:value" fragment.
//
// mapstr  text of the form key:value
//
// Returns everything after the first ':' in `mapstr`, or an empty string when
// `mapstr` contains no ':' at all.
inline std::string get_value_from_map(std::string mapstr)
{
    // Single-character search: no need for the set-based find_first_of.
    const std::size_t sep_pos = mapstr.find(':');
    if (sep_pos == std::string::npos)
    {
        return std::string();
    }

    // `mapstr` is already a private copy, so trim it in place and hand it back
    // instead of allocating a second string with substr().
    mapstr.erase(0, sep_pos + 1);
    return mapstr;
}
// NOTE(review): only the signature is visible in this excerpt. Judging by the
// call sites in parse_header, it presumably removes a trailing `c` from `s`
// in place - confirm against the full file.
232 inline void pop_char(std::string& s,
char c)
// Parses the textual npy header dictionary.
//
// header         header text, taken by value and edited locally
// descr          receives the type descriptor with its quotes removed
// fortran_order  receives true/false for the 'fortran_order' entry
// shape          receives the dimensions, appended in order
//
// Problems are reported through XTENSOR_THROW(std::runtime_error, ...).
// NOTE(review): several lines of this function (braces, the shape loop header,
// local declarations) are not visible in this excerpt.
241 parse_header(std::string header, std::string& descr,
bool* fortran_order, std::vector<std::size_t>& shape)
// The header text must end with a newline.
276 if (header.back() !=
'\n')
278 XTENSOR_THROW(std::runtime_error,
"invalid header");
// Drop every space character so the text can be sliced by position.
283 header.erase(std::remove(header.begin(), header.end(),
' '), header.end());
// Remove the enclosing braces of the dictionary.
286 header = unwrap_s(header,
'{',
'}');
// Locate the three keys.
289 std::size_t keypos_descr = header.find(
"'descr'");
290 std::size_t keypos_fortran = header.find(
"'fortran_order'");
291 std::size_t keypos_shape = header.find(
"'shape'");
// All three keys are required.
294 if (keypos_descr == std::string::npos)
296 XTENSOR_THROW(std::runtime_error,
"missing 'descr' key");
298 if (keypos_fortran == std::string::npos)
300 XTENSOR_THROW(std::runtime_error,
"missing 'fortran_order' key");
302 if (keypos_shape == std::string::npos)
304 XTENSOR_THROW(std::runtime_error,
"missing 'shape' key");
// The slicing below relies on the keys appearing in exactly this order:
// 'descr', then 'fortran_order', then 'shape'.
311 if (keypos_descr >= keypos_fortran || keypos_fortran >= keypos_shape)
313 XTENSOR_THROW(std::runtime_error,
"header keys in wrong order");
// Cut out each key:value segment (from one key up to the next) and pass it
// to pop_char with ',' as the character to remove.
317 std::string keyvalue_descr;
318 keyvalue_descr = header.substr(keypos_descr, keypos_fortran - keypos_descr);
319 pop_char(keyvalue_descr,
',');
321 std::string keyvalue_fortran;
322 keyvalue_fortran = header.substr(keypos_fortran, keypos_shape - keypos_fortran);
323 pop_char(keyvalue_fortran,
',');
325 std::string keyvalue_shape;
326 keyvalue_shape = header.substr(keypos_shape, std::string::npos);
327 pop_char(keyvalue_shape,
',');
// Keep only the value part of each segment.
330 std::string descr_s = get_value_from_map(keyvalue_descr);
331 std::string fortran_s = get_value_from_map(keyvalue_fortran);
332 std::string shape_s = get_value_from_map(keyvalue_shape);
// Validate the still-quoted descriptor, then strip its single quotes.
334 parse_typestring(descr_s);
335 descr = unwrap_s(descr_s,
'\'',
'\'');
// Only the exact spellings "True" and "False" are accepted.
338 if (fortran_s ==
"True")
340 *fortran_order =
true;
342 else if (fortran_s ==
"False")
344 *fortran_order =
false;
348 XTENSOR_THROW(std::runtime_error,
"invalid fortran_order value");
// Strip the parentheses of the shape tuple, then split it on commas.
355 shape_s = unwrap_s(shape_s,
'(',
')');
// NOTE(review): `pos`, `dim_s`, `tmp` and the enclosing loop are declared on
// lines not visible in this excerpt.
361 std::size_t pos_next = shape_s.find_first_of(
',', pos);
// Take the text up to the next comma, or the remainder when there is none.
364 if (pos_next != std::string::npos)
366 dim_s = shape_s.substr(pos, pos_next - pos);
370 dim_s = shape_s.substr(pos);
// An empty item that is followed by another comma is rejected.
373 if (dim_s.length() == 0)
375 if (pos_next != std::string::npos)
377 XTENSOR_THROW(std::runtime_error,
"invalid shape");
// NOTE(review): the conversion of `dim_s` into `tmp` is not visible here.
382 std::stringstream ss;
386 shape.push_back(tmp);
389 if (pos_next != std::string::npos)
// Writes the npy preamble and header dictionary to `out`.
//
// out            output sink used through write_magic() and write()
// descr          type descriptor placed under the 'descr' key
// fortran_order  written as the text True or False
// shape          sequence of dimensions, iterated with std::begin/std::end
//
// NOTE(review): some lines (braces, the declaration of `s_shape`, the final
// write of the header text) are not visible in this excerpt.
400 template <
class O,
class S>
401 inline void write_header(O& out,
const std::string& descr,
bool fortran_order,
const S& shape)
403 std::ostringstream ss_header;
404 std::string s_fortran_order;
// Textual form of the boolean flag.
407 s_fortran_order =
"True";
411 s_fortran_order =
"False";
// Format the dimensions as a comma-separated list.
415 std::ostringstream ss_shape;
417 for (
auto shape_it = std::begin(shape); shape_it != std::end(shape); ++shape_it)
419 ss_shape << *shape_it <<
", ";
421 s_shape = ss_shape.str();
// With several dimensions the trailing ", " is removed entirely; with exactly
// one dimension only the space is removed, so the trailing comma stays.
422 if (xtl::sequence_size(shape) > 1)
424 s_shape = s_shape.erase(s_shape.size() - 2);
426 else if (xtl::sequence_size(shape) == 1)
428 s_shape = s_shape.erase(s_shape.size() - 1);
432 ss_header <<
"{'descr': '" << descr <<
"', 'fortran_order': " << s_fortran_order
433 <<
", 'shape': " << s_shape <<
", }";
// Header length before padding; the +1 presumably accounts for the newline
// appended below - confirm.
435 std::size_t header_len_pre = ss_header.str().length() + 1;
// Magic prefix + 2 version bytes + 2-byte length field + header text.
436 std::size_t metadata_len = magic_string_length + 2 + 2 + header_len_pre;
438 unsigned char version[2] = {1, 0};
// NOTE(review): the threshold is 255 * 255 = 65025, while a 2-byte length
// field can represent values up to 65535 - confirm the intended limit.
439 if (metadata_len >= 255 * 255)
// Same computation with a 4-byte length field.
441 metadata_len = magic_string_length + 2 + 4 + header_len_pre;
// Pad with spaces up to a multiple of 64 bytes.
// NOTE(review): when metadata_len is already a multiple of 64 this adds a
// full extra block of 64 spaces.
445 std::size_t padding_len = 64 - (metadata_len % 64);
446 std::string padding(padding_len,
' ');
447 ss_header << padding;
// NOTE(review): std::endl also flushes the string stream; '\n' would do.
448 ss_header << std::endl;
450 std::string header = ss_header.str();
453 write_magic(out, version[0], version[1]);
// Version 1.0 stores the header length in 2 bytes, low byte first.
456 if (version[0] == 1 && version[1] == 0)
458 char header_len_le16[2];
459 uint16_t header_len = uint16_t(header.length());
461 header_len_le16[0] = char((header_len >> 0) & 0xff);
462 header_len_le16[1] = char((header_len >> 8) & 0xff);
463 out.write(
reinterpret_cast<char*
>(header_len_le16), 2);
// Otherwise the header length is stored in 4 bytes, low byte first.
467 char header_len_le32[4];
468 uint32_t header_len = uint32_t(header.length());
470 header_len_le32[0] = char((header_len >> 0) & 0xff);
471 header_len_le32[1] = char((header_len >> 8) & 0xff);
472 header_len_le32[2] = char((header_len >> 16) & 0xff);
473 header_len_le32[3] = char((header_len >> 24) & 0xff);
474 out.write(
reinterpret_cast<char*
>(header_len_le32), 4);
480 inline std::string read_header_1_0(std::istream& istream)
483 char header_len_le16[2];
484 istream.read(header_len_le16, 2);
486 uint16_t header_length = uint16_t(header_len_le16[0] << 0) | uint16_t(header_len_le16[1] << 8);
488 if ((magic_string_length + 2 + 2 + header_length) % 16 != 0)
493 std::unique_ptr<char[]> buf(
new char[header_length]);
494 istream.read(buf.get(), header_length);
495 std::string header(buf.get(), header_length);
500 inline std::string read_header_2_0(std::istream& istream)
503 char header_len_le32[4];
504 istream.read(header_len_le32, 4);
506 uint32_t header_length = uint32_t(header_len_le32[0] << 0) | uint32_t(header_len_le32[1] << 8)
507 | uint32_t(header_len_le32[2] << 16) | uint32_t(header_len_le32[3] << 24);
509 if ((magic_string_length + 2 + 4 + header_length) % 16 != 0)
514 std::unique_ptr<char[]> buf(
new char[header_length]);
515 istream.read(buf.get(), header_length);
516 std::string header(buf.get(), header_length);
// Compiler-generated default constructor.
// NOTE(review): the scalar members visible in this excerpt have no in-class
// initialisers; whether the buffer pointer starts out null is not visible
// here and matters for the destructor - confirm.
523 npy_file() =
default;
// Creates a file object for the given shape, memory order and type
// descriptor, and allocates the raw byte buffer for the payload.
// NOTE(review): the first entry of the initialiser list is not visible in
// this excerpt.
525 npy_file(std::vector<std::size_t>& shape,
bool fortran_order, std::string typestring)
527 , m_fortran_order(fortran_order)
528 , m_typestring(typestring)
// Element width in bytes: the decimal number starting at index 2 of the
// descriptor, i.e. after the byte-order and kind characters.
// NOTE(review): atoi yields 0 for non-numeric text and cannot report errors.
531 m_word_size = std::size_t(atoi(&typestring[2]));
// Payload size in bytes; the multiplication is not checked for overflow.
532 m_n_bytes = compute_size(shape) * m_word_size;
533 m_buffer = std::allocator<char>{}.allocate(m_n_bytes);
// Releases the byte buffer if this object still holds one.
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the destructor body - confirm.
538 if (m_buffer !=
nullptr)
540 std::allocator<char>{}.deallocate(m_buffer, m_n_bytes);
// Copying is disabled: copy construction and copy assignment are deleted.
545 npy_file(
const npy_file&) =
delete;
546 npy_file& operator=(
const npy_file&) =
delete;
549 npy_file(npy_file&& rhs)
550 : m_shape(std::move(rhs.m_shape))
551 , m_fortran_order(std::move(rhs.m_fortran_order))
552 , m_word_size(std::move(rhs.m_word_size))
553 , m_n_bytes(std::move(rhs.m_n_bytes))
554 , m_typestring(std::move(rhs.m_typestring))
555 , m_buffer(rhs.m_buffer)
557 rhs.m_buffer =
nullptr;
// Move assignment: takes over the shape, metadata and byte buffer of `rhs`
// and leaves `rhs` with a null buffer pointer.
// NOTE(review): in the lines visible here the current m_buffer is overwritten
// without being deallocated first, which looks like a leak when this object
// already owns a buffer - confirm against the full file.
560 npy_file& operator=(npy_file&& rhs)
564 m_shape = std::move(rhs.m_shape);
565 m_fortran_order = std::move(rhs.m_fortran_order);
566 m_word_size = std::move(rhs.m_word_size);
567 m_n_bytes = std::move(rhs.m_n_bytes);
568 m_typestring = std::move(rhs.m_typestring);
569 m_buffer = rhs.m_buffer;
// Detach the buffer from the source object.
570 rhs.m_buffer =
nullptr;
// Shared implementation behind the cast() overloads: reinterprets the byte
// buffer as T* and returns a tuple of (pointer, element count, shape, strides).
//
// check_type  when true, the stored type descriptor must equal the one built
//             for T
//
// NOTE(review): several lines (braces, the layout handling and the stride
// computation) are not visible in this excerpt.
575 template <
class T, layout_type L>
576 auto cast_impl(
bool check_type)
// A null buffer is reported as "already cast".
578 if (m_buffer ==
nullptr)
580 XTENSOR_THROW(std::runtime_error,
"This npy_file has already been cast.");
582 T* ptr =
reinterpret_cast<T*
>(&m_buffer[0]);
583 std::vector<std::size_t>
strides(m_shape.size());
584 std::size_t sz = compute_size(m_shape);
// Optional check that the file's descriptor matches the requested type.
587 if (check_type && m_typestring != detail::build_typestring<T>())
591 "Cast error: formats not matching "s + m_typestring +
" vs "s
592 + detail::build_typestring<T>()
// Message used when the file's memory order conflicts with layout L.
601 "Cast error: layout mismatch between npy file and requested layout."
// The shape is copied so the returned tuple owns its own vector.
610 std::vector<std::size_t> shape(m_shape);
612 return std::make_tuple(ptr, sz, std::move(shape), std::move(
strides));
// cast() for rvalue npy_file objects.
// Passes the tuple from cast_impl (pointer, element count, shape, strides)
// on to a call that is not visible in this excerpt.
// NOTE(review): presumably the result takes ownership of the buffer in this
// overload - confirm against the full file.
615 template <
class T, layout_type L = layout_type::dynamic>
616 auto cast(
bool check_type =
true) &&
618 auto cast_elems = cast_impl<T, L>(check_type);
621 std::move(std::get<0>(cast_elems)),
622 std::get<1>(cast_elems),
624 std::get<2>(cast_elems),
625 std::get<3>(cast_elems)
// cast() for const lvalue npy_file objects.
// Passes the tuple from cast_impl (pointer, element count, shape, strides)
// on to a call that is not visible in this excerpt.
629 template <
class T, layout_type L = layout_type::dynamic>
630 auto cast(
bool check_type =
true) const&
632 auto cast_elems = cast_impl<T, L>(check_type);
634 std::get<0>(cast_elems),
635 std::get<1>(cast_elems),
637 std::get<2>(cast_elems),
638 std::get<3>(cast_elems)
// cast() for non-const lvalue npy_file objects.
// Passes the tuple from cast_impl (pointer, element count, shape, strides)
// on to a call that is not visible in this excerpt.
642 template <
class T, layout_type L = layout_type::dynamic>
643 auto cast(
bool check_type =
true) &
645 auto cast_elems = cast_impl<T, L>(check_type);
647 std::get<0>(cast_elems),
648 std::get<1>(cast_elems),
650 std::get<2>(cast_elems),
651 std::get<3>(cast_elems)
// NOTE(review): only the signature is visible in this excerpt; presumably it
// returns the payload size in bytes (m_n_bytes) - confirm.
660 std::size_t n_bytes()
// Dimensions of the stored array.
665 std::vector<std::size_t> m_shape;
// Value of the header's 'fortran_order' entry.
666 bool m_fortran_order;
// Size of one element in bytes, parsed from the type descriptor.
667 std::size_t m_word_size;
// Size of the byte buffer: element count times m_word_size.
668 std::size_t m_n_bytes;
// Type descriptor string as given to the constructor.
669 std::string m_typestring;
// Reads a complete npy file from `stream` into an npy_file object:
// preamble, header (format 1.0 or 2.0) and then the raw payload bytes.
// Any other format version is reported as "unsupported file format version".
// NOTE(review): the declarations of `header`, `typestr` and `fortran_order`
// and the final return are not visible in this excerpt.
673 inline npy_file load_npy_file(std::istream& stream)
676 unsigned char v_major, v_minor;
677 detail::read_magic(stream, &v_major, &v_minor);
// Choose the header reader that matches the file's format version.
681 if (v_major == 1 && v_minor == 0)
683 header = detail::read_header_1_0(stream);
685 else if (v_major == 2 && v_minor == 0)
687 header = detail::read_header_2_0(stream);
691 XTENSOR_THROW(std::runtime_error,
"unsupported file format version");
698 std::vector<std::size_t> shape;
699 detail::parse_header(header, typestr, &fortran_order, shape);
// The constructor allocates the buffer that the payload is read into.
701 npy_file result(shape, fortran_order, typestr);
// NOTE(review): no stream-state check is visible after this read.
703 stream.read(result.ptr(), std::streamsize((result.n_bytes())));
// Writes the expression `e` to `stream` in npy format: header first, then the
// raw element bytes of the evaluated expression.
// NOTE(review): the condition that switches fortran_order to true and the
// call that receives the payload arguments are not visible in this excerpt.
707 template <
class O,
class E>
708 inline void dump_npy_stream(O& stream,
const xexpression<E>& e)
710 using value_type =
typename E::value_type;
711 const E& ex = e.derived_cast();
// Evaluate the expression so that its elements can be accessed via data().
712 auto&& eval_ex =
eval(ex);
713 bool fortran_order =
false;
716 fortran_order =
true;
719 std::string typestring = detail::build_typestring<value_type>();
721 auto shape = eval_ex.shape();
722 detail::write_header(stream, typestring, fortran_order, shape);
// Payload: compute_size(shape) elements of sizeof(value_type) bytes each.
724 std::size_t size = compute_size(shape);
726 reinterpret_cast<const char*
>(eval_ex.data()),
727 std::streamsize((
sizeof(value_type) * size))
// NOTE(review): the function header is not visible in this excerpt;
// presumably this is the overload that writes an expression to the file named
// `filename` - confirm.
738 template <
typename E>
// Failure to open the output file is reported together with the file name.
744 XTENSOR_THROW(std::runtime_error,
"IO Error: failed to open file: "s +
filename);
// The serialisation itself is delegated to detail::dump_npy_stream.
747 detail::dump_npy_stream(
stream,
e);
// NOTE(review): the function header is not visible in this excerpt;
// presumably this overload serialises into an in-memory stream - confirm.
755 template <
typename E>
// The serialisation itself is delegated to detail::dump_npy_stream.
759 detail::dump_npy_stream(
stream,
e);
// Loads npy data from `stream` and casts it to element type T with layout L.
// NOTE(review): the function header is not visible in this excerpt.
773 template <
typename T, layout_type L = layout_type::dynamic>
776 detail::npy_file
file = detail::load_npy_file(
stream);
// Cast from an rvalue so that the rvalue-qualified cast() overload is chosen.
777 return std::move(
file).cast<T,
L>();
// NOTE(review): the function header and most of the body are not visible in
// this excerpt; presumably this overload opens a file by name and reports
// the error below when that fails - confirm.
790 template <
typename T, layout_type L = layout_type::dynamic>
796 XTENSOR_THROW(std::runtime_error,
"io error: failed to open a file.");
auto cast(E &&e) noexcept -> detail::xfunction_type_t< typename detail::cast< R >::functor, E >
Element-wise static_cast.
auto adapt(C &&container, const SC &shape, layout_type l=L)
Constructs:
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto load_npy(std::istream &stream)
Loads an npy file (the NumPy storage format)
void dump_npy(const std::string &filename, const xexpression< E > &e)
Save xexpression to NumPy npy format.