diff --git a/federatedscope/core/secret_sharing/__init__.py b/federatedscope/core/secret_sharing/__init__.py
index 1cbf8b5d3..51e25d66a 100644
--- a/federatedscope/core/secret_sharing/__init__.py
+++ b/federatedscope/core/secret_sharing/__init__.py
@@ -1,2 +1,2 @@
 from federatedscope.core.secret_sharing.secret_sharing import \
-    AdditiveSecretSharing
+    AdditiveSecretSharing, MultiplicativeSecretSharing
diff --git a/federatedscope/core/secret_sharing/secret_sharing.py b/federatedscope/core/secret_sharing/secret_sharing.py
index 31fb99b79..26f247ea7 100644
--- a/federatedscope/core/secret_sharing/secret_sharing.py
+++ b/federatedscope/core/secret_sharing/secret_sharing.py
@@ -58,10 +58,10 @@ def secret_split(self, secret):
 
         secret = self.float2fixedpoint(secret)
         secret_seq = np.random.randint(low=0, high=self.mod_number, size=shape)
-        # last_seq = self.mod_funs(secret - self.mod_funs(np.sum(secret_seq,
-        # axis=0)))
-        last_seq = self.mod_funs(secret -
-                                 self.mod_funs(np.sum(secret_seq, axis=0)))
+        # last_seq =
+        # self.mod_funs(secret - self.mod_funs(np.sum(secret_seq, axis=0)))
+        last_seq = self.mod_funs(
+            secret - self.mod_funs(np.sum(secret_seq, axis=0))).astype(int)
 
         secret_seq = np.append(secret_seq,
                                np.expand_dims(last_seq, axis=0),
@@ -82,6 +82,14 @@ def secret_reconstruct(self, secret_seq):
                     else:
                         merge_model[key] += secret_seq[idx][key]
                 merge_model[key] = self.fixedpoint2float(merge_model[key])
+        # if merge_model is an ndarray or a list
+        else:
+            for idx in range(len(secret_seq)):
+                if idx == 0:
+                    merge_model = secret_seq[idx].copy()
+                else:
+                    merge_model += secret_seq[idx]
+            merge_model = self.fixedpoint2float(merge_model)
 
         return merge_model
 
@@ -96,3 +104,111 @@ def _fixedpoint2float(self, x):
             return -1 * (self.mod_number - x) / self.epsilon
         else:
             return x / self.epsilon
+
+
+class MultiplicativeSecretSharing(AdditiveSecretSharing):
+    """
+    MultiplicativeSecretSharing class, which splits a number into pieces and
+    supports multiplying secret pieces with the help of Beaver triples
+    """
+    def __init__(self, shared_party_num, size=60):
+        super().__init__(shared_party_num, size)
+        self.maximum = 2**size
+        self.mod_number = 2 * self.maximum + 1
+        self.epsilon = 1e8
+
+    def secret_split(self, secret, cls=None):
+        """
+        To split the secret into frames according to the shared_party_num
+        """
+        if isinstance(secret, dict):
+            secret_list = [dict() for _ in range(self.shared_party_num)]
+            for key in secret:
+                for idx, each in enumerate(
+                        self.secret_split(secret[key], cls=cls)):
+                    secret_list[idx][key] = each
+            return secret_list
+
+        if isinstance(secret, list) or isinstance(secret, np.ndarray):
+            secret = np.asarray(secret).astype(int)
+            shape = [self.shared_party_num - 1] + list(secret.shape)
+        elif isinstance(secret, torch.Tensor):
+            secret = secret.numpy()
+            shape = [self.shared_party_num - 1] + list(secret.shape)
+        else:
+            shape = [self.shared_party_num - 1]
+
+        if cls is None:
+            secret = self.float2fixedpoint(secret)
+        secret_seq = np.random.randint(low=0, high=self.mod_number, size=shape)
+        # last_seq =
+        # self.mod_funs(secret - self.mod_funs(np.sum(secret_seq, axis=0)))
+        last_seq = self.mod_funs(
+            secret - self.mod_funs(np.sum(secret_seq, axis=0))).astype(int)
+
+        secret_seq = np.append(secret_seq,
+                               np.expand_dims(last_seq, axis=0),
+                               axis=0)
+        return secret_seq
+
+    def secret_add_lists(self, args):
+        # args is a list
+        # whose last element is a list consisting of secret pieces
+        # TODO: add the condition that all elements in args are numbers
+        for i in range(len(args) - 1):
+            # if isinstance(args[i], int) or isinstance(args[i], np.int64):
+            if not isinstance(args[i], list) and not isinstance(
+                    args[i], np.ndarray):
+                args[i] = [args[i]] * len(args[-1])
+        return self.mod_funs(np.sum(args, axis=0))
+        # TODO: in the future, when large numbers are involved, numpy may
+        # overflow; in that case, the following would work
+        # n = len(args[0])
+        # num = len(args)
+        # res = [0] * n
+        # for i in range(n):
+        #     for j in range(num):
+        #         res[i] += args[j][i]
+        #     res[i] = res[i] % self.mod_number
+        # return np.asarray(res)
+
+    def secret_ndarray_star_ndarray(self, arr1, arr2):
+        # return an array whose i-th element equals
+        # the product of the i-th elements of arr1 and arr2
+        # where arr1 and arr2 are both secret pieces
+        if isinstance(arr1, int) or isinstance(arr1, np.int64):
+            arr1 = [arr1] * len(arr2)
+        if isinstance(arr2, int) or isinstance(arr2, np.int64):
+            arr2 = [arr2] * len(arr1)
+        n = len(arr1)
+        res = [0] * n
+        for i in range(n):
+            res[i] = (int(arr1[i]) * int(arr2[i])) % self.mod_number
+        return np.asarray(res)
+
+    def beaver_triple(self, *args):
+        a = np.random.randint(0, self.mod_number, args).astype(int)
+        b = np.random.randint(0, self.mod_number, args).astype(int)
+
+        a_list = []
+        b_list = []
+        c = [(a[i].item() * b[i].item()) % self.mod_number
+             for i in range(len(a))]
+        c_list = []
+        for i in range(self.shared_party_num - 1):
+            a_tmp = np.random.randint(0, self.mod_number, args)
+            a_list.append(a_tmp)
+            a -= a_tmp
+            a = a % self.mod_number
+            b_tmp = np.random.randint(0, self.mod_number, args)
+            b_list.append(b_tmp)
+            b -= b_tmp
+            b = b % self.mod_number
+            c_tmp = np.random.randint(0, self.mod_number, args)
+            c_list.append(c_tmp)
+            c -= c_tmp
+            c = c % self.mod_number
+        a_list.append(a)
+        b_list.append(b)
+        c_list.append(c)
+        return a_list, b_list, c_list
diff --git a/federatedscope/core/secret_sharing/ss_multiplicative_wrapper.py b/federatedscope/core/secret_sharing/ss_multiplicative_wrapper.py
new file mode 100644
index 000000000..d2d76d05a
--- /dev/null
+++ b/federatedscope/core/secret_sharing/ss_multiplicative_wrapper.py
@@ -0,0 +1,106 @@
+import types
+import logging
+from federatedscope.core.message import Message
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_client_for_ss_multiplicative(client):
+    # TODO: this only works when one of the arguments is a secret piece of
+    #  the indicator vector, which is not converted to a fixed-point number.
+    #  For general cases, we should add a truncation step at the end.
+    def ss_multiplicative(self,
+                          secret1,
+                          secret2,
+                          shared_party_num,
+                          behavior=None):
+        self.secret1 = secret1
+        self.secret2 = secret2
+        self.behavior = behavior
+        self.shared_party_num = shared_party_num
+        self.pe_dict = dict()
+        self.pf_dict = dict()
+        self.res = None
+        if self.own_label:
+            self.comm_manager.send(
+                Message(msg_type='random_numbers',
+                        sender=self.ID,
+                        state=self.state,
+                        receiver=[self.server_id],
+                        content=(shared_party_num, len(secret2))))
+
+    def callback_fun_for_beaver_triplets(self, message: Message):
+        pa, pb, self.pc = message.content
+        pe = self.ss.secret_add_lists([self.secret1, -pa])
+        pf = self.ss.secret_add_lists([self.secret2, -pb])
+
+        self.pe_dict[self.ID] = pe
+        self.pf_dict[self.ID] = pf
+        for i in range(self.shared_party_num):
+            if i + 1 != self.ID:
+                self.comm_manager.send(
+                    Message(msg_type='part_e_and_f',
+                            sender=self.ID,
+                            state=self.state,
+                            receiver=[i + 1],
+                            content=(pe, pf)))
+
+    def callback_func_for_part_e_and_f(self, message: Message):
+        pe, pf = message.content
+        self.pe_dict[message.sender] = pe
+        self.pf_dict[message.sender] = pf
+        if len(self.pe_dict) == self.shared_party_num:
+            e = self.ss.secret_add_lists([x for x in self.pe_dict.values()])
+            f = self.ss.secret_add_lists([x for x in self.pf_dict.values()])
+            self.pe_dict = {}
+            self.pf_dict = {}
+            t1 = self.ss.secret_ndarray_star_ndarray(f, self.secret1)
+            t2 = self.ss.secret_ndarray_star_ndarray(e, self.secret2)
+            if not self.own_label:
+                self.res = self.ss.secret_add_lists([t1, t2, self.pc])
+            else:
+                t3 = self.ss.secret_ndarray_star_ndarray(e, f)
+                self.res = self.ss.secret_add_lists([-t3, t1, t2, self.pc])
+            self.continue_next()
+
+    def continue_next(self):
+        if self.behavior == 'left_child':
+            self.set_left_child()
+        elif self.behavior == 'right_child':
+            self.set_right_child()
+        elif self.behavior == 'weight':
+            self.set_weight()
+
+    client.ss_multiplicative = types.MethodType(ss_multiplicative, client)
+    client.continue_next = types.MethodType(continue_next, client)
+    client.callback_fun_for_beaver_triplets = types.MethodType(
+        callback_fun_for_beaver_triplets, client)
+    client.callback_fun_for_part_e_and_f = types.MethodType(
+        callback_func_for_part_e_and_f, client)
+
+    client.register_handlers('beaver_triplets',
+                             client.callback_fun_for_beaver_triplets)
+    client.register_handlers('part_e_and_f',
+                             client.callback_fun_for_part_e_and_f)
+
+    return client
+
+
+def wrap_server_for_ss_multiplicative(server):
+    def callback_func_for_random_numbers(self, message: Message):
+        shared_party_num, size = message.content
+        a_list, b_list, c_list = self.ss.beaver_triple(size)
+        for i in range(shared_party_num):
+            self.comm_manager.send(
+                Message(msg_type='beaver_triplets',
+                        sender=self.ID,
+                        receiver=[i + 1],
+                        state=self.state,
+                        content=(a_list[i], b_list[i], c_list[i])))
+
+    server.callback_func_for_random_numbers = types.MethodType(
+        callback_func_for_random_numbers, server)
+    server.register_handlers('random_numbers',
+                             server.callback_func_for_random_numbers)
+
+    return server
diff --git a/federatedscope/vertical_fl/README.md b/federatedscope/vertical_fl/README.md
index b9c4530c9..894100ec5 100644
--- a/federatedscope/vertical_fl/README.md
+++ b/federatedscope/vertical_fl/README.md
@@ -116,7 +116,7 @@ For label-scattering model, we provide privacy protection algorithms proposed by
 
 ```
 vertical:
-  mode: 'label_based'
+  mode: 'label_scattering'
   protect_object: 'grad_and_hess'
   protect_method: 'he'
   key_size: ks
diff --git a/federatedscope/vertical_fl/tree_based_models/baseline/xgb_feature_gathering_on_adult_by_ss_eval.yaml b/federatedscope/vertical_fl/tree_based_models/baseline/xgb_feature_gathering_on_adult_by_ss_eval.yaml
new file mode 100644
index 000000000..a06f64165
--- /dev/null
+++ b/federatedscope/vertical_fl/tree_based_models/baseline/xgb_feature_gathering_on_adult_by_ss_eval.yaml
@@ -0,0 +1,36 @@
+use_gpu: False
+device: 0
+backend: torch
+federate:
+  mode: standalone
+  client_num: 2
+model:
+  type: xgb_tree
+  lambda_: 0.1
+  gamma: 0
+  num_of_trees: 10
+  max_tree_depth: 3
+data:
+  root: data/
+  type: adult
+  splits: [1.0, 0.0]
+dataloader:
+  type: raw
+  batch_size: 2000
+criterion:
+  type: CrossEntropyLoss
+trainer:
+  type: verticaltrainer
+train:
+  optimizer:
+    # learning rate for xgb model
+    eta: 0.5
+vertical:
+  use: True
+  dims: [7, 14]
+  algo: 'xgb'
+  eval_protection: 'ss'
+  data_size_for_debug: 2000
+eval:
+  freq: 3
+  best_res_update_round_wise_key: test_loss
\ No newline at end of file
diff --git a/federatedscope/vertical_fl/tree_based_models/worker/TreeClient.py b/federatedscope/vertical_fl/tree_based_models/worker/TreeClient.py
index 7c4234f09..434c68c4b 100644
--- a/federatedscope/vertical_fl/tree_based_models/worker/TreeClient.py
+++ b/federatedscope/vertical_fl/tree_based_models/worker/TreeClient.py
@@ -6,6 +6,7 @@
 from federatedscope.core.message import Message
 from federatedscope.vertical_fl.Paillier import \
     abstract_paillier
+from federatedscope.core.secret_sharing import MultiplicativeSecretSharing
 
 logger = logging.getLogger(__name__)
 
@@ -36,6 +37,9 @@ def __init__(self,
             keys = abstract_paillier.generate_paillier_keypair(
                 n_length=self._cfg.vertical.key_size)
             self.public_key, self.private_key = keys
+        elif self._cfg.vertical.eval_protection == 'ss':
+            self.ss = MultiplicativeSecretSharing(
+                shared_party_num=self.client_num)
 
         self.feature_order = None
         self.merged_feature_order = None
diff --git a/federatedscope/vertical_fl/tree_based_models/worker/TreeServer.py b/federatedscope/vertical_fl/tree_based_models/worker/TreeServer.py
index ddbc1d59a..1f0b6d928 100644
--- a/federatedscope/vertical_fl/tree_based_models/worker/TreeServer.py
+++ b/federatedscope/vertical_fl/tree_based_models/worker/TreeServer.py
@@ -2,6 +2,7 @@
 
 from federatedscope.core.workers import Server
 from federatedscope.core.message import Message
+from federatedscope.core.secret_sharing import MultiplicativeSecretSharing
 
 import logging
 
@@ -29,6 +30,10 @@ def __init__(self,
         self.total_num_of_feature = self._cfg.vertical.dims[-1]
         self._init_data_related_var()
 
+        if self._cfg.vertical.eval_protection == 'ss':
+            self.ss = MultiplicativeSecretSharing(
+                shared_party_num=self.client_num)
+
     def _init_data_related_var(self):
         pass
 
diff --git a/federatedscope/vertical_fl/tree_based_models/worker/ss_evaluation_wrapper.py b/federatedscope/vertical_fl/tree_based_models/worker/ss_evaluation_wrapper.py
new file mode 100644
index 000000000..0a7f2d3b0
--- /dev/null
+++ b/federatedscope/vertical_fl/tree_based_models/worker/ss_evaluation_wrapper.py
@@ -0,0 +1,306 @@
+import types
+import logging
+from abc import abstractmethod
+
+import numpy as np
+
+from federatedscope.vertical_fl.loss.utils import get_vertical_loss
+from federatedscope.core.message import Message
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_client_for_ss_evaluation(client):
+    def eval(self, tree_num):
+        self.pe_dict = dict()
+        self.pf_dict = dict()
+        self.plf_dict = dict()
+        self.prf_dict = dict()
+
+        self.criterion = get_vertical_loss(loss_type=self._cfg.criterion.type,
+                                           model_type=self._cfg.model.type)
+        if self.test_x is None:
+            self.test_x, self.test_y = self._fetch_test_data()
+            self.merged_test_result = list()
+        self.test_result_dict = dict()
+        self.test_result_piece_list = list()
+
+        indicator_piece_list = self.ss.secret_split(np.ones(
+            self.test_x.shape[0]),
+                                                    cls='ss_piece')
+
+        self.model[tree_num][0].indicator = indicator_piece_list[-1]
+        self.tree_num = tree_num
+        self.node_num = 0
+        for i in range(self.client_num - 1):
+            self.comm_manager.send(
+                Message(msg_type='indicator_piece',
+                        sender=self.ID,
+                        state=self.state,
+                        receiver=[i + 1],
+                        content=(tree_num, 0, indicator_piece_list[i])))
+        self._test_for_node(tree_num, node_num=0)
+
+    def _fetch_test_data(self):
+        test_x = self.data['test']['x']
+        test_y = self.data['test']['y'] if 'y' in self.data['test'] else None
+
+        return test_x, test_y
+
+    def callback_func_for_indicator_piece(self, message: Message):
+        tree_num, node_num, indicator_piece = message.content
+        self.tree_num = tree_num
+        self.node_num = node_num
+        self.model[tree_num][node_num].indicator = indicator_piece
+        self.test_result_piece_list = list()
+
+    def _feedback_eval_metrics(self):
+        test_loss = self.criterion.get_loss(self.test_y,
+                                            self.merged_test_result)
+        metrics = self.criterion.get_metric(self.test_y,
+                                            self.merged_test_result)
+        modified_metrics = dict()
+        for key in metrics.keys():
+            if 'test' not in key:
+                modified_metrics['test_' + key] = metrics[key]
+            else:
+                modified_metrics[key] = metrics[key]
+        modified_metrics.update({
+            'test_loss': test_loss,
+            'test_total': len(self.test_y)
+        })
+
+        self.comm_manager.send(
+            Message(msg_type='eval_metric',
+                    sender=self.ID,
+                    state=self.state,
+                    receiver=[self.server_id],
+                    content=modified_metrics))
+        self.comm_manager.send(
+            Message(msg_type='feature_importance',
+                    sender=self.ID,
+                    state=self.state,
+                    receiver=[self.server_id],
+                    content=self.feature_importance))
+        self.comm_manager.send(
+            Message(msg_type='ask_for_feature_importance',
+                    sender=self.ID,
+                    state=self.state,
+                    receiver=[
+                        each
+                        for each in list(self.comm_manager.neighbors.keys())
+                        if each != self.server_id
+                    ],
+                    content='None'))
+
+    def _test_for_node(self, tree_num, node_num):
+        # All nodes have been traversed
+        if node_num >= 2**self.model.max_depth - 1:
+            part_result = self.ss.secret_add_lists(self.test_result_piece_list)
+            self.test_result_dict[self.ID] = part_result
+            self.comm_manager.send(
+                Message(
+                    msg_type='ask_for_part_result',
+                    sender=self.ID,
+                    state=self.state,
+                    receiver=[
+                        each
+                        for each in list(self.comm_manager.neighbors.keys())
+                        if each != self.server_id
+                    ],
+                    content=(tree_num, node_num)))
+        # The client owns the weight
+        elif self.model[tree_num][node_num].weight:
+            if self._cfg.model.type in ['xgb_tree', 'gbdt_tree']:
+                eta = self._cfg.train.optimizer.eta
+            else:
+                eta = 1.0
+
+            weight = self.model[tree_num][node_num].weight * eta
+            weight_piece_list = self.ss.secret_split(weight)
+            # keep the last piece locally and send the others to the clients
+            self.model[tree_num][node_num].weight_piece = weight_piece_list[-1]
+            for i in range(self.client_num - 1):
+                self.comm_manager.send(
+                    Message(msg_type='weight_piece',
+                            sender=self.ID,
+                            state=self.state,
+                            receiver=[i + 1],
+                            content=(tree_num, node_num,
+                                     weight_piece_list[i])))
+
+            self.ss_multiplicative(self.model[tree_num][node_num].weight_piece,
+                                   self.model[tree_num][node_num].indicator,
+                                   self.client_num, 'weight')
+        # Other client owns the weight, need to communicate
+        elif self.model[tree_num][node_num].member:
+            send_message = Message(
+                msg_type='split_request',
+                sender=self.ID,
+                state=self.state,
+                receiver=[self.model[tree_num][node_num].member],
+                content=(tree_num, node_num))
+            if self.model[tree_num][node_num].member == self.ID:
+                self.callback_func_for_split_request(send_message)
+            else:
+                self.comm_manager.send(send_message)
+
+        else:
+            self._test_for_node(tree_num, node_num + 1)
+
+    def callback_func_for_ask_for_part_result(self, message: Message):
+        tree_num, node_num = message.content
+        sender = message.sender
+        part_result = self.ss.secret_add_lists(self.test_result_piece_list)
+        self.comm_manager.send(
+            Message(msg_type='part_result',
+                    sender=self.ID,
+                    state=self.state,
+                    receiver=[sender],
+                    content=(tree_num, node_num, part_result)))
+
+    def callback_func_for_part_result(self, message: Message):
+        tree_num, node_num, part_result = message.content
+        self.test_result_dict[message.sender] = part_result
+        if len(self.test_result_dict) == self.client_num:
+            result = self.ss.secret_reconstruct(
+                list(self.test_result_dict.values()))
+            self.merged_test_result.append(result)
+            if (
+                    tree_num + 1
+            ) % self._cfg.eval.freq == 0 or \
+                    tree_num + 1 == self._cfg.model.num_of_trees:
+                self._feedback_eval_metrics()
+            self.eval_finish_flag = True
+            self._check_eval_finish(tree_num)
+
+    def callback_func_for_weight_piece(self, message: Message):
+        tree_num, node_num, weight_piece = message.content
+        self.model[tree_num][node_num].weight_piece = weight_piece
+        self.ss_multiplicative(self.model[tree_num][node_num].weight_piece,
+                               self.model[tree_num][node_num].indicator,
+                               self.client_num, 'weight')
+
+    def set_weight(self):
+        self.test_result_piece_list.append(self.res)
+        self.node_num += 1
+        if self.own_label:
+            self._test_for_node(self.tree_num, self.node_num)
+
+    def callback_func_for_split_request(self, message: Message):
+        if self.test_x is None:
+            self.test_x, self.test_y = self._fetch_test_data()
+        tree_num, node_num = message.content
+        feature_idx = self.model[tree_num][node_num].feature_idx
+        feature_value = self.model[tree_num][node_num].feature_value
+        left_child, right_child = self.model[tree_num].split_childern(
+            self.test_x[:, feature_idx], feature_value)
+
+        left_child_piece_list = self.ss.secret_split(left_child,
+                                                     cls='ss_piece')
+        right_child_piece_list = self.ss.secret_split(right_child,
+                                                      cls='ss_piece')
+        self.left_child_piece = left_child_piece_list[self.ID - 1]
+        self.right_child_piece = right_child_piece_list[self.ID - 1]
+        for i in range(self.client_num):
+            if i + 1 != self.ID:
+                self.comm_manager.send(
+                    Message(msg_type='split_result',
+                            sender=self.ID,
+                            state=self.state,
+                            receiver=[i + 1],
+                            content=(tree_num, node_num,
+                                     left_child_piece_list[i],
+                                     right_child_piece_list[i])))
+        self.ss_multiplicative(self.model[tree_num][node_num].indicator,
+                               self.left_child_piece, self.client_num,
+                               'left_child')
+
+    @abstractmethod
+    def ss_multiplicative(self,
+                          secret1,
+                          secret2,
+                          shared_party_num,
+                          behavior=None):
+        pass
+
+    def callback_func_for_split_result(self, message: Message):
+        tree_num, node_num, left_child_piece, right_child_piece \
+            = message.content
+        self.left_child_piece = left_child_piece
+        self.right_child_piece = right_child_piece
+        self.ss_multiplicative(self.model[tree_num][node_num].indicator,
+                               self.left_child_piece, self.client_num,
+                               'left_child')
+
+    def set_left_child(self):
+        self.model[self.tree_num][2 * self.node_num + 1].indicator = self.res
+
+        self.ss_multiplicative(
+            self.model[self.tree_num][self.node_num].indicator,
+            self.right_child_piece, self.client_num, 'right_child')
+
+    def set_right_child(self):
+        self.model[self.tree_num][2 * self.node_num + 2].indicator = self.res
+        self.node_num += 1
+        if self.own_label:
+            self._test_for_node(self.tree_num, self.node_num)
+
+    def callback_func_for_feature_importance(self, message: Message):
+        state = message.state
+        self.comm_manager.send(
+            Message(msg_type='feature_importance',
+                    sender=self.ID,
+                    state=state,
+                    receiver=[self.server_id],
+                    content=self.feature_importance))
+
+    # Bind method to instance
+    client.eval = types.MethodType(eval, client)
+    client._fetch_test_data = types.MethodType(_fetch_test_data, client)
+    client._test_for_node = types.MethodType(_test_for_node, client)
+    client._feedback_eval_metrics = types.MethodType(_feedback_eval_metrics,
+                                                     client)
+
+    # client.ss_multiplicative = types.MethodType(ss_multiplicative, client)
+    client.set_left_child = types.MethodType(set_left_child, client)
+    client.set_right_child = types.MethodType(set_right_child, client)
+    client.set_weight = types.MethodType(set_weight, client)
+
+    client.callback_func_for_indicator_piece = types.MethodType(
+        callback_func_for_indicator_piece, client)
+    client.callback_func_for_weight_piece = types.MethodType(
+        callback_func_for_weight_piece, client)
+    client.callback_func_for_ask_for_part_result = types.MethodType(
+        callback_func_for_ask_for_part_result, client)
+    client.callback_func_for_part_result = types.MethodType(
+        callback_func_for_part_result, client)
+
+    client.callback_func_for_split_request = types.MethodType(
+        callback_func_for_split_request, client)
+    client.callback_func_for_split_result = types.MethodType(
+        callback_func_for_split_result, client)
+    client.callback_func_for_feature_importance = types.MethodType(
+        callback_func_for_feature_importance, client)
+
+    # Register handler functions
+    client.register_handlers('split_request',
+                             client.callback_func_for_split_request)
+    client.register_handlers('split_result',
+                             client.callback_func_for_split_result)
+    client.register_handlers('ask_for_feature_importance',
+                             client.callback_func_for_feature_importance)
+
+    client.register_handlers('indicator_piece',
+                             client.callback_func_for_indicator_piece)
+    client.register_handlers('weight_piece',
+                             client.callback_func_for_weight_piece)
+    client.register_handlers('ask_for_part_result',
+                             client.callback_func_for_ask_for_part_result)
+    client.register_handlers('part_result',
+                             client.callback_func_for_part_result)
+    return client
+
+
+def wrap_server_for_ss_evaluation(server):
+    return server
diff --git a/federatedscope/vertical_fl/utils.py b/federatedscope/vertical_fl/utils.py
index 47eb4df92..2080a1feb 100644
--- a/federatedscope/vertical_fl/utils.py
+++ b/federatedscope/vertical_fl/utils.py
@@ -3,12 +3,19 @@
     wrap_client_for_evaluation, wrap_server_for_evaluation
 from federatedscope.vertical_fl.tree_based_models.worker.he_evaluation_wrapper\
     import wrap_client_for_he_evaluation
+from federatedscope.vertical_fl.tree_based_models.worker.ss_evaluation_wrapper\
+    import wrap_client_for_ss_evaluation, wrap_server_for_ss_evaluation
+from federatedscope.core.secret_sharing.ss_multiplicative_wrapper import \
+    wrap_server_for_ss_multiplicative, wrap_client_for_ss_multiplicative
 
 
 def wrap_vertical_server(server, config):
     if config.vertical.algo in ['xgb', 'gbdt', 'rf']:
         server = wrap_server_for_train(server)
         server = wrap_server_for_evaluation(server)
+        if config.vertical.eval_protection == 'ss':
+            server = wrap_server_for_ss_evaluation(server)
+            server = wrap_server_for_ss_multiplicative(server)
 
     return server
 
@@ -17,6 +24,9 @@ def wrap_vertical_client(client, config):
     if config.vertical.algo in ['xgb', 'gbdt', 'rf']:
         if config.vertical.eval_protection == 'he':
             client = wrap_client_for_he_evaluation(client)
+        elif config.vertical.eval_protection == 'ss':
+            client = wrap_client_for_ss_evaluation(client)
+            client = wrap_client_for_ss_multiplicative(client)
         else:
             client = wrap_client_for_evaluation(client)
         client = wrap_client_for_train(client)
diff --git a/tests/test_tree_based_model_for_vfl.py b/tests/test_tree_based_model_for_vfl.py
index 16276c579..afb7d3b6d 100644
--- a/tests/test_tree_based_model_for_vfl.py
+++ b/tests/test_tree_based_model_for_vfl.py
@@ -86,6 +86,43 @@ def set_config_for_he_eval(self, cfg):
 
         return backup_cfg
 
+    def set_config_for_ss_eval(self, cfg):
+        backup_cfg = cfg.clone()
+
+        import torch
+        cfg.use_gpu = torch.cuda.is_available()
+
+        cfg.federate.mode = 'standalone'
+        cfg.federate.client_num = 2
+
+        cfg.model.type = 'xgb_tree'
+        cfg.model.lambda_ = 0.1
+        cfg.model.gamma = 0
+        cfg.model.num_of_trees = 10
+        cfg.model.max_tree_depth = 3
+
+        cfg.train.optimizer.eta = 0.5
+
+        cfg.data.root = 'test_data/'
+        cfg.data.type = 'adult'
+
+        cfg.dataloader.type = 'raw'
+        cfg.dataloader.batch_size = 2000
+
+        cfg.criterion.type = 'CrossEntropyLoss'
+
+        cfg.vertical.use = True
+        cfg.vertical.dims = [7, 14]
+        cfg.vertical.algo = 'xgb'
+        cfg.vertical.data_size_for_debug = 2000
+        cfg.vertical.eval_protection = 'ss'
+
+        cfg.trainer.type = 'verticaltrainer'
+        cfg.eval.freq = 5
+        cfg.eval.best_res_update_round_wise_key = "test_loss"
+
+        return backup_cfg
+
     def set_config_for_gbdt_base(self, cfg):
         backup_cfg = cfg.clone()
 
@@ -429,6 +466,27 @@ def test_XGB_Base_for_he_eval(self):
         self.assertGreater(test_results['server_global_eval']['test_acc'],
                            0.79)
 
+    def test_XGB_Base_for_ss_eval(self):
+        init_cfg = global_cfg.clone()
+        backup_cfg = self.set_config_for_ss_eval(init_cfg)
+        setup_seed(init_cfg.seed)
+        update_logger(init_cfg, True)
+
+        data, modified_config = get_data(init_cfg.clone())
+        init_cfg.merge_from_other_cfg(modified_config)
+        self.assertIsNotNone(data)
+
+        Fed_runner = get_runner(data=data,
+                                server_class=get_server_cls(init_cfg),
+                                client_class=get_client_cls(init_cfg),
+                                config=init_cfg.clone())
+        self.assertIsNotNone(Fed_runner)
+        test_results = Fed_runner.run()
+        init_cfg.merge_from_other_cfg(backup_cfg)
+        print(test_results)
+        self.assertGreater(test_results['server_global_eval']['test_acc'],
+                           0.79)
+
     def test_GBDT_Base(self):
         init_cfg = global_cfg.clone()
         backup_cfg = self.set_config_for_gbdt_base(init_cfg)