diff --git a/fluke/algorithms/fednh.py b/fluke/algorithms/fednh.py index 127f181..81bb918 100644 --- a/fluke/algorithms/fednh.py +++ b/fluke/algorithms/fednh.py @@ -1,3 +1,4 @@ + """Implementation of the [FedNH23]_ algorithm. References: @@ -24,6 +25,8 @@ from ..utils.model import STATE_DICT_KEYS_TO_IGNORE # NOQA from . import PersonalizedFL # NOQA +from collections import Counter, OrderedDict +import torch.nn.functional as F __all__ = [ "ProtoNet", "ArgMaxModule", @@ -43,18 +46,15 @@ def __init__(self, encoder: Module, n_protos: int, normalize: bool = False): requires_grad=False) self.prototypes.data = torch.nn.init.orthogonal_(torch.rand(n_protos, encoder.output_size)) - self.temperature = Parameter(torch.rand(1), requires_grad=True) + self.temperature = Parameter(torch.tensor(1.0)) def forward(self, x: torch.Tensor) -> torch.Tensor: embeddings = self.encoder(x) - if self._normalize: - embeddings_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True).clamp(min=1e-12) - embeddings = torch.div(embeddings, embeddings_norm) - prototype_norm = torch.norm(self.prototypes, p=2, dim=1, keepdim=True).clamp(min=1e-12) - normalized_prototypes = torch.div(self.prototypes, prototype_norm) - logits = torch.matmul(embeddings, normalized_prototypes.T) - else: - logits = torch.matmul(embeddings, self.prototypes.T) + embeddings_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True).clamp(min=1e-12) + embeddings = torch.div(embeddings, embeddings_norm) + prototype_norm = torch.norm(self.prototypes, p=2, dim=1, keepdim=True).clamp(min=1e-12) + normalized_prototypes = torch.div(self.prototypes, prototype_norm) + logits = torch.matmul(embeddings, self.prototypes.T) logits = self.temperature * logits return embeddings, logits @@ -91,6 +91,7 @@ def __init__(self, proto_norm=proto_norm ) self.model = self.personalized_model + self.count_by_class = torch.bincount(self.train_set.tensors[1]) def _update_protos(self, protos: Iterable[torch.Tensor]) -> None: for label, prts in protos.items(): @@ -98,6 +99,7 @@ def _update_protos(self, protos: Iterable[torch.Tensor]) -> None: self.model.prototypes.data[label] = torch.sum(prts, dim=0) / prts.shape[0] self.model.prototypes.data[label] /= torch.norm( self.model.prototypes.data[label]).clamp(min=1e-12) + self.model.prototypes.data[label] = self.model.prototypes.data[label] * prts.shape[0] else: self.model.prototypes.data[label] = torch.zeros_like( self.model.prototypes.data[label]) @@ -120,6 +122,7 @@ def fit(self, override_local_epochs: int = 0) -> float: _, logits = self.model(X) loss = self.hyper_params.loss_fn(logits, y) loss.backward() + torch.nn.utils.clip_grad_norm_(parameters=filter(lambda p: p.requires_grad, self.model.parameters()), max_norm=10) self.optimizer.step() running_loss += loss.item() self.scheduler.step() @@ -128,11 +131,12 @@ def fit(self, override_local_epochs: int = 0) -> float: running_loss /= (epochs * len(self.train_set)) protos = defaultdict(list) + for label in range(self.hyper_params.n_protos): Xlbl = self.train_set.tensors[0][self.train_set.tensors[1] == label] - protos[label] = self.model(Xlbl)[0].detach().data + protos[label] = self.model.encoder(Xlbl).detach().data - self._update_protos(protos) + self._update_protos(protos) return running_loss def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]: @@ -178,13 +182,18 @@ def aggregate(self, eligible: Iterable[PFLClient]) -> None: # server_lr = self.hyper_params.lr * self.hyper_params.lr_decay ** self.round server_lr = 1.0 weight = server_lr / len(clients_models) + cl_weight = torch.zeros(self.hyper_params.n_protos) + + # To get client.count_by_class is actually illegal in fluke, but irrelevant from an implementation point of view + for client in eligible: + cl_weight += client.count_by_class # Aggregate prototypes prototypes = self.model.prototypes.clone() if self.hyper_params.weighted: for label, protos in label_protos.items(): prototypes.data[label, :] = torch.sum( - weight * torch.stack(protos), dim=0) + torch.stack(protos)/cl_weight[label], dim=0) else: sim_weights = [] for protos in clients_protos: @@ -204,9 +213,7 @@ def aggregate(self, eligible: Iterable[PFLClient]) -> None: self.hyper_params.rho * self.model.prototypes.data # Normalize the prototypes again - self.model.prototypes.data /= torch.norm(self.model.prototypes.data, - dim=0).clamp(min=1e-12) - + self.model.prototypes.data /= torch.norm(self.model.prototypes.data, dim=0).clamp(min=1e-12) # Aggregate models = Federated Averaging avg_model_sd = OrderedDict() clients_sd = [client.encoder.state_dict() for client in clients_models] @@ -236,3 +243,4 @@ def get_client_class(self) -> PFLClient: def get_server_class(self) -> Server: return FedNHServer +