Skip to content

Commit

Permalink
Fixed docs, tests and tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
makgyver committed Nov 1, 2024
1 parent f61e9f1 commit 1bc8408
Show file tree
Hide file tree
Showing 16 changed files with 836 additions and 107 deletions.
2 changes: 1 addition & 1 deletion docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ In this page you can find the list of modules/submodules defined in `fluke` with
:nosignatures:
DataContainer
DummyDataContainer
FastDataLoader
DataSplitter
DummyDataSplitter
```

Expand Down
2 changes: 0 additions & 2 deletions docs/examples/tutorials/fluke_custom_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@
"\n",
" def __init__(self):\n",
" super(MyMLP, self).__init__()\n",
" self.output_size = 2\n",
"\n",
" self.fc1 = torch.nn.Linear(2, 3)\n",
" self.fc2 = torch.nn.Linear(3, 2)\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions docs/examples/tutorials/fluke_custom_nn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
"\n",
" def __init__(self):\n",
" super(MyMLP, self).__init__()\n",
" self.output_size = 10\n",
"\n",
" self.fc1 = torch.nn.Linear(28*28, 100)\n",
" self.fc2 = torch.nn.Linear(100, 64)\n",
" self.fc3 = torch.nn.Linear(64, 10)\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/fluke.data.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ Classes
:nosignatures:
DataContainer
DummyDataContainer
FastDataLoader
DataSplitter
DummyDataSplitter
```

<h3>
Expand Down Expand Up @@ -113,13 +113,13 @@ Classes

<h3>

{bdg-primary}`class` ``fluke.data.DummyDataSplitter``
{bdg-primary}`class` ``fluke.data.DummyDataContainer``

</h3>

