graph-networks

A minimalist PyTorch implementation of the Graph Network (GN) block from "Relational inductive biases, deep learning, and graph networks" (Battaglia et al., 2018).
This codebase implements the GN block together with all of its components:

  • Node, edge and global models: ./models.py
  • Node, edge and global aggregations: ./aggregators.py
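To make these two kinds of components concrete: a model is an update function φ that maps its inputs to a new attribute vector. An aggregation ρ reduces a set of vectors to a single vector in a permutation-invariant way. The sketch below is a hypothetical stand-in, not the contents of ./models.py or ./aggregators.py. It assumes every model concatenates its inputs and applies a small MLP, and the names `ConcatMLP`, `sum_agg`, `mean_agg` and `max_agg` are illustrative.

```python
import torch
from torch import nn


class ConcatMLP(nn.Module):
    """Hypothetical update function phi: concatenate all inputs, then apply an MLP.

    The same class can play phi^e(e_k, v_rk, v_sk, u), phi^v(e'_bar_i, v_i, u)
    or phi^u(e'_bar, v'_bar, u); only in_dim changes.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        # Inputs are attribute vectors; they are joined along the last dimension
        return self.mlp(torch.cat(inputs, dim=-1))


# Hypothetical aggregations rho: reduce an (n, d) stack of vectors to one (d,) vector.
# Each one ignores the order of the rows, i.e. it is permutation invariant.
def sum_agg(X: torch.Tensor) -> torch.Tensor:
    return X.sum(dim=0)


def mean_agg(X: torch.Tensor) -> torch.Tensor:
    return X.mean(dim=0)


def max_agg(X: torch.Tensor) -> torch.Tensor:
    return X.max(dim=0).values


# Usage: an edge model for 4-d edges, 3-d nodes and a 2-d global attribute
edge_model = ConcatMLP(in_dim=4 + 3 + 3 + 2, out_dim=5)
e_prime_k = edge_model(torch.randn(4), torch.randn(3), torch.randn(3), torch.randn(2))
print(e_prime_k.shape)  # torch.Size([5])
```

Using one concatenate-then-MLP class for all three models keeps the sketch short. Separate models with different architectures work just as well. Mean and max are not defined on an empty set: in PyTorch, mean returns NaN and max raises an error. This matters for nodes that receive no edges.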

The GN block itself is then assembled from these components in ./graphnet.py.

The implementation follows the flow of the paper's pseudocode so that the two can be read side by side (clarity takes priority over efficiency here):

[Figure: GN block pseudocode from the paper (algo.png)]
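The same six steps can be written as the update and aggregation equations the paper gives for a full GN block. Here φ are update functions, ρ are aggregations, and r_k and s_k are the receiver and sender of edge k. The step numbers match the numbered comments in the code.

```latex
\begin{aligned}
e'_k &= \phi^e\left(e_k, v_{r_k}, v_{s_k}, u\right) && \text{1. edge update} \\
\bar{e}'_i &= \rho^{e \rightarrow v}\left(E'_i\right), \quad E'_i = \left\{\left(e'_k, r_k, s_k\right)\right\}_{r_k = i} && \text{2. edges to node } i \\
v'_i &= \phi^v\left(\bar{e}'_i, v_i, u\right) && \text{3. node update} \\
\bar{e}' &= \rho^{e \rightarrow u}\left(E'\right), \quad E' = \left\{\left(e'_k, r_k, s_k\right)\right\}_{k=1}^{N^e} && \text{4. edges to global} \\
\bar{v}' &= \rho^{v \rightarrow u}\left(V'\right), \quad V' = \left\{v'_i\right\}_{i=1}^{N^v} && \text{5. nodes to global} \\
u' &= \phi^u\left(\bar{e}', \bar{v}', u\right) && \text{6. global update}
\end{aligned}
```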

Compare the pseudocode with the following code:

```python
import torch

def forward(self, E, V, u, r, s):
    # E: (Ne, e_dim) edge attributes, V: (Nn, n_dim) node attributes, u: global attribute
    # r, s: receiver and sender node index of each edge
    E_prime = torch.empty((self.Ne, self.e_dim))
    for k in range(self.Ne):
        e_k, v_rk, v_sk = E[k], V[r[k]], V[s[k]]
        # 1. Compute updated edge attributes
        e_prime_k = self.edge_model(e_k, v_rk, v_sk, u)
        E_prime[k] = e_prime_k

    V_prime = torch.empty((self.Nn, self.n_dim))
    for i in range(self.Nn):
        if (r == i).any():
            # Updated attributes of all edges received by node i
            E_prime_i = torch.stack([E_prime[k] for k in range(self.Ne) if r[k] == i], dim=0)
            # 2. Aggregate edge attributes per node
            e_prime_bar_i = self.edge_to_node_agg(E_prime_i)
        else:
            # Node i receives no edges: use zeros as the aggregate of the empty set
            # (what a sum aggregation would give) so V_prime[i] is not left uninitialized
            e_prime_bar_i = torch.zeros(self.e_dim)
        # 3. Compute updated node attributes
        v_prime_i = self.node_model(e_prime_bar_i, V[i], u)
        V_prime[i] = v_prime_i

    # 4. Aggregate edge attributes globally
    e_prime_bar = self.edge_to_global_agg(E_prime)
    # 5. Aggregate node attributes globally
    v_prime_bar = self.node_to_global_agg(V_prime)
    # 6. Compute updated global attribute
    u_prime = self.global_model(e_prime_bar, v_prime_bar, u)

    return E_prime, V_prime, u_prime
```
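The loops above trade speed for readability. For comparison, here is a hypothetical vectorized version. `gn_block_vectorized` and `ConcatLinear` are illustrative and not part of this repository. The sketch assumes that all three aggregations are sums and that the models accept a leading batch dimension. The per-node sum uses `Tensor.index_add`, which adds each updated edge into the row of its receiver node.

```python
import torch
from torch import nn


class ConcatLinear(nn.Module):
    """Hypothetical batched update function: concatenate inputs, apply a linear layer."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        return self.lin(torch.cat(inputs, dim=-1))


def gn_block_vectorized(E, V, u, r, s, edge_model, node_model, global_model):
    """One GN block pass with sum aggregations and no Python loops.

    E: (Ne, de) edges, V: (Nn, dv) nodes, u: (du,) global attribute,
    r, s: (Ne,) long tensors holding the receiver / sender of each edge.
    """
    Ne, Nn = E.size(0), V.size(0)
    # 1. Update every edge at once; u is broadcast to one copy per edge
    E_prime = edge_model(E, V[r], V[s], u.expand(Ne, -1))
    # 2. Sum updated edges into their receivers; nodes with no incoming edge get zeros
    E_bar = E_prime.new_zeros(Nn, E_prime.size(1)).index_add(0, r, E_prime)
    # 3. Update every node at once
    V_prime = node_model(E_bar, V, u.expand(Nn, -1))
    # 4.-5. Global sums over all updated edges and nodes
    e_bar, v_bar = E_prime.sum(dim=0), V_prime.sum(dim=0)
    # 6. Update the global attribute
    u_prime = global_model(e_bar, v_bar, u)
    return E_prime, V_prime, u_prime


# Usage on a toy graph with 4 nodes and 3 directed edges
torch.manual_seed(0)
E, V, u = torch.randn(3, 2), torch.randn(4, 3), torch.randn(1)
s = torch.tensor([0, 1, 2])  # senders
r = torch.tensor([1, 2, 1])  # receivers (nodes 0 and 3 receive nothing)
phi_e = ConcatLinear(2 + 3 + 3 + 1, 5)
phi_v = ConcatLinear(5 + 3 + 1, 6)
phi_u = ConcatLinear(5 + 6 + 1, 2)
E_p, V_p, u_p = gn_block_vectorized(E, V, u, r, s, phi_e, phi_v, phi_u)
print(E_p.shape, V_p.shape, u_p.shape)  # torch.Size([3, 5]) torch.Size([4, 6]) torch.Size([2])
```

Other aggregations such as mean or max would need a different scatter operation (for example `Tensor.scatter_reduce`) in place of `index_add`.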

Example

A node classifier, made of a GN block followed by a decoder, is trained on Zachary's karate club dataset.
To launch training, run the main script:

```shell
python main.py
```

This assumes both PyTorch and PyTorch Geometric (PyG) are installed. Training logs the loss and accuracy every five epochs:

```text
epoch   0 ~ loss=1.38 ~ acc=35.3%
epoch   5 ~ loss=0.83 ~ acc=73.5%
epoch  10 ~ loss=0.09 ~ acc=100.0%
epoch  15 ~ loss=0.00 ~ acc=100.0%
epoch  20 ~ loss=0.00 ~ acc=100.0%
epoch  25 ~ loss=0.00 ~ acc=100.0%
epoch  30 ~ loss=0.00 ~ acc=100.0%
```
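main.py is not reproduced here. One step any such example needs, though, is mapping a PyG graph onto the (E, V, u, r, s) arguments of forward. In PyG, `edge_index` is a (2, Ne) tensor. Under the default source-to-target flow, its first row holds source nodes and its second row holds target nodes, so s = edge_index[0] and r = edge_index[1].

The sketch below is hypothetical and may differ from what main.py actually does. It assumes the graph carries no edge or global attributes and fills both with ones. The helper `to_gn_inputs` is illustrative and runs without PyG by using a hand-written `edge_index`.

```python
import torch


def to_gn_inputs(x: torch.Tensor, edge_index: torch.Tensor, e_dim: int = 1, u_dim: int = 1):
    """Hypothetical helper: turn PyG-style (x, edge_index) into GN block inputs.

    Assumes the graph has no edge or global features, so both are filled with ones.
    """
    s, r = edge_index[0], edge_index[1]  # row 0: senders (sources), row 1: receivers (targets)
    E = torch.ones(edge_index.size(1), e_dim)  # placeholder edge attributes
    u = torch.ones(u_dim)  # placeholder global attribute
    V = x  # node attributes are used as-is
    return E, V, u, r, s


# Undirected path 0 - 1 - 2, stored with both edge directions as PyG does
x = torch.eye(3)  # one-hot node features
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
E, V, u, r, s = to_gn_inputs(x, edge_index)
print(E.shape, r.tolist(), s.tolist())  # torch.Size([4, 1]) [1, 0, 2, 1] [0, 1, 1, 2]

# With PyG installed, the same helper could be applied to the karate club graph:
# from torch_geometric.datasets import KarateClub
# data = KarateClub()[0]
# E, V, u, r, s = to_gn_inputs(data.x, data.edge_index)
```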