Commit

Merge pull request #14 from makgyver/sam_dev
Fixed FedNH
  • Loading branch information
2 parents 7dbc26a + 230c104 commit 594c9c3
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions fluke/algorithms/fednh.py
@@ -1,3 +1,4 @@

"""Implementation of the [FedNH23]_ algorithm.
References:
@@ -24,6 +25,8 @@
from ..utils.model import STATE_DICT_KEYS_TO_IGNORE # NOQA
from . import PersonalizedFL # NOQA

+from collections import Counter, OrderedDict
+import torch.nn.functional as F
__all__ = [
    "ProtoNet",
    "ArgMaxModule",
@@ -43,18 +46,15 @@ def __init__(self, encoder: Module, n_protos: int, normalize: bool = False):
                                    requires_grad=False)
        self.prototypes.data = torch.nn.init.orthogonal_(torch.rand(n_protos,
                                                                    encoder.output_size))
-        self.temperature = Parameter(torch.rand(1), requires_grad=True)
+        self.temperature = Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embeddings = self.encoder(x)
-        if self._normalize:
-            embeddings_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True).clamp(min=1e-12)
-            embeddings = torch.div(embeddings, embeddings_norm)
-            prototype_norm = torch.norm(self.prototypes, p=2, dim=1, keepdim=True).clamp(min=1e-12)
-            normalized_prototypes = torch.div(self.prototypes, prototype_norm)
-            logits = torch.matmul(embeddings, normalized_prototypes.T)
-        else:
-            logits = torch.matmul(embeddings, self.prototypes.T)
+        embeddings_norm = torch.norm(embeddings, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+        embeddings = torch.div(embeddings, embeddings_norm)
+        prototype_norm = torch.norm(self.prototypes, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+        normalized_prototypes = torch.div(self.prototypes, prototype_norm)
+        logits = torch.matmul(embeddings, self.prototypes.T)
        logits = self.temperature * logits
        return embeddings, logits

@@ -91,13 +91,15 @@ def __init__(self,
            proto_norm=proto_norm
        )
        self.model = self.personalized_model
+        self.count_by_class = torch.bincount(self.train_set.tensors[1])

    def _update_protos(self, protos: Iterable[torch.Tensor]) -> None:
        for label, prts in protos.items():
            if prts.shape[0] > 0:
                self.model.prototypes.data[label] = torch.sum(prts, dim=0) / prts.shape[0]
                self.model.prototypes.data[label] /= torch.norm(
                    self.model.prototypes.data[label]).clamp(min=1e-12)
+                self.model.prototypes.data[label] = self.model.prototypes.data[label] * prts.shape[0]
            else:
                self.model.prototypes.data[label] = torch.zeros_like(
                    self.model.prototypes.data[label])
@@ -120,6 +122,7 @@ def fit(self, override_local_epochs: int = 0) -> float:
                _, logits = self.model(X)
                loss = self.hyper_params.loss_fn(logits, y)
                loss.backward()
+                torch.nn.utils.clip_grad_norm_(parameters=filter(lambda p: p.requires_grad, self.model.parameters()), max_norm=10)
                self.optimizer.step()
                running_loss += loss.item()
            self.scheduler.step()
@@ -128,11 +131,12 @@
        running_loss /= (epochs * len(self.train_set))

        protos = defaultdict(list)

        for label in range(self.hyper_params.n_protos):
            Xlbl = self.train_set.tensors[0][self.train_set.tensors[1] == label]
-            protos[label] = self.model(Xlbl)[0].detach().data
+            protos[label] = self.model.encoder(Xlbl).detach().data

-            self._update_protos(protos)
+        self._update_protos(protos)
        return running_loss

def evaluate(self, evaluator: Evaluator, test_set: FastDataLoader) -> dict[str, float]:
@@ -178,13 +182,18 @@ def aggregate(self, eligible: Iterable[PFLClient]) -> None:
        # server_lr = self.hyper_params.lr * self.hyper_params.lr_decay ** self.round
        server_lr = 1.0
        weight = server_lr / len(clients_models)
+        cl_weight = torch.zeros(self.hyper_params.n_protos)
+
+        # To get client.count_by_class is actually illegal in fluke, but irrelevant from an implementation point of view
+        for client in eligible:
+            cl_weight += client.count_by_class

        # Aggregate prototypes
        prototypes = self.model.prototypes.clone()
        if self.hyper_params.weighted:
            for label, protos in label_protos.items():
                prototypes.data[label, :] = torch.sum(
-                    weight * torch.stack(protos), dim=0)
+                    torch.stack(protos)/cl_weight[label], dim=0)
        else:
            sim_weights = []
            for protos in clients_protos:
@@ -204,9 +213,7 @@ def aggregate(self, eligible: Iterable[PFLClient]) -> None:
                self.hyper_params.rho * self.model.prototypes.data

        # Normalize the prototypes again
-        self.model.prototypes.data /= torch.norm(self.model.prototypes.data,
-                                                 dim=0).clamp(min=1e-12)
-
+        self.model.prototypes.data /= torch.norm(self.model.prototypes.data, dim=0).clamp(min=1e-12)
        # Aggregate models = Federated Averaging
        avg_model_sd = OrderedDict()
        clients_sd = [client.encoder.state_dict() for client in clients_models]
@@ -236,3 +243,4 @@ def get_client_class(self) -> PFLClient:

    def get_server_class(self) -> Server:
        return FedNHServer
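
To illustrate what the new count-based weighting in aggregate computes, here is a minimal standalone sketch in plain PyTorch (not the fluke API; the function and variable names are made up, and the rho value and normalization axis are assumptions based on the diff above): each client reports prototypes already scaled by its per-class sample counts, and the server divides the summed prototypes by the pooled counts, blends them with the previous global prototypes, and re-normalizes.

import torch

def aggregate_prototypes(client_protos, client_counts, global_protos, rho=0.1):
    # client_protos: list of (n_classes, dim) tensors, each already multiplied by the
    #                client's per-class sample counts (mirroring _update_protos above).
    # client_counts: list of (n_classes,) tensors of per-class sample counts
    #                (mirroring count_by_class above).
    # global_protos: (n_classes, dim) tensor of the current global prototypes.
    # rho: smoothing factor between old and new prototypes (the value is an assumption).
    total_counts = torch.stack(client_counts).sum(dim=0)   # pooled counts, like cl_weight
    summed = torch.stack(client_protos).sum(dim=0)         # sum of count-scaled prototypes
    # Divide each class prototype by the total number of samples seen for that class.
    new_protos = summed / total_counts.clamp(min=1).unsqueeze(1)
    # Blend with the previous global prototypes, as the server update does with rho.
    protos = (1 - rho) * new_protos + rho * global_protos
    # Re-normalize (normalizing per prototype here is an assumption of this sketch).
    return protos / protos.norm(dim=1, keepdim=True).clamp(min=1e-12)

Pre-scaling prototypes by local class counts on the client and dividing by the pooled counts on the server turns the weighted branch into a per-class, sample-count-weighted average, which is why the commit adds count_by_class on the client and cl_weight on the server.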
