Skip to content

Commit 014c4ba

Browse files
committed
hetero neighbor sampling
1 parent 9532032 commit 014c4ba

File tree

7 files changed

+460
-11
lines changed

7 files changed

+460
-11
lines changed

csrc/cpu/hgt_sample_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
102102
const int64_t num_hops) {
103103

104104
// Create a mapping to convert single string relations to edge type triplets:
105-
std::unordered_map<rel_t, edge_t> to_edge_type;
105+
unordered_map<rel_t, edge_t> to_edge_type;
106106
for (const auto &kv : colptr_dict) {
107107
const auto &rel_type = kv.key();
108108
to_edge_type[rel_type] = split(rel_type);

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
#include "neighbor_sample_cpu.h"
2+
3+
#include "utils.h"
4+
5+
using namespace std;
6+
7+
namespace {
8+
9+
template <bool replace, bool directed>
10+
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
11+
sample(const torch::Tensor &colptr, const torch::Tensor &row,
12+
const torch::Tensor &input_node, const vector<int64_t> num_neighbors) {
13+
14+
// Initialize some data structures for the sampling process:
15+
vector<int64_t> samples;
16+
unordered_map<int64_t, int64_t> to_local_node;
17+
18+
auto *colptr_data = colptr.data_ptr<int64_t>();
19+
auto *row_data = row.data_ptr<int64_t>();
20+
auto *input_node_data = input_node.data_ptr<int64_t>();
21+
22+
for (int64_t i = 0; i < input_node.numel(); i++) {
23+
const auto &v = input_node_data[i];
24+
samples.push_back(v);
25+
to_local_node.insert({v, i});
26+
}
27+
28+
vector<int64_t> rows, cols, edges;
29+
30+
int64_t begin = 0, end = samples.size();
31+
for (int64_t ell = 0; ell < (int64_t)num_neighbors.size(); ell++) {
32+
const auto &num_samples = num_neighbors[ell];
33+
for (int64_t i = begin; i < end; i++) {
34+
const auto &w = samples[i];
35+
const auto &col_start = colptr_data[w];
36+
const auto &col_end = colptr_data[w + 1];
37+
const auto col_count = col_end - col_start;
38+
39+
if (col_count == 0)
40+
continue;
41+
42+
if (replace) {
43+
for (int64_t j = 0; j < num_samples; j++) {
44+
const int64_t offset = col_start + rand() % col_count;
45+
const int64_t &v = row_data[offset];
46+
const auto res = to_local_node.insert({v, samples.size()});
47+
if (res.second)
48+
samples.push_back(v);
49+
if (directed) {
50+
cols.push_back(i);
51+
rows.push_back(res.first->second);
52+
edges.push_back(offset);
53+
}
54+
}
55+
} else if (num_samples >= col_count) {
56+
for (int64_t offset = col_start; offset < col_end; offset++) {
57+
const int64_t &v = row_data[offset];
58+
const auto res = to_local_node.insert({v, samples.size()});
59+
if (res.second)
60+
samples.push_back(v);
61+
if (directed) {
62+
cols.push_back(i);
63+
rows.push_back(res.first->second);
64+
edges.push_back(offset);
65+
}
66+
}
67+
} else {
68+
unordered_set<int64_t> rnd_indices;
69+
for (int64_t j = col_count - num_samples; j < col_count; j++) {
70+
int64_t rnd = rand() % j;
71+
if (!rnd_indices.insert(rnd).second) {
72+
rnd = j;
73+
rnd_indices.insert(j);
74+
}
75+
const int64_t offset = col_start + rnd;
76+
const int64_t &v = row_data[offset];
77+
const auto res = to_local_node.insert({v, samples.size()});
78+
if (res.second)
79+
samples.push_back(v);
80+
if (directed) {
81+
cols.push_back(i);
82+
rows.push_back(res.first->second);
83+
edges.push_back(offset);
84+
}
85+
}
86+
}
87+
}
88+
begin = end, end = samples.size();
89+
}
90+
91+
if (!directed) {
92+
unordered_map<int64_t, int64_t>::iterator iter;
93+
for (int64_t i = 0; i < (int64_t)samples.size(); i++) {
94+
const auto &w = samples[i];
95+
const auto &col_start = colptr_data[w];
96+
const auto &col_end = colptr_data[w + 1];
97+
for (int64_t offset = col_start; offset < col_end; offset++) {
98+
const auto &v = row_data[offset];
99+
iter = to_local_node.find(v);
100+
if (iter != to_local_node.end()) {
101+
rows.push_back(iter->second);
102+
cols.push_back(i);
103+
edges.push_back(offset);
104+
}
105+
}
106+
}
107+
}
108+
109+
return make_tuple(from_vector<int64_t>(samples), from_vector<int64_t>(rows),
110+
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
111+
}
112+
113+
template <bool replace, bool directed>
114+
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
115+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
116+
hetero_sample(const std::vector<node_t> &node_types,
117+
const std::vector<edge_t> &edge_types,
118+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
119+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
120+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
121+
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
122+
const int64_t num_hops) {
123+
124+
// Create a mapping to convert single string relations to edge type triplets:
125+
unordered_map<rel_t, edge_t> to_edge_type;
126+
for (const auto &k : edge_types)
127+
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
128+
129+
// Initialize some data structures for the sampling process:
130+
unordered_map<node_t, vector<int64_t>> samples_dict;
131+
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
132+
for (const auto &k : node_types) {
133+
samples_dict[k];
134+
to_local_node_dict[k];
135+
}
136+
137+
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
138+
for (const auto &kv : colptr_dict) {
139+
const auto &rel_type = kv.key();
140+
rows_dict[rel_type];
141+
cols_dict[rel_type];
142+
edges_dict[rel_type];
143+
}
144+
145+
// Add the input nodes to the output nodes:
146+
for (const auto &kv : input_node_dict) {
147+
const auto &node_type = kv.key();
148+
const auto &input_node = kv.value();
149+
const auto *input_node_data = input_node.data_ptr<int64_t>();
150+
151+
auto &samples = samples_dict.at(node_type);
152+
auto &to_local_node = to_local_node_dict.at(node_type);
153+
for (int64_t i = 0; i < input_node.numel(); i++) {
154+
const auto &v = input_node_data[i];
155+
samples.push_back(v);
156+
to_local_node.insert({v, i});
157+
}
158+
}
159+
160+
unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
161+
for (const auto &kv : samples_dict)
162+
slice_dict[kv.first] = {0, kv.second.size()};
163+
164+
for (int64_t ell = 0; ell < num_hops; ell++) {
165+
for (const auto &kv : num_neighbors_dict) {
166+
const auto &rel_type = kv.key();
167+
const auto &edge_type = to_edge_type[rel_type];
168+
const auto &src_node_type = get<0>(edge_type);
169+
const auto &dst_node_type = get<2>(edge_type);
170+
const auto &num_samples = kv.value()[ell];
171+
const auto &dst_samples = samples_dict.at(dst_node_type);
172+
auto &src_samples = samples_dict.at(src_node_type);
173+
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
174+
175+
const auto *colptr_data = colptr_dict.at(rel_type).data_ptr<int64_t>();
176+
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
177+
178+
auto &rows = rows_dict.at(rel_type);
179+
auto &cols = cols_dict.at(rel_type);
180+
auto &edges = edges_dict.at(rel_type);
181+
182+
const auto &begin = slice_dict.at(dst_node_type).first;
183+
const auto &end = slice_dict.at(dst_node_type).second;
184+
for (int64_t i = begin; i < end; i++) {
185+
const auto &w = dst_samples[i];
186+
const auto &col_start = colptr_data[w];
187+
const auto &col_end = colptr_data[w + 1];
188+
const auto col_count = col_end - col_start;
189+
190+
if (col_count == 0)
191+
continue;
192+
193+
if (replace) {
194+
for (int64_t j = 0; j < num_samples; j++) {
195+
const int64_t offset = col_start + rand() % col_count;
196+
const int64_t &v = row_data[offset];
197+
const auto res = to_local_src_node.insert({v, src_samples.size()});
198+
if (res.second)
199+
src_samples.push_back(v);
200+
if (directed) {
201+
cols.push_back(i);
202+
rows.push_back(res.first->second);
203+
edges.push_back(offset);
204+
}
205+
}
206+
} else if (num_samples >= col_count) {
207+
for (int64_t offset = col_start; offset < col_end; offset++) {
208+
const int64_t &v = row_data[offset];
209+
const auto res = to_local_src_node.insert({v, src_samples.size()});
210+
if (res.second)
211+
src_samples.push_back(v);
212+
if (directed) {
213+
cols.push_back(i);
214+
rows.push_back(res.first->second);
215+
edges.push_back(offset);
216+
}
217+
}
218+
} else {
219+
unordered_set<int64_t> rnd_indices;
220+
for (int64_t j = col_count - num_samples; j < col_count; j++) {
221+
int64_t rnd = rand() % j;
222+
if (!rnd_indices.insert(rnd).second) {
223+
rnd = j;
224+
rnd_indices.insert(j);
225+
}
226+
const int64_t offset = col_start + rnd;
227+
const int64_t &v = row_data[offset];
228+
const auto res = to_local_src_node.insert({v, src_samples.size()});
229+
if (res.second)
230+
src_samples.push_back(v);
231+
if (directed) {
232+
cols.push_back(i);
233+
rows.push_back(res.first->second);
234+
edges.push_back(offset);
235+
}
236+
}
237+
}
238+
}
239+
}
240+
241+
for (const auto &kv : samples_dict) {
242+
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
243+
}
244+
}
245+
246+
if (!directed) { // Construct the subgraph among the sampled nodes:
247+
unordered_map<int64_t, int64_t>::iterator iter;
248+
for (const auto &kv : colptr_dict) {
249+
const auto &rel_type = kv.key();
250+
const auto &edge_type = to_edge_type[rel_type];
251+
const auto &src_node_type = get<0>(edge_type);
252+
const auto &dst_node_type = get<2>(edge_type);
253+
const auto &dst_samples = samples_dict.at(dst_node_type);
254+
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
255+
256+
const auto *colptr_data = kv.value().data_ptr<int64_t>();
257+
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
258+
259+
auto &rows = rows_dict.at(rel_type);
260+
auto &cols = cols_dict.at(rel_type);
261+
auto &edges = edges_dict.at(rel_type);
262+
263+
for (int64_t i = 0; i < (int64_t)dst_samples.size(); i++) {
264+
const auto &w = dst_samples[i];
265+
const auto &col_start = colptr_data[w];
266+
const auto &col_end = colptr_data[w + 1];
267+
for (int64_t offset = col_start; offset < col_end; offset++) {
268+
const auto &v = row_data[offset];
269+
iter = to_local_src_node.find(v);
270+
if (iter != to_local_src_node.end()) {
271+
rows.push_back(iter->second);
272+
cols.push_back(i);
273+
edges.push_back(offset);
274+
}
275+
}
276+
}
277+
}
278+
}
279+
280+
return make_tuple(from_vector<node_t, int64_t>(samples_dict),
281+
from_vector<rel_t, int64_t>(rows_dict),
282+
from_vector<rel_t, int64_t>(cols_dict),
283+
from_vector<rel_t, int64_t>(edges_dict));
284+
}
285+
286+
} // namespace
287+
288+
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
                    const torch::Tensor &input_node,
                    const vector<int64_t> num_neighbors, const bool replace,
                    const bool directed) {

  // Lifts the runtime flags `replace`/`directed` into the compile-time
  // template parameters of `sample` and forwards all remaining arguments.
  if (replace) {
    return directed
               ? sample<true, true>(colptr, row, input_node, num_neighbors)
               : sample<true, false>(colptr, row, input_node, num_neighbors);
  }
  return directed
             ? sample<false, true>(colptr, row, input_node, num_neighbors)
             : sample<false, false>(colptr, row, input_node, num_neighbors);
}
304+
305+
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
306+
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
307+
hetero_neighbor_sample_cpu(
308+
const std::vector<node_t> &node_types,
309+
const std::vector<edge_t> &edge_types,
310+
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
311+
const c10::Dict<rel_t, torch::Tensor> &row_dict,
312+
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
313+
const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
314+
const int64_t num_hops, const bool replace, const bool directed) {
315+
316+
if (replace && directed) {
317+
return hetero_sample<true, true>(node_types, edge_types, colptr_dict,
318+
row_dict, input_node_dict,
319+
num_neighbors_dict, num_hops);
320+
} else if (replace && !directed) {
321+
return hetero_sample<true, false>(node_types, edge_types, colptr_dict,
322+
row_dict, input_node_dict,
323+
num_neighbors_dict, num_hops);
324+
} else if (!replace && directed) {
325+
return hetero_sample<false, true>(node_types, edge_types, colptr_dict,
326+
row_dict, input_node_dict,
327+
num_neighbors_dict, num_hops);
328+
} else {
329+
return hetero_sample<false, false>(node_types, edge_types, colptr_dict,
330+
row_dict, input_node_dict,
331+
num_neighbors_dict, num_hops);
332+
}
333+
}

csrc/cpu/neighbor_sample_cpu.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once

#include <torch/extension.h>

// String-based identifiers used for heterogeneous graphs:
typedef std::string node_t; // node type name
// (source node type, relation type, destination node type) triplet:
typedef std::tuple<std::string, std::string, std::string> edge_t;
// Flat relation key built as "src__rel__dst":
typedef std::string rel_t;

// Multi-hop neighbor sampling on a homogeneous CSC graph (`colptr`/`row`),
// seeded at `input_node`; one hop per entry of `num_neighbors`.
// Returns (sampled node ids, local rows, local cols, global edge ids).
// `replace` toggles sampling with replacement; with `directed` false, the
// edge outputs describe the full subgraph induced by the sampled nodes.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
                    const torch::Tensor &input_node,
                    const std::vector<int64_t> num_neighbors,
                    const bool replace, const bool directed);

// Heterogeneous variant: per-relation CSC graphs and per-node-type seeds,
// running `num_hops` hops with per-relation fan-outs from
// `num_neighbors_dict`.  Returns dictionaries of sampled node ids (by node
// type) and of local rows/cols plus global edge ids (by relation).
std::tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
           c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
    const std::vector<node_t> &node_types,
    const std::vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, std::vector<int64_t>> &num_neighbors_dict,
    const int64_t num_hops, const bool replace, const bool directed);

0 commit comments

Comments
 (0)