Skip to content

Commit

Permalink
Feature/SK-811 | Example using Differential Privacy (#698)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattiasakesson authored Sep 12, 2024
1 parent 1ed8250 commit ed5ae94
Show file tree
Hide file tree
Showing 7 changed files with 498 additions and 0 deletions.
56 changes: 56 additions & 0 deletions examples/mnist-pytorch-DPSGD/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
FEDn Project: Federated Differential Privacy MNIST (Opacus + PyTorch)
----------------------------------------------------------------------

This example FEDn Project demonstrates how Differential Privacy can be integrated to enhance the confidentiality of client data.
We have expanded our baseline MNIST-PyTorch example by incorporating the Opacus framework, which is specifically designed for PyTorch models.



Prerequisites
-------------

- `Python >=3.8, <=3.12 <https://www.python.org/downloads>`__
- `A project in FEDn Studio <https://fedn.scaleoutsystems.com/signup>`__

Edit Differential Privacy budget
--------------------------------
- The **Differential Privacy budget** (``FINAL_EPSILON``, ``DELTA``) is configured in the compute package, in ``client/train.py`` (lines 35 and 39).
- If ``HARDLIMIT`` (line 40) is set to ``True``, ``FINAL_EPSILON`` will not exceed its specified limit.
- If ``HARDLIMIT`` is set to ``False``, the expected ``FINAL_EPSILON`` will be approximately its specified value, provided the server runs the number of rounds given by the ``GLOBAL_ROUNDS`` variable (line 36).

Creating the compute package and seed model
-------------------------------------------

Install fedn:

.. code-block::

   pip install fedn

Clone this repository, then navigate into this directory:

.. code-block::

   git clone https://github.com/scaleoutsystems/fedn.git
   cd fedn/examples/mnist-pytorch-DPSGD

Create the compute package:

.. code-block::

   fedn package create --path client

This creates a file 'package.tgz' in the project folder.

Next, generate the seed model:

.. code-block::

   fedn run build --path client

This will create a model file 'seed.npz' in the root of the project. This step will take a few minutes, depending on hardware and internet connection (builds a virtualenv).

Running the project on FEDn
----------------------------

To learn how to set up your FEDn Studio project and connect clients, take the quickstart tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html.
99 changes: 99 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
from math import floor

import torch
import torchvision

# Directory containing this script (symlinks resolved). Used below to build
# the default, script-relative data paths.
dir_path = os.path.dirname(os.path.realpath(__file__))
# realpath() already yields an absolute path; abspath() only normalises it.
abs_path = os.path.abspath(dir_path)


def get_data(out_dir="data"):
    """Download the MNIST train and test sets into ``out_dir``.

    Each split is fetched only when its target directory
    (``{out_dir}/train`` or ``{out_dir}/test``) does not exist yet.

    :param out_dir: Directory in which to store the raw datasets.
    :type out_dir: str
    """
    # Create the directory (and any missing parents); no error if it exists.
    os.makedirs(out_dir, exist_ok=True)

    # ``transform`` must be a callable *instance*; passing the bare class
    # would fail as soon as a sample is actually indexed.
    to_tensor = torchvision.transforms.ToTensor()

    # Only download if not already downloaded
    if not os.path.exists(f"{out_dir}/train"):
        torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=to_tensor, train=True, download=True)
    if not os.path.exists(f"{out_dir}/test"):
        torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=to_tensor, train=False, download=True)


def load_data(data_path, is_train=True):
    """Load one client's MNIST partition from disk.

    :param data_path: Path to the ``.pt`` data file. If ``None``, the
        ``FEDN_DATA_PATH`` environment variable is used, falling back to
        ``data/clients/1/mnist.pt`` next to this script.
    :type data_path: str
    :param is_train: Whether to load training or test data.
    :type is_train: bool
    :return: Tuple of data (divided by 255) and labels.
    :rtype: tuple
    """
    print("data_path is None: ", data_path is None)
    if data_path is None:
        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

    print("data path: ", data_path)
    # The file is a plain dict of tensors, so restrict unpickling to tensor
    # payloads: the path may come from the environment, and a full unpickle
    # of an untrusted file can execute arbitrary code.
    data = torch.load(data_path, weights_only=True)

    if is_train:
        X = data["x_train"]
        y = data["y_train"]
    else:
        X = data["x_test"]
        y = data["y_test"]

    # Normalize pixel values from [0, 255] to [0, 1]
    X = X / 255

    return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
local_n = floor(n / parts)
result = []
for i in range(parts):
result.append(dataset[i * local_n : (i + 1) * local_n])
return result


