From d1c3c31f6617778f96a18838499b9b6597e09a23 Mon Sep 17 00:00:00 2001
From: tjkessler
Date: Sun, 12 Jan 2025 17:10:57 -0500
Subject: [PATCH] Add option to specify activation function; add comments

---
 graphchem/nn/gcn.py | 56 ++++++++++++++++++++++++++++++++++++++++-----
 tests/test_gcn.py   | 50 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/graphchem/nn/gcn.py b/graphchem/nn/gcn.py
index fc8db80..99b60c1 100644
--- a/graphchem/nn/gcn.py
+++ b/graphchem/nn/gcn.py
@@ -18,6 +18,8 @@ class MoleculeGCN(nn.Module):
         Probability of an element to be zeroed in dropout layers.
     _n_messages : int
         Number of message passing steps.
+    act_fn : callable
+        Activation function, e.g., `torch.nn.functional.softplus`.
     emb_atom : nn.Embedding
         Embedding layer for atoms.
     emb_bond : nn.Embedding
@@ -36,7 +38,8 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
                  n_readout: Optional[int] = 2,
                  readout_dim: Optional[int] = 64,
                  p_dropout: Optional[float] = 0.0,
-                 aggr: Optional[str] = "add"):
+                 aggr: Optional[str] = "add",
+                 act_fn: Optional[callable] = F.softplus):
         """
         Initialize the MoleculeGCN object.

@@ -60,37 +63,57 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
             Dropout probability for the dropout layers.
         aggr : str, optional (default="add")
             Aggregation scheme to use in the GeneralConv layer.
+        act_fn : callable, optional (default=`torch.nn.functional.softplus`)
+            Activation function, e.g., `torch.nn.functional.softplus`,
+            `torch.nn.functional.sigmoid`, `torch.nn.functional.relu`, etc.
         """
         super().__init__()

+        # Store attributes
         self._p_dropout = p_dropout
         self._n_messages = n_messages
+        self.act_fn = act_fn

+        # Embedding layer for atoms
         self.emb_atom = nn.Embedding(atom_vocab_size, embedding_dim)
+
+        # Embedding layer for bonds
         self.emb_bond = nn.Embedding(bond_vocab_size, embedding_dim)

+        # General convolution layer for atoms with specified aggregation method
         self.atom_conv = GeneralConv(
             embedding_dim,
             embedding_dim,
             embedding_dim,
             aggr=aggr
         )
+
+        # Edge convolution layer for bonds using a linear transformation
         self.bond_conv = EdgeConv(nn.Sequential(
             nn.Linear(2 * embedding_dim, embedding_dim)
         ))

+        # Initialize the readout network if readout layers are specified
         if n_readout > 0:
+            # Create a list to hold the readout network modules
             self.readout = nn.ModuleList()
+
+            # First layer of the readout network
             self.readout.append(nn.Sequential(
                 nn.Linear(embedding_dim, readout_dim)
             ))
+
+            # Additional hidden layers for the readout network if needed
             if n_readout > 1:
                 for _ in range(n_readout - 1):
                     self.readout.append(nn.Sequential(
                         nn.Linear(readout_dim, readout_dim)
                     ))
+
+            # Final layer of the readout network to produce output dimensions
             self.readout.append(nn.Sequential(
                 nn.Linear(readout_dim, output_dim)
             ))
+        # No readout network if n_readout is 0
         else:
             self.readout = None
@@ -116,43 +139,64 @@ def forward(
         out_bond : torch.Tensor
             Bond-level representations after message passing.
""" + # Extract node features, edge attributes, edge indices, and batch + # vector from data x, edge_attr, edge_index, batch = data.x, data.edge_attr, \ data.edge_index, data.batch + + # If no node features are provided, initialize with ones if data.num_node_features == 0: x = torch.ones(data.num_nodes, 1) + # Embed and activate atom features out_atom = self.emb_atom(x) - out_atom = F.softplus(out_atom) + out_atom = self.act_fn(out_atom) + # Embed and activate bond features out_bond = self.emb_bond(edge_attr) - out_bond = F.softplus(out_bond) + out_bond = self.act_fn(out_bond) + # Perform message passing for the specified number of steps for _ in range(self._n_messages): + # Update bond representations using edge convolution out_bond = self.bond_conv(out_bond, edge_index) - out_bond = F.softplus(out_bond) + out_bond = self.act_fn(out_bond) + + # Apply dropout out_bond = F.dropout( out_bond, p=self._p_dropout, training=self.training ) + # Update atom representations using general convolution out_atom = self.atom_conv(out_atom, edge_index, out_bond) - out_atom = F.softplus(out_atom) + out_atom = self.act_fn(out_atom) + + # Apply dropout out_atom = F.dropout( out_atom, p=self._p_dropout, training=self.training ) + # Aggregate atom representations across batches with global add pooling out = global_add_pool(out_atom, batch) + # Process aggregated atom representation through the readout network if self.readout is not None: + # Iterate over all but the last layer of the readout network for layer in self.readout[:-1]: + # Pass through layer and activate out = layer(out) - out = F.softplus(out) + out = self.act_fn(out) + + # Apply dropout out = F.dropout( out, p=self._p_dropout, training=self.training ) + # Final layer of the readout network to produce output dimensions out = self.readout[-1](out) + # Return final prediction, atom representations, bond representations return (out, out_atom, out_bond) diff --git a/tests/test_gcn.py b/tests/test_gcn.py index 389ee29..9b2b84a 100644 --- a/tests/test_gcn.py +++ b/tests/test_gcn.py @@ -99,3 +99,53 @@ def test_moleculegcn_no_readout(): assert out.shape == (1, embedding_dim) assert out_atom.shape == (n_atoms, embedding_dim) assert out_bond.shape == (n_bonds, embedding_dim) + + +# Test the forward pass of the MoleculeGCN model with different activation fns +def test_moleculegcn_act_fns(): + + functions = [ + torch.nn.functional.softplus, + torch.nn.functional.sigmoid, + torch.nn.functional.relu, + torch.nn.functional.tanh, + torch.nn.functional.leaky_relu + ] + + for fn in functions: + + atom_vocab_size = 10 + bond_vocab_size = 5 + embedding_dim = 64 + n_atoms = 3 + n_bonds = 5 + + model = MoleculeGCN( + atom_vocab_size=atom_vocab_size, + bond_vocab_size=bond_vocab_size, + output_dim=None, + embedding_dim=embedding_dim, + n_messages=3, + n_readout=0, + readout_dim=128, + p_dropout=0.2, + act_fn=fn + ) + + assert model.readout is None + + n_atoms = 3 + + # Create a mock input for the forward pass + data = Data( + x=torch.randint(0, atom_vocab_size, (n_atoms,)), + edge_index=torch.randint(0, n_atoms, (2, n_bonds)), + edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)), + batch=torch.tensor([0, 0, 0], dtype=torch.long) + ) + + out, out_atom, out_bond = model.forward(data) + + assert out.shape == (1, embedding_dim) + assert out_atom.shape == (n_atoms, embedding_dim) + assert out_bond.shape == (n_bonds, embedding_dim)