Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MoleculeGCN updates #18

Merged
merged 2 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 75 additions & 24 deletions graphchem/nn/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class MoleculeGCN(nn.Module):
Probability of an element to be zeroed in dropout layers.
_n_messages : int
Number of message passing steps.
act_fn : callable
Activation function, e.g., `torch.nn.functional.softplus`
emb_atom : nn.Embedding
Embedding layer for atoms.
emb_bond : nn.Embedding
Expand All @@ -36,7 +38,8 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
n_readout: Optional[int] = 2,
readout_dim: Optional[int] = 64,
p_dropout: Optional[float] = 0.0,
aggr: Optional[str] = "add"):
aggr: Optional[str] = "add",
act_fn: Optional[callable] = F.softplus):
"""
Initialize the MoleculeGCN object.

Expand All @@ -60,34 +63,59 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
Dropout probability for the dropout layers.
aggr : str, optional (default="add")
Aggregation scheme to use in the GeneralConv layer.
act_fn : callable, optional (default=`torch.nn.functional.softplus`)
Activation function, e.g., `torch.nn.functional.softplus`,
`torch.nn.functional.sigmoid`, `torch.nn.functional.relu`, etc.
"""
super().__init__()

# Store attributes
self._p_dropout = p_dropout
self._n_messages = n_messages
self.act_fn = act_fn

# Embedding layer for atoms
self.emb_atom = nn.Embedding(atom_vocab_size, embedding_dim)

# Embedding layer for bonds
self.emb_bond = nn.Embedding(bond_vocab_size, embedding_dim)

# General convolution layer for atoms with specified aggregation method
self.atom_conv = GeneralConv(
embedding_dim, embedding_dim, embedding_dim, aggr=aggr
)

# Edge convolution layer for bonds using a linear transformation
self.bond_conv = EdgeConv(nn.Sequential(
nn.Linear(2 * embedding_dim, embedding_dim)
))

self.readout = nn.ModuleList()
self.readout.append(nn.Sequential(
nn.Linear(embedding_dim, readout_dim)
))
if n_readout > 1:
for _ in range(n_readout - 1):
self.readout.append(nn.Sequential(
nn.Linear(readout_dim, readout_dim)
))
self.readout.append(nn.Sequential(
nn.Linear(readout_dim, output_dim)
))
# Initialize the readout network if readout layers are specified
if n_readout > 0:

# Create a list to hold the readout network modules
self.readout = nn.ModuleList()

# First layer of the readout network
self.readout.append(nn.Sequential(
nn.Linear(embedding_dim, readout_dim)
))

# Additional hidden layers for the readout network if needed
if n_readout > 1:
for _ in range(n_readout - 1):
self.readout.append(nn.Sequential(
nn.Linear(readout_dim, readout_dim)
))

# Final layer of the readout network to produce output dimensions
self.readout.append(nn.Sequential(
nn.Linear(readout_dim, output_dim)
))

# No readout network if n_readout is 0
else:
self.readout = None

