Skip to content

Commit

Permalink
Add option to specify activation function; add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tjkessler committed Jan 12, 2025
1 parent a747e08 commit d1c3c31
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 6 deletions.
56 changes: 50 additions & 6 deletions graphchem/nn/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class MoleculeGCN(nn.Module):
Probability of an element to be zeroed in dropout layers.
_n_messages : int
Number of message passing steps.
act_fn : callable
Activation function, e.g., `torch.nn.functional.softplus`
emb_atom : nn.Embedding
Embedding layer for atoms.
emb_bond : nn.Embedding
Expand All @@ -36,7 +38,8 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
n_readout: Optional[int] = 2,
readout_dim: Optional[int] = 64,
p_dropout: Optional[float] = 0.0,
aggr: Optional[str] = "add"):
aggr: Optional[str] = "add",
act_fn: Optional[callable] = F.softplus):
"""
Initialize the MoleculeGCN object.
Expand All @@ -60,37 +63,57 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
Dropout probability for the dropout layers.
aggr : str, optional (default="add")
Aggregation scheme to use in the GeneralConv layer.
act_fn : callable, optional (default=`torch.nn.functional.softplus`)
Activation function, e.g., `torch.nn.functional.softplus`,
`torch.nn.functional.sigmoid`, `torch.nn.functional.relu`, etc.
"""
super().__init__()

# Store attributes
self._p_dropout = p_dropout
self._n_messages = n_messages
self.act_fn = act_fn

# Embedding layer for atoms
self.emb_atom = nn.Embedding(atom_vocab_size, embedding_dim)

# Embedding layer for bonds
self.emb_bond = nn.Embedding(bond_vocab_size, embedding_dim)

# General convolution layer for atoms with specified aggregation method
self.atom_conv = GeneralConv(
embedding_dim, embedding_dim, embedding_dim, aggr=aggr
)

# Edge convolution layer for bonds using a linear transformation
self.bond_conv = EdgeConv(nn.Sequential(
nn.Linear(2 * embedding_dim, embedding_dim)
))

# Initialize the readout network if readout layers are specified
if n_readout > 0:

# Create a list to hold the readout network modules
self.readout = nn.ModuleList()

# First layer of the readout network
self.readout.append(nn.Sequential(
nn.Linear(embedding_dim, readout_dim)
))

# Additional hidden layers for the readout network if needed
if n_readout > 1:
for _ in range(n_readout - 1):
self.readout.append(nn.Sequential(
nn.Linear(readout_dim, readout_dim)
))

# Final layer of the readout network to produce output dimensions
self.readout.append(nn.Sequential(
nn.Linear(readout_dim, output_dim)
))

# No readout network if n_readout is 0
else:
self.readout = None

Expand All @@ -116,43 +139,64 @@ def forward(
out_bond : torch.Tensor
Bond-level representations after message passing.
"""
# Extract node features, edge attributes, edge indices, and batch
# vector from data
x, edge_attr, edge_index, batch = data.x, data.edge_attr, \
data.edge_index, data.batch

# If no node features are provided, initialize with ones
if data.num_node_features == 0:
x = torch.ones(data.num_nodes, 1)

# Embed and activate atom features
out_atom = self.emb_atom(x)
out_atom = F.softplus(out_atom)
out_atom = self.act_fn(out_atom)

# Embed and activate bond features
out_bond = self.emb_bond(edge_attr)
out_bond = F.softplus(out_bond)
out_bond = self.act_fn(out_bond)

# Perform message passing for the specified number of steps
for _ in range(self._n_messages):

# Update bond representations using edge convolution
out_bond = self.bond_conv(out_bond, edge_index)
out_bond = F.softplus(out_bond)
out_bond = self.act_fn(out_bond)

# Apply dropout
out_bond = F.dropout(
out_bond, p=self._p_dropout, training=self.training
)

# Update atom representations using general convolution
out_atom = self.atom_conv(out_atom, edge_index, out_bond)
out_atom = F.softplus(out_atom)
out_atom = self.act_fn(out_atom)

# Apply dropout
out_atom = F.dropout(
out_atom, p=self._p_dropout, training=self.training
)

# Aggregate atom representations across batches with global add pooling
out = global_add_pool(out_atom, batch)

# Process aggregated atom representation through the readout network
if self.readout is not None:

# Iterate over all but the last layer of the readout network
for layer in self.readout[:-1]:

# Pass through layer and activate
out = layer(out)
out = F.softplus(out)
out = self.act_fn(out)

# Apply dropout
out = F.dropout(
out, p=self._p_dropout, training=self.training
)

# Final layer of the readout network to produce output dimensions
out = self.readout[-1](out)

# Return final prediction, atom representations, bond representations
return (out, out_atom, out_bond)
50 changes: 50 additions & 0 deletions tests/test_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,53 @@ def test_moleculegcn_no_readout():
assert out.shape == (1, embedding_dim)
assert out_atom.shape == (n_atoms, embedding_dim)
assert out_bond.shape == (n_bonds, embedding_dim)


# Test the forward pass of the MoleculeGCN model with different activation fns
def test_moleculegcn_act_fns():

functions = [
torch.nn.functional.softplus,
torch.nn.functional.sigmoid,
torch.nn.functional.relu,
torch.nn.functional.tanh,
torch.nn.functional.leaky_relu
]

for fn in functions:

atom_vocab_size = 10
bond_vocab_size = 5
embedding_dim = 64
n_atoms = 3
n_bonds = 5

model = MoleculeGCN(
atom_vocab_size=atom_vocab_size,
bond_vocab_size=bond_vocab_size,
output_dim=None,
embedding_dim=embedding_dim,
n_messages=3,
n_readout=0,
readout_dim=128,
p_dropout=0.2,
act_fn=fn
)

assert model.readout is None

n_atoms = 3

# Create a mock input for the forward pass
data = Data(
x=torch.randint(0, atom_vocab_size, (n_atoms,)),
edge_index=torch.randint(0, n_atoms, (2, n_bonds)),
edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)),
batch=torch.tensor([0, 0, 0], dtype=torch.long)
)

out, out_atom, out_bond = model.forward(data)

assert out.shape == (1, embedding_dim)
assert out_atom.shape == (n_atoms, embedding_dim)
assert out_bond.shape == (n_bonds, embedding_dim)

0 comments on commit d1c3c31

Please sign in to comment.