forked from facebookresearch/swav
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
180 lines (145 loc) · 4.86 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
from logging import getLogger
import pickle
import os
import numpy as np
import torch
import torch.nn as nn
from .logger import create_logger, PD_Stats
import torch.distributed as dist
# Lower-cased spellings that ``bool_flag`` maps to False / True respectively.
FALSY_STRINGS = {"0", "false", "off"}
TRUTHY_STRINGS = {"1", "true", "on"}

# Module-wide logger; ``getLogger()`` with no name returns the root logger.
logger = getLogger()
def bool_flag(s):
    """
    Convert a command-line string into a bool (intended as an argparse ``type=``).

    Matching is case-insensitive against ``FALSY_STRINGS`` and
    ``TRUTHY_STRINGS``.

    Raises:
        argparse.ArgumentTypeError: if ``s`` is in neither set.
    """
    lowered = s.lower()
    if lowered in FALSY_STRINGS:
        return False
    if lowered in TRUTHY_STRINGS:
        return True
    raise argparse.ArgumentTypeError("invalid value for a boolean flag")
def _slurm_world_size():
    """
    Return the total number of tasks in the current SLURM job step.

    ``SLURM_NTASKS`` holds the total directly and is preferred. Otherwise fall
    back to ``SLURM_NNODES`` times the per-node task count. That count is
    parsed as a whole integer from ``SLURM_TASKS_PER_NODE``, whose value can
    look like ``"16"``, ``"8(x2)"`` or ``"2(x3),1"``. The fallback assumes
    every node runs the same number of tasks as the first one listed.
    """
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"])
    tasks_per_node = os.environ["SLURM_TASKS_PER_NODE"].split(",")[0].split("(")[0]
    return int(os.environ["SLURM_NNODES"]) * int(tasks_per_node)


def init_distributed_mode(args):
    """
    Initialize the NCCL process group and pick this process's GPU.

    Sets on ``args``:
    - is_slurm_job: True when ``SLURM_JOB_ID`` is in the environment
    - rank, world_size: from SLURM variables, otherwise from ``RANK`` /
      ``WORLD_SIZE`` (as exported by torch.distributed launchers)
    - gpu_to_work_on: ``rank % torch.cuda.device_count()``, also made the
      current CUDA device

    ``args.dist_url`` is read as the process-group ``init_method``.

    Raises:
        KeyError: if a required environment variable is missing.
    """
    args.is_slurm_job = "SLURM_JOB_ID" in os.environ
    if args.is_slurm_job:
        args.rank = int(os.environ["SLURM_PROCID"])
        # Previously this used SLURM_TASKS_PER_NODE[0], i.e. only the first
        # *character* of the value, which is wrong for >= 10 tasks per node.
        args.world_size = _slurm_world_size()
    else:
        # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
        # read environment variables
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])

    # prepare distributed
    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    # set cuda device: ranks are spread round-robin over the visible GPUs
    args.gpu_to_work_on = args.rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu_to_work_on)
def initialize_exp(params, *args, dump_params=True):
    """
    Initialize the experiment:
    - dump parameters
    - create checkpoint repo
    - create a logger
    - create a pandas-backed object to keep track of the training statistics

    Args:
        params: namespace-like object; ``dump_path`` and ``rank`` are read, and
            ``dump_checkpoints`` is set on it as a side effect.
        *args: forwarded unchanged to ``PD_Stats`` (presumably the names of the
            tracked statistics - TODO confirm against ``PD_Stats``).
        dump_params: when True, pickle ``params`` to ``<dump_path>/params.pkl``.

    Returns:
        (logger, training_stats): the logger from ``create_logger`` and the
        ``PD_Stats`` instance writing to ``<dump_path>/stats<rank>.pkl``.
    """
    # dump parameters; the context manager closes (and flushes) the file,
    # which the previous bare ``open(...)`` left to the garbage collector
    if dump_params:
        with open(os.path.join(params.dump_path, "params.pkl"), "wb") as f:
            pickle.dump(params, f)

    # create repo to store checkpoints (only the process whose rank is falsy, i.e. 0)
    params.dump_checkpoints = os.path.join(params.dump_path, "checkpoints")
    if not params.rank and not os.path.isdir(params.dump_checkpoints):
        os.mkdir(params.dump_checkpoints)

    # create a panda object to log loss and acc (one stats file per rank)
    training_stats = PD_Stats(
        os.path.join(params.dump_path, "stats" + str(params.rank) + ".pkl"), args
    )

    # create a logger
    logger = create_logger(
        os.path.join(params.dump_path, "train.log"), rank=params.rank
    )
    logger.info("============ Initialized logger ============")
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items()))
    )
    logger.info("The experiment will be stored in %s\n" % params.dump_path)
    logger.info("")
    return logger, training_stats
def restart_from_checkpoint(ckp_paths, run_variables=None, **kwargs):
    """
    Re-start from checkpoint.

    Args:
        ckp_paths: a checkpoint path, or a list/tuple of candidate paths of
            which the first existing file is used.
        run_variables: optional dict; every key that also appears in the
            checkpoint is overwritten in place with the checkpoint's value.
        **kwargs: checkpoint key -> object exposing ``load_state_dict``
            (example: ``state_dict=model``). ``None`` values are skipped.

    Returns:
        None. Returns early, loading nothing, when no checkpoint file exists
        (including when ``ckp_paths`` is an empty list).
    """
    # look for a checkpoint in exp repository
    if isinstance(ckp_paths, (list, tuple)):
        # first existing candidate; None if the list is empty or none exist
        # (an empty list used to raise UnboundLocalError here)
        ckp_path = next((p for p in ckp_paths if os.path.isfile(p)), None)
    else:
        ckp_path = ckp_paths
    if ckp_path is None or not os.path.isfile(ckp_path):
        return
    logger.info("Found checkpoint at {}".format(ckp_path))

    # Load onto this process's GPU when a process group is running; otherwise
    # fall back to the CPU instead of failing on an uninitialized group or on
    # a machine without CUDA (device_count() == 0).
    if torch.cuda.is_available() and dist.is_available() and dist.is_initialized():
        map_location = "cuda:" + str(dist.get_rank() % torch.cuda.device_count())
    else:
        map_location = "cpu"
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(ckp_path, map_location=map_location)

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
            except TypeError:
                # load_state_dict without a ``strict`` parameter
                # (torch optimizers, for example)
                msg = value.load_state_dict(checkpoint[key])
            if msg is not None:
                # e.g. missing/unexpected keys reported by a non-strict module load
                logger.info("=> load_state_dict({}) returned: {}".format(key, msg))
            logger.info("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
        else:
            logger.warning(
                "=> failed to load {} from checkpoint '{}'".format(key, ckp_path)
            )

    # re load variable important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
def fix_random_seeds(seed=31):
    """
    Seed torch (CPU), torch (all CUDA devices) and NumPy's global RNG with ``seed``.

    Python's built-in ``random`` module is not seeded here.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
class AverageMeter:
    """Track the latest value plus a running, count-weighted sum and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value observed over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count