From a747e082443c54f94bbfd374ba78ef90faf22854 Mon Sep 17 00:00:00 2001 From: tjkessler Date: Sun, 12 Jan 2025 16:39:14 -0500 Subject: [PATCH] Add option to omit readout layer(s) --- graphchem/nn/gcn.py | 45 +++++++++++++++++++++--------------- tests/test_gcn.py | 56 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/graphchem/nn/gcn.py b/graphchem/nn/gcn.py index b09e378..fc8db80 100644 --- a/graphchem/nn/gcn.py +++ b/graphchem/nn/gcn.py @@ -76,18 +76,23 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int, nn.Linear(2 * embedding_dim, embedding_dim) )) - self.readout = nn.ModuleList() - self.readout.append(nn.Sequential( - nn.Linear(embedding_dim, readout_dim) - )) - if n_readout > 1: - for _ in range(n_readout - 1): - self.readout.append(nn.Sequential( - nn.Linear(readout_dim, readout_dim) - )) - self.readout.append(nn.Sequential( - nn.Linear(readout_dim, output_dim) - )) + if n_readout > 0: + + self.readout = nn.ModuleList() + self.readout.append(nn.Sequential( + nn.Linear(embedding_dim, readout_dim) + )) + if n_readout > 1: + for _ in range(n_readout - 1): + self.readout.append(nn.Sequential( + nn.Linear(readout_dim, readout_dim) + )) + self.readout.append(nn.Sequential( + nn.Linear(readout_dim, output_dim) + )) + + else: + self.readout = None def forward( self, @@ -138,14 +143,16 @@ def forward( out = global_add_pool(out_atom, batch) - for layer in self.readout[:-1]: + if self.readout is not None: - out = layer(out) - out = F.softplus(out) - out = F.dropout( - out, p=self._p_dropout, training=self.training - ) + for layer in self.readout[:-1]: + + out = layer(out) + out = F.softplus(out) + out = F.dropout( + out, p=self._p_dropout, training=self.training + ) - out = self.readout[-1](out) + out = self.readout[-1](out) return (out, out_atom, out_bond) diff --git a/tests/test_gcn.py b/tests/test_gcn.py index 6125ff0..389ee29 100644 --- a/tests/test_gcn.py +++ b/tests/test_gcn.py @@ -6,7 +6,7 @@ # Test instantiation of the MoleculeGCN class with various parameters -def test_moleculgcn_instantiation(): +def test_moleculegcn_instantiation(): model = MoleculeGCN( atom_vocab_size=10, @@ -21,21 +21,24 @@ def test_moleculgcn_instantiation(): assert model._p_dropout == 0.2 assert model._n_messages == 3 + assert model.readout is not None # Test the forward pass of the MoleculeGCN model -def test_moleculgcn_forward_pass(): +def test_moleculegcn_forward_pass(): atom_vocab_size = 10 bond_vocab_size = 5 + embedding_dim = 64 + output_dim = 2 n_atoms = 3 n_bonds = 5 model = MoleculeGCN( atom_vocab_size=atom_vocab_size, bond_vocab_size=bond_vocab_size, - output_dim=2, - embedding_dim=64, + output_dim=output_dim, + embedding_dim=embedding_dim, n_messages=3, n_readout=3, readout_dim=128, @@ -54,6 +57,45 @@ def test_moleculgcn_forward_pass(): out, out_atom, out_bond = model.forward(data) - assert out.shape == (1, 2) - assert out_atom.shape == (3, 64) - assert out_bond.shape == (5, 64) + assert out.shape == (1, output_dim) + assert out_atom.shape == (n_atoms, embedding_dim) + assert out_bond.shape == (n_bonds, embedding_dim) + + +# Test the forward pass of the MoleculeGCN model without readout layers +def test_moleculegcn_no_readout(): + + atom_vocab_size = 10 + bond_vocab_size = 5 + embedding_dim = 64 + n_atoms = 3 + n_bonds = 5 + + model = MoleculeGCN( + atom_vocab_size=atom_vocab_size, + bond_vocab_size=bond_vocab_size, + output_dim=None, + embedding_dim=embedding_dim, + n_messages=3, + n_readout=0, + readout_dim=128, + p_dropout=0.2 + ) + + assert model.readout is None + + n_atoms = 3 + + # Create a mock input for the forward pass + data = Data( + x=torch.randint(0, atom_vocab_size, (n_atoms,)), + edge_index=torch.randint(0, n_atoms, (2, n_bonds)), + edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)), + batch=torch.tensor([0, 0, 0], dtype=torch.long) + ) + + out, out_atom, out_bond = model.forward(data) + + assert out.shape == (1, embedding_dim) + assert out_atom.shape == (n_atoms, embedding_dim) + assert out_bond.shape == (n_bonds, embedding_dim)