Add option to omit readout layer(s)

tjkessler committed Jan 12, 2025
1 parent 34120c3 commit a747e08

Showing 2 changed files with 75 additions and 26 deletions.
45 changes: 26 additions & 19 deletions graphchem/nn/gcn.py
@@ -76,18 +76,23 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
             nn.Linear(2 * embedding_dim, embedding_dim)
         ))
 
-        self.readout = nn.ModuleList()
-        self.readout.append(nn.Sequential(
-            nn.Linear(embedding_dim, readout_dim)
-        ))
-        if n_readout > 1:
-            for _ in range(n_readout - 1):
-                self.readout.append(nn.Sequential(
-                    nn.Linear(readout_dim, readout_dim)
-                ))
-        self.readout.append(nn.Sequential(
-            nn.Linear(readout_dim, output_dim)
-        ))
+        if n_readout > 0:
+
+            self.readout = nn.ModuleList()
+            self.readout.append(nn.Sequential(
+                nn.Linear(embedding_dim, readout_dim)
+            ))
+            if n_readout > 1:
+                for _ in range(n_readout - 1):
+                    self.readout.append(nn.Sequential(
+                        nn.Linear(readout_dim, readout_dim)
+                    ))
+            self.readout.append(nn.Sequential(
+                nn.Linear(readout_dim, output_dim)
+            ))
+
+        else:
+            self.readout = None
 
     def forward(
         self,
@@ -138,14 +143,16 @@ def forward(
 
         out = global_add_pool(out_atom, batch)
 
-        for layer in self.readout[:-1]:
+        if self.readout is not None:
 
-            out = layer(out)
-            out = F.softplus(out)
-            out = F.dropout(
-                out, p=self._p_dropout, training=self.training
-            )
+            for layer in self.readout[:-1]:
+
+                out = layer(out)
+                out = F.softplus(out)
+                out = F.dropout(
+                    out, p=self._p_dropout, training=self.training
+                )
 
-        out = self.readout[-1](out)
+            out = self.readout[-1](out)
 
         return (out, out_atom, out_bond)
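
For context, not part of the commit itself: with this change, passing n_readout=0 skips construction of the readout stack (self.readout is None) and forward() returns the pooled graph embedding directly. A minimal sketch of both modes, mirroring the shapes asserted in the tests below; the import path is assumed from the repository layout.

import torch
from torch_geometric.data import Data

from graphchem.nn.gcn import MoleculeGCN  # import path assumed

# Mock molecular graph: 3 atoms, 5 bonds (same setup as the tests)
data = Data(
    x=torch.randint(0, 10, (3,)),
    edge_index=torch.randint(0, 3, (2, 5)),
    edge_attr=torch.randint(0, 5, (5,)),
    batch=torch.tensor([0, 0, 0], dtype=torch.long)
)

# With readout layers: graph-level output is projected to output_dim
model = MoleculeGCN(atom_vocab_size=10, bond_vocab_size=5, output_dim=2,
                    embedding_dim=64, n_messages=3, n_readout=3,
                    readout_dim=128, p_dropout=0.2)
out, out_atom, out_bond = model.forward(data)
print(out.shape)  # torch.Size([1, 2])

# Without readout layers: graph-level output is the raw pooled embedding
model = MoleculeGCN(atom_vocab_size=10, bond_vocab_size=5, output_dim=None,
                    embedding_dim=64, n_messages=3, n_readout=0,
                    readout_dim=128, p_dropout=0.2)
out, out_atom, out_bond = model.forward(data)
print(out.shape)  # torch.Size([1, 64])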
56 changes: 49 additions & 7 deletions tests/test_gcn.py
@@ -6,7 +6,7 @@
 
 
 # Test instantiation of the MoleculeGCN class with various parameters
-def test_moleculgcn_instantiation():
+def test_moleculegcn_instantiation():
 
     model = MoleculeGCN(
         atom_vocab_size=10,
@@ -21,21 +21,24 @@ def test_moleculgcn_instantiation():
 
     assert model._p_dropout == 0.2
     assert model._n_messages == 3
+    assert model.readout is not None
 
 
 # Test the forward pass of the MoleculeGCN model
-def test_moleculgcn_forward_pass():
+def test_moleculegcn_forward_pass():
 
     atom_vocab_size = 10
     bond_vocab_size = 5
+    embedding_dim = 64
+    output_dim = 2
     n_atoms = 3
     n_bonds = 5
 
     model = MoleculeGCN(
         atom_vocab_size=atom_vocab_size,
         bond_vocab_size=bond_vocab_size,
-        output_dim=2,
-        embedding_dim=64,
+        output_dim=output_dim,
+        embedding_dim=embedding_dim,
         n_messages=3,
         n_readout=3,
         readout_dim=128,
@@ -54,6 +57,45 @@ def test_moleculgcn_forward_pass():
 
     out, out_atom, out_bond = model.forward(data)
 
-    assert out.shape == (1, 2)
-    assert out_atom.shape == (3, 64)
-    assert out_bond.shape == (5, 64)
+    assert out.shape == (1, output_dim)
+    assert out_atom.shape == (n_atoms, embedding_dim)
+    assert out_bond.shape == (n_bonds, embedding_dim)
+
+
+# Test the forward pass of the MoleculeGCN model without readout layers
+def test_moleculegcn_no_readout():
+
+    atom_vocab_size = 10
+    bond_vocab_size = 5
+    embedding_dim = 64
+    n_atoms = 3
+    n_bonds = 5
+
+    model = MoleculeGCN(
+        atom_vocab_size=atom_vocab_size,
+        bond_vocab_size=bond_vocab_size,
+        output_dim=None,
+        embedding_dim=embedding_dim,
+        n_messages=3,
+        n_readout=0,
+        readout_dim=128,
+        p_dropout=0.2
+    )
+
+    assert model.readout is None
+
+    n_atoms = 3
+
+    # Create a mock input for the forward pass
+    data = Data(
+        x=torch.randint(0, atom_vocab_size, (n_atoms,)),
+        edge_index=torch.randint(0, n_atoms, (2, n_bonds)),
+        edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)),
+        batch=torch.tensor([0, 0, 0], dtype=torch.long)
+    )
+
+    out, out_atom, out_bond = model.forward(data)
+
+    assert out.shape == (1, embedding_dim)
+    assert out_atom.shape == (n_atoms, embedding_dim)
+    assert out_bond.shape == (n_bonds, embedding_dim)
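
A natural use of the new option, sketched here as an assumption rather than anything stated in the commit: with n_readout=0 the model acts as a graph encoder, and a task-specific head can be attached downstream. The head below is hypothetical, and the import path is again assumed.

import torch
import torch.nn as nn
from torch_geometric.data import Data

from graphchem.nn.gcn import MoleculeGCN  # import path assumed

# Encoder-only model: no readout stack, returns pooled 64-dim embeddings
encoder = MoleculeGCN(
    atom_vocab_size=10, bond_vocab_size=5, output_dim=None,
    embedding_dim=64, n_messages=3, n_readout=0,
    readout_dim=128, p_dropout=0.2
)
head = nn.Linear(64, 1)  # hypothetical single-target regression head

# Mock molecular graph, constructed as in the tests above
data = Data(
    x=torch.randint(0, 10, (3,)),
    edge_index=torch.randint(0, 3, (2, 5)),
    edge_attr=torch.randint(0, 5, (5,)),
    batch=torch.tensor([0, 0, 0], dtype=torch.long)
)

emb, out_atom, out_bond = encoder.forward(data)  # emb: (1, 64)
pred = head(emb)                                 # pred: (1, 1)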
