-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path neighbor_sampling.py
49 lines (39 loc) · 1.86 KB
/
neighbor_sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch_geometric as pyg
from torch_geometric.loader.utils import filter_data
from torch_geometric.nn import SAGEConv
import tch_geometric as thg
# ---------------------------------------------------------------------------
# Shared configuration and graph preparation used by every example below.
# ---------------------------------------------------------------------------
# NOTE(review): ``device`` is not referenced anywhere in this script; it is
# kept only so the module-level name stays available.
device = 'cpu'
# Passed straight to neighbor_sampling_homogenous; presumably the per-hop
# fan-out (4 neighbours at the first hop, 3 at the second) — confirm against
# the tch_geometric documentation.
num_neighbors = [4, 3]
# Every seed tensor below is tiled this many times before sampling
# (``start.repeat(samples_per_node)``).
samples_per_node = 4

# A synthetic PyG dataset; only its first graph is used by the examples.
dataset = pyg.datasets.FakeDataset()
data = dataset[0]
# CSC (compressed sparse column) view of the graph, as returned by
# thg.loader.to_csc. ``perm`` is forwarded to filter_data below; presumably it
# maps CSC edge order back to the original edge order — verify.
col_ptrs, row_indices, perm = thg.loader.to_csc(data)
# ---------------------------------------------------------------------------
# Standard sampling (no custom sampler and no edge filter).
# ---------------------------------------------------------------------------
# Seed node ids 0..7.
start = torch.arange(8, dtype=torch.long)
# Tile the seeds (0..7, 0..7, ...) so that several samples are drawn per node.
seeds = start.repeat(samples_per_node)
samples, rows, cols, edge_index, layer_offsets = (
    thg.native.neighbor_sampling_homogenous(
        col_ptrs,
        row_indices,
        seeds,
        num_neighbors,
    )
)
# Build the sampled sub-graph from the sampler outputs.
batch = filter_data(data, samples, rows, cols, edge_index, perm)
# in_channels=(-1, -1): PyG infers source/target feature sizes lazily from the
# first forward call; 32 output channels.
layer = SAGEConv((-1, -1), 32)
output = layer(x=batch.x, edge_index=batch.edge_index)
# ---------------------------------------------------------------------------
# Weighted sampling: neighbours are drawn through a WeightedEdgeSampler.
# ---------------------------------------------------------------------------
# Seed node ids 0..7.
start = torch.arange(8, dtype=torch.long)
# One random float64 weight per entry of ``row_indices`` (i.e. per CSC edge).
weights = torch.rand(row_indices.shape, dtype=torch.double)
# Seeds are tiled first and the sampler is built second, matching the
# left-to-right argument evaluation order of a single inline call.
seeds = start.repeat(samples_per_node)
edge_sampler = thg.loader.WeightedEdgeSampler(weights)
samples, rows, cols, edge_index, layer_offsets = (
    thg.native.neighbor_sampling_homogenous(
        col_ptrs,
        row_indices,
        seeds,
        num_neighbors,
        sampler=edge_sampler,
    )
)
# Build the sampled sub-graph and run a lazily-initialised SAGEConv on it.
batch = filter_data(data, samples, rows, cols, edge_index, perm)
layer = SAGEConv((-1, -1), 32)
output = layer(x=batch.x, edge_index=batch.edge_index)
# ---------------------------------------------------------------------------
# Temporal filtering: edges are filtered by timestamp relative to each seed's
# initial timestamp.
# ---------------------------------------------------------------------------
start = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)
# Tile the seeds (0..7, 0..7, ...) so that several samples are drawn per node.
seeds = start.repeat(samples_per_node)
# Draw one initial timestamp per distinct seed node, then tile it exactly like
# ``start``. This gives len(initial_timestamps) == len(seeds), and every
# repeated sample of a node shares that node's query time.
# NOTE(review): assumes the filter expects one initial timestamp per entry of
# the seed tensor — confirm against tch_geometric's sampler contract.
initial_timestamps = torch.randint(
    size=start.shape, low=0, high=5, dtype=torch.long
).repeat(samples_per_node)
# One timestamp per entry of ``row_indices`` (i.e. per CSC edge).
timestamps = torch.randint(size=row_indices.shape, low=0, high=5, dtype=torch.long)
# (0, 3) is passed to TemporalEdgeFilter as-is; presumably the admissible
# time window — verify its exact meaning in tch_geometric.
temporal_filter = thg.loader.TemporalEdgeFilter((0, 3), timestamps)
samples, rows, cols, edge_index, layer_offsets = thg.native.neighbor_sampling_homogenous(
    col_ptrs, row_indices, seeds, num_neighbors,
    filter=(temporal_filter, initial_timestamps),
)
# Build the sampled sub-graph and run a lazily-initialised SAGEConv on it.
batch = filter_data(data, samples, rows, cols, edge_index, perm)
layer = SAGEConv((-1, -1), 32)
output = layer(x=batch.x, edge_index=batch.edge_index)