diff --git a/graphchem/nn/gcn.py b/graphchem/nn/gcn.py index b09e378..99b60c1 100644 --- a/graphchem/nn/gcn.py +++ b/graphchem/nn/gcn.py @@ -18,6 +18,8 @@ class MoleculeGCN(nn.Module): Probability of an element to be zeroed in dropout layers. _n_messages : int Number of message passing steps. + act_fn : callable + Activation function, e.g., `torch.nn.functional.softplus` emb_atom : nn.Embedding Embedding layer for atoms. emb_bond : nn.Embedding @@ -36,7 +38,8 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int, n_readout: Optional[int] = 2, readout_dim: Optional[int] = 64, p_dropout: Optional[float] = 0.0, - aggr: Optional[str] = "add"): + aggr: Optional[str] = "add", + act_fn: Optional[callable] = F.softplus): """ Initialize the MoleculeGCN object. @@ -60,34 +63,59 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int, Dropout probability for the dropout layers. aggr : str, optional (default="add") Aggregation scheme to use in the GeneralConv layer. + act_fn : callable, optional (default=`torch.nn.functional.softplus`) + Activation function, e.g., `torch.nn.functional.softplus`, + `torch.nn.functional.sigmoid`, `torch.nn.functional.relu`, etc. 
""" super().__init__() + # Store attributes self._p_dropout = p_dropout self._n_messages = n_messages + self.act_fn = act_fn + # Embedding layer for atoms self.emb_atom = nn.Embedding(atom_vocab_size, embedding_dim) + + # Embedding layer for bonds self.emb_bond = nn.Embedding(bond_vocab_size, embedding_dim) + # General convolution layer for atoms with specified aggregation method self.atom_conv = GeneralConv( embedding_dim, embedding_dim, embedding_dim, aggr=aggr ) + + # Edge convolution layer for bonds using a linear transformation self.bond_conv = EdgeConv(nn.Sequential( nn.Linear(2 * embedding_dim, embedding_dim) )) - self.readout = nn.ModuleList() - self.readout.append(nn.Sequential( - nn.Linear(embedding_dim, readout_dim) - )) - if n_readout > 1: - for _ in range(n_readout - 1): - self.readout.append(nn.Sequential( - nn.Linear(readout_dim, readout_dim) - )) - self.readout.append(nn.Sequential( - nn.Linear(readout_dim, output_dim) - )) + # Initialize the readout network if readout layers are specified + if n_readout > 0: + + # Create a list to hold the readout network modules + self.readout = nn.ModuleList() + + # First layer of the readout network + self.readout.append(nn.Sequential( + nn.Linear(embedding_dim, readout_dim) + )) + + # Additional hidden layers for the readout network if needed + if n_readout > 1: + for _ in range(n_readout - 1): + self.readout.append(nn.Sequential( + nn.Linear(readout_dim, readout_dim) + )) + + # Final layer of the readout network to produce output dimensions + self.readout.append(nn.Sequential( + nn.Linear(readout_dim, output_dim) + )) + + # No readout network if n_readout is 0 + else: + self.readout = None def forward( self, @@ -111,41 +139,64 @@ def forward( out_bond : torch.Tensor Bond-level representations after message passing. 
""" + # Extract node features, edge attributes, edge indices, and batch + # vector from data x, edge_attr, edge_index, batch = data.x, data.edge_attr, \ data.edge_index, data.batch + + # If no node features are provided, initialize with ones if data.num_node_features == 0: x = torch.ones(data.num_nodes, 1) + # Embed and activate atom features out_atom = self.emb_atom(x) - out_atom = F.softplus(out_atom) + out_atom = self.act_fn(out_atom) + # Embed and activate bond features out_bond = self.emb_bond(edge_attr) - out_bond = F.softplus(out_bond) + out_bond = self.act_fn(out_bond) + # Perform message passing for the specified number of steps for _ in range(self._n_messages): + # Update bond representations using edge convolution out_bond = self.bond_conv(out_bond, edge_index) - out_bond = F.softplus(out_bond) + out_bond = self.act_fn(out_bond) + + # Apply dropout out_bond = F.dropout( out_bond, p=self._p_dropout, training=self.training ) + # Update atom representations using general convolution out_atom = self.atom_conv(out_atom, edge_index, out_bond) - out_atom = F.softplus(out_atom) + out_atom = self.act_fn(out_atom) + + # Apply dropout out_atom = F.dropout( out_atom, p=self._p_dropout, training=self.training ) + # Sum atom representations within each graph of the batch (add pooling) out = global_add_pool(out_atom, batch) - for layer in self.readout[:-1]: + # Process aggregated atom representation through the readout network + if self.readout is not None: - out = layer(out) - out = F.softplus(out) - out = F.dropout( - out, p=self._p_dropout, training=self.training - ) + # Iterate over all but the last layer of the readout network + for layer in self.readout[:-1]: + + # Pass through layer and activate + out = layer(out) + out = self.act_fn(out) + + # Apply dropout + out = F.dropout( + out, p=self._p_dropout, training=self.training + ) - out = self.readout[-1](out) + # Final layer of the readout network to produce output dimensions + out = self.readout[-1](out) + #
Return final prediction, atom representations, bond representations return (out, out_atom, out_bond) diff --git a/tests/test_gcn.py b/tests/test_gcn.py index 6125ff0..9b2b84a 100644 --- a/tests/test_gcn.py +++ b/tests/test_gcn.py @@ -6,7 +6,7 @@ # Test instantiation of the MoleculeGCN class with various parameters -def test_moleculgcn_instantiation(): +def test_moleculegcn_instantiation(): model = MoleculeGCN( atom_vocab_size=10, @@ -21,21 +21,24 @@ def test_moleculgcn_instantiation(): assert model._p_dropout == 0.2 assert model._n_messages == 3 + assert model.readout is not None # Test the forward pass of the MoleculeGCN model -def test_moleculgcn_forward_pass(): +def test_moleculegcn_forward_pass(): atom_vocab_size = 10 bond_vocab_size = 5 + embedding_dim = 64 + output_dim = 2 n_atoms = 3 n_bonds = 5 model = MoleculeGCN( atom_vocab_size=atom_vocab_size, bond_vocab_size=bond_vocab_size, - output_dim=2, - embedding_dim=64, + output_dim=output_dim, + embedding_dim=embedding_dim, n_messages=3, n_readout=3, readout_dim=128, @@ -54,6 +57,95 @@ def test_moleculgcn_forward_pass(): out, out_atom, out_bond = model.forward(data) - assert out.shape == (1, 2) - assert out_atom.shape == (3, 64) - assert out_bond.shape == (5, 64) + assert out.shape == (1, output_dim) + assert out_atom.shape == (n_atoms, embedding_dim) + assert out_bond.shape == (n_bonds, embedding_dim) + + +# Test the forward pass of the MoleculeGCN model without readout layers +def test_moleculegcn_no_readout(): + + atom_vocab_size = 10 + bond_vocab_size = 5 + embedding_dim = 64 + n_atoms = 3 + n_bonds = 5 + + model = MoleculeGCN( + atom_vocab_size=atom_vocab_size, + bond_vocab_size=bond_vocab_size, + output_dim=None, + embedding_dim=embedding_dim, + n_messages=3, + n_readout=0, + readout_dim=128, + p_dropout=0.2 + ) + + assert model.readout is None + + n_atoms = 3 + + # Create a mock input for the forward pass + data = Data( + x=torch.randint(0, atom_vocab_size, (n_atoms,)), + edge_index=torch.randint(0, 
n_atoms, (2, n_bonds)), + edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)), + batch=torch.tensor([0, 0, 0], dtype=torch.long) + ) + + out, out_atom, out_bond = model.forward(data) + + assert out.shape == (1, embedding_dim) + assert out_atom.shape == (n_atoms, embedding_dim) + assert out_bond.shape == (n_bonds, embedding_dim) + + +# Test the forward pass of the MoleculeGCN model with different activation fns +def test_moleculegcn_act_fns(): + + functions = [ + torch.nn.functional.softplus, + torch.nn.functional.sigmoid, + torch.nn.functional.relu, + torch.nn.functional.tanh, + torch.nn.functional.leaky_relu + ] + + for fn in functions: + + atom_vocab_size = 10 + bond_vocab_size = 5 + embedding_dim = 64 + n_atoms = 3 + n_bonds = 5 + + model = MoleculeGCN( + atom_vocab_size=atom_vocab_size, + bond_vocab_size=bond_vocab_size, + output_dim=None, + embedding_dim=embedding_dim, + n_messages=3, + n_readout=0, + readout_dim=128, + p_dropout=0.2, + act_fn=fn + ) + + assert model.readout is None + + n_atoms = 3 + + # Create a mock input for the forward pass + data = Data( + x=torch.randint(0, atom_vocab_size, (n_atoms,)), + edge_index=torch.randint(0, n_atoms, (2, n_bonds)), + edge_attr=torch.randint(0, bond_vocab_size, (n_bonds,)), + batch=torch.tensor([0, 0, 0], dtype=torch.long) + ) + + out, out_atom, out_bond = model.forward(data) + + assert out.shape == (1, embedding_dim) + assert out_atom.shape == (n_atoms, embedding_dim) + assert out_bond.shape == (n_bonds, embedding_dim)