Skip to content

Commit 84b4617

Browse files
committed
bugfix
1 parent 014c4ba commit 84b4617

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
3939
if (col_count == 0)
4040
continue;
4141

42-
if (replace) {
43-
for (int64_t j = 0; j < num_samples; j++) {
44-
const int64_t offset = col_start + rand() % col_count;
42+
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
43+
for (int64_t offset = col_start; offset < col_end; offset++) {
4544
const int64_t &v = row_data[offset];
4645
const auto res = to_local_node.insert({v, samples.size()});
4746
if (res.second)
@@ -52,8 +51,9 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
5251
edges.push_back(offset);
5352
}
5453
}
55-
} else if (num_samples >= col_count) {
56-
for (int64_t offset = col_start; offset < col_end; offset++) {
54+
} else if (replace) {
55+
for (int64_t j = 0; j < num_samples; j++) {
56+
const int64_t offset = col_start + rand() % col_count;
5757
const int64_t &v = row_data[offset];
5858
const auto res = to_local_node.insert({v, samples.size()});
5959
if (res.second)
@@ -111,14 +111,14 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
111111
}
112112

113113
template <bool replace, bool directed>
114-
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
115-
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
116-
hetero_sample(const std::vector<node_t> &node_types,
117-
const std::vector<edge_t> &edge_types,
114+
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
115+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
116+
hetero_sample(const vector<node_t> &node_types,
117+
const vector<edge_t> &edge_types,
118118
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
119119
const c10::Dict<rel_t, torch::Tensor> &row_dict,
120120
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
121-
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
121+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
122122
const int64_t num_hops) {
123123

124124
// Create a mapping to convert single string relations to edge type triplets:
@@ -129,9 +129,9 @@ hetero_sample(const std::vector<node_t> &node_types,
129129
// Initialize some data structures for the sampling process:
130130
unordered_map<node_t, vector<int64_t>> samples_dict;
131131
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
132-
for (const auto &k : node_types) {
133-
samples_dict[k];
134-
to_local_node_dict[k];
132+
for (const auto &node_type : node_types) {
133+
samples_dict[node_type];
134+
to_local_node_dict[node_type];
135135
}
136136

137137
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
@@ -167,7 +167,7 @@ hetero_sample(const std::vector<node_t> &node_types,
167167
const auto &edge_type = to_edge_type[rel_type];
168168
const auto &src_node_type = get<0>(edge_type);
169169
const auto &dst_node_type = get<2>(edge_type);
170-
const auto &num_samples = kv.value()[ell];
170+
const auto num_samples = kv.value()[ell];
171171
const auto &dst_samples = samples_dict.at(dst_node_type);
172172
auto &src_samples = samples_dict.at(src_node_type);
173173
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
@@ -190,9 +190,8 @@ hetero_sample(const std::vector<node_t> &node_types,
190190
if (col_count == 0)
191191
continue;
192192

193-
if (replace) {
194-
for (int64_t j = 0; j < num_samples; j++) {
195-
const int64_t offset = col_start + rand() % col_count;
193+
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
194+
for (int64_t offset = col_start; offset < col_end; offset++) {
196195
const int64_t &v = row_data[offset];
197196
const auto res = to_local_src_node.insert({v, src_samples.size()});
198197
if (res.second)
@@ -203,8 +202,9 @@ hetero_sample(const std::vector<node_t> &node_types,
203202
edges.push_back(offset);
204203
}
205204
}
206-
} else if (num_samples >= col_count) {
207-
for (int64_t offset = col_start; offset < col_end; offset++) {
205+
} else if (replace) {
206+
for (int64_t j = 0; j < num_samples; j++) {
207+
const int64_t offset = col_start + rand() % col_count;
208208
const int64_t &v = row_data[offset];
209209
const auto res = to_local_src_node.insert({v, src_samples.size()});
210210
if (res.second)
@@ -302,15 +302,14 @@ neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
302302
}
303303
}
304304

305-
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
306-
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
305+
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
306+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
307307
hetero_neighbor_sample_cpu(
308-
const std::vector<node_t> &node_types,
309-
const std::vector<edge_t> &edge_types,
308+
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
310309
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
311310
const c10::Dict<rel_t, torch::Tensor> &row_dict,
312311
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
313-
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
312+
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
314313
const int64_t num_hops, const bool replace, const bool directed) {
315314

316315
if (replace && directed) {

0 commit comments

Comments
 (0)