@@ -39,9 +39,8 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
3939 if (col_count == 0 )
4040 continue ;
4141
42- if (replace) {
43- for (int64_t j = 0 ; j < num_samples; j++) {
44- const int64_t offset = col_start + rand () % col_count;
42+ if ((num_samples < 0 ) || (!replace && (num_samples >= col_count))) {
43+ for (int64_t offset = col_start; offset < col_end; offset++) {
4544 const int64_t &v = row_data[offset];
4645 const auto res = to_local_node.insert ({v, samples.size ()});
4746 if (res.second )
@@ -52,8 +51,9 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
5251 edges.push_back (offset);
5352 }
5453 }
55- } else if (num_samples >= col_count) {
56- for (int64_t offset = col_start; offset < col_end; offset++) {
54+ } else if (replace) {
55+ for (int64_t j = 0 ; j < num_samples; j++) {
56+ const int64_t offset = col_start + rand () % col_count;
5757 const int64_t &v = row_data[offset];
5858 const auto res = to_local_node.insert ({v, samples.size ()});
5959 if (res.second )
@@ -111,14 +111,14 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
111111}
112112
113113template <bool replace, bool directed>
114- std:: tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
115- c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
116- hetero_sample (const std:: vector<node_t > &node_types,
117- const std:: vector<edge_t > &edge_types,
114+ tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
115+ c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
116+ hetero_sample (const vector<node_t > &node_types,
117+ const vector<edge_t > &edge_types,
118118 const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
119119 const c10::Dict<rel_t , torch::Tensor> &row_dict,
120120 const c10::Dict<node_t , torch::Tensor> &input_node_dict,
121- const c10::Dict<rel_t , std:: vector<int64_t >> &num_neighbors_dict,
121+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
122122 const int64_t num_hops) {
123123
124124 // Create a mapping to convert single string relations to edge type triplets:
@@ -129,9 +129,9 @@ hetero_sample(const std::vector<node_t> &node_types,
129129 // Initialize some data structures for the sampling process:
130130 unordered_map<node_t , vector<int64_t >> samples_dict;
131131 unordered_map<node_t , unordered_map<int64_t , int64_t >> to_local_node_dict;
132- for (const auto &k : node_types) {
133- samples_dict[k ];
134- to_local_node_dict[k ];
132+ for (const auto &node_type : node_types) {
133+ samples_dict[node_type ];
134+ to_local_node_dict[node_type ];
135135 }
136136
137137 unordered_map<rel_t , vector<int64_t >> rows_dict, cols_dict, edges_dict;
@@ -167,7 +167,7 @@ hetero_sample(const std::vector<node_t> &node_types,
167167 const auto &edge_type = to_edge_type[rel_type];
168168 const auto &src_node_type = get<0 >(edge_type);
169169 const auto &dst_node_type = get<2 >(edge_type);
170- const auto & num_samples = kv.value ()[ell];
170+ const auto num_samples = kv.value ()[ell];
171171 const auto &dst_samples = samples_dict.at (dst_node_type);
172172 auto &src_samples = samples_dict.at (src_node_type);
173173 auto &to_local_src_node = to_local_node_dict.at (src_node_type);
@@ -190,9 +190,8 @@ hetero_sample(const std::vector<node_t> &node_types,
190190 if (col_count == 0 )
191191 continue ;
192192
193- if (replace) {
194- for (int64_t j = 0 ; j < num_samples; j++) {
195- const int64_t offset = col_start + rand () % col_count;
193+ if ((num_samples < 0 ) || (!replace && (num_samples >= col_count))) {
194+ for (int64_t offset = col_start; offset < col_end; offset++) {
196195 const int64_t &v = row_data[offset];
197196 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
198197 if (res.second )
@@ -203,8 +202,9 @@ hetero_sample(const std::vector<node_t> &node_types,
203202 edges.push_back (offset);
204203 }
205204 }
206- } else if (num_samples >= col_count) {
207- for (int64_t offset = col_start; offset < col_end; offset++) {
205+ } else if (replace) {
206+ for (int64_t j = 0 ; j < num_samples; j++) {
207+ const int64_t offset = col_start + rand () % col_count;
208208 const int64_t &v = row_data[offset];
209209 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
210210 if (res.second )
@@ -302,15 +302,14 @@ neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
302302 }
303303}
304304
305- std:: tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
306- c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
305+ tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
306+ c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
307307hetero_neighbor_sample_cpu (
308- const std::vector<node_t > &node_types,
309- const std::vector<edge_t > &edge_types,
308+ const vector<node_t > &node_types, const vector<edge_t > &edge_types,
310309 const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
311310 const c10::Dict<rel_t , torch::Tensor> &row_dict,
312311 const c10::Dict<node_t , torch::Tensor> &input_node_dict,
313- const c10::Dict<rel_t , std:: vector<int64_t >> &num_neighbors_dict,
312+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
314313 const int64_t num_hops, const bool replace, const bool directed) {
315314
316315 if (replace && directed) {
0 commit comments