def split(out_dir="data"):
    """Partition MNIST into per-client shards and write them to disk.

    The number of shards is read from the ``FEDN_NUM_DATA_SPLITS``
    environment variable (default 2). Shard ``i`` (1-based) is saved to
    ``{out_dir}/clients/{i}/mnist.pt`` as a dict with the keys ``x_train``,
    ``y_train``, ``x_test`` and ``y_test``.

    :param out_dir: Directory holding the raw ``train``/``test`` datasets;
        the ``clients`` directory is created inside it.
    :type out_dir: str
    :raises ValueError: If ``FEDN_NUM_DATA_SPLITS`` is smaller than 1.
    """
    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
    if n_splits < 1:
        # A non-positive value would otherwise yield no shards at all (or a
        # bare ZeroDivisionError) with no hint about the misconfiguration.
        raise ValueError(f"FEDN_NUM_DATA_SPLITS must be >= 1, got {n_splits}")

    # Make dir (including missing parents; no error if it already exists)
    os.makedirs(f"{out_dir}/clients", exist_ok=True)

    # ``transform`` must be a callable instance, not the ToTensor class itself.
    to_tensor = torchvision.transforms.ToTensor()

    # Load and convert to dict
    train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=to_tensor, train=True)
    test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=to_tensor, train=False)
    data = {
        "x_train": splitset(train_data.data, n_splits),
        "y_train": splitset(train_data.targets, n_splits),
        "x_test": splitset(test_data.data, n_splits),
        "y_test": splitset(test_data.targets, n_splits),
    }

    # Make splits: one directory and one file per client, numbered from 1
    for i in range(n_splits):
        subdir = f"{out_dir}/clients/{i + 1}"
        os.makedirs(subdir, exist_ok=True)
        torch.save(
            {key: shards[i] for key, shards in data.items()},
            f"{subdir}/mnist.pt",
        )


if __name__ == "__main__":
    # Prepare data if not already done. The script-relative directory is used
    # for the existence check AND as the download/split target, so the two
    # always agree (and match load_data's default path) regardless of the
    # current working directory.
    data_dir = abs_path + "/data"
    if not os.path.exists(data_dir + "/clients/1"):
        get_data(data_dir)
        split(data_dir)
12 changes: 12 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Environment specification for this package (presumably used to build the
# client's virtualenv - confirm against the FEDn documentation).
python_env: python_env.yaml
# One command per entry point.
entry_points:
  build:
    # model.py, run as a script, writes the seed model to ../seed.npz.
    command: python model.py
  startup:
    # data.py, run as a script, downloads and splits MNIST if not present.
    command: python data.py
  train:
    command: python train.py
  validate:
    command: python validate.py
  predict:
    # NOTE(review): confirm that predict.py is shipped with this package.
    command: python predict.py
76 changes: 76 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import collections

import torch

from fedn.utils.helpers.helpers import get_helper

# Name of the FEDn helper used below to save and load model parameters as a
# list of NumPy arrays.
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)


def compile_model():
    """Build the MNIST classifier.

    The network is a fully connected model with layer sizes
    784 -> 64 -> 32 -> 10, dropout (p=0.5) after the first hidden layer and
    log-softmax outputs.

    :return: A freshly initialised model.
    :rtype: torch.nn.Module
    """

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(784, 64)
            self.fc2 = torch.nn.Linear(64, 32)
            self.fc3 = torch.nn.Linear(32, 10)

        def forward(self, x):
            F = torch.nn.functional
            # Flatten every sample in the batch to a 784-element vector.
            flat = x.reshape(x.size(0), 784)
            hidden = F.relu(self.fc1(flat))
            # Dropout is only active while the module is in training mode.
            hidden = F.dropout(hidden, p=0.5, training=self.training)
            hidden = F.relu(self.fc2(hidden))
            logits = self.fc3(hidden)
            return F.log_softmax(logits, dim=1)

    return Net()


def save_parameters(model, out_path):
    """Save model parameters to file.

    Every entry of the model's state dict is moved to the CPU and converted
    to a NumPy array; the resulting list is written with the module-level
    helper.

    :param model: The model to serialize.
    :type model: torch.nn.Module
    :param out_path: The path to save to.
    :type out_path: str
    """
    arrays = []
    for tensor in model.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    helper.save(arrays, out_path)


def load_parameters(model_path):
    """Load model parameters from file and populate a new model.

    The stored arrays are paired, in order, with the keys of a freshly
    compiled model's state dict and loaded with ``strict=True``.

    :param model_path: The path to load from.
    :type model_path: str
    :return: The loaded model.
    :rtype: torch.nn.Module
    """
    model = compile_model()
    arrays = helper.load(model_path)

    # Rebuild the state dict positionally: the i-th array belongs to the
    # i-th state-dict key.
    state_dict = collections.OrderedDict()
    for key, array in zip(model.state_dict().keys(), arrays):
        state_dict[key] = torch.tensor(array)

    model.load_state_dict(state_dict, strict=True)
    return model


def init_seed(out_path="seed.npz"):
    """Create a freshly initialised model and save it as the seed model.

    :param out_path: The path to save the seed model to.
    :type out_path: str
    """
    save_parameters(compile_model(), out_path)


if __name__ == "__main__":
    # Run as a script: write the seed model one directory above the current
    # working directory.
    init_seed("../seed.npz")
15 changes: 15 additions & 0 deletions examples/mnist-pytorch-DPSGD/client/python_env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: mnist-pytorch
build_dependencies:
  - pip
  - setuptools
  - wheel
dependencies:
  - fedn
  # NOTE: the ``sys_platform`` environment marker is "win32" on Windows (it
  # mirrors sys.platform); "win" never matches, which would leave Windows
  # without the pinned torch/torchvision/numpy versions.
  - torch==2.4.1; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")
  # PyTorch macOS x86 builds deprecation
  - torch==2.2.2; sys_platform == "darwin" and platform_machine == "x86_64"
  - torchvision==0.19.1; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")
  - torchvision==0.17.2; sys_platform == "darwin" and platform_machine == "x86_64"
  - numpy==2.0.2; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")
  - numpy==1.26.4; sys_platform == "darwin" and platform_machine == "x86_64"
  - opacus
Loading

0 comments on commit ed5ae94

Please sign in to comment.