11#ifndef XTENSOR_NPY_HPP
12#define XTENSOR_NPY_HPP
31#include <xtl/xplatform.hpp>
32#include <xtl/xsequence.hpp>
34#include "../containers/xadapt.hpp"
35#include "../containers/xarray.hpp"
36#include "../core/xeval.hpp"
37#include "../core/xstrides.hpp"
38#include "../core/xtensor_config.hpp"
42 using namespace std::string_literals;
// Every npy stream begins with these six bytes: 0x93 followed by "NUMPY".
constexpr char magic_string[] = "\x93NUMPY";
// Count of magic bytes; sizeof includes the terminating NUL, hence the - 1.
constexpr std::size_t magic_string_length = sizeof(magic_string) - 1;

/// Emits the npy preamble: the magic bytes, then one major and one minor
/// format-version byte.
///
/// @param ostream  output stream object providing write() and put()
/// @param v_major  major format version byte (defaults to 1)
/// @param v_minor  minor format version byte (defaults to 0)
template <class O>
inline void write_magic(O& ostream, unsigned char v_major = 1, unsigned char v_minor = 0)
{
    const char major_byte = static_cast<char>(v_major);
    const char minor_byte = static_cast<char>(v_minor);
    ostream.write(magic_string, magic_string_length);
    ostream.put(major_byte);
    ostream.put(minor_byte);
}
// Reads the npy preamble: magic_string_length magic bytes followed by a
// major and a minor version byte, and validates the magic bytes.
// @param istream  stream expected to be positioned at the start of npy data
// @param v_major  out: receives the byte following the magic string
// @param v_minor  out: receives the byte after that
// @throws std::runtime_error("io error: failed reading file") — the guarding
//         condition is not visible here; presumably a failed read(), confirm.
// @throws std::runtime_error when any leading byte differs from magic_string.
58 inline void read_magic(std::istream& istream,
unsigned char* v_major,
unsigned char* v_minor)
// Room for the magic bytes plus the two version bytes.
60 std::unique_ptr<char[]> buf(
new char[magic_string_length + 2]);
61 istream.read(buf.get(), magic_string_length + 2);
65 XTENSOR_THROW(std::runtime_error,
"io error: failed reading file");
// Compare the leading bytes one by one against the expected magic string.
68 for (std::size_t i = 0; i < magic_string_length; i++)
70 if (buf[i] != magic_string[i])
72 XTENSOR_THROW(std::runtime_error,
"this file do not have a valid npy format.");
// The two bytes right after the magic string carry the format version.
76 *v_major =
static_cast<unsigned char>(buf[magic_string_length]);
77 *v_minor =
static_cast<unsigned char>(buf[magic_string_length + 1]);
// Maps the element type T to the single type-code character used in the npy
// "descr" typestring (parse_typestring accepts the codes i, f, u, c, b).
// Each branch compares T with one fundamental or std::complex type; the
// character returned by each branch is not visible here.
// NOTE(review): T is known at compile time, so `if constexpr` would avoid
// evaluating this chain at run time.
// @throws std::runtime_error("Type not known.") when T matches no branch.
81 inline char map_type()
// Floating-point types.
83 if (std::is_same<T, float>::value)
87 if (std::is_same<T, double>::value)
91 if (std::is_same<T, long double>::value)
// Plain char and the signed integer types.
96 if (std::is_same<T, char>::value)
100 if (std::is_same<T, signed char>::value)
104 if (std::is_same<T, short>::value)
108 if (std::is_same<T, int>::value)
112 if (std::is_same<T, long>::value)
116 if (std::is_same<T, long long>::value)
// Unsigned integer types.
121 if (std::is_same<T, unsigned char>::value)
125 if (std::is_same<T, unsigned short>::value)
129 if (std::is_same<T, unsigned int>::value)
133 if (std::is_same<T, unsigned long>::value)
137 if (std::is_same<T, unsigned long long>::value)
// Boolean.
142 if (std::is_same<T, bool>::value)
// Complex types.
147 if (std::is_same<T, std::complex<float>>::value)
151 if (std::is_same<T, std::complex<double>>::value)
155 if (std::is_same<T, std::complex<long double>>::value)
// Reached only for unsupported element types.
160 XTENSOR_THROW(std::runtime_error,
"Type not known.");
164 inline char get_endianess()
166 constexpr char little_endian_char =
'<';
167 constexpr char big_endian_char =
'>';
168 constexpr char no_endian_char =
'|';
170 if (
sizeof(T) <=
sizeof(
char))
172 return no_endian_char;
175 switch (xtl::endianness())
177 case xtl::endian::little_endian:
178 return little_endian_char;
179 case xtl::endian::big_endian:
180 return big_endian_char;
182 return no_endian_char;
// Builds the npy "descr" typestring for T by streaming, in order, the
// byte-order character (get_endianess<T>), the type code (map_type<T>) and
// sizeof(T) in decimal. npy_file::cast_impl compares the file's typestring
// against this value.
// NOTE(review): the return statement is not visible here; presumably it
// returns ss.str() — confirm.
187 inline std::string build_typestring()
189 std::stringstream ss;
190 ss << get_endianess<T>() << map_type<T>() <<
sizeof(T);
// Validates a quoted typestring such as '<f8'. The pattern is: a single
// quote, one byte-order character from [<>|], one type code from [ifucb],
// one or more digits, and a closing single quote. regex_match requires the
// whole string to match.
// NOTE(review): the boolean returned by regex_match is discarded; the
// declaration of `sm` and the condition guarding the throw are not visible
// here — presumably the match results are inspected, confirm.
// NOTE(review): the std::regex is rebuilt on every call, and std::regex
// construction is comparatively expensive.
// @throws std::runtime_error("invalid typestring").
195 inline void parse_typestring(std::string typestring)
197 std::regex re(
"'([<>|])([ifucb])(\\d+)'");
200 std::regex_match(typestring, sm, re);
203 XTENSOR_THROW(std::runtime_error,
"invalid typestring");
208 inline std::string unwrap_s(std::string s,
char delim_front,
char delim_back)
210 if ((s.back() == delim_back) && (s.front() == delim_front))
212 return s.substr(1, s.length() - 2);
216 XTENSOR_THROW(std::runtime_error,
"unable to unwrap");
// Returns everything after the first ':' of a "key:value" fragment; for
// example "'shape':(3,4)" yields "(3,4)".
// NOTE(review): the branch taken when no ':' is found is not visible here;
// presumably it throws — confirm.
220 inline std::string get_value_from_map(std::string mapstr)
222 std::size_t sep_pos = mapstr.find_first_of(
":");
223 if (sep_pos == std::string::npos)
228 return mapstr.substr(sep_pos + 1);
// In-place helper that parse_header calls with ',' on each key/value
// fragment. The body is not visible here; presumably it removes a trailing
// `c` from `s` — TODO confirm.
231 inline void pop_char(std::string& s,
char c)
// Parses the dictionary literal of an npy header into its three fields. An
// example of the layout write_header emits is:
//   {'descr': '<f8', 'fortran_order': False, 'shape': (3, 4), }
// @param header         raw header text; must end with '\n'
// @param descr          out: the typestring without its surrounding quotes
// @param fortran_order  out: true for True, false for False
// @param shape          out: one entry per dimension, appended with
//                       push_back (no clear() is visible here)
// @throws std::runtime_error on:
//         - a missing trailing newline;
//         - a missing key, or keys out of order;
//         - an invalid typestring or fortran_order value;
//         - a malformed shape tuple.
240 parse_header(std::string header, std::string& descr,
bool* fortran_order, std::vector<std::size_t>& shape)
// NOTE(review): header.back() on an empty string is undefined behaviour; an
// explicit empty check should precede this.
275 if (header.back() !=
'\n')
277 XTENSOR_THROW(std::runtime_error,
"invalid header");
// Remove every space (including the alignment padding). This lets values
// such as True/False be compared literally below.
282 header.erase(std::remove(header.begin(), header.end(),
' '), header.end());
// Strip the dictionary braces. unwrap_s requires the last character to be
// '}', so the trailing '\n' checked above is presumably removed in lines
// not visible here — confirm.
285 header = unwrap_s(header,
'{',
'}');
// Locate each required key; the quotes are part of the key text.
288 std::size_t keypos_descr = header.find(
"'descr'");
289 std::size_t keypos_fortran = header.find(
"'fortran_order'");
290 std::size_t keypos_shape = header.find(
"'shape'");
293 if (keypos_descr == std::string::npos)
295 XTENSOR_THROW(std::runtime_error,
"missing 'descr' key");
297 if (keypos_fortran == std::string::npos)
299 XTENSOR_THROW(std::runtime_error,
"missing 'fortran_order' key");
301 if (keypos_shape == std::string::npos)
303 XTENSOR_THROW(std::runtime_error,
"missing 'shape' key");
// The slicing below relies on the order descr < fortran_order < shape.
310 if (keypos_descr >= keypos_fortran || keypos_fortran >= keypos_shape)
312 XTENSOR_THROW(std::runtime_error,
"header keys in wrong order");
// Cut the dictionary into one "key:value," fragment per key and drop the
// ',' separator from each with pop_char.
316 std::string keyvalue_descr;
317 keyvalue_descr = header.substr(keypos_descr, keypos_fortran - keypos_descr);
318 pop_char(keyvalue_descr,
',');
320 std::string keyvalue_fortran;
321 keyvalue_fortran = header.substr(keypos_fortran, keypos_shape - keypos_fortran);
322 pop_char(keyvalue_fortran,
',');
324 std::string keyvalue_shape;
325 keyvalue_shape = header.substr(keypos_shape, std::string::npos);
326 pop_char(keyvalue_shape,
',');
// Keep only the value part of each fragment.
329 std::string descr_s = get_value_from_map(keyvalue_descr);
330 std::string fortran_s = get_value_from_map(keyvalue_fortran);
331 std::string shape_s = get_value_from_map(keyvalue_shape);
// Validate the quoted typestring, then store it without its quotes.
333 parse_typestring(descr_s);
334 descr = unwrap_s(descr_s,
'\'',
'\'');
// fortran_order must be exactly the Python literal True or False.
337 if (fortran_s ==
"True")
339 *fortran_order =
true;
341 else if (fortran_s ==
"False")
343 *fortran_order =
false;
347 XTENSOR_THROW(std::runtime_error,
"invalid fortran_order value");
// Strip the tuple parentheses; after space removal, "(3, 4)" becomes "3,4"
// and "(3,)" becomes "3,".
354 shape_s = unwrap_s(shape_s,
'(',
')');
// Walk the comma-separated dimensions. The declarations of pos and dim_s,
// and the loop header, are not visible here.
360 std::size_t pos_next = shape_s.find_first_of(
',', pos);
363 if (pos_next != std::string::npos)
365 dim_s = shape_s.substr(pos, pos_next - pos);
369 dim_s = shape_s.substr(pos);
// An empty element followed by another ',' (e.g. "3,,4") is rejected.
372 if (dim_s.length() == 0)
374 if (pos_next != std::string::npos)
376 XTENSOR_THROW(std::runtime_error,
"invalid shape");
// Convert the dimension text to an integer through a stringstream (the
// extraction into tmp is not visible here).
381 std::stringstream ss;
385 shape.push_back(tmp);
388 if (pos_next != std::string::npos)
// Writes the npy preamble and header for an array with the given typestring,
// memory order and shape, in this order:
//   1. magic string and version, via write_magic;
//   2. little-endian header length: 2 bytes for version 1.0, 4 otherwise;
//   3. presumably the header text itself (that write is not visible here).
// The header text is the dictionary literal, space-padded so that the
// preamble size is a multiple of 64, and terminated by a newline.
// @param out            destination with write()/put() (see write_magic)
// @param descr          npy typestring, e.g. from build_typestring
// @param fortran_order  emitted as True or False
// @param shape          range of extents; must support std::begin/std::end
//                       and std::size
399 template <
class O,
class S>
400 inline void write_header(O& out,
const std::string& descr,
bool fortran_order,
const S& shape)
402 std::ostringstream ss_header;
// Python literal for the fortran_order entry (the selecting `if` is not
// visible here).
403 std::string s_fortran_order;
406 s_fortran_order =
"True";
410 s_fortran_order =
"False";
// Render the extents as "a, b, ". Then trim to "a, b" for rank > 1, or to
// "a," for rank 1 so a one-element tuple keeps its trailing comma. The
// surrounding parentheses are presumably added in lines not visible here.
414 std::ostringstream ss_shape;
416 for (
auto shape_it = std::begin(shape); shape_it != std::end(shape); ++shape_it)
418 ss_shape << *shape_it <<
", ";
420 s_shape = ss_shape.str();
421 if (std::size(shape) > 1)
423 s_shape = s_shape.erase(s_shape.size() - 2);
425 else if (std::size(shape) == 1)
427 s_shape = s_shape.erase(s_shape.size() - 1);
431 ss_header <<
"{'descr': '" << descr <<
"', 'fortran_order': " << s_fortran_order
432 <<
", 'shape': " << s_shape <<
", }";
// header_len_pre is the dictionary text plus its terminating newline (+1).
// metadata_len adds the magic string, the 2 version bytes and the 2-byte
// length field.
434 std::size_t header_len_pre = ss_header.str().length() + 1;
435 std::size_t metadata_len = magic_string_length + 2 + 2 + header_len_pre;
437 unsigned char version[2] = {1, 0};
// Long headers are re-measured with a 4-byte length field; the
// corresponding version update is not visible here.
// NOTE(review): 255 * 255 == 65025 rather than the uint16 limit 65535. This
// is conservative (it switches format earlier than required) — confirm it
// is intended.
438 if (metadata_len >= 255 * 255)
440 metadata_len = magic_string_length + 2 + 4 + header_len_pre;
// Pad with spaces so the complete preamble is a multiple of 64 bytes.
// NOTE(review): an already aligned metadata_len receives 64 bytes of
// padding; the result is still aligned, just larger than needed.
444 std::size_t padding_len = 64 - (metadata_len % 64);
445 std::string padding(padding_len,
' ');
446 ss_header << padding;
// NOTE(review): std::endl also flushes; '\n' is sufficient for a string
// stream.
447 ss_header << std::endl;
449 std::string header = ss_header.str();
452 write_magic(out, version[0], version[1]);
// Header length, written byte by byte in little-endian order.
455 if (version[0] == 1 && version[1] == 0)
457 char header_len_le16[2];
458 uint16_t header_len = uint16_t(header.length());
460 header_len_le16[0] = char((header_len >> 0) & 0xff);
461 header_len_le16[1] = char((header_len >> 8) & 0xff);
462 out.write(
reinterpret_cast<char*
>(header_len_le16), 2);
466 char header_len_le32[4];
467 uint32_t header_len = uint32_t(header.length());
469 header_len_le32[0] = char((header_len >> 0) & 0xff);
470 header_len_le32[1] = char((header_len >> 8) & 0xff);
471 header_len_le32[2] = char((header_len >> 16) & 0xff);
472 header_len_le32[3] = char((header_len >> 24) & 0xff);
473 out.write(
reinterpret_cast<char*
>(header_len_le32), 4);
// Reads a version 1.0 header: a 2-byte little-endian length, followed by
// that many bytes of header text, which is returned as a string.
479 inline std::string read_header_1_0(std::istream& istream)
482 char header_len_le16[2];
483 istream.read(header_len_le16, 2);
// NOTE(review): char is signed on common platforms. A byte >= 0x80 (e.g.
// the low byte of 182 == 0xB6) is then negative. It is sign-extended before
// the uint16_t conversion, which sets the upper bits and corrupts the
// length. Shifting a negative value left is also undefined before C++20.
// Convert each byte through unsigned char first, e.g.
//   uint16_t(static_cast<unsigned char>(header_len_le16[0]))
//   | uint16_t(static_cast<unsigned char>(header_len_le16[1]) << 8)
485 uint16_t header_length = uint16_t(header_len_le16[0] << 0) | uint16_t(header_len_le16[1] << 8);
// The preamble plus header is expected to be a multiple of 16 bytes. The
// action taken otherwise is not visible here.
487 if ((magic_string_length + 2 + 2 + header_length) % 16 != 0)
// NOTE(review): neither read() result is checked in the visible lines.
492 std::unique_ptr<char[]> buf(
new char[header_length]);
493 istream.read(buf.get(), header_length);
494 std::string header(buf.get(), header_length);
// Reads a version 2.0 header: a 4-byte little-endian length, followed by
// that many bytes of header text, which is returned as a string.
499 inline std::string read_header_2_0(std::istream& istream)
502 char header_len_le32[4];
503 istream.read(header_len_le32, 4);
// NOTE(review): as in read_header_1_0, each signed char byte >= 0x80 is
// sign-extended (and left-shifted while negative, which is undefined before
// C++20). This corrupts the length and can trigger a huge allocation below.
// Cast each byte to unsigned char before widening and shifting.
505 uint32_t header_length = uint32_t(header_len_le32[0] << 0) | uint32_t(header_len_le32[1] << 8)
506 | uint32_t(header_len_le32[2] << 16) | uint32_t(header_len_le32[3] << 24);
// Expected 16-byte alignment of preamble plus header; the action taken
// otherwise is not visible here.
508 if ((magic_string_length + 2 + 4 + header_length) % 16 != 0)
// NOTE(review): neither read() result is checked in the visible lines.
513 std::unique_ptr<char[]> buf(
new char[header_length]);
514 istream.read(buf.get(), header_length);
515 std::string header(buf.get(), header_length);
// Default constructor.
// NOTE(review): m_fortran_order, m_word_size and m_n_bytes are declared
// without in-class initialisers, so a default-initialised npy_file leaves
// them indeterminate. Whether m_buffer has an initialiser is not visible
// here. The destructor reads m_buffer, so confirm it defaults to nullptr.
522 npy_file() =
default;
// Allocates an uninitialised byte buffer for an array of the given shape.
// The element size is parsed from the digits after the byte-order and
// type-code characters (e.g. "<f8" -> 8).
// @param shape          array extents (the m_shape initialiser is not
//                       visible here)
// @param fortran_order  stored unchanged
// @param typestring     npy descr without quotes, as produced by
//                       parse_header, e.g. "<f8"
// NOTE(review): `shape` is taken by non-const reference although nothing
// visible modifies it. atoi() cannot report a malformed size.
524 npy_file(std::vector<std::size_t>& shape,
bool fortran_order, std::string typestring)
526 , m_fortran_order(fortran_order)
527 , m_typestring(typestring)
530 m_word_size = std::size_t(atoi(&typestring[2]));
531 m_n_bytes = compute_size(shape) * m_word_size;
// Raw storage from std::allocator<char>; released with the same allocator
// and size in the destructor.
532 m_buffer = std::allocator<char>{}.allocate(m_n_bytes);
// Presumably the body of ~npy_file() (its declaration is not visible here).
// Frees the buffer unless ownership was handed off; the move operations
// signal a hand-off by setting m_buffer to nullptr.
537 if (m_buffer !=
nullptr)
539 std::allocator<char>{}.deallocate(m_buffer, m_n_bytes);
// Copying is disabled: the object owns a raw buffer (m_buffer) that the
// destructor deallocates, so it must have exactly one owner.
544 npy_file(
const npy_file&) =
delete;
545 npy_file& operator=(
const npy_file&) =
delete;
// Move constructor: takes over rhs's state and buffer, then nulls
// rhs.m_buffer so that rhs's destructor does not free it.
// NOTE(review): consider marking this noexcept. Nothing visible here can
// throw (vector and string moves are noexcept).
548 npy_file(npy_file&& rhs)
549 : m_shape(std::move(rhs.m_shape))
550 , m_fortran_order(std::move(rhs.m_fortran_order))
551 , m_word_size(std::move(rhs.m_word_size))
552 , m_n_bytes(std::move(rhs.m_n_bytes))
553 , m_typestring(std::move(rhs.m_typestring))
554 , m_buffer(rhs.m_buffer)
556 rhs.m_buffer =
nullptr;
// Move assignment: transfers all state from rhs and nulls rhs.m_buffer.
// NOTE(review): the visible lines do not deallocate this object's current
// m_buffer before overwriting it, so assigning over an npy_file that still
// owns a buffer would leak it. Any fix must use the old m_n_bytes.
// A self-assignment guard, a noexcept specifier and the `return *this` are
// also not visible here — confirm.
559 npy_file& operator=(npy_file&& rhs)
563 m_shape = std::move(rhs.m_shape);
564 m_fortran_order = std::move(rhs.m_fortran_order);
565 m_word_size = std::move(rhs.m_word_size);
566 m_n_bytes = std::move(rhs.m_n_bytes);
567 m_typestring = std::move(rhs.m_typestring);
568 m_buffer = rhs.m_buffer;
569 rhs.m_buffer =
nullptr;
// Shared implementation of the cast() overloads. It:
//   - checks that the buffer is still owned;
//   - optionally checks the element type;
//   - returns the tuple (T* data, element count, shape copy, strides).
// The strides are presumably filled in lines not visible here.
// @tparam T          requested element type
// @tparam L          requested layout; checked against the file (the check
//                    itself is not visible here)
// @param check_type  when true, m_typestring must equal build_typestring<T>()
// @throws std::runtime_error in three cases: when m_buffer is nullptr
//         ("already been cast"), on a typestring mismatch, or on a layout
//         mismatch.
// NOTE(review): with check_type == false nothing visible compares sizeof(T)
// with m_word_size. sz counts elements, so a T wider than the stored word
// size would let callers read past the end of the buffer.
574 template <
class T, layout_type L>
575 auto cast_impl(
bool check_type)
577 if (m_buffer ==
nullptr)
579 XTENSOR_THROW(std::runtime_error,
"This npy_file has already been cast.");
581 T* ptr =
reinterpret_cast<T*
>(&m_buffer[0]);
582 std::vector<std::size_t>
strides(m_shape.size());
583 std::size_t sz = compute_size(m_shape);
586 if (check_type && m_typestring != detail::build_typestring<T>())
590 "Cast error: formats not matching "s + m_typestring +
" vs "s
591 + detail::build_typestring<T>()
600 "Cast error: layout mismatch between npy file and requested layout."
609 std::vector<std::size_t> shape(m_shape);
611 return std::make_tuple(ptr, sz, std::move(shape), std::move(
strides));
// Rvalue overload, selected by e.g. std::move(file).cast<T, L>(). Adapts the
// buffer returned by cast_impl into an xtensor container.
// NOTE(review): the adapt call, the argument between the element count and
// the shape, and any reset of m_buffer are not visible here. cast_impl's
// "already been cast" check suggests this overload takes ownership and
// nulls m_buffer — confirm.
614 template <
class T, layout_type L = layout_type::dynamic>
615 auto cast(
bool check_type =
true) &&
617 auto cast_elems = cast_impl<T, L>(check_type);
620 std::move(std::get<0>(cast_elems)),
621 std::get<1>(cast_elems),
623 std::get<2>(cast_elems),
624 std::get<3>(cast_elems)
// const-lvalue overload: adapts the buffer obtained from cast_impl into an
// xtensor container. Presumably the result does not own the buffer and must
// not outlive this npy_file — ownership argument not visible, confirm.
// NOTE(review): cast_impl as declared here is not const-qualified, so
// calling it from this const&-qualified member fails to compile once the
// template is instantiated. Either cast_impl needs a const variant or this
// overload is unusable — confirm.
628 template <
class T, layout_type L = layout_type::dynamic>
629 auto cast(
bool check_type =
true) const&
631 auto cast_elems = cast_impl<T, L>(check_type);
633 std::get<0>(cast_elems),
634 std::get<1>(cast_elems),
636 std::get<2>(cast_elems),
637 std::get<3>(cast_elems)
// Non-const lvalue overload: adapts the buffer obtained from cast_impl into
// an xtensor container. Presumably the result does not own the buffer and
// must not outlive this npy_file — ownership argument not visible, confirm.
641 template <
class T, layout_type L = layout_type::dynamic>
642 auto cast(
bool check_type =
true) &
644 auto cast_elems = cast_impl<T, L>(check_type);
646 std::get<0>(cast_elems),
647 std::get<1>(cast_elems),
649 std::get<2>(cast_elems),
650 std::get<3>(cast_elems)
// Byte count that load_npy_file reads into the buffer. The body is not
// visible here; presumably it returns m_n_bytes — confirm.
// NOTE(review): could be const-qualified.
659 std::size_t n_bytes()
// Array extents taken from the header's 'shape' entry.
664 std::vector<std::size_t> m_shape;
// Value of the header's 'fortran_order' entry (True -> true).
665 bool m_fortran_order;
// Bytes per element, parsed from the typestring's trailing digits.
666 std::size_t m_word_size;
// Buffer size in bytes: element count times m_word_size.
667 std::size_t m_n_bytes;
// npy descr without surrounding quotes, e.g. "<f8".
668 std::string m_typestring;
// Reads a complete npy stream in four steps:
//   1. magic and version;
//   2. the version 1.0 or 2.0 header;
//   3. parse descr, fortran_order and shape from it;
//   4. read n_bytes() of raw data into a newly allocated npy_file.
// @throws std::runtime_error("unsupported file format version") for any
//         other version, plus whatever the header helpers throw.
// NOTE(review): NumPy also defines format 3.0 (UTF-8 header), which this
// rejects.
// NOTE(review): the final read() is not checked in the visible lines, so a
// truncated file would leave part of the buffer uninitialised.
672 inline npy_file load_npy_file(std::istream& stream)
675 unsigned char v_major, v_minor;
676 detail::read_magic(stream, &v_major, &v_minor);
680 if (v_major == 1 && v_minor == 0)
682 header = detail::read_header_1_0(stream);
684 else if (v_major == 2 && v_minor == 0)
686 header = detail::read_header_2_0(stream);
690 XTENSOR_THROW(std::runtime_error,
"unsupported file format version");
697 std::vector<std::size_t> shape;
698 detail::parse_header(header, typestr, &fortran_order, shape);
700 npy_file result(shape, fortran_order, typestr);
702 stream.read(result.ptr(), std::streamsize((result.n_bytes())));
// Writes expression `e` to `stream` in npy format. It evaluates the
// expression, builds the typestring for its value_type, writes the header,
// then writes sizeof(value_type) * size raw bytes starting at
// eval_ex.data(). The opening line of that final write call is not visible
// here.
// fortran_order is set to true under a condition not visible here;
// presumably a column-major layout of the evaluated expression — confirm.
// NOTE(review): writing from data() assumes the evaluated container is
// contiguous — confirm for adaptors with custom strides.
// NOTE(review): `auto shape` copies the shape; const auto& would avoid that.
706 template <
class O,
class E>
707 inline void dump_npy_stream(O& stream,
const xexpression<E>& e)
709 using value_type =
typename E::value_type;
710 const E& ex = e.derived_cast();
711 auto&& eval_ex =
eval(ex);
712 bool fortran_order =
false;
715 fortran_order =
true;
718 std::string typestring = detail::build_typestring<value_type>();
720 auto shape = eval_ex.shape();
721 detail::write_header(stream, typestring, fortran_order, shape);
723 std::size_t size = compute_size(shape);
725 reinterpret_cast<const char*
>(eval_ex.data()),
726 std::streamsize((
sizeof(value_type) * size))
// Saves expression `e` to `filename` in npy format. The file is opened in
// binary mode and written through dump_npy_stream.
// @throws std::runtime_error("IO Error: failed to open file: " + filename)
//         when opening fails (the check itself is not visible here).
737 template <
typename E>
740 std::ofstream stream(filename, std::ofstream::binary);
743 XTENSOR_THROW(std::runtime_error,
"IO Error: failed to open file: "s + filename);
746 detail::dump_npy_stream(stream, e);
// Serialises expression `e` in npy format into an in-memory stringstream.
// The signature and return statement are not visible here; presumably the
// function returns stream.str() — confirm.
754 template <
typename E>
757 std::stringstream stream;
758 detail::dump_npy_stream(stream, e);
// Loads an npy stream into an xtensor container with element type T and
// layout L. The npy_file is moved into cast(), which selects the rvalue
// overload, so the result can take over the buffer.
// @throws std::runtime_error on format, type or layout errors, via
//         load_npy_file and cast().
772 template <
typename T, layout_type L = layout_type::dynamic>
775 detail::npy_file file = detail::load_npy_file(stream);
776 return std::move(file).cast<T, L>();
// Opens `filename` in binary mode and loads it as an npy array with element
// type T. The open check and the load/return are not visible here.
// @throws std::runtime_error("io error: failed to open a file.") when the
//         file cannot be opened.
// NOTE(review): unlike dump_npy, this message omits the file name.
789 template <
typename T, layout_type L = layout_type::dynamic>
792 std::ifstream stream(filename, std::ifstream::binary);
795 XTENSOR_THROW(std::runtime_error,
"io error: failed to open a file.");
Base class for xexpressions.
auto adapt(C &&container, const SC &shape, layout_type l=L)
Constructs:
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto load_npy(std::istream &stream)
Loads a npy file (the NumPy storage format)
void dump_npy(const std::string &filename, const xexpression< E > &e)
Save xexpression to NumPy npy format.