Skip to content

Commit 3cbb9be

Browse files
committed
Handle symmetric with binsparse read/write
1 parent ea97e94 commit 3cbb9be

File tree

4 files changed

+102
-30
lines changed

4 files changed

+102
-30
lines changed

examples/convert_binsparse.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ void convert_to_binsparse(std::string input_file, std::string output_file,
1414
T, I, binsparse::__detail::csr_matrix_owning<T, I>>(input_file);
1515
binsparse::csr_matrix<T, I> matrix{
1616
x.values().data(), x.colind().data(), x.rowptr().data(),
17-
std::get<0>(x.shape()), std::get<1>(x.shape()), I(x.size())};
17+
std::get<0>(x.shape()), std::get<1>(x.shape()), I(x.size()),
18+
x.structure()};
1819
binsparse::write_csr_matrix(output_file, matrix, user_keys);
1920
std::cout << "Writing to binsparse file " << output_file << " using "
2021
<< format << " format...\n";
@@ -23,7 +24,8 @@ void convert_to_binsparse(std::string input_file, std::string output_file,
2324
T, I, binsparse::__detail::coo_matrix_owning<T, I>>(input_file);
2425
binsparse::coo_matrix<T, I> matrix{
2526
x.values().data(), x.rowind().data(), x.colind().data(),
26-
std::get<0>(x.shape()), std::get<1>(x.shape()), I(x.size())};
27+
std::get<0>(x.shape()), std::get<1>(x.shape()), I(x.size()),
28+
x.structure()};
2729
binsparse::write_coo_matrix(output_file, matrix, user_keys);
2830
std::cout << "Writing to binsparse file " << output_file << " using "
2931
<< format << " format...\n";

