diff --git a/include/xtensor/io/xcsv.hpp b/include/xtensor/io/xcsv.hpp index 2b758bd9e..080ccaf54 100644 --- a/include/xtensor/io/xcsv.hpp +++ b/include/xtensor/io/xcsv.hpp @@ -10,7 +10,6 @@ #ifndef XTENSOR_CSV_HPP #define XTENSOR_CSV_HPP -#include #include #include #include @@ -211,30 +210,40 @@ namespace xt { using size_type = typename E::size_type; const E& ex = e.derived_cast(); - if (ex.dimension() != 2) + if (ex.dimension() == 1) { - XTENSOR_THROW(std::runtime_error, "Only 2-D expressions can be serialized to CSV"); - } - size_type nbrows = ex.shape()[0], nbcols = ex.shape()[1]; - auto st = ex.stepper_begin(ex.shape()); - for (size_type r = 0; r != nbrows; ++r) - { - for (size_type c = 0; c != nbcols; ++c) + const size_type n = ex.shape()[0]; + for (size_type i = 0; i != n; ++i) { - stream << *st; - if (c != nbcols - 1) + stream << ex(i); + if (i != n - 1) { - st.step(1); stream << ','; } - else + } + stream << std::endl; + } + else if (ex.dimension() == 2) + { + const size_type nbrows = ex.shape()[0]; + const size_type nbcols = ex.shape()[1]; + for (size_type r = 0; r != nbrows; ++r) + { + for (size_type c = 0; c != nbcols; ++c) { - st.reset(1); - st.step(0); - stream << std::endl; + stream << ex(r, c); + if (c != nbcols - 1) + { + stream << ','; + } } + stream << std::endl; } } + else + { + XTENSOR_THROW(std::runtime_error, "Only 1-D and 2-D expressions can be serialized to CSV"); + } } struct xcsv_config diff --git a/test/test_xcsv.cpp b/test/test_xcsv.cpp index 584d17345..767718fe1 100644 --- a/test/test_xcsv.cpp +++ b/test/test_xcsv.cpp @@ -7,10 +7,8 @@ * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ -#include #include -#include "xtensor/core/xmath.hpp" #include "xtensor/io/xcsv.hpp" #include "xtensor/io/xio.hpp" @@ -18,6 +16,19 @@ namespace xt { + TEST(xcsv, load_1D) + { + const std::string source = "1, 2, 3, 4"; + + std::stringstream source_stream(source); + + const xtensor res = load_csv(source_stream); + + const xtensor exp{{1, 2, 3, 4}}; + + ASSERT_TRUE(all(equal(res, exp))); + } + TEST(xcsv, load_double) { std::string source = "1.0, 2.0, 3.0, 4.0\n" @@ -49,6 +60,16 @@ namespace xt ASSERT_TRUE(all(equal(res, exp))); } + TEST(xcsv, dump_1D) + { + xtensor data{{1.0, 2.0, 3.0, 4.0}}; + + std::stringstream res; + + dump_csv(res, data); + ASSERT_EQ("1,2,3,4\n", res.str()); + } + TEST(xcsv, dump_double) { xtensor data{{1.0, 2.0, 3.0, 4.0}, {10.0, 12.0, 15.0, 18.0}};