Skip to content

Commit

Permalink
Fixed issue with device during evaluation
Browse files Browse the repository at this point in the history
- Fixed model in kafé configuration
- Added unit test for kafé
  • Loading branch information
makgyver committed Oct 5, 2024
1 parent df5f4aa commit e7559f4
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ fluke_fl.egg-info/
tests/*.ipynb
runs/
repr_results/
checkpoint*/
checkpoint*/
chk_*
2 changes: 1 addition & 1 deletion configs/kafe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ hyperparameters:
server:
weighted: true
bandwidth: 1.0
model: CNN_Mnist
model: MNIST_2NN
name: fluke.algorithms.kafe.Kafe
2 changes: 1 addition & 1 deletion fluke/algorithms/fedhp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def fit(self, override_local_epochs: int = 0) -> float:
def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]:
if test_set is not None and self.initial_prototypes is not None:
model = FedHPModel(self.model)
return evaluator.evaluate(self._last_round, model, test_set)
return evaluator.evaluate(self._last_round, model, test_set, device=self.device)
return {}

def finalize(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions fluke/algorithms/fednh.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def fit(self, override_local_epochs: int = 0) -> float:
def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]:
if test_set is not None and self.model is not None:
model = ArgMaxModule(self.model)
return evaluator.evaluate(self._last_round, model, test_set)
return evaluator.evaluate(self._last_round, model, test_set, device=self.device)
return {}

def finalize(self) -> None:
Expand Down Expand Up @@ -225,7 +225,7 @@ def aggregate(self, eligible: Iterable[PFLClient]) -> None:
def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]:
if self.test_set is not None:
model = ArgMaxModule(self.model)
return evaluator.evaluate(self.rounds + 1, model, self.test_set)
return evaluator.evaluate(self.rounds + 1, model, self.test_set, device=self.device)
return {}


Expand Down
2 changes: 1 addition & 1 deletion fluke/algorithms/fedproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def fit(self, override_local_epochs: int = 0) -> float:
def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]:
if test_set is not None and self.prototypes[0] is not None:
model = FedProtoModel(self.model, self.prototypes, self.device)
return evaluator.evaluate(self._last_round, model, test_set)
return evaluator.evaluate(self._last_round, model, test_set, device=self.device)
return {}

def finalize(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion fluke/algorithms/fedrod.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str,
if test_set is not None and self.model is not None and self.inner_model is not None:
return evaluator.evaluate(self._last_round,
RODModel(self.model, self.inner_model),
test_set)
test_set,
device=self.device)
return {}


Expand Down
9 changes: 6 additions & 3 deletions fluke/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def evaluate(self,
the results.
"""
if test_set is not None and self.model is not None:
return evaluator.evaluate(self._last_round, self.model, test_set)
return evaluator.evaluate(self._last_round, self.model, test_set, device=self.device)
return {}

def finalize(self) -> None:
Expand Down Expand Up @@ -319,7 +319,7 @@ def __str__(self) -> str:
hpstr = ", ".join([f"{h}={str(v)}" for h, v in self.hyper_params.items()])
hpstr = ", " + hpstr if hpstr else ""
return f"{self.__class__.__name__}[{self._index}](optim={self.optimizer_cfg}, " + \
f"batch_size={self.train_set._batch_size}{hpstr})"
f"batch_size={self.train_set._batch_size}{hpstr})"

def __repr__(self) -> str:
return str(self)
Expand Down Expand Up @@ -380,7 +380,10 @@ def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str,
the results.
"""
if test_set is not None and self.personalized_model is not None:
return evaluator.evaluate(self._last_round, self.personalized_model, test_set)
return evaluator.evaluate(self._last_round,
self.personalized_model,
test_set,
device=self.device)
return {}

def state_dict(self) -> dict:
Expand Down
3 changes: 2 additions & 1 deletion fluke/data/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down

data_file = (self.training_file if self.train
else self.test_file)
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file),
weights_only=False)

def __getitem__(self, index):
"""Get images and target for data loader.
Expand Down
4 changes: 2 additions & 2 deletions fluke/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def centralized(alg_cfg: str = typer.Argument(..., help='Config file for the alg
optimizer.step()
scheduler.step()

epoch_eval = evaluator.evaluate(e+1, model, test_loader, criterion)
epoch_eval = evaluator.evaluate(e+1, model, test_loader, criterion, device=device)
history.append(epoch_eval)
for k, v in epoch_eval.items():
log.add_scalar(k, v, e+1)
Expand Down Expand Up @@ -198,7 +198,7 @@ def clients_only(alg_cfg: str = typer.Argument(..., help='Config file for the al
optimizer.step()
scheduler.step()

client_eval = evaluator.evaluate(e+1, model, test_loader, criterion)
client_eval = evaluator.evaluate(e+1, model, test_loader, criterion, device=device)
running_evals[i].append(client_eval)

log.pretty_log(client_eval, title=f"Client [{i}] Performance")
Expand Down
2 changes: 1 addition & 1 deletion fluke/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str,
the results.
"""
if test_set is not None:
return evaluator.evaluate(self.rounds + 1, self.model, test_set)
return evaluator.evaluate(self.rounds + 1, self.model, test_set, device=self.device)
return {}

def finalize(self) -> None:
Expand Down
17 changes: 17 additions & 0 deletions tests/configs/alg/kafe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
hyperparameters:
client:
batch_size: 64
local_epochs: 5
loss: CrossEntropyLoss
optimizer:
lr: 0.1
momentum: 0.5
# weight_decay: 0.0001
scheduler:
gamma: 1
step_size: 1
server:
weighted: true
bandwidth: 1.0
model: MNIST_2NN
name: fluke.algorithms.kafe.Kafe
4 changes: 4 additions & 0 deletions tests/configs/exp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ exp:
seed: 42
logger:
name: local
eval:
locals: true
pre_fit: true
post_fit: true
protocol:
eligible_perc: 1
n_clients: 100
Expand Down
7 changes: 6 additions & 1 deletion tests/test_alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ def test_fedsgd():
# "./tests/configs/alg/fedsgd.yaml", oncpu=False)


def test_kafe():
kafe, log = _test_algo("./tests/configs/exp.yaml", "./tests/configs/alg/kafe.yaml")


def test_lgfedavg():
lgfedavg, log = _test_algo("./tests/configs/exp.yaml", "./tests/configs/alg/lg_fedavg.yaml")
# lgfedavg, log = _test_algo("./tests/configs/exp.yaml",
Expand Down Expand Up @@ -451,7 +455,7 @@ def test_superfed():
# test_fedrep()
# test_lgfedavg()
# test_moon()
test_fedbn()
# test_fedbn()
# test_pfedme() # TO BE CHECKED
# test_scaffold()
# test_superfed()
Expand All @@ -465,3 +469,4 @@ def test_superfed():
# test_fedavgm()
# test_fedhp()
# test_fednh()
test_kafe()

0 comments on commit e7559f4

Please sign in to comment.