-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_gcn_randomfeatures.py
92 lines (80 loc) · 3.42 KB
/
train_gcn_randomfeatures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trains a GCN on random node features over a synthetic kNN graph.

TODO(tsitsulin): add headers, tests, and improve style.
"""
from absl import app
from absl import flags
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import normalized_mutual_info_score
import tensorflow.compat.v2 as tf
from graph_embedding.dmon.models.multilayer_gcn import multilayer_gcn
from graph_embedding.dmon.synthetic_data.graph_util import construct_knn_graph
from graph_embedding.dmon.synthetic_data.overlapping_gaussians import overlapping_gaussians
# Turn on TF2 behavior in case the installed TensorFlow defaults to 1.x
# semantics; the rest of the script is written against the TF2 Keras API.
tf.compat.v1.enable_v2_behavior()

FLAGS = flags.FLAGS

# Sizes must be at least 1: a zero-node graph, zero clusters, or zero-width
# random features cannot produce a trainable model.
flags.DEFINE_integer(
    'n_nodes', 1000, 'Number of nodes for the synthetic graph.', lower_bound=1)
flags.DEFINE_integer(
    'n_clusters',
    2,
    'Number of clusters for the synthetic graph.',
    lower_bound=1)
# A proportion of the nodes; values above 1 would make main() sample more
# training nodes than exist (np.random.choice without replacement fails).
flags.DEFINE_float(
    'train_size',
    0.2,
    'Training data proportion.',
    lower_bound=0,
    upper_bound=1)
flags.DEFINE_integer(
    'n_epochs', 200, 'Number of epochs to train.', lower_bound=0)
flags.DEFINE_integer(
    'n_random_features', 64, 'Number of random features.', lower_bound=1)
flags.DEFINE_float(
    'learning_rate', 0.01, 'Optimizer\'s learning rate.', lower_bound=0)
def main(argv):
  """Trains a GCN on random node features over a synthetic kNN graph.

  Draws a synthetic dataset with `overlapping_gaussians` and builds a kNN
  graph from its clean features. The features themselves are then replaced
  by i.i.d. Gaussian noise, so the model can only exploit the graph
  structure. A random `--train_size` fraction of nodes is used for the
  loss. NMI and accuracy of the predicted classes are printed for the
  remaining (test) nodes.

  Args:
    argv: Positional command-line arguments left after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  print('Bröther may i have some self-lööps')

  n_nodes = FLAGS.n_nodes
  n_clusters = FLAGS.n_clusters
  n_random_features = FLAGS.n_random_features
  train_size = FLAGS.train_size

  # Only the clean features (to build the graph) and the labels are used;
  # the second return value of overlapping_gaussians is not needed here.
  data_clean, _, labels = overlapping_gaussians(n_nodes, n_clusters)
  # The model never sees data_clean: it is fed pure noise features.
  data_random = np.random.normal(size=(n_nodes, n_random_features))
  # Dense (n_nodes, n_nodes) adjacency. construct_knn_graph presumably returns
  # a SciPy sparse matrix (it supports `.todense()`); `.toarray()` densifies
  # it directly to a plain ndarray instead of an np.matrix.
  graph_clean = construct_knn_graph(data_clean).toarray()

  # Semi-supervised split: a random train_size fraction of nodes is labeled.
  train_mask = np.zeros(n_nodes, dtype=bool)
  train_indices = np.random.choice(
      n_nodes, int(n_nodes * train_size), replace=False)
  train_mask[train_indices] = True
  test_mask = ~train_mask
  print(f'Data shape: {data_clean.shape}, graph shape: {graph_clean.shape}')
  print(f'Train size: {train_mask.sum()}, test size: {test_mask.sum()}')

  input_features = tf.keras.layers.Input(shape=(n_random_features,))
  input_graph = tf.keras.layers.Input((n_nodes,))
  model = multilayer_gcn([input_features, input_graph], [64, 32, n_clusters])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])

  # Full-batch, unshuffled training: every step sees all n_nodes rows
  # together with the matching adjacency rows (presumably required by
  # multilayer_gcn's graph propagation). Test nodes get zero sample weight,
  # so they do not contribute to the loss.
  if FLAGS.n_epochs > 0:
    model.fit([data_random, graph_clean],
              labels,
              batch_size=n_nodes,
              epochs=FLAGS.n_epochs,
              shuffle=False,
              sample_weight=train_mask)

  # Metrics on an empty test set are meaningless (e.g. --train_size=1).
  if not test_mask.any():
    print('Test set is empty; skipping evaluation.')
    return

  # The model outputs per-class logits; the predicted class is the argmax.
  predictions = model([data_random, graph_clean]).numpy().argmax(axis=1)
  clusters = predictions[test_mask]
  test_labels = labels[test_mask]
  print(
      'NMI:',
      normalized_mutual_info_score(
          test_labels, clusters, average_method='arithmetic'))
  print('Accuracy:', accuracy_score(test_labels, clusters))


if __name__ == '__main__':
  app.run(main)