include/binsparse/binsparse.hpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ void write_dense_matrix(std::string fname, dense_matrix<T, I, Order> m,
3434
j["binsparse"]["nnz"] = m.m * m.n;
3535
j["binsparse"]["data_types"]["values"] = type_info<T>::label();
3636

37+
if (m.structure != general) {
38+
j["binsparse"]["structure"] =
39+
__detail::get_structure_name(m.structure).value();
40+
}
41+
3742
for (auto&& v : user_keys.items()) {
3843
j[v.key()] = v.value();
3944
}
@@ -67,7 +72,13 @@ auto read_dense_matrix(std::string fname, Allocator&& alloc = Allocator{}) {
6772

6873
auto values = hdf5_tools::read_dataset<T>(f, "values", alloc);
6974

70-
return dense_matrix<T, I, Order>{values.data(), nrows, ncols};
75+
structure_t structure = general;
76+
77+
if (binsparse_metadata.contains("structure")) {
78+
structure = __detail::parse_structure(binsparse_metadata["structure"]);
79+
}
80+
81+
return dense_matrix<T, I, Order>{values.data(), nrows, ncols, structure};
7182
}
7283

7384
// CSR Format
@@ -96,6 +107,11 @@ void write_csr_matrix(std::string fname, csr_matrix<T, I> m,
96107
j["binsparse"]["data_types"]["indices_1"] = type_info<I>::label();
97108
j["binsparse"]["data_types"]["values"] = type_info<T>::label();
98109

110+
if (m.structure != general) {
111+
j["binsparse"]["structure"] =
112+
__detail::get_structure_name(m.structure).value();
113+
}
114+
99115
for (auto&& v : user_keys.items()) {
100116
j[v.key()] = v.value();
101117
}
@@ -130,8 +146,14 @@ csr_matrix<T, I> read_csr_matrix(std::string fname, Allocator&& alloc) {
130146
auto colind = hdf5_tools::read_dataset<I>(f, "indices_1", i_alloc);
131147
auto row_ptr = hdf5_tools::read_dataset<I>(f, "pointers_to_1", i_alloc);
132148

133-
return csr_matrix<T, I>{values.data(), colind.data(), row_ptr.data(),
134-
nrows, ncols, nnz};
149+
structure_t structure = general;
150+
151+
if (binsparse_metadata.contains("structure")) {
152+
structure = __detail::parse_structure(binsparse_metadata["structure"]);
153+
}
154+
155+
return csr_matrix<T, I>{values.data(), colind.data(), row_ptr.data(), nrows,
156+
ncols, nnz, structure};
135157
}
136158

137159
template <typename T, typename I>
@@ -165,6 +187,11 @@ void write_csc_matrix(std::string fname, csc_matrix<T, I> m,
165187
j["binsparse"]["data_types"]["indices_1"] = type_info<I>::label();
166188
j["binsparse"]["data_types"]["values"] = type_info<T>::label();
167189

190+
if (m.structure != general) {
191+
j["binsparse"]["structure"] =
192+
__detail::get_structure_name(m.structure).value();
193+
}
194+
168195
for (auto&& v : user_keys.items()) {
169196
j[v.key()] = v.value();
170197
}
@@ -199,8 +226,14 @@ csc_matrix<T, I> read_csc_matrix(std::string fname, Allocator&& alloc) {
199226
auto rowind = hdf5_tools::read_dataset<I>(f, "indices_1", i_alloc);
200227
auto col_ptr = hdf5_tools::read_dataset<I>(f, "pointers_to_1", i_alloc);
201228

202-
return csc_matrix<T, I>{values.data(), rowind.data(), col_ptr.data(),
203-
nrows, ncols, nnz};
229+
structure_t structure = general;
230+
231+
if (binsparse_metadata.contains("structure")) {
232+
structure = __detail::parse_structure(binsparse_metadata["structure"]);
233+
}
234+
235+
return csc_matrix<T, I>{values.data(), rowind.data(), col_ptr.data(), nrows,
236+
ncols, nnz, structure};
204237
}
205238

206239
template <typename T, typename I>
@@ -234,6 +267,11 @@ void write_coo_matrix(std::string fname, coo_matrix<T, I> m,
234267
j["binsparse"]["data_types"]["indices_1"] = type_info<I>::label();
235268
j["binsparse"]["data_types"]["values"] = type_info<T>::label();
236269

270+
if (m.structure != general) {
271+
j["binsparse"]["structure"] =
272+
__detail::get_structure_name(m.structure).value();
273+
}
274+
237275
for (auto&& v : user_keys.items()) {
238276
j[v.key()] = v.value();
239277
}
@@ -270,8 +308,14 @@ coo_matrix<T, I> read_coo_matrix(std::string fname, Allocator&& alloc) {
270308
auto rows = hdf5_tools::read_dataset<I>(f, "indices_0", i_alloc);
271309
auto cols = hdf5_tools::read_dataset<I>(f, "indices_1", i_alloc);
272310

273-
return coo_matrix<T, I>{values.data(), rows.data(), cols.data(),
274-
nrows, ncols, nnz};
311+
structure_t structure = general;
312+
313+
if (binsparse_metadata.contains("structure")) {
314+
structure = __detail::parse_structure(binsparse_metadata["structure"]);
315+
}
316+
317+
return coo_matrix<T, I>{values.data(), rows.data(), cols.data(), nrows,
318+
ncols, nnz, structure};
275319
}
276320

277321
template <typename T, typename I>

include/binsparse/detail.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@ inline std::string unalias_format(const std::string& format) {
2727
}
2828
}
2929

30+
inline std::optional<std::string> get_structure_name(structure_t structure) {
31+
if (structure == general) {
32+
return {};
33+
} else if (structure == symmetric) {
34+
return "symmetric_lower";
35+
} else if (structure == skew_symmetric) {
36+
return "skew_symmetric_lower";
37+
} else if (structure == hermitian) {
38+
return "hermitian";
39+
} else {
40+
throw std::runtime_error("get_structure_name: unknown structure");
41+
}
42+
}
43+
44+
inline structure_t parse_structure(const std::string& structure) {
45+
if (structure == "symmetric_lower") {
46+
return symmetric;
47+
} else if (structure == "skew_symmetric_lower") {
48+
return skew_symmetric;
49+
} else if (structure == "hermitian") {
50+
return hermitian;
51+
} else {
52+
throw std::runtime_error("parse_structure: unsupported structure");
53+
}
54+
}
55+
3056
} // namespace __detail
3157

3258
} // namespace binsparse

include/binsparse/matrix_market/matrix_market_read.hpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ namespace __detail {
1111
template <typename T, typename I>
1212
class csr_matrix_owning {
1313
public:
14-
csr_matrix_owning(std::tuple<I, I> shape) : shape_(shape) {}
14+
csr_matrix_owning(std::tuple<I, I> shape, structure_t structure = general)
15+
: shape_(shape), structure_(structure) {}
1516

1617
auto values() {
1718
return std::ranges::views::all(values_);
@@ -80,17 +81,23 @@ class csr_matrix_owning {
8081
return values_.size();
8182
}
8283

84+
auto structure() const {
85+
return structure_;
86+
}
87+
8388
private:
8489
std::tuple<I, I> shape_;
8590
std::vector<T> values_;
8691
std::vector<I> rowptr_;
8792
std::vector<I> colind_;
93+
structure_t structure_;
8894
};
8995

9096
template <typename T, typename I>
9197
class coo_matrix_owning {
9298
public:
93-
coo_matrix_owning(std::tuple<I, I> shape) : shape_(shape) {}
99+
coo_matrix_owning(std::tuple<I, I> shape, structure_t structure = general)
100+
: shape_(shape), structure_(structure) {}
94101

95102
auto values() {
96103
return std::ranges::views::all(values_);
@@ -144,11 +151,16 @@ class coo_matrix_owning {
144151
return values_.size();
145152
}
146153

154+
auto structure() const {
155+
return structure_;
156+
}
157+
147158
private:
148159
std::tuple<I, I> shape_;
149160
std::vector<T> values_;
150161
std::vector<I> rowind_;
151162
std::vector<I> colind_;
163+
structure_t structure_;
152164
};
153165

154166
/// Read in the Matrix Market file at location `file_path` and
@@ -170,9 +182,8 @@ inline MatrixType mmread(std::string file_path, bool one_indexed = true) {
170182
std::string buf;
171183

172184
// Make sure the file is matrix market matrix, coordinate, and check whether
173-
// it is symmetric. If the matrix is symmetric, non-diagonal elements will
174-
// be inserted in both (i, j) and (j, i). Error out if skew-symmetric or
175-
// Hermitian.
185+
// it is symmetric. If the matrix is symmetric.
186+
// Error out if skew-symmetric or Hermitian.
176187
std::getline(f, buf);
177188
std::istringstream ss(buf);
178189
std::string item;
@@ -200,11 +211,11 @@ inline MatrixType mmread(std::string file_path, bool one_indexed = true) {
200211
}
201212
// TODO: do something with real vs. integer vs. pattern?
202213
ss >> item;
203-
bool symmetric;
214+
structure_t structure;
204215
if (item == "general") {
205-
symmetric = false;
216+
structure = general;
206217
} else if (item == "symmetric") {
207-
symmetric = true;
218+
structure = symmetric;
208219
} else {
209220
throw std::runtime_error(file_path + " has an unsupported matrix type");
210221
}
@@ -224,18 +235,11 @@ inline MatrixType mmread(std::string file_path, bool one_indexed = true) {
224235
ss.str(buf);
225236
ss >> m >> n >> nnz;
226237

227-
// NOTE for symmetric matrices: `nnz` holds the number of stored values in
228-
// the matrix market file, while `matrix.nnz_` will hold the total number of
229-
// stored values (including "mirrored" symmetric values).
230-
MatrixType m_out({m, n});
238+
MatrixType m_out({m, n}, structure);
231239

232240
using coo_type = std::vector<std::tuple<std::tuple<I, I>, T>>;
233241
coo_type matrix;
234-
if (symmetric) {
235-
matrix.reserve(2 * nnz);
236-
} else {
237-
matrix.reserve(nnz);
238-
}
242+
matrix.reserve(nnz);
239243

240244
size_type c = 0;
241245
while (std::getline(f, buf)) {
@@ -260,10 +264,6 @@ inline MatrixType mmread(std::string file_path, bool one_indexed = true) {
260264

261265
matrix.push_back({{i, j}, v});
262266

263-
if (symmetric && i != j) {
264-
matrix.push_back({{j, i}, v});
265-
}
266-
267267
c++;
268268
if (c > nnz) {
269269
throw std::runtime_error("read_MatrixMarket: error reading Matrix Market "

0 commit comments

Comments
 (0)