def forward(
self,
Expand All @@ -111,41 +139,64 @@ def forward(
out_bond : torch.Tensor
Bond-level representations after message passing.
"""
# Extract node features, edge attributes, edge indices, and batch
# vector from data
x, edge_attr, edge_index, batch = data.x, data.edge_attr, \
data.edge_index, data.batch

# If no node features are provided, initialize with ones
if data.num_node_features == 0:
x = torch.ones(data.num_nodes, 1)

# Embed and activate atom features
out_atom = self.emb_atom(x)
out_atom = F.softplus(out_atom)
out_atom = self.act_fn(out_atom)

# Embed and activate bond features
out_bond = self.emb_bond(edge_attr)
out_bond = F.softplus(out_bond)
out_bond = self.act_fn(out_bond)

# Perform message passing for the specified number of steps
for _ in range(self._n_messages):

# Update bond representations using edge convolution
out_bond = self.bond_conv(out_bond, edge_index)
out_bond = F.softplus(out_bond)
out_bond = self.act_fn(out_bond)

# Apply dropout
out_bond = F.dropout(
out_bond, p=self._p_dropout, training=self.training
)

# Update atom representations using general convolution
out_atom = self.atom_conv(out_atom, edge_index, out_bond)
out_atom = F.softplus(out_atom)
out_atom = self.act_fn(out_atom)

# Apply dropout
out_atom = F.dropout(
out_atom, p=self._p_dropout, training=self.training
)

# Sum atom representations within each graph of the batch (global add pooling)
out = global_add_pool(out_atom, batch)

for layer in self.readout[:-1]:
# Process aggregated atom representation through the readout network
if self.readout is not None:

out = layer(out)
out = F.softplus(out)
out = F.dropout(
out, p=self._p_dropout, training=self.training
)
# Iterate over all but the last layer of the readout network
for layer in self.readout[:-1]:

# Pass through layer and activate
out = layer(out)
out = self.act_fn(out)

# Apply dropout
out = F.dropout(
out, p=self._p_dropout, training=self.training
)

out = self.readout[-1](out)
# Final layer of the readout network to produce output dimensions
out = self.readout[-1](out)

# Return final prediction, atom representations, bond representations
return (out, out_atom, out_bond)
106 changes: 99 additions & 7 deletions tests/test_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# Test instantiation of the MoleculeGCN class with various parameters
def test_moleculgcn_instantiation():
def test_moleculegcn_instantiation():

model = MoleculeGCN(
atom_vocab_size=10,
Expand All @@ -21,21 +21,24 @@ def test_moleculgcn_instantiation():

assert model._p_dropout == 0.2
assert model._n_messages == 3
assert model.readout is not None


# Test the forward pass of the MoleculeGCN model
def test_moleculgcn_forward_pass():
def test_moleculegcn_forward_pass():

atom_vocab_size = 10
bond_vocab_size = 5
embedding_dim = 64
output_dim = 2
n_atoms = 3
n_bonds = 5

model = MoleculeGCN(
atom_vocab_size=atom_vocab_size,
bond_vocab_size=bond_vocab_size,
output_dim=2,
embedding_dim=64,
output_dim=output_dim,
embedding_dim=embedding_dim,
n_messages=3,
n_readout=3,
readout_dim=128,
Expand All @@ -54,6 +57,95 @@ def test_moleculgcn_forward_pass():

out, out_atom, out_bond = model.forward(data)

assert out.shape == (1, 2)
assert out_atom.shape == (3, 64)
assert out_bond.shape == (5, 64)
assert out.shape == (1, output_dim)
assert out_atom.shape == (n_atoms, embedding_dim)
assert out_bond.shape == (n_bonds, embedding_dim)


# Test the forward pass of the MoleculeGCN model without readout layers
def test_moleculegcn_no_readout():
    """Check that ``n_readout=0`` disables the readout network.

    With no readout layers ``model.readout`` must be ``None``, and the first
    value returned by ``forward`` is then expected to have one
    ``embedding_dim``-sized row per graph instead of ``output_dim`` columns.
    """
    atom_vocab_size = 10
    bond_vocab_size = 5
    embedding_dim = 64
    n_atoms = 3
    n_bonds = 5

    model = MoleculeGCN(
        atom_vocab_size=atom_vocab_size,
        bond_vocab_size=bond_vocab_size,
        output_dim=None,
        embedding_dim=embedding_dim,
        n_messages=3,
        n_readout=0,
        readout_dim=128,
        p_dropout=0.2
    )

    assert model.readout is None

    # Create a mock input for the forward pass. The batch vector is derived
    # from n_atoms (all atoms belong to graph 0) so the fixture stays valid
    # if the atom count is changed.
    data = Data(
        x=torch.randint(0, atom_vocab_size, (n_atoms,)),
        edge_index=torch.randint(0, n_atoms, (2, n_bonds)),
        edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)),
        batch=torch.zeros(n_atoms, dtype=torch.long)
    )

    out, out_atom, out_bond = model.forward(data)

    assert out.shape == (1, embedding_dim)
    assert out_atom.shape == (n_atoms, embedding_dim)
    assert out_bond.shape == (n_bonds, embedding_dim)


# Test the forward pass of the MoleculeGCN model with different activation fns
def test_moleculegcn_act_fns():
    """Run a forward pass with several activation functions.

    A fresh model without readout layers is built for each activation
    function, and the shapes of the three tensors returned by ``forward``
    are checked. Assertion messages name the activation under test so a
    failure can be attributed to a specific function.
    """
    functions = [
        torch.nn.functional.softplus,
        torch.nn.functional.sigmoid,
        torch.nn.functional.relu,
        torch.nn.functional.tanh,
        torch.nn.functional.leaky_relu
    ]

    # Fixture sizes do not depend on the activation function, so they are
    # defined once outside the loop
    atom_vocab_size = 10
    bond_vocab_size = 5
    embedding_dim = 64
    n_atoms = 3
    n_bonds = 5

    # Create a mock input for the forward pass; a single graph, so every
    # atom is assigned to batch index 0. The same input is reused for every
    # activation function.
    data = Data(
        x=torch.randint(0, atom_vocab_size, (n_atoms,)),
        edge_index=torch.randint(0, n_atoms, (2, n_bonds)),
        edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)),
        batch=torch.zeros(n_atoms, dtype=torch.long)
    )

    for fn in functions:

        # Label used in assertion messages
        fn_name = getattr(fn, "__name__", repr(fn))

        model = MoleculeGCN(
            atom_vocab_size=atom_vocab_size,
            bond_vocab_size=bond_vocab_size,
            output_dim=None,
            embedding_dim=embedding_dim,
            n_messages=3,
            n_readout=0,
            readout_dim=128,
            p_dropout=0.2,
            act_fn=fn
        )

        assert model.readout is None, fn_name

        out, out_atom, out_bond = model.forward(data)

        assert out.shape == (1, embedding_dim), fn_name
        assert out_atom.shape == (n_atoms, embedding_dim), fn_name
        assert out_bond.shape == (n_bonds, embedding_dim), fn_name
Loading