11#ifndef XTENSOR_NPY_HPP
12#define XTENSOR_NPY_HPP
#include <algorithm>
#include <array>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <istream>
#include <iterator>
#include <locale>
#include <memory>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <xtl/xplatform.hpp>
#include <xtl/xsequence.hpp>

#include "../containers/xadapt.hpp"
#include "../containers/xarray.hpp"
#include "../core/xeval.hpp"
#include "../core/xstrides.hpp"
#include "../core/xtensor_config.hpp"
39 using namespace std::string_literals;
// NPY magic prefix: the byte 0x93 followed by the ASCII letters "NUMPY".
constexpr char magic_string[] = "\x93NUMPY";
// Number of magic bytes, i.e. the literal's length without its terminating NUL.
constexpr std::size_t magic_string_length = sizeof(magic_string) - 1;
48 inline void write_magic(O& ostream,
unsigned char v_major = 1,
unsigned char v_minor = 0)
50 ostream.write(magic_string, magic_string_length);
51 ostream.put(
char(v_major));
52 ostream.put(
char(v_minor));
55 inline void read_magic(std::istream& istream,
unsigned char* v_major,
unsigned char* v_minor)
57 std::unique_ptr<char[]> buf(
new char[magic_string_length + 2]);
58 istream.read(buf.get(), magic_string_length + 2);
62 XTENSOR_THROW(std::runtime_error,
"io error: failed reading file");
65 for (std::size_t i = 0; i < magic_string_length; i++)
67 if (buf[i] != magic_string[i])
69 XTENSOR_THROW(std::runtime_error,
"this file do not have a valid npy format.");
73 *v_major =
static_cast<unsigned char>(buf[magic_string_length]);
74 *v_minor =
static_cast<unsigned char>(buf[magic_string_length + 1]);
78 inline char map_type()
80 if (std::is_same<T, float>::value)
84 if (std::is_same<T, double>::value)
88 if (std::is_same<T, long double>::value)
93 if (std::is_same<T, char>::value)
97 if (std::is_same<T, signed char>::value)
101 if (std::is_same<T, short>::value)
105 if (std::is_same<T, int>::value)
109 if (std::is_same<T, long>::value)
113 if (std::is_same<T, long long>::value)
118 if (std::is_same<T, unsigned char>::value)
122 if (std::is_same<T, unsigned short>::value)
126 if (std::is_same<T, unsigned int>::value)
130 if (std::is_same<T, unsigned long>::value)
134 if (std::is_same<T, unsigned long long>::value)
139 if (std::is_same<T, bool>::value)
144 if (std::is_same<T, std::complex<float>>::value)
148 if (std::is_same<T, std::complex<double>>::value)
152 if (std::is_same<T, std::complex<long double>>::value)
157 XTENSOR_THROW(std::runtime_error,
"Type not known.");
161 inline char get_endianess()
163 constexpr char little_endian_char =
'<';
164 constexpr char big_endian_char =
'>';
165 constexpr char no_endian_char =
'|';
167 if (
sizeof(T) <=
sizeof(
char))
169 return no_endian_char;
172 switch (xtl::endianness())
174 case xtl::endian::little_endian:
175 return little_endian_char;
176 case xtl::endian::big_endian:
177 return big_endian_char;
179 return no_endian_char;
184 inline std::string build_typestring()
186 std::stringstream ss;
187 ss << get_endianess<T>() << map_type<T>() <<
sizeof(T);
192 inline void parse_typestring(std::string typestring)
194 std::regex re(
"'([<>|])([ifucb])(\\d+)'");
197 std::regex_match(typestring, sm, re);
200 XTENSOR_THROW(std::runtime_error,
"invalid typestring");
205 inline std::string unwrap_s(std::string s,
char delim_front,
char delim_back)
207 if ((s.back() == delim_back) && (s.front() == delim_front))
209 return s.substr(1, s.length() - 2);
213 XTENSOR_THROW(std::runtime_error,
"unable to unwrap");
/**
 * Returns the text following the first ':' of a "key:value" fragment, or an
 * empty string if the fragment contains no ':'.
 */
inline std::string get_value_from_map(std::string mapstr)
{
    const std::size_t colon = mapstr.find(':');
    return (colon == std::string::npos) ? std::string() : mapstr.substr(colon + 1);
}
/**
 * Removes the last character of s if it equals c; otherwise s is unchanged.
 * Empty strings are left untouched.
 */
