This small framework can be used to scale up and train any TensorFlow- or PyTorch-based graph neural network in a federated manner on partitioned or distributed graphs. Some graphs are so large that training a neural network on them takes a huge amount of time. Distributed learning is a potential approach in this case, but in some settings (e.g., distributed graphs located in different datacenters) federated learning can be a better fit because it minimizes the communication overhead and speeds up training.
- Python 3.5 or higher (A conda or a Virtual environment is preferred)
- Tensorflow 2
- Stellargraph
- Scikit-learn
- Pandas
- Numpy
- Clone the repository
- Install the dependencies:
pip3 install -r requirements.txt
To be trained using this framework, a neural network should be wrapped in a Python class that provides the following methods. A few example models are in the models folder of the repo.
class Model:

    def __init__(self, nodes, edges):
        # define class variables here
        ...

    def initialize(self, **hyper_params):
        # define model initialization logic here
        # the **hyper_params dictionary can be used to pass any variable
        # return the initial model weights
        ...

    def set_weights(self, weights):
        # set the passed weights on your model
        ...

    def get_weights(self):
        # extract and return the model weights
        ...

    def fit(self, epochs=4):
        # define training logic here
        # return the model weights and the training history
        ...

    def gen_embeddings(self):
        # this method is optional
        # return the embeddings as a pandas DataFrame
        ...
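To make the interface concrete, here is a minimal sketch. It assumes nothing about the real GNN: the "weights" are plain NumPy arrays and fit only perturbs them, so a real model would replace the method bodies with StellarGraph/Keras calls. DummyModel and the dim hyper-parameter are illustrative names, not part of the framework.

```python
import numpy as np
import pandas as pd


class DummyModel:
    """Minimal sketch of the wrapper interface using plain NumPy weights."""

    def __init__(self, nodes, edges):
        self.nodes = nodes
        self.edges = edges
        self.weights = None

    def initialize(self, **hyper_params):
        # hyper_params can carry anything, e.g. a layer size
        dim = hyper_params.get("dim", 4)
        self.weights = [np.zeros((dim, dim)), np.zeros(dim)]
        return self.weights

    def set_weights(self, weights):
        # copy so the caller's arrays are not mutated during training
        self.weights = [w.copy() for w in weights]

    def get_weights(self):
        return self.weights

    def fit(self, epochs=4):
        # stand-in for real training: perturb the weights each epoch
        history = []
        for epoch in range(epochs):
            self.weights = [w + 0.01 for w in self.weights]
            history.append({"epoch": epoch, "loss": 1.0 / (epoch + 1)})
        return self.weights, history

    def gen_embeddings(self):
        # one embedding row per node
        return pd.DataFrame(np.zeros((len(self.nodes), self.weights[1].size)))
```

A real implementation would build the GNN from the node and edge tables in `__init__`/`initialize` and delegate `get_weights`/`set_weights` to the underlying Keras model.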
Repository
├── data: partitioned CORA dataset for testing
├── misc: miscellaneous
├── models: where models are stored
- supervised.py: supervised implementation of GraphSAGE
- unsupervised.py: unsupervised implementation of GraphSAGE
Before starting the clients or the server, set the PYTHONHASHSEED environment variable to 0 to get reproducible results.
export PYTHONHASHSEED=0
The following arguments must be passed, in the following order.
For unsupervised training
- path_weights - Location to extract and store model weights
- path_nodes - Where your graph nodes are stored
- path_edges - Where your graph edges are stored
- graph_id - ID used to identify the graph
- partition_id - ID of the partition, located on the server, that is used to initialize the weights
- num_clients - Number of clients that will join the federated training
- num_rounds - Number of federated rounds to train
- IP (optional, default localhost) - IP of the machine fl_server runs on
- PORT (optional, default 5000) - Port that should be used to communicate with the clients
python fl_server_unsupervised.py ./weights/ ./data/ ./data/ 4 0 2 3 localhost 5000
For supervised training
python fl_server.py ./weights/ ./data/ ./data/ 4 0 2 3 localhost 5000
For supervised + scheduled training
python fl_server_shed.py ./weights/ ./data4/ ./data4/ 3 0 1 3 localhost 5000
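In each federated round the server collects the clients' updated weights and combines them into a new global model. The repo does not spell out the aggregation rule, but a common choice is federated averaging; the sketch below (function name and uniform weighting are assumptions, not the framework's documented API) shows the idea:

```python
import numpy as np


def federated_average(client_weights):
    """Average a list of per-client weight lists element-wise (uniform FedAvg).

    client_weights: one entry per client, each a list of NumPy arrays
    with identical shapes across clients.
    """
    num_clients = len(client_weights)
    # zip groups the corresponding layer arrays across clients;
    # sum them and divide to get the new global weights
    return [sum(layers) / num_clients for layers in zip(*client_weights)]
```

With unequal partition sizes, a weighted average (scaling each client's arrays by its number of training nodes) is the usual refinement.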
For unsupervised training
- path_weights - Location to extract and store model weights
- path_embeddings - Location to store node embeddings if you want to generate them
- path_nodes - Where your graph nodes are stored
- path_edges - Where your graph edges are stored
- graph_id - ID used to identify the graph
- partition_id - ID of the partition, located on the server, that is used to initialize the weights
- epochs - Number of epochs to train
- IP (optional, default localhost) - IP of the machine fl_server runs on
- PORT (optional, default 5000) - Port that fl_server is listening on
Any number of clients can be started, but the number of clients must match the value passed to fl_server when it was started, as explained above.
client 1
python fl_client_unsupervised.py ./weights/ ./embeddings/ ./data/ ./data/ 4 0 4 localhost 5000
client 2
python fl_client_unsupervised.py ./weights/ ./embeddings/ ./data/ ./data/ 4 1 4 localhost 5000
For supervised training
- path_weights - Location to extract and store model weights
- path_nodes - Where your graph nodes are stored
- path_edges - Where your graph edges are stored
- graph_id - ID used to identify the graph
- partition_id - ID of the partition, located on the server, that is used to initialize the weights
- epochs - Number of epochs to train
- IP (optional, default localhost) - IP of the machine fl_server runs on
- PORT (optional, default 5000) - Port that fl_server is listening on
client 1
python fl_client.py ./weights/ ./data/ ./data/ 4 0 10 localhost 5000
client 2
python fl_client.py ./weights/ ./data/ ./data/ 4 1 10 localhost 5000
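On the client side, one federated round follows directly from the Model interface above: sync to the global weights, train locally, and send the result back for aggregation. A sketch (the function name is illustrative; transport to the server is abstracted away):

```python
def run_client_round(model, global_weights, epochs):
    """One client round against any object implementing the Model interface:
    adopt the server's global weights, train locally for the given number
    of epochs, and return the updated weights plus the training history."""
    model.set_weights(global_weights)
    new_weights, history = model.fit(epochs=epochs)
    return new_weights, history
```

The client repeats this once per federated round, each time starting from the freshly aggregated global weights it receives from fl_server.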
For supervised + scheduled training
- client_id - ID used to identify the client
- path_weights - Location to extract and store model weights
- path_nodes - Where your graph nodes are stored
- path_edges - Where your graph edges are stored
- graph_id - ID used to identify the graph
- partition_ids - Comma-separated IDs of the partitions, ordered in the scheduling order (e.g., 1,2,5)
- epochs - Number of epochs to train
- IP (optional, default localhost) - IP of the machine fl_server runs on
- PORT (optional, default 5000) - Port that fl_server is listening on
client 1
python fl_client_shed.py 1 ./weights/ ./data4/ ./data4/ 3 0,1 3 localhost 5000
client 2
python fl_client_shed.py 2 ./weights/ ./data4/ ./data4/ 3 2,3 3 localhost 5000
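In the scheduled variant each client works through its comma-separated partition_ids in the given order, carrying the weights from one partition into the next. A sketch of that schedule handling (both function names and the model factory are assumptions for illustration):

```python
def parse_partition_schedule(arg):
    """Parse the comma-separated partition_ids CLI argument, e.g. "1,2,5"."""
    return [int(p) for p in arg.split(",")]


def run_schedule(model_for_partition, schedule, global_weights, epochs):
    """Train the scheduled partitions in order, carrying weights forward.

    model_for_partition: callable returning a Model-interface object
    for a given partition ID.
    """
    weights = global_weights
    for pid in schedule:
        model = model_for_partition(pid)
        model.set_weights(weights)
        weights, _ = model.fit(epochs=epochs)
    return weights
```

For example, client 1 above would run partitions 0 and 1 in that order within each round before reporting its weights back to fl_server_shed.py.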