From caadd61199047bc03048862782d83caa5cf532ad Mon Sep 17 00:00:00 2001 From: private-mechanism <45845531+private-mechanism@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:00:06 +0800 Subject: [PATCH] Byzantine dynamic defense --- .../attack/byzantine_attacks/fang_attack.py | 179 +++++++++++ .../attack/byzantine_attacks/she_attack.py | 282 ++++++++++++++++++ federatedscope/attack/trainer/__init__.py | 5 +- .../attack/trainer/gaussian_attack_trainer.py | 4 +- .../attack/trainer/label_flipping_trainer.py | 78 +++++ .../trainer/malicious_SGA_attack_trainer.py | 46 +++ federatedscope/core/aggregators/__init__.py | 16 +- .../core/aggregators/bulyan_aggregator.py | 107 +++++++ .../aggregators/clients_avg_aggregator.py | 7 +- .../aggregators/dynamic_defense_aggregator.py | 168 +++++++++++ .../core/aggregators/krum_aggregator.py | 18 +- .../core/aggregators/median_aggregator.py | 52 ++++ .../aggregators/normbounding_aggregator.py | 64 ++++ .../aggregators/trimmedmean_aggregator.py | 58 ++++ .../core/auxiliaries/aggregator_builder.py | 31 +- federatedscope/core/configs/cfg_aggregator.py | 58 +++- federatedscope/core/workers/client.py | 28 +- federatedscope/core/workers/server.py | 127 +++++--- .../she_attack_convnet2_femnist.yaml | 45 +++ tests/test_she_attack.py | 110 +++++++ 20 files changed, 1395 insertions(+), 88 deletions(-) create mode 100644 federatedscope/attack/byzantine_attacks/fang_attack.py create mode 100644 federatedscope/attack/byzantine_attacks/she_attack.py create mode 100644 federatedscope/attack/trainer/label_flipping_trainer.py create mode 100644 federatedscope/attack/trainer/malicious_SGA_attack_trainer.py create mode 100644 federatedscope/core/aggregators/bulyan_aggregator.py create mode 100644 federatedscope/core/aggregators/dynamic_defense_aggregator.py create mode 100644 federatedscope/core/aggregators/median_aggregator.py create mode 100644 federatedscope/core/aggregators/normbounding_aggregator.py create mode 100644 federatedscope/core/aggregators/trimmedmean_aggregator.py create mode 100644 scripts/attack_exp_scripts/byzantine_attacks/she_attack_convnet2_femnist.yaml create mode 100644 tests/test_she_attack.py diff --git a/federatedscope/attack/byzantine_attacks/fang_attack.py b/federatedscope/attack/byzantine_attacks/fang_attack.py new file mode 100644 index 000000000..5ca9e95fe --- /dev/null +++ b/federatedscope/attack/byzantine_attacks/fang_attack.py @@ -0,0 +1,179 @@ +import logging +import copy +import torch +import random +import numpy as np +from federatedscope.core.aggregators import ClientsAvgAggregator +from federatedscope.core.aggregators.krum_aggregator import KrumAggregator +from federatedscope.core.aggregators.median_aggregator import MedianAggregator +from federatedscope.core.aggregators.trimmedmean_aggregator import \ + TrimmedmeanAggregator +from federatedscope.core.aggregators.bulyan_aggregator import \ + BulyanAggregator + +class Fang_adaptive_attacks(): + """ + Define the Fang_adaptive_attacks with can attack a specific aggregation rule with pertinence. 
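+    The attacker is assumed to know the deployed aggregation rule and crafts
+    malicious updates tailored to it; ``str2defense`` below maps each
+    supported rule (krum/median/trmean/bulyan) to its defense routine.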
+ """ + + def __init__(self, model=None, device='cpu',config=None): + self.str2defense = {'krum': KrumAggregator(model,device,config)._para_avg_with_krum, + 'median': MedianAggregator(model,device,config)._aggre_with_median, + 'trmean': TrimmedmeanAggregator(model,device,config)._aggre_with_trimmedmean, + 'bulyan': BulyanAggregator(model,device,config)._aggre_with_bulyan} + self.byzantine_node_num = config.aggregator.byzantine_node_num + self.config = config + self.model = model + + + def compute_lambda(self, model_re, all_updates, model_byzantine): + """ + Calculate an upper bound of the scaling factor used in Fang_krum attack. + """ + distances = [] + n_benign, d = all_updates.shape + for update in all_updates: + distance = torch.norm((all_updates - update), dim=1) + distances = distance[None, :] if not len(distances) else torch.cat((distances, distance[None, :]), 0) + distances[distances == 0] = 10000 + distances = torch.sort(distances, dim=1)[0] + scores = torch.sum(distances[:, :n_benign - 2 - len(model_byzantine)], dim=1) + min_score = torch.min(scores) + term_1 = min_score / ((n_benign - len(model_byzantine) - 1) * torch.sqrt(torch.Tensor([d]))[0]) + max_wre_dist = torch.max(torch.norm((all_updates - model_re), dim=1)) / (torch.sqrt(torch.Tensor([d]))[0]) + return term_1 + max_wre_dist + + def fang_krum(self, models): + """ + Please refer to "Local model poisoning attacks to byzantine-robust federated learning" + https://www.usenix.org/system/files/sec20-fang.pdf + """ + all_updates = torch.stack([each_model[0][1] for each_model in models]) + model_re = torch.mean(all_updates, 0) + deviation = torch.sign(model_re) + model_byzantine = list() + mal_id = {} + index_ = [] + for index, model in enumerate(models): + if model[1] < self.byzantine_node_num + 1: + mal_id[model[1]] = model[0][0] + model_byzantine.append(model[0]) + index_.append(index) + for i in range(len(index_) - 1, -1, -1): + models.pop(index_[i]) + agg_update = copy.deepcopy(model_re) + lamda = self.compute_lambda(agg_update, all_updates, model_byzantine) + threshold = 1e-5 + mal_update = [] + while lamda > threshold: + mal_update = agg_update - lamda * deviation + for id in mal_id: + mal_update = self._insert_the_bn(mal_update, mal_id[id][1]) + models.append(((mal_id[id][0], mal_update), id)) + agg_num = self.config.aggregator.fang_krum_agg_num + _, krum_candidate = self.str2defense['krum'](models, agg_num) + for i in range(len(model_byzantine)): + models.pop(-1) + if krum_candidate[0] < self.byzantine_node_num + 1: + for id in mal_id: + models.append(((mal_id[id][0], mal_update), id)) + return models + else: + mal_update = [] + lamda *= 0.5 + if not len(mal_update): + mal_update = agg_update - lamda * deviation + for id in mal_id: + models.append(((mal_id[id][0], mal_update), id)) + return models + + ######## fang attack median ################################# + + def fang_median(self, models): + """ + Please refer to "Local model poisoning attacks to byzantine-robust federated learning" + https://www.usenix.org/system/files/sec20-fang.pdf + """ + model_byzantine = list() + mal_id = {} + index_ = [] + for index, model in enumerate(models): + if model[1] < self.byzantine_node_num + 1: + mal_id[model[1]] = model[0][0] + model_byzantine.append(model[0]) + index_.append(index) + for i in range(len(index_) - 1, -1, -1): + models.pop(index_[i]) + if len(model_byzantine) == 0: + return models + else: + temp = torch.stack([each_model[0][1] for each_model in models], 0) + model_re = torch.mean(temp, 0) + if len(model_byzantine) 
== 1: + model_std = torch.zeros_like(model_re) + else: + model_std = torch.std(temp, 0) + deviation = torch.sign(model_re) + max_vector_low = model_re + 3 * model_std + max_vector_hig = model_re + 4 * model_std + min_vector_low = model_re - 4 * model_std + min_vector_hig = model_re - 3 * model_std + max_range = torch.cat((max_vector_low[:, None], max_vector_hig[:, None]), dim=1) + min_range = torch.cat((min_vector_low[:, None], min_vector_hig[:, None]), dim=1) + rand = torch.from_numpy(np.random.uniform(0, 1, [len(deviation), len(model_byzantine)])).type( + torch.FloatTensor) # 【800,4】 + max_rand = torch.stack([max_range[:, 0]] * rand.shape[1]).T + rand * torch.stack( + [max_range[:, 1] - max_range[:, 0]] * rand.shape[1]).T + min_rand = torch.stack([min_range[:, 0]] * rand.shape[1]).T + rand * torch.stack( + [min_range[:, 1] - min_range[:, 0]] * rand.shape[1]).T + former = torch.stack([(deviation < 0).type(torch.FloatTensor)] * max_rand.shape[1]).T * max_rand + latter = torch.stack([(deviation > 0).type(torch.FloatTensor)] * min_rand.shape[1]).T * min_rand + mal_vec = (former + latter).T + tuple_ = torch.chunk(mal_vec, len(model_byzantine), dim=0) + i = 0 + for id in mal_id: + fake_update = tuple_[i].reshape(model_re.shape) + models.append(((mal_id[id][0], fake_update), id)) + i += 1 + return models + + + ##### methods to transform the model update and tensor #### + def _flatten_updates(self, model): + model_update=[] + init_model = self.model.state_dict() + for key in init_model: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim = 0) + + def _flatten_updates_without_bn(self, model): + model_update=[] + init_model = self.model.state_dict() + for key in init_model: + if 'bn' not in key: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim = 0) + + def _reconstruct_updates(self, flatten_updates): + start_idx = 0 + init_model = self.model.state_dict() + reconstructed_model = copy.deepcopy(init_model) + for key in init_model: + reconstructed_model[key] = flatten_updates[start_idx:start_idx+len(init_model[key].view(-1))].reshape(init_model[key].shape) + start_idx=start_idx+len(init_model[key].view(-1)) + return reconstructed_model + + def _extract_the_bn(self, model): + temp_model = copy.deepcopy(self.model.state_dict()) + model = self._reconstruct_updates(model) + bn_dict={} + for key in temp_model: + if 'bn' in key: + bn_dict[key] = model[key] + return bn_dict + + def _insert_the_bn(self, model_tensor, dict): + model = self._reconstruct_updates(model_tensor) + for key in dict: + model[key] = dict[key] + return self._flatten_updates(model) diff --git a/federatedscope/attack/byzantine_attacks/she_attack.py b/federatedscope/attack/byzantine_attacks/she_attack.py new file mode 100644 index 000000000..fd2fe2ec5 --- /dev/null +++ b/federatedscope/attack/byzantine_attacks/she_attack.py @@ -0,0 +1,282 @@ +import logging +import copy +import torch +import random +import numpy as np +from federatedscope.core.aggregators import ClientsAvgAggregator +from federatedscope.core.aggregators.krum_aggregator import KrumAggregator +from federatedscope.core.aggregators.median_aggregator import MedianAggregator +from federatedscope.core.aggregators.trimmedmean_aggregator import \ + TrimmedmeanAggregator +from federatedscope.core.aggregators.bulyan_aggregator import \ + BulyanAggregator + +logger = logging.getLogger(__name__) + +class She_adaptive_attacks(): + """ + Define the She_adaptive_attacks which can attack a specific aggregation rule with pertinence. 
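+    Following the cited paper, each attack searches for the largest scaling
+    factor lambda such that the crafted malicious update still maximally
+    shifts the aggregate accepted by the targeted rule.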
+ """ + + def __init__(self, model=None, device='cpu',config=None): + self.str2defense = {'krum': KrumAggregator(model,device,config)._para_avg_with_krum, + 'median': MedianAggregator(model,device,config)._aggre_with_median, + 'trmean': TrimmedmeanAggregator(model,device,config)._aggre_with_trimmedmean, + 'bulyan': BulyanAggregator(model,device,config)._aggre_with_bulyan} + self.byzantine_node_num = config.aggregator.byzantine_node_num + self.config = config + self.model = model + + + + def she_krum(self, models, dev_type = 'sign'): + """ + Please refer to "Manipulating the Byzantine: Optimizing Model Poisoning \ + Attacks and Defensesfor Federated Learning" + https://par.nsf.gov/servlets/purl/10286354 + """ + all_updates = torch.stack([each_model[0][1] for each_model in models]) + model_re = torch.mean(all_updates, 0) + model_byzantine = list() + mal_id = {} + index_ = [] + for index, model in enumerate(models): + if model[1] < self.byzantine_node_num + 1: + mal_id[model[1]] = (model[0][0],\ + self._extract_the_bn(model[0][1])) + model_byzantine.append(model[0]) + index_.append(index) + for i in range(len(index_) - 1, -1, -1): + models.pop(index_[i]) + if len(model_byzantine) == 0: + return models + else: + all_updates = torch.stack([each_model[0][1] for each_model in models]) + if dev_type == 'sign': + deviation = torch.sign(model_re) + elif dev_type == 'unit_vec': + deviation = model_re / torch.norm(model_re) + elif dev_type == 'std': + deviation = torch.std(all_updates, 0) + deviation = deviation + threshold_diff = 1e-5 + lamda = 20.0 + lamda_fail = lamda + lamda_succ = 0. + while np.abs(lamda_succ - lamda) > threshold_diff: + mal_update = (model_re - lamda * deviation) + for id in mal_id: + mal_update = self._insert_the_bn(mal_update, mal_id[id][1]) + models.append(((mal_id[id][0], mal_update), id)) + models = [((model[0][0], self._reconstruct_updates(model[0][1])), \ + model[1]) for model in models] + agg_res, mal_candidiate = self.str2defense['krum'](models) + models = [((model[0][0], self._flatten_updates(model[0][1])), model[1]) for model in models] + for i in range(len(model_byzantine)): + models.pop(-1) + pre_num = 0 + num = 0 + for id in mal_candidiate: + if id < len(model_byzantine) + 1: + num += 1 + if self.config.aggregator.BFT_args.krum_agg_num == 1: + if num == 1: + lamda_succ = lamda + lamda = lamda + lamda_fail / 2 + else: + lamda = lamda - lamda_fail / 2 + lamda_fail = lamda_fail / 2 + else: + if num > 0 and num >= pre_num: + pre_num = num + lamda_succ = lamda + lamda = lamda + lamda_fail / 2 + else: + lamda = lamda - lamda_fail / 2 + lamda_fail = lamda_fail / 2 + mal_update = (model_re - lamda_succ * deviation) + i=0 + for id in mal_id: + mal_update = self._insert_the_bn(mal_update, mal_id[id][1]) + models.append(((mal_id[id][0], mal_update),id)) + i+=1 + logger.info(f'the model length is {len(models)}') + # if num == self.config.aggregator.BFT_args.krum_agg_num: + # lamda_succ = lamda + # lamda = lamda + lamda_fail / 2 + # else: + # lamda = lamda - lamda_fail / 2 + # lamda_fail = lamda_fail / 2 + # mal_update = model_re - lamda_succ * deviation + # for id in mal_id: + # models.append(((mal_id[id], mal_update), id)) + return models + + + def she_median(self, models, dev_type= 'sign'): + """ + Please refer to "Manipulating the Byzantine: Optimizing Model Poisoning \ + Attacks and Defensesfor Federated Learning" + https://par.nsf.gov/servlets/purl/10286354 + """ + all_updates = torch.stack([each_model[0][1] for each_model in models]) + model_re = torch.mean(all_updates, 
0)
+        threshold = 5.0
+        threshold_diff = 1e-5
+        model_byzantine = list()
+        mal_id = {}
+        index_ = []
+        for index, model in enumerate(models):
+            if model[1] < self.byzantine_node_num + 1:
+                mal_id[model[1]] = (model[0][0],
+                                    self._extract_the_bn(model[0][1]))
+                model_byzantine.append(model[0])
+                index_.append(index)
+        for i in range(len(index_) - 1, -1, -1):
+            models.pop(index_[i])
+
+        if len(model_byzantine) == 0:
+            return models
+        else:
+            all_updates = torch.stack([each_model[0][1] for each_model in models])
+            if dev_type == 'sign':
+                deviation = torch.sign(model_re)
+            elif dev_type == 'unit_vec':
+                deviation = model_re / torch.norm(model_re)
+            elif dev_type == 'std':
+                deviation = torch.std(all_updates, 0)
+            lamda = threshold
+            prev_loss = -1
+            lamda_fail = lamda
+            lamda_succ = 0.
+            while np.abs(lamda_succ - lamda) > threshold_diff:
+                mal_update = model_re - lamda * deviation
+                for id in mal_id:
+                    mal_update = self._insert_the_bn(mal_update, mal_id[id][1])
+                    models.append(((mal_id[id][0], mal_update), id))
+                mal_updates = torch.stack([each_model[0][1] for each_model in models])
+                for i in range(len(model_byzantine)):
+                    models.pop(-1)
+                agg_grads = torch.median(mal_updates, 0)[0]
+                loss = torch.norm(agg_grads - model_re)
+                if prev_loss < loss:
+                    lamda_succ = lamda
+                    lamda = lamda + lamda_fail / 2
+                else:
+                    lamda = lamda - lamda_fail / 2
+                lamda_fail = lamda_fail / 2
+                prev_loss = loss
+            mal_update = model_re - lamda_succ * deviation
+            for id in mal_id:
+                mal_update = self._insert_the_bn(mal_update, mal_id[id][1])
+                models.append(((mal_id[id][0], mal_update), id))
+            logger.info(f'the model length is {len(models)}')
+            return models
+
+    def she_trimmedmean(self, models, dev_type='sign'):
+        """
+        Please refer to "Manipulating the Byzantine: Optimizing Model
+        Poisoning Attacks and Defenses for Federated Learning"
+        https://par.nsf.gov/servlets/purl/10286354
+        """
+        all_updates = torch.stack([each_model[0][1] for each_model in models])
+        model_re = torch.mean(all_updates, 0)
+        threshold = 5.0
+        threshold_diff = 1e-1
+        model_byzantine = list()
+        mal_id = dict()
+        index_ = list()
+        for index, model in enumerate(models):
+            if model[1] < self.byzantine_node_num + 1:
+                mal_id[model[1]] = (model[0][0],
+                                    self._extract_the_bn(model[0][1]))
+                model_byzantine.append(model[0])
+                index_.append(index)
+        for i in range(len(index_) - 1, -1, -1):
+            models.pop(index_[i])
+
+        if len(model_byzantine) == 0:
+            return models
+        else:
+            all_updates = torch.stack([each_model[0][1] for each_model in models])
+            if dev_type == 'sign':
+                deviation = torch.sign(model_re)
+            elif dev_type == 'unit_vec':
+                deviation = model_re / torch.norm(model_re)
+            elif dev_type == 'std':
+                deviation = torch.std(all_updates, 0)
+            lamda = threshold
+            prev_loss = -1
+            lamda_fail = lamda
+            lamda_succ = 0
+            while np.abs(lamda_succ - lamda) > threshold_diff:
+                mal_update = model_re - lamda * deviation
+                for id in mal_id:
+                    mal_update = self._insert_the_bn(mal_update, mal_id[id][1])
+                    models.append(((mal_id[id][0], mal_update), id))
+                models = [((model[0][0], self._reconstruct_updates(model[0][1])),
+                           model[1]) for model in models]
+                agg_grads, _ = self.str2defense['trmean'](models)
+                models = [((model[0][0], self._flatten_updates(model[0][1])),
+                           model[1]) for model in models]
+                # pop the malicious updates only after they have been fed
+                # into the aggregation above
+                for i in range(len(model_byzantine)):
+                    models.pop(-1)
+                agg_grads = self._flatten_updates(agg_grads)
+                loss = torch.norm(agg_grads - model_re)
+                if prev_loss < loss:
+                    lamda_succ = lamda
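+                    # this lambda still enlarged the aggregation error, so
+                    # record it and keep growing the scaling factor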
lamda = lamda + lamda_fail / 2 + else: + lamda = lamda - lamda_fail / 2 + lamda_fail = lamda_fail / 2 + prev_loss = loss + mal_update = (model_re - lamda_succ * deviation) + for id in mal_id: + mal_update = self._insert_the_bn(mal_update, mal_id[id][1]) + models.append(((mal_id[id], mal_update), id)) + return models + + + ##### methods to transform the model update and tensor #### + def _flatten_updates(self, model): + model_update=[] + init_model = self.model.state_dict() + for key in init_model: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim = 0) + + def _flatten_updates_without_bn(self, model): + model_update=[] + init_model = self.model.state_dict() + for key in init_model: + if 'bn' not in key: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim = 0) + + def _reconstruct_updates(self, flatten_updates): + start_idx = 0 + init_model = self.model.state_dict() + reconstructed_model = copy.deepcopy(init_model) + for key in init_model: + reconstructed_model[key] = flatten_updates[start_idx:start_idx\ + +len(init_model[key].view(-1))].reshape(init_model[key].shape) + start_idx=start_idx+len(init_model[key].view(-1)) + return reconstructed_model + + def _extract_the_bn(self, model): + temp_model = copy.deepcopy(self.model.state_dict()) + model = self._reconstruct_updates(model) + bn_dict={} + for key in temp_model: + if 'bn' in key: + bn_dict[key] = model[key] + return bn_dict + + def _insert_the_bn(self, model_tensor, dict): + model = self._reconstruct_updates(model_tensor) + for key in dict: + model[key] = dict[key] + return self._flatten_updates(model) diff --git a/federatedscope/attack/trainer/__init__.py b/federatedscope/attack/trainer/__init__.py index 9760e25d4..b90001fe5 100644 --- a/federatedscope/attack/trainer/__init__.py +++ b/federatedscope/attack/trainer/__init__.py @@ -4,6 +4,8 @@ from federatedscope.attack.trainer.backdoor_trainer import * from federatedscope.attack.trainer.benign_trainer import * from federatedscope.attack.trainer.gaussian_attack_trainer import * +from federatedscope.attack.trainer.label_flipping_trainer import * +from federatedscope.attack.trainer.malicious_SGA_attack_trainer import * __all__ = [ 'wrap_GANTrainer', 'hood_on_fit_start_generator', @@ -13,5 +15,6 @@ 'hook_on_fit_start_count_round', 'hook_on_batch_start_replace_data_batch', 'hook_on_batch_backward_invert_gradient', 'hook_on_fit_start_loss_on_target_data', 'wrap_backdoorTrainer', - 'wrap_benignTrainer', 'wrap_GaussianAttackTrainer' + 'wrap_benignTrainer', 'wrap_GaussianAttackTrainer','wrap_LabelFlippingTrainer', + 'wrap_SGAAttackTrainer' ] diff --git a/federatedscope/attack/trainer/gaussian_attack_trainer.py b/federatedscope/attack/trainer/gaussian_attack_trainer.py index 91467657d..57067e004 100644 --- a/federatedscope/attack/trainer/gaussian_attack_trainer.py +++ b/federatedscope/attack/trainer/gaussian_attack_trainer.py @@ -28,7 +28,6 @@ def wrap_GaussianAttackTrainer( def hook_on_batch_backward_generate_gaussian_noise_gradient(ctx): - ctx.optimizer.zero_grad() ctx.loss_task.backward() grad_values = list() @@ -37,7 +36,8 @@ def hook_on_batch_backward_generate_gaussian_noise_gradient(ctx): grad_values.append(param.grad.detach().cpu().view(-1)) grad_values = torch.cat(grad_values) - mean_for_gaussian_noise = torch.mean(grad_values) + 0.1 + # mean_for_gaussian_noise = torch.mean(grad_values) + 0.1 + mean_for_gaussian_noise = 0 std_for_gaussian_noise = torch.std(grad_values) for name, param in ctx.model.named_parameters(): diff --git 
a/federatedscope/attack/trainer/label_flipping_trainer.py b/federatedscope/attack/trainer/label_flipping_trainer.py
new file mode 100644
index 000000000..eba277f7c
--- /dev/null
+++ b/federatedscope/attack/trainer/label_flipping_trainer.py
@@ -0,0 +1,53 @@
+import logging
+from typing import Type
+
+import torch
+from federatedscope.core.trainers.enums import MODE, LIFECYCLE
+from federatedscope.core.trainers.context import CtxVar
+from federatedscope.core.trainers import GeneralTorchTrainer
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_LabelFlippingTrainer(
+        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+    '''
+    Wrap the label flipping trainer.
+
+    Args:
+        base_trainer: Type: core.trainers.GeneralTorchTrainer
+    :returns:
+        The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+    '''
+
+    base_trainer.replace_hook_in_train(
+        new_hook=_hook_on_batch_forward_with_flipped_labels,
+        target_trigger='on_batch_forward',
+        target_hook_name='_hook_on_batch_forward')
+
+    return base_trainer
+
+
+def _hook_on_batch_forward_with_flipped_labels(ctx):
+    """
+    Note:
+      The modified attributes and according operations are shown below:
+        ==================================  ===========================
+        Attribute                           Operation
+        ==================================  ===========================
+        ``ctx.y_true``                      Move to `ctx.device`
+        ``ctx.y_prob``                      Forward propagation get y_prob
+        ``ctx.loss_batch``                  Calculate the loss
+        ``ctx.batch_size``                  Get the batch_size
+        ==================================  ===========================
+    """
+    x, label = [_.to(ctx.device) for _ in ctx.data_batch]
+    # flip every label to its complementary class
+    fake_label = ctx.cfg.model.out_channels - 1 - label
+    pred = ctx.model(x)
+    if len(fake_label.size()) == 0:
+        fake_label = fake_label.unsqueeze(0)
+    ctx.y_true = CtxVar(fake_label, LIFECYCLE.BATCH)
+    ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
+    ctx.loss_batch = CtxVar(ctx.criterion(pred, fake_label), LIFECYCLE.BATCH)
+    ctx.batch_size = CtxVar(len(fake_label), LIFECYCLE.BATCH)
\ No newline at end of file
diff --git a/federatedscope/attack/trainer/malicious_SGA_attack_trainer.py b/federatedscope/attack/trainer/malicious_SGA_attack_trainer.py
new file mode 100644
index 000000000..d738db6b5
--- /dev/null
+++ b/federatedscope/attack/trainer/malicious_SGA_attack_trainer.py
@@ -0,0 +1,46 @@
+import logging
+from typing import Type
+
+import torch
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_SGAAttackTrainer(
+        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+    '''
+    Wrap the malicious stochastic gradient ascent (SGA) attack trainer.
+
+    Args:
+        base_trainer: Type: core.trainers.GeneralTorchTrainer
:returns: + The wrapped trainer; Type: core.trainers.GeneralTorchTrainer + ''' + + base_trainer.replace_hook_in_train( + new_hook=hook_on_batch_backward_generate_inverse_gradient, + target_trigger='on_batch_backward', + target_hook_name='_hook_on_batch_backward') + + return base_trainer + + +def hook_on_batch_backward_generate_inverse_gradient(ctx): + ctx.optimizer.zero_grad() + ctx.loss_task.backward() + # grad_values = list() + # for name, param in ctx.model.named_parameters(): + # if 'bn' not in name: + # grad_values.append(param.grad.detach().cpu().view(-1)) + + # grad_values = torch.cat(grad_values) + # mean_for_gaussian_noise = torch.mean(grad_values) + 0.1 + # mean_for_gaussian_noise = 0 + # std_for_gaussian_noise = torch.std(grad_values) + # for name, param in ctx.model.named_parameters(): + # if 'bn' not in name: + # generated_grad = param.grad + # param.grad = generated_grad.to(param.grad.device) + ctx.optimizer.step() diff --git a/federatedscope/core/aggregators/__init__.py b/federatedscope/core/aggregators/__init__.py index b0bc7e95a..45b2f559f 100644 --- a/federatedscope/core/aggregators/__init__.py +++ b/federatedscope/core/aggregators/__init__.py @@ -8,6 +8,15 @@ import ServerClientsInterpolateAggregator from federatedscope.core.aggregators.fedopt_aggregator import FedOptAggregator from federatedscope.core.aggregators.krum_aggregator import KrumAggregator +from federatedscope.core.aggregators.median_aggregator import MedianAggregator +from federatedscope.core.aggregators.trimmedmean_aggregator import \ + TrimmedmeanAggregator +from federatedscope.core.aggregators.bulyan_aggregator import \ + BulyanAggregator +from federatedscope.core.aggregators.normbounding_aggregator import \ + NormboundingAggregator +from federatedscope.core.aggregators.dynamic_defense_aggregator import \ + Weighted_sampled_robustAggregator __all__ = [ 'Aggregator', @@ -18,4 +27,9 @@ 'ServerClientsInterpolateAggregator', 'FedOptAggregator', 'KrumAggregator', -] + 'MedianAggregator', + 'TrimmedmeanAggregator', + 'BulyanAggregator', + 'NormboundingAggregator', + 'Weighted_sampled_robustAggregator' +] \ No newline at end of file diff --git a/federatedscope/core/aggregators/bulyan_aggregator.py b/federatedscope/core/aggregators/bulyan_aggregator.py new file mode 100644 index 000000000..756058222 --- /dev/null +++ b/federatedscope/core/aggregators/bulyan_aggregator.py @@ -0,0 +1,107 @@ +import copy +import torch +from federatedscope.core.aggregators import ClientsAvgAggregator + + +class BulyanAggregator(ClientsAvgAggregator): + """ + Implementation of Bulyan refers to `The Hidden Vulnerability + of Distributed Learning in Byzantium` + [Mhamdi et al., 2018] + (http://proceedings.mlr.press/v80/mhamdi18a/mhamdi18a.pdf) + It combines the MultiKrum aggregator and the treamedmean aggregator + """ + def __init__(self, model=None, device='cpu', config=None): + super(BulyanAggregator, self).__init__(model, device, config) + self.byzantine_node_num = config.aggregator.byzantine_node_num + self.sample_client_rate = config.federate.sample_client_rate + assert 4 * self.byzantine_node_num + 3 <= config.federate.client_num + + def aggregate(self, agg_info): + """ + To preform aggregation with Median aggregation rule + Arguments: + agg_info (dict): the feedbacks from clients + :returns: the aggregated results + :rtype: dict + """ + models = agg_info["client_feedback"] + models = [((model[0][0], self._flatten_updates(model[0][1])), \ + model[1]) for model in models] + avg_model = self._aggre_with_bulyan(models) + 
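+        # add the aggregated delta back onto the current global weights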
updated_model = copy.deepcopy(avg_model) + init_model = self.model.state_dict() + for key in avg_model: + updated_model[key] = init_model[key] + avg_model[key] + return updated_model + + def _calculate_distance(self, model_a, model_b): + """ + Calculate the Euclidean distance between two given model para delta + """ + distance = 0.0 + + for key in model_a: + if isinstance(model_a[key], torch.Tensor): + model_a[key] = model_a[key].float() + model_b[key] = model_b[key].float() + else: + model_a[key] = torch.FloatTensor(model_a[key]) + model_b[key] = torch.FloatTensor(model_b[key]) + + distance += torch.dist(model_a[key], model_b[key], p=2) + return distance + + def _calculate_score(self, models): + """ + Calculate Krum scores + """ + model_num = len(models) + closest_num = model_num - self.byzantine_node_num - 2 + + distance_matrix = torch.zeros(model_num, model_num) + for index_a in range(model_num): + for index_b in range(index_a, model_num): + if index_a == index_b: + distance_matrix[index_a, index_b] = float('inf') + else: + distance_matrix[index_a, index_b] = distance_matrix[ + index_b, index_a] = self._calculate_distance( + models[index_a], models[index_b]) + + sorted_distance = torch.sort(distance_matrix)[0] + krum_scores = torch.sum(sorted_distance[:, :closest_num], axis=-1) + return krum_scores + + def _aggre_with_bulyan(self, models): + ''' + Apply MultiKrum to select \theta (\theta <= client_num- + 2*self.byzantine_node_num) local models + ''' + init_model = self.model.state_dict() + global_update = copy.deepcopy(init_model) + models_para = [each_model[0][1] for each_model in models] + krum_scores = self._calculate_score(models_para) + index_order = torch.sort(krum_scores)[1].numpy() + reliable_models = list() + for number, index in enumerate(index_order): + if number < len(models) - int( + 2 * self.sample_client_rate * self.byzantine_node_num): + reliable_models.append(models[index]) + ''' + Sort parameter for each coordinate of the rest \theta reliable + local models, and find \gamma (gamma<\theta-2*self.byzantine_num) + parameters closest to the median to perform averaging + ''' + exluded_num = int(self.sample_client_rate * self.byzantine_node_num) + gamma = len(reliable_models) - 2 * exluded_num + for key in init_model: + temp = torch.stack( + [each_model[0][1][key] for each_model in reliable_models], 0) + pos_largest, _ = torch.topk(temp, exluded_num, 0) + neg_smallest, _ = torch.topk(-temp, exluded_num, 0) + new_stacked = torch.cat([temp, -pos_largest, + neg_smallest]).sum(0).float() + new_stacked /= gamma + global_update[key] = new_stacked + return global_update, _ \ No newline at end of file diff --git a/federatedscope/core/aggregators/clients_avg_aggregator.py b/federatedscope/core/aggregators/clients_avg_aggregator.py index 53aa953fa..3775174ff 100644 --- a/federatedscope/core/aggregators/clients_avg_aggregator.py +++ b/federatedscope/core/aggregators/clients_avg_aggregator.py @@ -2,6 +2,7 @@ import torch from federatedscope.core.aggregators import Aggregator from federatedscope.core.auxiliaries.utils import param2tensor +import copy class ClientsAvgAggregator(Aggregator): @@ -19,10 +20,8 @@ def __init__(self, model=None, device='cpu', config=None): def aggregate(self, agg_info): """ To preform aggregation - Arguments: agg_info (dict): the feedbacks from clients - Returns: dict: the aggregated results """ @@ -61,13 +60,15 @@ def _para_weighted_avg(self, models, recover_fun=None): """ Calculates the weighted average of models. 
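+        Each entry of ``models`` arrives as ((sample_size, model_para),
+        client_id); the unwrapping below drops the client id before the
+        weighted average is computed.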
""" + models = [model[0] for model in models] training_set_size = 0 for i in range(len(models)): sample_size, _ = models[i] training_set_size += sample_size - sample_size, avg_model = models[0] + sample_size, avg_model = copy.deepcopy(models[0]) for key in avg_model: + # if 'bn' not in key: for i in range(len(models)): local_sample_size, local_model = models[i] diff --git a/federatedscope/core/aggregators/dynamic_defense_aggregator.py b/federatedscope/core/aggregators/dynamic_defense_aggregator.py new file mode 100644 index 000000000..e8645e08a --- /dev/null +++ b/federatedscope/core/aggregators/dynamic_defense_aggregator.py @@ -0,0 +1,168 @@ +import logging +import copy +import torch +import random +import numpy as np +from federatedscope.core.aggregators import ClientsAvgAggregator +from federatedscope.core.aggregators.krum_aggregator import KrumAggregator +from federatedscope.core.aggregators.median_aggregator import MedianAggregator +from federatedscope.core.aggregators.trimmedmean_aggregator import \ + TrimmedmeanAggregator +from federatedscope.core.aggregators.bulyan_aggregator import \ + BulyanAggregator +from federatedscope.attack.byzantine_attacks.fang_attack import \ + Fang_adaptive_attacks +from federatedscope.attack.byzantine_attacks.she_attack import \ + She_adaptive_attacks + + +logger = logging.getLogger(__name__) + +torch.cuda.empty_cache() + + +class Weighted_sampled_robustAggregator(ClientsAvgAggregator): + """ + randomly sample a robust aggregator in each round of the FL course + """ + def __init__(self, model=None, device='cpu', config=None): + super(Weighted_sampled_robustAggregator, self).__init__(model, device, config) + self.byzantine_node_num = config.aggregator.byzantine_node_num + self.client_sampled_ratio = config.federate.sample_client_rate + self.excluded_ratio = config.aggregator.BFT_args.trimmedmean_excluded_ratio + self.candidate=config.aggregator.BFT_args.dynamic_candidate + self.config = config + self.str2attack = {'she_krum': She_adaptive_attacks(model,device,config).she_krum, + 'she_median': She_adaptive_attacks(model,device,config).she_median, + 'she_trimmedmean': She_adaptive_attacks(model,device,config).she_trimmedmean, + 'she_bulyan': She_adaptive_attacks(model,device,config).she_krum, + 'fang_krum': Fang_adaptive_attacks(model,device,config).fang_krum, + 'fang_median': Fang_adaptive_attacks(model,device,config).fang_median, + 'fang_trimmedmean': Fang_adaptive_attacks(model,device,config).fang_median, + 'fang_bulyan': Fang_adaptive_attacks(model,device,config).fang_krum} + self.str2defense = {'krum': KrumAggregator(model,device,config)._para_avg_with_krum, + 'median': MedianAggregator(model,device,config)._aggre_with_median, + 'trmean': TrimmedmeanAggregator(model,device,config)._aggre_with_trimmedmean, + 'bulyan': BulyanAggregator(model,device,config)._aggre_with_bulyan} + + if 'krum' in self.candidate: + assert 2 * self.byzantine_node_num + 2 < config.federate.client_num + if 'trmean' in self.candidate: + assert 2 * self.excluded_ratio < 1 + if 'bulyan' in self.candidate: + assert 4 * self.byzantine_node_num + 3 <= config.federate.client_num + + + + def aggregate(self, agg_info): + """ + To preform aggregation with a rule randomly sampled from a candidate set. 
+ + Arguments: + agg_info (dict): the feedbacks from clients + :returns: the aggregated results + :rtype: dict + """ + models = agg_info["client_feedback"] + + ## simulate the Byzantine attacks + if self.config.aggregator.BFT_args.attack == True: + models = [((model[0][0], self._flatten_updates(model[0][1])), \ + model[1]) for model in models] + attack_method = self.config.aggregator.BFT_args.attack_method + logger.info(f'the attack {attack_method} is launching') + models = self.str2attack[attack_method](models) + models = [((model[0][0], self._reconstruct_updates(model[0][1])), \ + model[1]) for model in models] + + init_model = self.model.state_dict() + + if self.config.aggregator.BFT_args.dynamic_weighted == False: + ## uniformly sampling from the candidate set + self.rule_cur=random.choice(self.candidate) + logger.info(f'the sampled rule is {self.rule_cur}') + avg_model, _ = self.str2defense[self.rule_cur](models) + + else: + ## weighted sampling from the candidate set. each weight is determined by the angle between\ + # the local update and the global update + global_delta = agg_info["global_delta"] + global_delta_without_bn = self._flatten_updates_without_bn(global_delta) + avg_model_ = copy.deepcopy(init_model) + temp = [] + for rule in self.candidate: + logger.info(f'weighted sampling: aggregate with the rule of {rule}') + models_temp = copy.deepcopy(models) + avg_model_, _ = self.str2defense[rule](models_temp) + temp.append(copy.deepcopy(avg_model_)) + #############compute the angles and then sampling according to the prob + prob = [] + slice_global_delta = global_delta_without_bn + temp_ = [] + temp_list = copy.deepcopy(temp) + for model in temp_list: + temp_.append(self._flatten_updates_without_bn(model)) + TS = [torch.dot(tmp_delta,slice_global_delta)/(torch.linalg.norm(tmp_delta)* \ + torch.linalg.norm(slice_global_delta)) for tmp_delta in temp_] + for ele in TS: + if ele < 0.1: + ele = 0 + else: + ele = ele.cpu() + prob.append(ele) + if np.sum(prob) == 0: + prob = [1 for ele in TS] + index = random.choices([i for i in range(len(prob))], weights=prob,k=1) + logger.info(f'the sampling weights is {prob} \ + and the sampled rule is {self.candidate[int(index[0])]}') + avg_model = temp[int(index[0])] + + updated_model = copy.deepcopy(init_model) + for key in init_model: + updated_model[key] = init_model[key] + avg_model[key].cpu() + torch.cuda.empty_cache() + return updated_model + + + + + ##### methods to transform the model update and tensor #### + def _flatten_updates(self, model): + model_update=[] + init_model = self.model.state_dict() + for key in init_model: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim = 0) + + def _flatten_updates_without_bn(self, model): + model_update=[] + init_model = self.model.state_dict() + for key in init_model: + if 'bn' not in key: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim = 0) + + def _reconstruct_updates(self, flatten_updates): + start_idx = 0 + init_model = self.model.state_dict() + reconstructed_model = copy.deepcopy(init_model) + for key in init_model: + reconstructed_model[key] = flatten_updates[start_idx:start_idx+ \ + len(init_model[key].view(-1))].reshape(init_model[key].shape) + start_idx=start_idx+len(init_model[key].view(-1)) + return reconstructed_model + + def _extract_the_bn(self, model): + temp_model = copy.deepcopy(self.model.state_dict()) + model = self._reconstruct_updates(model) + bn_dict={} + for key in temp_model: + if 'bn' in key: + bn_dict[key] = 
model[key]
+        return bn_dict
+
+    def _insert_the_bn(self, model_tensor, dict):
+        model = self._reconstruct_updates(model_tensor)
+        for key in dict:
+            model[key] = dict[key]
+        return self._flatten_updates(model)
diff --git a/federatedscope/core/aggregators/krum_aggregator.py b/federatedscope/core/aggregators/krum_aggregator.py
index e6206178a..8bbaff239 100644
--- a/federatedscope/core/aggregators/krum_aggregator.py
+++ b/federatedscope/core/aggregators/krum_aggregator.py
@@ -14,7 +14,7 @@ class KrumAggregator(ClientsAvgAggregator):
     def __init__(self, model=None, device='cpu', config=None):
         super(KrumAggregator, self).__init__(model, device, config)
         self.byzantine_node_num = config.aggregator.byzantine_node_num
-        self.krum_agg_num = config.aggregator.krum.agg_num
+        self.krum_agg_num = config.aggregator.BFT_args.krum_agg_num
         assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
             "it should be satisfied that 2*byzantine_node_num + 2 < client_num"

@@ -28,7 +28,7 @@ def aggregate(self, agg_info):
         :rtype: dict
         """
         models = agg_info["client_feedback"]
-        avg_model = self._para_avg_with_krum(models, agg_num=self.krum_agg_num)
+        avg_model, _ = self._para_avg_with_krum(models)

         # When using Krum/multi-Krum aggregation, the return feedback is model
         # delta rather than the model param
@@ -76,15 +76,15 @@ def _calculate_score(self, models):
         krum_scores = torch.sum(sorted_distance[:, :closest_num], axis=-1)
         return krum_scores

-    def _para_avg_with_krum(self, models, agg_num=1):
-
+    def _para_avg_with_krum(self, models):
+        this_round_id = []
-        # each_model: (sample_size, model_para)
+        # each_model: ((sample_size, model_para), client_id)
-        models_para = [each_model[1] for each_model in models]
+        models_para = [each_model[0][1] for each_model in models]
         krum_scores = self._calculate_score(models_para)
         index_order = torch.sort(krum_scores)[1].numpy()
         reliable_models = list()
         for number, index in enumerate(index_order):
-            if number < agg_num:
+            if number < self.krum_agg_num:
                 reliable_models.append(models[index])
-
-        return self._para_weighted_avg(models=reliable_models)
+                this_round_id.append(models[index][-1])
+        return self._para_weighted_avg(models=reliable_models), this_round_id
diff --git a/federatedscope/core/aggregators/median_aggregator.py b/federatedscope/core/aggregators/median_aggregator.py
new file mode 100644
index 000000000..dcd293e70
--- /dev/null
+++ b/federatedscope/core/aggregators/median_aggregator.py
@@ -0,0 +1,52 @@
+import copy
+import torch
+import numpy as np
+from federatedscope.core.aggregators import ClientsAvgAggregator
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class MedianAggregator(ClientsAvgAggregator):
+    """
+    Implementation of median refers to `Byzantine-robust distributed
+    learning: Towards optimal statistical rates`
+    [Yin et al., 2018]
+    (http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)
+    It computes the coordinate-wise median of received updates from clients.
+    The code is adapted from https://github.com/bladesteam/blades
+    """
+    def __init__(self, model=None, device='cpu', config=None):
+        super(MedianAggregator, self).__init__(model, device, config)
+        self.byzantine_node_num = config.aggregator.byzantine_node_num
+        assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
+            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"
+
+    def aggregate(self, agg_info):
+        """
+        To perform aggregation with the median aggregation rule
+
+        Arguments:
+            agg_info
(dict): the feedbacks from clients + :returns: the aggregated results + :rtype: dict + """ + models = agg_info["client_feedback"] + models = [((model[0][0], self._flatten_updates(model[0][1])), \ + model[1]) for model in models] + avg_model,_ = self._aggre_with_median(models) + updated_model = copy.deepcopy(avg_model) + init_model = self.model.state_dict() + for key in avg_model: + updated_model[key] = init_model[key] + avg_model[key] + return updated_model + + def _aggre_with_median(self, models): + init_model = self.model.state_dict() + global_update = copy.deepcopy(init_model) + for key in init_model: + temp = torch.stack([each_model[0][1][key] for each_model in models], + 0) + temp_pos, _ = torch.median(temp, dim=0) + temp_neg, _ = torch.median(-temp, dim=0) + global_update[key] = (temp_pos - temp_neg) / 2 + return global_update, _ \ No newline at end of file diff --git a/federatedscope/core/aggregators/normbounding_aggregator.py b/federatedscope/core/aggregators/normbounding_aggregator.py new file mode 100644 index 000000000..a42439a1d --- /dev/null +++ b/federatedscope/core/aggregators/normbounding_aggregator.py @@ -0,0 +1,64 @@ +import logging +import copy +import torch +import numpy as np +from federatedscope.core.aggregators import ClientsAvgAggregator + +logger = logging.getLogger(__name__) + + +class NormboundingAggregator(ClientsAvgAggregator): + """ + The server clips each update to reduce the negative impact \ + of malicious updates. + """ + def __init__(self, model=None, device='cpu', config=None): + super(NormboundingAggregator, self).__init__(model, device, config) + self.norm_bound = config.aggregator.BFT_args.normbounding_norm_bound + + def aggregate(self, agg_info): + """ + To preform aggregation with normbounding aggregation rule + Arguments: + agg_info (dict): the feedbacks from clients + :returns: the aggregated results + :rtype: dict + """ + models = agg_info["client_feedback"] + avg_model = self._aggre_with_normbounding(models) + updated_model = copy.deepcopy(avg_model) + init_model = self.model.state_dict() + for key in avg_model: + updated_model[key] = init_model[key] + avg_model[key] + return updated_model + + def _aggre_with_normbounding(self, models): + models_temp = [] + for each_model in models: + param = self._flatten_updates(each_model[1]) + if torch.norm(param, p=2) > self.norm_bound: + scaling_rate = self.norm_bound / torch.norm(param, p=2) + scaled_param = scaling_rate * param + models_temp.append( + (each_model[0], self._reconstruct_updates(scaled_param))) + else: + models_temp.append(each_model) + return self._para_weighted_avg(models_temp) + + def _flatten_updates(self, model): + model_update = [] + init_model = self.model.state_dict() + for key in init_model: + model_update.append(model[key].view(-1)) + return torch.cat(model_update, dim=0) + + def _reconstruct_updates(self, flatten_updates): + start_idx = 0 + init_model = self.model.state_dict() + reconstructed_model = copy.deepcopy(init_model) + for key in init_model: + reconstructed_model[key] = flatten_updates[ + start_idx:start_idx + len(init_model[key].view(-1))].reshape( + init_model[key].shape) + start_idx = start_idx + len(init_model[key].view(-1)) + return reconstructed_model \ No newline at end of file diff --git a/federatedscope/core/aggregators/trimmedmean_aggregator.py b/federatedscope/core/aggregators/trimmedmean_aggregator.py new file mode 100644 index 000000000..58e2be17e --- /dev/null +++ b/federatedscope/core/aggregators/trimmedmean_aggregator.py @@ -0,0 +1,58 @@ +import copy 
+import torch
+import numpy as np
+from federatedscope.core.aggregators import ClientsAvgAggregator
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TrimmedmeanAggregator(ClientsAvgAggregator):
+    """
+    Implementation of trimmed mean refers to `Byzantine-robust distributed
+    learning: Towards optimal statistical rates`
+    [Yin et al., 2018]
+    (http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)
+    The code is adapted from https://github.com/bladesteam/blades
+    """
+    def __init__(self, model=None, device='cpu', config=None):
+        super(TrimmedmeanAggregator, self).__init__(model, device, config)
+        self.excluded_ratio = \
+            config.aggregator.BFT_args.trimmedmean_excluded_ratio
+        self.byzantine_node_num = config.aggregator.byzantine_node_num
+        assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
+            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"
+        assert self.excluded_ratio < 0.5
+
+    def aggregate(self, agg_info):
+        """
+        To perform aggregation with the trimmed mean aggregation rule
+
+        Arguments:
+            agg_info (dict): the feedbacks from clients
+        :returns: the aggregated results
+        :rtype: dict
+        """
+        # the feedbacks arrive as ((sample_size, model_para), client_id)
+        # with model_para in dict form, as _aggre_with_trimmedmean expects
+        models = agg_info["client_feedback"]
+        avg_model, _ = self._aggre_with_trimmedmean(models)
+        updated_model = copy.deepcopy(avg_model)
+        init_model = self.model.state_dict()
+        for key in avg_model:
+            updated_model[key] = init_model[key] + avg_model[key]
+        return updated_model
+
+    def _aggre_with_trimmedmean(self, models):
+        init_model = self.model.state_dict()
+        global_update = copy.deepcopy(init_model)
+        excluded_num = int(len(models) * self.excluded_ratio)
+        for key in init_model:
+            temp = torch.stack(
+                [each_model[0][1][key] for each_model in models], 0)
+            pos_largest, _ = torch.topk(temp, excluded_num, 0)
+            neg_smallest, _ = torch.topk(-temp, excluded_num, 0)
+            new_stacked = torch.cat([temp, -pos_largest,
+                                     neg_smallest]).sum(0).float()
+            new_stacked /= len(temp) - 2 * excluded_num
+            global_update[key] = new_stacked
+        return global_update, None
\ No newline at end of file
diff --git a/federatedscope/core/auxiliaries/aggregator_builder.py b/federatedscope/core/auxiliaries/aggregator_builder.py
index 9e506a5d3..d60f6c9de 100644
--- a/federatedscope/core/auxiliaries/aggregator_builder.py
+++ b/federatedscope/core/auxiliaries/aggregator_builder.py
@@ -58,7 +58,20 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
     from federatedscope.core.aggregators import ClientsAvgAggregator, \
         OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \
         FedOptAggregator, NoCommunicationAggregator, \
-        AsynClientsAvgAggregator, KrumAggregator
+        AsynClientsAvgAggregator, KrumAggregator, \
+        MedianAggregator, TrimmedmeanAggregator, \
+        BulyanAggregator, NormboundingAggregator, \
+        Weighted_sampled_robustAggregator
+
+    STR2AGG = {
+        'fedavg': ClientsAvgAggregator,
+        'krum': KrumAggregator,
+        'median': MedianAggregator,
+        'bulyan': BulyanAggregator,
+        'trimmedmean': TrimmedmeanAggregator,
+        'normbounding': NormboundingAggregator,
+        'dynamic': Weighted_sampled_robustAggregator
+    }

     if method.lower() in constants.AGGREGATOR_TYPE:
         aggregator_type = constants.AGGREGATOR_TYPE[method.lower()]
@@ -87,12 +100,17 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
             return AsynClientsAvgAggregator(model=model,
                                             device=device,
                                             config=config)
-        elif config.aggregator.krum.use:
-            return KrumAggregator(model=model, device=device, config=config)
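+        # every robust rule (krum/median/trimmedmean/bulyan/normbounding/
+        # dynamic) is now selected through config.aggregator.robust_rule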
         else:
-            return ClientsAvgAggregator(model=model,
-                                        device=device,
-                                        config=config)
+            if config.aggregator.robust_rule not in STR2AGG:
+                logger.warning(
+                    f'The specified {config.aggregator.robust_rule} '
+                    f'aggregation rule is not supported, the vanilla '
+                    f'fedavg algorithm will be used instead.')
+            return STR2AGG.get(config.aggregator.robust_rule,
+                               ClientsAvgAggregator)(model=model,
+                                                     device=device,
+                                                     config=config)
+
     elif aggregator_type == 'server_clients_interpolation':
         return ServerClientsInterpolateAggregator(
             model=model,
@@ -105,4 +123,4 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
             config=config)
     else:
         raise NotImplementedError(
-            "Aggregator {} is not implemented.".format(aggregator_type))
+            "Aggregator {} is not implemented.".format(aggregator_type))
\ No newline at end of file
diff --git a/federatedscope/core/configs/cfg_aggregator.py b/federatedscope/core/configs/cfg_aggregator.py
index adb4b2ee7..63c355bf3 100644
--- a/federatedscope/core/configs/cfg_aggregator.py
+++ b/federatedscope/core/configs/cfg_aggregator.py
@@ -11,11 +11,16 @@ def extend_aggregator_cfg(cfg):
     # ---------------------------------------------------------------------- #
     cfg.aggregator = CN()
     cfg.aggregator.byzantine_node_num = 0
+    cfg.aggregator.client_sampled_ratio = 0.2

-    # For krum/multi-krum Algos
-    cfg.aggregator.krum = CN()
-    cfg.aggregator.krum.use = False
-    cfg.aggregator.krum.agg_num = 1
+    cfg.aggregator.robust_rule = 'fedavg'
+    # rule-specific hyper-parameters (e.g. krum_agg_num,
+    # trimmedmean_excluded_ratio, dynamic_candidate) live here
+    cfg.aggregator.BFT_args = CN(new_allowed=True)
+
+    # For normbounding Algos
+    cfg.aggregator.normbounding = CN()
+    cfg.aggregator.normbounding.use = False
+    cfg.aggregator.normbounding.tau = 10.0

     # For ATC method
     cfg.aggregator.num_agg_groups = 1
@@ -26,10 +31,9 @@ def extend_aggregator_cfg(cfg):
     # --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_aggregator_cfg) - def assert_aggregator_cfg(cfg): - if cfg.aggregator.byzantine_node_num == 0 and cfg.aggregator.krum.use: + if cfg.aggregator.byzantine_node_num == 0 and cfg.aggregator.robust_rule == 'krum': logging.warning('Although krum aggregtion rule is applied, we found ' 'that cfg.aggregator.byzantine_node_num == 0') diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index e17ff228d..78924e93a 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -5,7 +5,7 @@ from federatedscope.core.message import Message from federatedscope.core.communication import StandaloneCommManager, \ - StandaloneDDPCommManager, gRPCCommManager + gRPCCommManager from federatedscope.core.monitors.early_stopper import EarlyStopper from federatedscope.core.auxiliaries.trainer_builder import get_trainer from federatedscope.core.secret_sharing import AdditiveSecretSharing @@ -14,7 +14,6 @@ from federatedscope.core.workers.base_client import BaseClient logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class Client(BaseClient): @@ -106,7 +105,6 @@ def __init__(self, config=self._cfg, is_attacker=self.is_attacker, monitor=self._monitor) - self.device = device # For client-side evaluation self.best_results = dict() @@ -151,12 +149,8 @@ def __init__(self, self.server_id = server_id if self.mode == 'standalone': comm_queue = kwargs['shared_comm_queue'] - if self._cfg.federate.process_num <= 1: - self.comm_manager = StandaloneCommManager( - comm_queue=comm_queue, monitor=self._monitor) - else: - self.comm_manager = StandaloneDDPCommManager( - comm_queue=comm_queue, monitor=self._monitor) + self.comm_manager = StandaloneCommManager(comm_queue=comm_queue, + monitor=self._monitor) self.local_address = None elif self.mode == 'distributed': host = kwargs['host'] @@ -196,6 +190,8 @@ def _calculate_model_delta(self, init_model, updated_model): for model_index in range(len(init_model)): model_delta = copy.deepcopy(init_model[model_index]) for key in init_model[model_index].keys(): + if key not in updated_model[model_index]: + continue model_delta[key] = updated_model[model_index][ key] - init_model[model_index][key] model_deltas.append(model_delta) @@ -229,13 +225,6 @@ def run(self): if msg.msg_type == 'finish': break - def run_standalone(self): - """ - Run in standalone mode - """ - self.join_in() - self.run() - def callback_funcs_for_model_para(self, message: Message): """ The handling function for receiving model parameters, \ @@ -295,9 +284,6 @@ def callback_funcs_for_model_para(self, message: Message): # ensure all the model params (which might be updated by other # clients in the previous local training process) are overwritten # and synchronized with the received model - if self._cfg.federate.process_num > 1: - for k, v in content.items(): - content[k] = v.to(self.device) self.trainer.update(content, strict=self._cfg.federate.share_local_model) self.state = round @@ -385,7 +371,9 @@ def callback_funcs_for_model_para(self, message: Message): self.msg_buffer['train'][self.state] = [(sample_size, content_frame)] else: - if self._cfg.asyn.use or self._cfg.aggregator.krum.use: + if self._cfg.asyn.use or self._cfg.aggregator.robust_rule in \ + ['krum', 'normbounding', 'median', 'trimmedmean', + 'bulyan','dynamic']: # Return the model delta when using asynchronous training # protocol, because the staled updated might be discounted # and cause that the sum of the aggregated weights might diff 
--git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py
index 6d86d54d2..d1083772d 100644
--- a/federatedscope/core/workers/server.py
+++ b/federatedscope/core/workers/server.py
@@ -2,15 +2,14 @@
 import copy
 import os
 import sys
 import numpy as np
 import pickle
-import time

 from federatedscope.core.monitors.early_stopper import EarlyStopper
 from federatedscope.core.message import Message
 from federatedscope.core.communication import StandaloneCommManager, \
-    StandaloneDDPCommManager, gRPCCommManager
+    gRPCCommManager
 from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator
 from federatedscope.core.auxiliaries.sampler_builder import get_sampler
 from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
@@ -20,7 +20,6 @@
 from federatedscope.core.workers.base_server import BaseServer

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)


 class Server(BaseServer):
@@ -90,8 +89,7 @@ def __init__(self,
             self._cfg.early_stop.improve_indicator_mode,
             self._monitor.the_larger_the_better)

-        if self._cfg.federate.share_local_model \
-                and not self._cfg.federate.process_num > 1:
+        if self._cfg.federate.share_local_model:
             # put the model to the specified device
             model.to(device)
         # Build aggregator
@@ -130,18 +128,20 @@ def __init__(self,
             shared_party_num=int(self._cfg.federate.sample_client_num)
         ).fixedpoint2float if self._cfg.federate.use_ss else None

-        if self._cfg.federate.make_global_eval:
+        if self._cfg.federate.make_global_eval \
+                or self._cfg.aggregator.robust_rule == 'dynamic' \
+                or self._cfg.attack.attack_method == 'pga_attack':
             # set up a trainer for conducting evaluation in server
-            assert self.models is not None
+            assert self.model is not None
             assert self.data is not None
             self.trainer = get_trainer(
-                model=self.models[0],
+                model=self.model,
                 data=self.data,
                 device=self.device,
                 config=self._cfg,
-                only_for_eval=True,
                 monitor=self._monitor
-            )  # the trainer is only used for global evaluation
+            )
+            # the trainer is used both for global evaluation and for
+            # training on the server-side root dataset (see
+            # _global_trainer/_mal_global_trainer), so only_for_eval
+            # is dropped
             self.trainers = [self.trainer]
             if self.model_num > 1:
                 # By default, the evaluation is conducted by calling
@@ -151,6 +151,11 @@ def __init__(self,
                     for _ in range(self.model_num - 1)
                 ])

+        if self._cfg.attack.attack_method == 'pga_attack' or \
+                self._cfg.aggregator.robust_rule == 'dynamic':
+            # a sign-gradient-ascent (SGA) trainer used to estimate a
+            # malicious update direction on the server side
+            from federatedscope.attack.trainer import wrap_SGAAttackTrainer
+            self.mal_sga_trainer = wrap_SGAAttackTrainer(
+                copy.deepcopy(self.trainer))
+
         # Initialize the number of joined-in clients
         self._client_num = client_num
         self._total_round_num = total_round_num
@@ -198,16 +203,9 @@ def __init__(self,
         self.msg_buffer = {'train': dict(), 'eval': dict()}
         self.staled_msg_buffer = list()
         if self.mode == 'standalone':
-            comm_queue = kwargs.get('shared_comm_queue', None)
-            if self._cfg.federate.process_num > 1:
-                id2comm = kwargs.get('id2comm', None)
-                self.comm_manager = StandaloneDDPCommManager(
-                    comm_queue=comm_queue,
-                    monitor=self._monitor,
-                    id2comm=id2comm)
-            else:
-                self.comm_manager = StandaloneCommManager(
-                    comm_queue=comm_queue, monitor=self._monitor)
+            comm_queue = kwargs['shared_comm_queue']
+            self.comm_manager = StandaloneCommManager(comm_queue=comm_queue,
+                                                      monitor=self._monitor)
         elif self.mode == 'distributed':
             host = kwargs['host']
             port = kwargs['port']
@@ -332,6 +330,7 @@ def check_and_move_on(self,
             if not check_eval_result:
                 # Receiving enough feedback in the training process
                 aggregated_num = self._perform_federated_aggregation()
+                self.state += 1
                 if self.state % self._cfg.eval.freq == 0 and self.state != \
                         self.total_round_num:
@@ -352,16 +351,14 @@ def check_and_move_on(self,
                     # Start a new training round
                     self._start_new_training_round(aggregated_num)
                 else:
                     # Final Evaluate
                     logger.info('Server: Training is finished! Starting '
                                 'evaluation.')
                     self.eval()
-
             else:
                 # Receiving enough feedback in the evaluation process
                 self._merge_and_format_eval_results()
                 if self.state >= self.total_round_num:
                     self.is_finish = True

         else:
             move_on_flag = False
@@ -373,7 +370,6 @@ def check_and_save(self):
         To save the results and save model after each evaluation, and check \
         whether to early stop.
         """
-
         # early stopping
         if "Results_weighted_avg" in self.history_results and \
                 self._cfg.eval.best_res_update_round_wise_key in \
@@ -433,13 +443,12 @@ def _perform_federated_aggregation(self):

                 for client_id in train_msg_buffer.keys():
                     if self.model_num == 1:
-                        msg_list.append(train_msg_buffer[client_id])
+                        # attach the client id so that robust aggregators
+                        # can identify the source of each update
+                        msg_list.append(
+                            (train_msg_buffer[client_id], client_id))
                     else:
                         train_data_size, model_para_multiple = \
                             train_msg_buffer[client_id]
-                        msg_list.append(
-                            (train_data_size, model_para_multiple[model_idx]))
-
+                        msg_list.append((train_data_size,
+                                         model_para_multiple[model_idx],
+                                         client_id))
                     # The staleness of the messages in train_msg_buffer
                     # should be 0
                     staleness.append((client_id, 0))
@@ -456,7 +465,7 @@ def _perform_federated_aggregation(self):
                     staleness.append((client_id, self.state - state))

             # Trigger the monitor here (for training)
-            self._monitor.calc_model_metric(self.models[0].state_dict(),
+            self._monitor.calc_model_metric(self.model.state_dict(),
                                             msg_list,
                                             rnd=self.state)
@@ -465,16 +474,24 @@ def _perform_federated_aggregation(self):
             agg_info = {
                 'client_feedback': msg_list,
                 'recover_fun': self.recover_fun,
-                'staleness': staleness,
+                'staleness': staleness
             }
             # logger.info(f'The staleness is {staleness}')
+            if self._cfg.aggregator.robust_rule == 'dynamic':
+                if self._cfg.aggregator.BFT_args.dynamic_weighted:
+                    # the dynamic defense weights clients against a
+                    # reference delta estimated on the server's root data
+                    agg_info['global_delta'] = self._mal_global_trainer()
+                else:
+                    agg_info['global_delta'] = 0
+            if self._cfg.attack.attack_method.lower() == 'pga_attack':
+                # simulate the PGA attacker's malicious direction
+                agg_info['malicious_delta'] = self._mal_global_trainer()
             result = aggregator.aggregate(agg_info)
             # Due to lazy load, we merge two state dict
             merged_param = merge_param_dict(model.state_dict().copy(), result)
             model.load_state_dict(merged_param, strict=False)

         return aggregated_num

     def _start_new_training_round(self, aggregated_num=0):
         """
         The behaviors for starting a new training round
@@ -664,7 +681,7 @@ def broadcast_model_para(self,
             model_para = [{} if skip_broadcast else model.state_dict() for
                           model in self.models]
         else:
-            model_para = {} if skip_broadcast else self.models[0].state_dict()
+            model_para = {} if skip_broadcast else self.model.state_dict()

         # We define the evaluation happens at the end of an epoch
         rnd = self.state - 1 if msg_type == 'evaluate' else self.state
@@ -689,7 +706,6 @@ def broadcast_client_address(self):
         To broadcast the communication addresses of clients (used for \
         additive secret sharing)
         """
-
         self.comm_manager.send(
             Message(msg_type='address',
                     sender=self.ID,
@@ -781,7 +797,7 @@ def trigger_for_start(self):
             else:
                 if self._cfg.backend == 'torch':
                     model_size = sys.getsizeof(pickle.dumps(
-                        self.models[0])) / 1024.0 * 8.
+                        self.model)) / 1024.0 * 8.
                 else:
                     # TODO: calculate model size for TF Model
                     model_size = 1.0
@@ -815,9 +831,6 @@ def trigger_for_start(self):
             logger.info(
                 '----------- Starting training (Round #{:d}) -------------'.
                 format(self.state))
-            print(
-                time.strftime('%Y-%m-%d %H:%M:%S',
-                              time.localtime(time.time())))

     def trigger_for_feat_engr(self,
                               trigger_train_func,
@@ -851,7 +864,7 @@ def terminate(self, msg_type='finish'):
         if self.model_num > 1:
             model_para = [model.state_dict() for model in self.models]
         else:
-            model_para = self.models[0].state_dict()
+            model_para = self.model.state_dict()

         self._monitor.finish_fl()

@@ -866,9 +879,8 @@ def eval(self):
     def eval(self):
         """
         To conduct evaluation. When ``cfg.federate.make_global_eval=True``, \
         a global evaluation is conducted by the server.
         """
-
         if self._cfg.federate.make_global_eval:
             # By default, the evaluation is conducted one-by-one for all
             # internal models;
@@ -901,6 +913,42 @@ def eval(self):
             self.broadcast_model_para(msg_type='evaluate',
                                       filter_unseen_clients=False)

+    def _calculate_model_delta(self, init_model, updated_model):
+        """
+        Compute the key-wise difference ``updated_model - init_model``; \
+        keys missing from ``updated_model`` get a zero delta.
+        """
+        model_delta = copy.deepcopy(self.model.state_dict())
+        for key in model_delta:
+            if key not in updated_model:
+                # zero tensor of the right shape and dtype
+                model_delta[key] = init_model[key] - init_model[key]
+            else:
+                model_delta[key] = updated_model[key] - init_model[key]
+        return model_delta
+
+    def _global_trainer(self):
+        """
+        Conduct global model training on a root dataset held by the \
+        server. For now, this function is only used to implement the \
+        FLTrust aggregator, an advanced Byzantine-robust aggregation rule.
+        """
+        temp_model = copy.deepcopy(self.model.state_dict())
+        _, model_para_all, _ = self.trainer.train(
+            target_data_split_name='train')
+        global_delta = self._calculate_model_delta(
+            init_model=temp_model, updated_model=model_para_all)
+        # restore the parameters that the training step modified
+        self.model.load_state_dict(temp_model)
+        return global_delta
+
+    def _mal_global_trainer(self):
+        """
+        Fine-tune the model on a root dataset with the SGA trainer. For \
+        now, this function is only used to simulate the PGA attacker on \
+        the server side, which in a real deployment would be done by the \
+        malicious clients.
+        """
+        temp_model = copy.deepcopy(self.model.state_dict())
+        _, model_para_all, _ = self.mal_sga_trainer.train(
+            target_data_split_name='val')
+        global_delta = self._calculate_model_delta(
+            init_model=temp_model, updated_model=model_para_all)
+        self.model.load_state_dict(temp_model)
+        return global_delta
+
     def callback_funcs_model_para(self, message: Message):
         """
         The handling function for receiving model parameters, which triggers \
@@ -945,7 +993,6 @@ def callback_funcs_model_para(self, message: Message):
                 'after_receiving':
             self.broadcast_model_para(msg_type='model_para',
                                       sample_client_num=1)
-
         return move_on_flag

     def callback_funcs_for_join_in(self, message: Message):
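Note for reviewers: `_global_trainer` computes a reference delta on a server-held root dataset, in the spirit of FLTrust. As a minimal standalone sketch of how such a delta can reweight client updates (not part of the patch; `flatten` and `trust_weights` are illustrative names, and the ReLU-clipped cosine weighting follows the FLTrust idea):

    import torch

    def flatten(state_dict):
        # concatenate every tensor in a state dict into one 1-D vector
        return torch.cat([v.float().view(-1) for v in state_dict.values()])

    def trust_weights(global_delta, client_deltas):
        # ReLU-clipped cosine similarity between each client delta and
        # the server's root-dataset delta; dissimilar updates get weight 0
        g = flatten(global_delta)
        scores = [torch.clamp(torch.nn.functional.cosine_similarity(
            g, flatten(d), dim=0), min=0.0) for d in client_deltas]
        total = sum(scores)
        return [s / total for s in scores] if total > 0 else scores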
+ """ + temp_model = copy.deepcopy(self.model.state_dict()) + _, model_para_all, _ = self.mal_sga_trainer.train(target_data_split_name='val') + global_delta = self._calculate_model_delta(init_model=temp_model, updated_model=model_para_all) + self.model.load_state_dict(temp_model) + return global_delta + + def callback_funcs_model_para(self, message: Message): """ The handling function for receiving model parameters, which triggers \ @@ -945,7 +993,6 @@ def callback_funcs_model_para(self, message: Message): 'after_receiving': self.broadcast_model_para(msg_type='model_para', sample_client_num=1) - return move_on_flag def callback_funcs_for_join_in(self, message: Message): diff --git a/scripts/attack_exp_scripts/byzantine_attacks/she_attack_convnet2_femnist.yaml b/scripts/attack_exp_scripts/byzantine_attacks/she_attack_convnet2_femnist.yaml new file mode 100644 index 000000000..b032820fd --- /dev/null +++ b/scripts/attack_exp_scripts/byzantine_attacks/she_attack_convnet2_femnist.yaml @@ -0,0 +1,45 @@ +use_gpu: True +early_stop: + patience: 0 +seed: 12345 +federate: + mode: standalone + total_round_num: 50 + client_num: 200 + sample_client_rate: 0.2 + merge_test_data: True + make_global_eval: True + merge_val_data: True +data: + root: data/ + type: femnist + splits: [0.6,0.2,0.2] + subsample: 0.25 + transform: [['ToTensor'], ['Normalize', {'mean': [0.9637], 'std': [0.1592]}]] +dataloader: + batch_size: 10 +model: + type: convnet2 + hidden: 2048 + out_channels: 62 + dropout: 0.0 +train: + local_update_steps: 1 + batch_or_epoch: epoch + optimizer: + lr: 0.01 + weight_decay: 0.0 +grad: + grad_clip: 5.0 +criterion: + type: CrossEntropyLoss +trainer: + type: cvtrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] + best_res_update_round_wise_key: test_acc +outdir: exp/ +expname: 0109_cao_attack_krum_femnist/ + + diff --git a/tests/test_she_attack.py b/tests/test_she_attack.py new file mode 100644 index 000000000..a48bb8871 --- /dev/null +++ b/tests/test_she_attack.py @@ -0,0 +1,110 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
diff --git a/tests/test_she_attack.py b/tests/test_she_attack.py
new file mode 100644
index 000000000..a48bb8871
--- /dev/null
+++ b/tests/test_she_attack.py
@@ -0,0 +1,113 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from federatedscope.core.auxiliaries.data_builder import get_data
+from federatedscope.core.auxiliaries.utils import setup_seed
+from federatedscope.core.auxiliaries.logging import update_logger
+from federatedscope.core.configs.config import global_cfg
+from federatedscope.core.auxiliaries.runner_builder import get_runner
+from federatedscope.core.auxiliaries.worker_builder import \
+    get_server_cls, get_client_cls
+
+
+class SHEAttackAlgoTest(unittest.TestCase):
+    def setUp(self):
+        print('Testing %s.%s' % (type(self).__name__,
+                                 self._testMethodName))
+
+    def set_config_sample0(self, cfg):
+        backup_cfg = cfg.clone()
+        cfg.merge_from_file(
+            'scripts/attack_exp_scripts/byzantine_attacks/'
+            'she_attack_convnet2_femnist.yaml')
+        cfg.device = 0
+        cfg.federate.client_num = 50
+        # attack
+        cfg.aggregator.byzantine_node_num = 5
+        cfg.aggregator.BFT_args.attack = True
+        cfg.aggregator.BFT_args.attack_method = 'she_krum'
+        # defense: the dynamic rule restricted to a single candidate
+        cfg.aggregator.robust_rule = 'dynamic'
+        cfg.aggregator.BFT_args.dynamic_weighted = False
+        cfg.aggregator.BFT_args.dynamic_candidate = ['krum']
+        cfg.aggregator.BFT_args.krum_agg_num = 5
+        cfg.aggregator.BFT_args.trimmedmean_excluded_ratio = 0.2
+        cfg.eval.freq = 1
+        cfg.outdir = 'test_attack/'
+        cfg.expname = 'she_attack_femnist/'
+        cfg.expname_tag = 'she_krum_krum'
+        return backup_cfg
+
+    def set_config_sample1(self, cfg):
+        backup_cfg = cfg.clone()
+        cfg.merge_from_file(
+            'scripts/attack_exp_scripts/byzantine_attacks/'
+            'she_attack_convnet2_femnist.yaml')
+        cfg.device = 0
+        cfg.federate.client_num = 50
+        # attack
+        cfg.aggregator.byzantine_node_num = 5
+        cfg.aggregator.BFT_args.attack = True
+        cfg.aggregator.BFT_args.attack_method = 'she_krum'
+        # defense: the dynamic rule choosing among four candidates
+        cfg.aggregator.robust_rule = 'dynamic'
+        cfg.aggregator.BFT_args.dynamic_weighted = False
+        cfg.aggregator.BFT_args.dynamic_candidate = \
+            ['krum', 'median', 'bulyan', 'trmean']
+        cfg.aggregator.BFT_args.krum_agg_num = 5
+        cfg.aggregator.BFT_args.trimmedmean_excluded_ratio = 0.2
+        cfg.eval.freq = 1
+        cfg.outdir = 'test_attack/'
+        cfg.expname = 'she_attack_femnist/'
+        cfg.expname_tag = 'she_krum_dynamic'
+        return backup_cfg
+
+    def test_0_sample(self):
+        init_cfg = global_cfg.clone()
+        backup_cfg = self.set_config_sample0(init_cfg)
+        setup_seed(init_cfg.seed)
+        update_logger(init_cfg, True)
+
+        data, modified_cfg = get_data(init_cfg.clone())
+        init_cfg.merge_from_other_cfg(modified_cfg)
+        self.assertIsNotNone(data)
+
+        Fed_runner = get_runner(data=data,
+                                server_class=get_server_cls(init_cfg),
+                                client_class=get_client_cls(init_cfg),
+                                config=init_cfg.clone())
+        self.assertIsNotNone(Fed_runner)
+        test_best_results = Fed_runner.run()
+        print(test_best_results)
+        # the she_krum attack should keep the best weighted test
+        # accuracy below 10% when krum is the only candidate rule
+        self.assertLess(
+            test_best_results['client_summarized_weighted_avg']['test_acc'],
+            0.1)
+        init_cfg.merge_from_other_cfg(backup_cfg)
+
+    def test_1_sample(self):
+        init_cfg = global_cfg.clone()
+        backup_cfg = self.set_config_sample1(init_cfg)
+        setup_seed(init_cfg.seed)
+        update_logger(init_cfg, True)
+
+        data, modified_cfg = get_data(init_cfg.clone())
+        init_cfg.merge_from_other_cfg(modified_cfg)
+        self.assertIsNotNone(data)
+
+        Fed_runner = get_runner(data=data,
+                                server_class=get_server_cls(init_cfg),
+                                client_class=get_client_cls(init_cfg),
+                                config=init_cfg.clone())
+        self.assertIsNotNone(Fed_runner)
+        test_best_results = Fed_runner.run()
+        print(test_best_results)
+        self.assertLess(
+            test_best_results['client_summarized_weighted_avg']['test_acc'],
+            0.1)
+        init_cfg.merge_from_other_cfg(backup_cfg)
+
+
+if __name__ == '__main__':
+    unittest.main()
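A short usage note (not part of the patch): both cases follow the standard unittest pattern and can be run with, e.g., `python -m unittest tests.test_she_attack -v`. Each asserts that the best weighted test accuracy stays below 0.1 under the she_krum attack, first with krum as the sole candidate rule and then with the full dynamic candidate set.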