from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch_geometric.data import Data


class TemporalData(Data):

    def __init__(self, **kwargs) -> None:
        super(TemporalData, self).__init__(**kwargs)

    def __inc__(self, key, value, *args, **kwargs):
        # Defer to the parent class for the per-key increment applied when
        # collating graphs into a batch; the *args/**kwargs pass-through
        # matches the signature expected by recent PyTorch Geometric releases.
        return super().__inc__(key, value, *args, **kwargs)

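# A minimal construction sketch (the field names `x` and `edge_index` follow
# PyTorch Geometric conventions and are illustrative assumptions, not
# requirements of this class; any keyword becomes an attribute on the
# instance):
#
#     data = TemporalData(x=torch.randn(10, 2),
#                         edge_index=torch.empty(2, 0, dtype=torch.long))
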

class DistanceDropEdge(object):
    """Drops every edge whose attribute (a relative-position vector) has an
    L2 norm of at least ``max_distance``. With ``max_distance=None`` the
    transform is a no-op."""

    def __init__(self, max_distance: Optional[float] = None) -> None:
        self.max_distance = max_distance

    def __call__(self,
                 edge_index: torch.Tensor,
                 edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.max_distance is None:
            return edge_index, edge_attr
        row, col = edge_index
        # Keep only the edges whose endpoints lie closer than max_distance.
        mask = torch.norm(edge_attr, p=2, dim=-1) < self.max_distance
        edge_index = torch.stack([row[mask], col[mask]], dim=0)
        edge_attr = edge_attr[mask]
        return edge_index, edge_attr

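# Usage sketch for DistanceDropEdge (the 50.0 radius and the `data` object
# with `edge_index`/`edge_attr` fields are illustrative assumptions; only the
# [2, E] index / [E, D] attribute layout is required by __call__):
#
#     drop_edge = DistanceDropEdge(max_distance=50.0)
#     edge_index, edge_attr = drop_edge(data.edge_index, data.edge_attr)
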

def init_weights(m: nn.Module) -> None:
    """Applies a layer-appropriate weight initialization to a single module.
    Intended to be passed to ``nn.Module.apply``."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        # Xavier-style uniform bound computed from the per-group channel
        # counts (the kernel size is left out of the fan computation here).
        fan_in = m.in_channels / m.groups
        fan_out = m.out_channels / m.groups
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        nn.init.uniform_(m.weight, -bound, bound)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        if m.in_proj_weight is not None:
            # Single packed q/k/v projection: Xavier bound based on embed_dim.
            fan_in = m.embed_dim
            fan_out = m.embed_dim
            bound = (6.0 / (fan_in + fan_out)) ** 0.5
            nn.init.uniform_(m.in_proj_weight, -bound, bound)
        else:
            # Separate q/k/v projections (used when their dimensions differ).
            nn.init.xavier_uniform_(m.q_proj_weight)
            nn.init.xavier_uniform_(m.k_proj_weight)
            nn.init.xavier_uniform_(m.v_proj_weight)
        if m.in_proj_bias is not None:
            nn.init.zeros_(m.in_proj_bias)
        nn.init.xavier_uniform_(m.out_proj.weight)
        if m.out_proj.bias is not None:
            nn.init.zeros_(m.out_proj.bias)
        if m.bias_k is not None:
            nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
        if m.bias_v is not None:
            nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                # Input-to-hidden weights: Xavier per gate (i, f, g, o).
                for ih in param.chunk(4, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                # Hidden-to-hidden weights: orthogonal per gate.
                for hh in param.chunk(4, 0):
                    nn.init.orthogonal_(hh)
            elif 'weight_hr' in name:
                # Projection weights of a projected LSTM (proj_size > 0).
                nn.init.xavier_uniform_(param)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)
                # Set the forget-gate bias to 1 to ease early gradient flow.
                nn.init.ones_(param.chunk(4, 0)[1])
    elif isinstance(m, nn.GRU):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                # Input-to-hidden weights: Xavier per gate (r, z, n).
                for ih in param.chunk(3, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                # Hidden-to-hidden weights: orthogonal per gate.
                for hh in param.chunk(3, 0):
                    nn.init.orthogonal_(hh)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)

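# Typical usage with nn.Module.apply, which visits every submodule
# recursively (`MyModel` is a hypothetical model class, shown only to
# illustrate the call pattern):
#
#     model = MyModel()
#     model.apply(init_weights)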