
Commit

first commit
Minqi824 committed Jun 13, 2022
1 parent ca2a975 commit 716cd9b
Showing 2 changed files with 15 additions and 20 deletions.
34 changes: 15 additions & 19 deletions myutils.py
@@ -14,36 +14,32 @@
# statistical analysis
from scipy.stats import wilcoxon

-'''
-A collection of commonly used utility functions
-'''
class Utils():
def __init__(self):
pass

-# remove randomness, fix the results
+# remove randomness
def set_seed(self, seed):
# os.environ['PYTHONHASHSEED'] = str(seed)
# os.environ['TF_CUDNN_DETERMINISTIC'] = 'true'
# os.environ['TF_DETERMINISTIC_OPS'] = 'true'

-#basic seed
+# basic seed
np.random.seed(seed)
random.seed(seed)

-#tensorflow seed
+# tensorflow seed
try:
tf.random.set_seed(seed) # for tf >= 2.0
except AttributeError: # tf < 2.0 has no tf.random.set_seed
tf.set_random_seed(seed)
tf.random.set_random_seed(seed)

-#pytorch seed
+# pytorch seed
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
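
# Usage sketch (hypothetical, not part of the commit): call set_seed once
# before building any model so numpy / random / TensorFlow / PyTorch draws
# are repeatable across runs:
#   utils = Utils()
#   utils.set_seed(42)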

-# check whether a GPU is available
def get_device(self):
# if torch.cuda.is_available():
# n_gpu = torch.cuda.device_count()
@@ -58,7 +54,7 @@ def get_device(self):
device = torch.device("cpu")
return device
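
# Usage sketch (hypothetical, not part of the commit):
#   device = Utils().get_device()
#   model = model.to(device)  # 'model' is a stand-in name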

-# generate a unique real number from two real numbers
+# generate unique value
def unique(self, a, b):
u = 0.5 * (a + b) * (a + b + 1) + b
return int(u)
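
# This is the Cantor pairing function, which maps every ordered pair of
# non-negative integers to a distinct integer. Worked example:
#   unique(2, 3) = 0.5 * 5 * 6 + 3 = 18
#   unique(3, 2) = 0.5 * 5 * 6 + 2 = 17  # order matters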
@@ -72,14 +68,14 @@ def data_description(self, X, y):

print(des_dict)

-# return the two metrics commonly used in anomaly detection: AUC-ROC and AUC-PR
+# metric
def metric(self, y_true, y_score, pos_label=1):
aucroc = roc_auc_score(y_true=y_true, y_score=y_score)
aucpr = average_precision_score(y_true=y_true, y_score=y_score, pos_label=pos_label)

return {'aucroc':aucroc, 'aucpr':aucpr}
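
# Usage sketch (hypothetical labels and scores, not part of the commit):
#   y_true = np.array([0, 0, 1, 1])
#   y_score = np.array([0.1, 0.4, 0.35, 0.8])
#   Utils().metric(y_true, y_score)  # -> {'aucroc': 0.75, 'aucpr': approx. 0.83}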

-# resampling function
+# resampling function
def sampler(self, X_train, y_train, batch_size):
index_u = np.where(y_train == 0)[0]
index_a = np.where(y_train == 1)[0]
@@ -147,23 +143,23 @@ def sampler_pairs(self, X_train_tensor, y_train, epoch, batch_num, batch_size, s
for i in range(batch_num): # i.e., drop_last = True
index = []

-# six groups of samples, for the pairs (a,a); (a,u); (u,u) respectively
+# pairs of (a,a); (a,u); (u,u)
for j in range(6):
# generate unique seed and set seed
# seed = self.unique(epoch, i)
# seed = self.unique(seed, j)
# self.set_seed(seed)

-if j < 3: # batch size // 4 here is consistent with the original paper
+if j < 3:
index_sub = np.random.choice(index_a, batch_size // 4, replace=True)
index.append(list(index_sub))

if j == 3:
-index_sub = np.random.choice(index_u, batch_size // 4, replace=True) # replace could be set to False for the unlabeled part
+index_sub = np.random.choice(index_u, batch_size // 4, replace=True)
index.append(list(index_sub))

if j > 3:
-index_sub = np.random.choice(index_u, batch_size // 2, replace=True) # replace could be set to False for the unlabeled part
+index_sub = np.random.choice(index_u, batch_size // 2, replace=True)
index.append(list(index_sub))

# index[0] + index[1] = (a,a), batch / 4
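# (the next lines are elided in this view; judging from the comment above and
# the six index groups built in the j-loop, the remaining pairs are presumably
# index[2] + index[3] = (a,u), batch / 4, and index[4] + index[5] = (u,u),
# batch / 2, i.e. a 1:1:2 mix of (a,a), (a,u), (u,u) pairs per batch)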
@@ -189,12 +185,12 @@
y_train_new = y_train_new[index_shuffle]

# save
-data_loader_X.append([X_train_tensor_left, X_train_tensor_right]) # note the order of left and right
+data_loader_X.append([X_train_tensor_left, X_train_tensor_right])
data_loader_y.append(y_train_new)

return data_loader_X, data_loader_y
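
# Downstream usage sketch (hypothetical training loop, not part of the
# commit): each entry of data_loader_X is a [left, right] pair of tensors
# aligned with one label tensor in data_loader_y:
#   for (x_left, x_right), y_batch in zip(data_loader_X, data_loader_y):
#       scores = model(x_left, x_right)  # 'model' is a stand-in name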

-# return the gradient
+# gradient norm
def grad_norm(self, grad_tuple):

grad = torch.tensor([0.0])
@@ -224,7 +220,7 @@ def plot_grad_flow(self, named_parameters):
# # Compute the first Wasserstein distance between two 1D distributions.
# return (torch_cdf_loss(tensor_a, tensor_b, p=1))

-#Calculate the First Wasserstein Distance
+# Calculate the First Wasserstein Distance
def torch_cdf_loss(self, tensor_a, tensor_b, p=1):
# last-dimension is weight distribution
# p is the norm of the distance, p=1 --> First Wasserstein Distance
@@ -249,7 +245,7 @@ def torch_cdf_loss(self, tensor_a, tensor_b, p=1):
cdf_loss = cdf_distance.mean()
return cdf_loss
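
# Usage sketch (hypothetical inputs, not part of the commit): the last
# dimension is a weight distribution, so each row should sum to 1; per the
# comment above, p=1 gives the first Wasserstein distance between the rows:
#   a = torch.tensor([[0.2, 0.3, 0.5]])
#   b = torch.tensor([[0.4, 0.4, 0.2]])
#   dist = Utils().torch_cdf_loss(a, b, p=1)  # scalar tensor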

-#Calculate the loss like devnet in PyTorch
+# Calculate the loss like devnet in PyTorch
def cal_loss(self, y, y_pred, mode='devnet'):
if mode == 'devnet':
y_pred.squeeze_()
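# The rest of this function is elided in this diff view. A hedged sketch of
# the deviation loss as defined in the DevNet paper (Pang et al., 2019),
# assuming a standard-normal reference and a confidence margin of 5; the
# actual elided body may differ:
#   ref = torch.randn(5000)                            # reference scores ~ N(0, 1)
#   dev = (y_pred - torch.mean(ref)) / torch.std(ref)  # standardized deviation
#   loss = torch.mean((1 - y) * torch.abs(dev)
#                     + y * torch.clamp(5.0 - dev, min=0.0))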
1 change: 0 additions & 1 deletion run.py
@@ -145,7 +145,6 @@ def dataset_filter(self):
dataset_list_org = [os.path.splitext(_)[0] for _ in os.listdir(os.path.join(os.getcwd(), 'datasets'))
if os.path.splitext(_)[1] != '']
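# e.g. a hypothetical file 'datasets/cardio.npz' yields the name 'cardio';
# entries without a file extension are skipped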

-# filter out the datasets that do not meet the criteria
dataset_list = []
dataset_size = []