```{eval-rst}
.. autoclass:: fluke.data.DummyDataSplitter
.. autoclass:: fluke.data.DummyDataContainer
:members: assign
:show-inheritance:
Expand Down
Empty file added tests/__init__.py
Empty file.
7 changes: 4 additions & 3 deletions tests/test_alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from fluke.utils.log import Log # NOQA


GlobalSettings().set_evaluator(ClassificationEval(1, 10))
GlobalSettings().set_eval_cfg(DDict(post_fit=True, pre_fit=True))


def test_centralized_fl():
hparams = DDict(
# model="fluke.nets.MNIST_2NN",
Expand Down Expand Up @@ -70,9 +74,6 @@ def test_centralized_fl():
splitter = DataSplitter(mnist, client_split=0.1)
fl = CentralizedFL(2, splitter, hparams)

GlobalSettings().set_evaluator(ClassificationEval(1, 10))
GlobalSettings().set_eval_cfg(DDict(post_fit=True, pre_fit=True))

assert isinstance(fl.server.model, MNIST_2NN)
assert isinstance(fl.clients[0].hyper_params.loss_fn, CrossEntropyLoss)

Expand Down
2 changes: 0 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def test_client():
class Model(Linear):
def __init__(self):
super().__init__(10, 2)
self.output_size = 2

# initialize weights to 0
self.weight.data.fill_(0)
Expand Down Expand Up @@ -121,7 +120,6 @@ def test_pflclient():
class Model(Linear):
def __init__(self):
super().__init__(10, 2)
self.output_size = 2

# function that taken a 10-dimensional input returns a 0 if the
# sum of the first 7 elements is less than 2.5
Expand Down
18 changes: 6 additions & 12 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
sys.path.append("..")

from fluke import DDict # NOQA
from fluke.data import (DataContainer, DataSplitter, DummyDataSplitter, # NOQA
from fluke.data import (DataContainer, DataSplitter, DummyDataContainer, # NOQA
FastDataLoader)
from fluke.data.datasets import Datasets # NOQA

Expand Down Expand Up @@ -347,17 +347,11 @@ def test_splitter():
with pytest.raises(AssertionError):
DataSplitter(data_container, **cfg.exclude("dataset"))

dummy = DummyDataSplitter((ctr, cte, ste))

assert dummy.data_container is None
assert dummy.distribution == 'iid'
assert dummy.client_split is None
assert dummy.num_classes() == 10

(ctr_, cte_), ste_ = dummy.assign(10, batch_size=10)
assert ctr_ == ctr
assert cte_ == cte
assert ste_ == ste
dummy = DummyDataContainer(ctr, cte, ste, 10)
assert dummy.num_classes == 10
assert dummy.clients_tr == ctr
assert dummy.clients_te == cte
assert dummy.server_data == ste


if __name__ == "__main__":
Expand Down
137 changes: 76 additions & 61 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
sys.path.append(".")
sys.path.append("..")

from fluke.data import DataSplitter # NOQA
from fluke.data.datasets import Datasets # NOQA
from fluke.data.support import CINIC10, MNISTM # NOQA

Expand Down Expand Up @@ -160,20 +159,22 @@ def test_femnist():
femnist = Datasets.FEMNIST("./data")
except AssertionError:
return
assert len(femnist[0]) == 3597 # Total number of clients
assert len(femnist[1]) == 3597 # Total number of clients
assert femnist[0][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert len(femnist.clients_tr) == 3597 # Total number of clients
assert len(femnist.clients_te) == 3597 # Total number of clients
assert femnist.clients_tr[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist[0][i].tensors[1].shape[0] == femnist[0][i].tensors[0].shape[0]
for i in range(len(femnist[0]))]) == len(femnist[0])
assert sum([femnist.clients_tr[i].tensors[1].shape[0] == femnist.clients_tr[i].tensors[0].shape[0]
for i in range(len(femnist.clients_tr))]) == len(femnist.clients_tr)

assert femnist[1][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert femnist.clients_te[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist[1][i].tensors[1].shape[0] == femnist[1][i].tensors[0].shape[0]
for i in range(len(femnist[1]))]) == len(femnist[1])
assert sum([femnist.clients_te[i].tensors[1].shape[0] == femnist.clients_te[i].tensors[0].shape[0]
for i in range(len(femnist.clients_te))]) == len(femnist.clients_te)

lbl_train = set.union(*[set(femnist[0][i].tensors[1].numpy()) for i in range(len(femnist[0]))])
lbl_test = set.union(*[set(femnist[1][i].tensors[1].numpy()) for i in range(len(femnist[1]))])
lbl_train = set.union(*[set(femnist.clients_tr[i].tensors[1].numpy())
for i in range(len(femnist.clients_tr))])
lbl_test = set.union(*[set(femnist.clients_te[i].tensors[1].numpy())
for i in range(len(femnist.clients_te))])
assert len(lbl_train | lbl_test) == 62 # Total number of classes


Expand All @@ -182,22 +183,22 @@ def test_femnist_dig():
femnist_dig = Datasets.FEMNIST("./data", filter="digits")
except AssertionError:
return
assert len(femnist_dig[0]) == 3597 # Total number of clients
assert len(femnist_dig[1]) == 3597 # Total number of clients
assert femnist_dig[0][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert len(femnist_dig.clients_tr) == 3597 # Total number of clients
assert len(femnist_dig.clients_te) == 3597 # Total number of clients
assert femnist_dig.clients_tr[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist_dig[0][i].tensors[1].shape[0] == femnist_dig[0][i].tensors[0].shape[0]
for i in range(len(femnist_dig[0]))]) == len(femnist_dig[0])
assert sum([femnist_dig.clients_tr[i].tensors[1].shape[0] == femnist_dig.clients_tr[i].tensors[0].shape[0]
for i in range(len(femnist_dig.clients_tr))]) == len(femnist_dig.clients_tr)

assert femnist_dig[1][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert femnist_dig.clients_te[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist_dig[1][i].tensors[1].shape[0] == femnist_dig[1][i].tensors[0].shape[0]
for i in range(len(femnist_dig[1]))]) == len(femnist_dig[1])
assert sum([femnist_dig.clients_te[i].tensors[1].shape[0] == femnist_dig.clients_te[i].tensors[0].shape[0]
for i in range(len(femnist_dig.clients_te))]) == len(femnist_dig.clients_te)

lbl_train = set.union(*[set(femnist_dig[0][i].tensors[1].numpy())
for i in range(len(femnist_dig[0]))])
lbl_test = set.union(*[set(femnist_dig[1][i].tensors[1].numpy())
for i in range(len(femnist_dig[1]))])
lbl_train = set.union(*[set(femnist_dig.clients_tr[i].tensors[1].numpy())
for i in range(len(femnist_dig.clients_tr))])
lbl_test = set.union(*[set(femnist_dig.clients_te[i].tensors[1].numpy())
for i in range(len(femnist_dig.clients_te))])
assert len(lbl_train | lbl_test) == 10


Expand All @@ -206,22 +207,22 @@ def test_femnist_upp():
femnist_u = Datasets.FEMNIST("./data", filter="uppercase")
except AssertionError:
return
assert len(femnist_u[0]) == 3597 # Total number of clients
assert len(femnist_u[1]) == 3597 # Total number of clients
assert femnist_u[0][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert len(femnist_u.clients_tr) == 3597 # Total number of clients
assert len(femnist_u.clients_te) == 3597 # Total number of clients
assert femnist_u.clients_tr[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist_u[0][i].tensors[1].shape[0] == femnist_u[0][i].tensors[0].shape[0]
for i in range(len(femnist_u[0]))]) == len(femnist_u[0])
assert sum([femnist_u.clients_tr[i].tensors[1].shape[0] == femnist_u.clients_tr[i].tensors[0].shape[0]
for i in range(len(femnist_u.clients_tr))]) == len(femnist_u.clients_tr)

assert femnist_u[1][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert femnist_u.clients_te[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist_u[1][i].tensors[1].shape[0] == femnist_u[1][i].tensors[0].shape[0]
for i in range(len(femnist_u[1]))]) == len(femnist_u[1])
assert sum([femnist_u.clients_te[i].tensors[1].shape[0] == femnist_u.clients_te[i].tensors[0].shape[0]
for i in range(len(femnist_u.clients_te))]) == len(femnist_u.clients_te)

lbl_train = set.union(*[set(femnist_u[0][i].tensors[1].numpy())
for i in range(len(femnist_u[0]))])
lbl_test = set.union(*[set(femnist_u[1][i].tensors[1].numpy())
for i in range(len(femnist_u[1]))])
lbl_train = set.union(*[set(femnist_u.clients_tr[i].tensors[1].numpy())
for i in range(len(femnist_u.clients_tr))])
lbl_test = set.union(*[set(femnist_u.clients_te[i].tensors[1].numpy())
for i in range(len(femnist_u.clients_te))])
assert len(lbl_train | lbl_test) == 26


Expand All @@ -230,22 +231,22 @@ def test_femnist_low():
femnist_l = Datasets.FEMNIST("./data", filter="lowercase")
except AssertionError:
return
assert len(femnist_l[0]) == 3597 # Total number of clients
assert len(femnist_l[1]) == 3597 # Total number of clients
assert femnist_l[0][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert len(femnist_l.clients_tr) == 3597 # Total number of clients
assert len(femnist_l.clients_te) == 3597 # Total number of clients
assert femnist_l.clients_tr[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist_l[0][i].tensors[1].shape[0] == femnist_l[0][i].tensors[0].shape[0]
for i in range(len(femnist_l[0]))]) == len(femnist_l[0])
assert sum([femnist_l.clients_tr[i].tensors[1].shape[0] == femnist_l.clients_tr[i].tensors[0].shape[0]
for i in range(len(femnist_l.clients_tr))]) == len(femnist_l.clients_tr)

assert femnist_l[1][0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
assert femnist_l.clients_te[0].tensors[0].shape[1:] == torch.Size([1, 28, 28]) # image shape
# number of labels matches number of images in each client
assert sum([femnist_l[1][i].tensors[1].shape[0] == femnist_l[1][i].tensors[0].shape[0]
for i in range(len(femnist_l[1]))]) == len(femnist_l[1])
assert sum([femnist_l.clients_te[i].tensors[1].shape[0] == femnist_l.clients_te[i].tensors[0].shape[0]
for i in range(len(femnist_l.clients_te))]) == len(femnist_l.clients_te)

lbl_train = set.union(*[set(femnist_l[0][i].tensors[1].numpy())
for i in range(len(femnist_l[0]))])
lbl_test = set.union(*[set(femnist_l[1][i].tensors[1].numpy())
for i in range(len(femnist_l[1]))])
lbl_train = set.union(*[set(femnist_l.clients_tr[i].tensors[1].numpy())
for i in range(len(femnist_l.clients_tr))])
lbl_test = set.union(*[set(femnist_l.clients_te[i].tensors[1].numpy())
for i in range(len(femnist_l.clients_te))])
assert len(lbl_train | lbl_test) == 26


Expand All @@ -256,15 +257,15 @@ def test_shakespeare():
except AssertionError:
return

assert len(shake[0]) == 660
assert len(shake[1]) == 660
assert shake[0][0].tensors[0].shape[1:] == torch.Size([80]) # Shakespeare text
assert shake[1][0].tensors[0].shape[1:] == torch.Size([80]) # Shakespeare text
assert len(shake.clients_tr) == 660
assert len(shake.clients_te) == 660
assert shake.clients_tr[0].tensors[0].shape[1:] == torch.Size([80]) # Shakespeare text
assert shake.clients_te[0].tensors[0].shape[1:] == torch.Size([80]) # Shakespeare text

assert sum([shake[0][i].tensors[1].shape[0] == shake[0][i].tensors[0].shape[0]
for i in range(len(shake[0]))]) == len(shake[0])
assert sum([shake[1][i].tensors[1].shape[0] == shake[1][i].tensors[0].shape[0]
for i in range(len(shake[1]))]) == len(shake[1])
assert sum([shake.clients_tr[i].tensors[1].shape[0] == shake.clients_tr[i].tensors[0].shape[0]
for i in range(len(shake.clients_tr))]) == len(shake.clients_tr)
assert sum([shake.clients_te[i].tensors[1].shape[0] == shake.clients_te[i].tensors[0].shape[0]
for i in range(len(shake.clients_te))]) == len(shake.clients_te)


# ### Fashion MNIST
Expand All @@ -274,8 +275,8 @@ def test_fashion_mnist():
assert fashion.train[1].shape == torch.Size([60000])
assert fashion.test[0].shape == torch.Size([10000, 28, 28])
assert fashion.test[1].shape == torch.Size([10000])
assert fashion.num_classes == len(
set(fashion.train[1].unique().tolist() + fashion.test[1].unique().tolist()))
assert fashion.num_classes == len(set(fashion.train[1].unique().tolist() +
fashion.test[1].unique().tolist()))
assert fashion.num_classes == 10


Expand All @@ -297,6 +298,19 @@ def test_cinic10():
assert cinic.num_classes == 10


# ### FCUBE
def test_fcube():
data_container = Datasets.FCUBE()
ctr, cte, ste = data_container.clients_tr, data_container.clients_te, data_container.server_data
assert len(ctr) == 4
assert len(cte) == 4
assert ste is not None
assert ste.size == 100
print(cte[0].size, ctr[0].size, ste.size)
assert sum([cte[i].size + ctr[i].size for i in range(4)]) == 900
assert ctr[0].num_labels == 2


if __name__ == "__main__":
# test_mnist()
# test_mnist4d()
Expand All @@ -306,13 +320,14 @@ def test_cinic10():
# test_cifar100()
# test_mnistm()
# test_tinyimagenet()
test_femnist()
test_femnist_dig()
test_femnist_upp()
test_femnist_low()
# test_femnist()
# test_femnist_dig()
# test_femnist_upp()
# test_femnist_low()
# test_shakespeare()
# test_fashion_mnist()
# test_cinic10()
test_fcube()

# 98% coverage on datasets.py
# 88% coverage on support.py
2 changes: 1 addition & 1 deletion tests/test_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_convnets():
y1 = model(x)
y2 = model.forward_head(z)
assert y1.shape == (1, 10)
assert z.shape == (1, 400)
assert z.shape == (1, 84)
assert torch.allclose(y1, y2)

model = FedBN_CNN()
Expand Down
1 change: 0 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def finished(self, round):
class Model(Linear):
def __init__(self):
super().__init__(10, 2)
self.output_size = 2

def target_function(x):
return 0 if x[:7].sum() < 2.5 else 1
Expand Down
15 changes: 14 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
batch_norm_to_group_norm, check_model_fit_mem,
diff_model, flatten_parameters,
get_global_model_dict, get_local_model_dict,
merge_models, mix_networks,
merge_models, mix_networks, get_activation_size,
safe_load_state_dict, set_lambda_model)


Expand Down Expand Up @@ -496,6 +496,19 @@ def test_check_mem():
assert check_model_fit_mem(net, (28 * 28,), 100, "cuda")


def test_get_activation_size():
net = MNIST_2NN()
x = torch.randn(1, 28 * 28)
assert 10 == get_activation_size(net, None)
assert 10 == get_activation_size(net, x)

net = FedBN_CNN()
x = torch.randn(1, 1, 28, 28)
with pytest.raises(ValueError):
get_activation_size(net.encoder, None)
assert 10 == get_activation_size(net, x)


def test_alllayeroutput():
net = MNIST_2NN()
all_out = AllLayerOutputModel(net)
Expand Down
Loading

0 comments on commit 1bc8408

Please sign in to comment.