inline void pop_char(std::string& s, char c)
{
    // Guard first: back() on an empty string is undefined behaviour.
    if (!s.empty() && s.back() == c)
    {
        s.pop_back();
    }
}
237 parse_header(std::string header, std::string& descr,
bool* fortran_order, std::vector<std::size_t>& shape)
272 if (header.back() !=
'\n')
274 XTENSOR_THROW(std::runtime_error,
"invalid header");
279 header.erase(std::remove(header.begin(), header.end(),
' '), header.end());
282 header = unwrap_s(header,
'{',
'}');
285 std::size_t keypos_descr = header.find(
"'descr'");
286 std::size_t keypos_fortran = header.find(
"'fortran_order'");
287 std::size_t keypos_shape = header.find(
"'shape'");
290 if (keypos_descr == std::string::npos)
292 XTENSOR_THROW(std::runtime_error,
"missing 'descr' key");
294 if (keypos_fortran == std::string::npos)
296 XTENSOR_THROW(std::runtime_error,
"missing 'fortran_order' key");
298 if (keypos_shape == std::string::npos)
300 XTENSOR_THROW(std::runtime_error,
"missing 'shape' key");
307 if (keypos_descr >= keypos_fortran || keypos_fortran >= keypos_shape)
309 XTENSOR_THROW(std::runtime_error,
"header keys in wrong order");
313 std::string keyvalue_descr;
314 keyvalue_descr = header.substr(keypos_descr, keypos_fortran - keypos_descr);
315 pop_char(keyvalue_descr,
',');
317 std::string keyvalue_fortran;
318 keyvalue_fortran = header.substr(keypos_fortran, keypos_shape - keypos_fortran);
319 pop_char(keyvalue_fortran,
',');
321 std::string keyvalue_shape;
322 keyvalue_shape = header.substr(keypos_shape, std::string::npos);
323 pop_char(keyvalue_shape,
',');
326 std::string descr_s = get_value_from_map(keyvalue_descr);
327 std::string fortran_s = get_value_from_map(keyvalue_fortran);
328 std::string shape_s = get_value_from_map(keyvalue_shape);
330 parse_typestring(descr_s);
331 descr = unwrap_s(descr_s,
'\'',
'\'');
334 if (fortran_s ==
"True")
336 *fortran_order =
true;
338 else if (fortran_s ==
"False")
340 *fortran_order =
false;
344 XTENSOR_THROW(std::runtime_error,
"invalid fortran_order value");
351 shape_s = unwrap_s(shape_s,
'(',
')');
357 std::size_t pos_next = shape_s.find_first_of(
',', pos);
360 if (pos_next != std::string::npos)
362 dim_s = shape_s.substr(pos, pos_next - pos);
366 dim_s = shape_s.substr(pos);
369 if (dim_s.length() == 0)
371 if (pos_next != std::string::npos)
373 XTENSOR_THROW(std::runtime_error,
"invalid shape");
378 std::stringstream ss;
382 shape.push_back(tmp);
385 if (pos_next != std::string::npos)
396 template <
class O,
class S>
397 inline void write_header(O& out,
const std::string& descr,
bool fortran_order,
const S& shape)
399 std::ostringstream ss_header;
400 std::string s_fortran_order;
403 s_fortran_order =
"True";
407 s_fortran_order =
"False";
411 std::ostringstream ss_shape;
413 for (
auto shape_it = std::begin(shape); shape_it != std::end(shape); ++shape_it)
415 ss_shape << *shape_it <<
", ";
417 s_shape = ss_shape.str();
418 if (std::size(shape) > 1)
420 s_shape = s_shape.erase(s_shape.size() - 2);
422 else if (std::size(shape) == 1)
424 s_shape = s_shape.erase(s_shape.size() - 1);
428 ss_header <<
"{'descr': '" << descr <<
"', 'fortran_order': " << s_fortran_order
429 <<
", 'shape': " << s_shape <<
", }";
431 std::size_t header_len_pre = ss_header.str().length() + 1;
432 std::size_t metadata_len = magic_string_length + 2 + 2 + header_len_pre;
434 unsigned char version[2] = {1, 0};
435 if (metadata_len >= 255 * 255)
437 metadata_len = magic_string_length + 2 + 4 + header_len_pre;
441 std::size_t padding_len = 64 - (metadata_len % 64);
442 std::string padding(padding_len,
' ');
443 ss_header << padding;
444 ss_header << std::endl;
446 std::string header = ss_header.str();
449 write_magic(out, version[0], version[1]);
452 if (version[0] == 1 && version[1] == 0)
454 char header_len_le16[2];
455 uint16_t header_len = uint16_t(header.length());
457 header_len_le16[0] = char((header_len >> 0) & 0xff);
458 header_len_le16[1] = char((header_len >> 8) & 0xff);
459 out.write(
reinterpret_cast<char*
>(header_len_le16), 2);
463 char header_len_le32[4];
464 uint32_t header_len = uint32_t(header.length());
466 header_len_le32[0] = char((header_len >> 0) & 0xff);
467 header_len_le32[1] = char((header_len >> 8) & 0xff);
468 header_len_le32[2] = char((header_len >> 16) & 0xff);
469 header_len_le32[3] = char((header_len >> 24) & 0xff);
470 out.write(
reinterpret_cast<char*
>(header_len_le32), 4);
476 inline std::string read_header_1_0(std::istream& istream)
479 char header_len_le16[2];
480 istream.read(header_len_le16, 2);
482 uint16_t header_b0 =
static_cast<uint16_t
>(
static_cast<uint8_t
>(header_len_le16[0]));
483 uint16_t header_b1 =
static_cast<uint16_t
>(
static_cast<uint8_t
>(header_len_le16[1])) << 8;
484 uint16_t header_length = header_b0 | header_b1;
486 if ((magic_string_length + 2 + 2 + header_length) % 16 != 0)
491 std::unique_ptr<char[]> buf(
new char[header_length]);
492 istream.read(buf.get(), header_length);
493 std::string header(buf.get(), header_length);
498 inline std::string read_header_2_0(std::istream& istream)
501 char header_len_le32[4];
502 istream.read(header_len_le32, 4);
504 uint32_t header_b0 =
static_cast<uint32_t
>(
static_cast<uint8_t
>(header_len_le32[0]));
505 uint32_t header_b1 =
static_cast<uint32_t
>(
static_cast<uint8_t
>(header_len_le32[1])) << 8;
506 uint32_t header_b2 =
static_cast<uint32_t
>(
static_cast<uint8_t
>(header_len_le32[2])) << 16;
507 uint32_t header_b3 =
static_cast<uint32_t
>(
static_cast<uint8_t
>(header_len_le32[3])) << 24;
508 uint32_t header_length = header_b0 | header_b1 | header_b2 | header_b3;
510 if ((magic_string_length + 2 + 4 + header_length) % 16 != 0)
515 std::unique_ptr<char[]> buf(
new char[header_length]);
516 istream.read(buf.get(), header_length);
517 std::string header(buf.get(), header_length);
524 npy_file() =
default;
526 npy_file(std::vector<std::size_t>& shape,
bool fortran_order, std::string typestring)
528 , m_fortran_order(fortran_order)
529 , m_typestring(typestring)
532 m_word_size = std::size_t(atoi(&typestring[2]));
533 m_n_bytes = compute_size(shape) * m_word_size;
534 m_buffer = std::allocator<char>{}.allocate(m_n_bytes);
539 if (m_buffer !=
nullptr)
541 std::allocator<char>{}.deallocate(m_buffer, m_n_bytes);
546 npy_file(
const npy_file&) =
delete;
547 npy_file& operator=(
const npy_file&) =
delete;
550 npy_file(npy_file&& rhs)
551 : m_shape(std::move(rhs.m_shape))
552 , m_fortran_order(std::move(rhs.m_fortran_order))
553 , m_word_size(std::move(rhs.m_word_size))
554 , m_n_bytes(std::move(rhs.m_n_bytes))
555 , m_typestring(std::move(rhs.m_typestring))
556 , m_buffer(rhs.m_buffer)
558 rhs.m_buffer =
nullptr;
561 npy_file& operator=(npy_file&& rhs)
565 m_shape = std::move(rhs.m_shape);
566 m_fortran_order = std::move(rhs.m_fortran_order);
567 m_word_size = std::move(rhs.m_word_size);
568 m_n_bytes = std::move(rhs.m_n_bytes);
569 m_typestring = std::move(rhs.m_typestring);
570 m_buffer = rhs.m_buffer;
571 rhs.m_buffer =
nullptr;
576 template <
class T, layout_type L>
577 auto cast_impl(
bool check_type)
579 if (m_buffer ==
nullptr)
581 XTENSOR_THROW(std::runtime_error,
"This npy_file has already been cast.");
583 T* ptr =
reinterpret_cast<T*
>(&m_buffer[0]);
584 std::vector<std::size_t>
strides(m_shape.size());
585 std::size_t sz = compute_size(m_shape);
588 if (check_type && m_typestring != detail::build_typestring<T>())
592 "Cast error: formats not matching "s + m_typestring +
" vs "s
593 + detail::build_typestring<T>()
602 "Cast error: layout mismatch between npy file and requested layout."
611 std::vector<std::size_t> shape(m_shape);
613 return std::make_tuple(ptr, sz, std::move(shape), std::move(
strides));
616 template <
class T, layout_type L = layout_type::dynamic>
617 auto cast(
bool check_type =
true) &&
619 auto cast_elems = cast_impl<T, L>(check_type);
622 std::move(std::get<0>(cast_elems)),
623 std::get<1>(cast_elems),
625 std::get<2>(cast_elems),
626 std::get<3>(cast_elems)
630 template <
class T, layout_type L = layout_type::dynamic>
631 auto cast(
bool check_type =
true) const&
633 auto cast_elems = cast_impl<T, L>(check_type);
635 std::get<0>(cast_elems),
636 std::get<1>(cast_elems),
638 std::get<2>(cast_elems),
639 std::get<3>(cast_elems)
643 template <
class T, layout_type L = layout_type::dynamic>
644 auto cast(
bool check_type =
true) &
646 auto cast_elems = cast_impl<T, L>(check_type);
648 std::get<0>(cast_elems),
649 std::get<1>(cast_elems),
651 std::get<2>(cast_elems),
652 std::get<3>(cast_elems)
661 std::size_t n_bytes()
666 std::vector<std::size_t> m_shape;
667 bool m_fortran_order;
668 std::size_t m_word_size;
669 std::size_t m_n_bytes;
670 std::string m_typestring;
674 inline npy_file load_npy_file(std::istream& stream)
677 unsigned char v_major, v_minor;
678 detail::read_magic(stream, &v_major, &v_minor);
682 if (v_major == 1 && v_minor == 0)
684 header = detail::read_header_1_0(stream);
686 else if (v_major == 2 && v_minor == 0)
688 header = detail::read_header_2_0(stream);
692 XTENSOR_THROW(std::runtime_error,
"unsupported file format version");
699 std::vector<std::size_t> shape;
700 detail::parse_header(header, typestr, &fortran_order, shape);
702 npy_file result(shape, fortran_order, typestr);
704 stream.read(result.ptr(), std::streamsize((result.n_bytes())));
708 template <
class O,
class E>
709 inline void dump_npy_stream(O& stream,
const xexpression<E>& e)
711 using value_type =
typename E::value_type;
712 const E& ex = e.derived_cast();
713 auto&& eval_ex =
eval(ex);
714 bool fortran_order =
false;
717 fortran_order =
true;
720 std::string typestring = detail::build_typestring<value_type>();
722 auto shape = eval_ex.shape();
723 detail::write_header(stream, typestring, fortran_order, shape);
725 std::size_t size = compute_size(shape);
727 reinterpret_cast<const char*
>(eval_ex.data()),
728 std::streamsize((
sizeof(value_type) * size))
739 template <
typename E>
742 std::ofstream stream(filename, std::ofstream::binary);
745 XTENSOR_THROW(std::runtime_error,
"IO Error: failed to open file: "s + filename);
748 detail::dump_npy_stream(stream, e);
756 template <
typename E>
759 std::stringstream stream;
760 detail::dump_npy_stream(stream, e);
774 template <
typename T, layout_type L = layout_type::dynamic>
777 detail::npy_file file = detail::load_npy_file(stream);
778 return std::move(file).cast<T, L>();
791 template <
typename T, layout_type L = layout_type::dynamic>
794 std::ifstream stream(filename, std::ifstream::binary);
797 XTENSOR_THROW(std::runtime_error,
"io error: failed to open a file.");
Base class for xexpressions.
auto adapt(C &&container, const SC &shape, layout_type l=L)
Constructs:
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto load_npy(std::istream &stream)
Loads a npy file (the NumPy storage format)
void dump_npy(const std::string &filename, const xexpression< E > &e)
Save xexpression to NumPy npy format.