-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils_hydra.py
86 lines (69 loc) · 2.53 KB
/
utils_hydra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import time
from typing import Optional

import hydra
import numpy as np
import omegaconf
import torch
import transformers
from hydra import compose, initialize
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from main import make_eval_model, make_task_sampler, wandb_init
# File containing imports to be used in hydra configs/for instantiating hydra
# objects outside main
def initialize_cfg(
        config_path="cfgs",
        hydra_overrides: Optional[list] = None,
        log_yaml: bool = False,
):
    """Compose the Hydra configuration named ``config``.

    Args:
        config_path: Config directory forwarded to ``hydra.initialize``.
        hydra_overrides: Overrides forwarded to ``hydra.compose``, e.g.
            ``["wandb_log=false"]``. ``None`` (the default) means no
            overrides.
        log_yaml: When true, print the composed configuration as YAML.

    Returns:
        The configuration object produced by ``hydra.compose``.
    """
    # Build a fresh list on every call instead of sharing one mutable default
    # object between all callers.
    overrides = [] if hydra_overrides is None else hydra_overrides

    with initialize(version_base=None, config_path=config_path,
                    job_name="test_app"):
        cfg = compose(config_name="config", overrides=overrides)

    if log_yaml:
        print('Loading the following configurations:')
        print(omegaconf.OmegaConf.to_yaml(cfg))
    return cfg
def load_run_cfgs_trainer(
        run_file_path_location: str,
        hydra_overrides_from_dict: Optional[dict] = None,
        task_sampler_kwargs: Optional[dict] = None,
        config_path="cfgs",
        batch_size: int = 1,
):
    """Build a trainer from a run config and load its parameters.

    Composes the Hydra config with ``run@_global_=<run_file_path_location>``
    plus any extra overrides, builds the models and task sampler through
    ``make_eval_model`` / ``make_task_sampler``, instantiates ``cfg.trainer``
    and then copies the parameters (and buffers, when the memory model
    reports it has buffers to merge) returned by
    ``trainer.sample_and_synchronize_params(best=True)`` into the memory
    model.

    Args:
        run_file_path_location: Value used for the ``run@_global_`` override.
        hydra_overrides_from_dict: Mapping turned into ``key=value`` Hydra
            overrides. Defaults to
            ``{"wandb_log": "false", "wandb_project": "scratch"}``.
        task_sampler_kwargs: Extra keyword arguments for
            ``make_task_sampler``. Defaults to
            ``{"tasks": ["lb/passage_retrieval_en"]}``.
        config_path: Config directory forwarded to ``initialize_cfg``.
        batch_size: Length of the zero-filled batch-index array handed to
            ``memory_policy.set_params_batch_idxs``.

    Returns:
        The instantiated trainer.
    """
    # Defaults are created per call so that neither the dicts nor the nested
    # task list can be shared (and mutated) across invocations.
    if hydra_overrides_from_dict is None:
        hydra_overrides_from_dict = dict(
            wandb_log="false",
            wandb_project="scratch",
        )
    if task_sampler_kwargs is None:
        task_sampler_kwargs = dict(
            tasks=["lb/passage_retrieval_en"],
        )

    print(f"Loading model specified in {run_file_path_location}")
    start_time = time.time()

    # The run file override comes first, followed by the user overrides.
    hydra_overrides = [f"run@_global_={run_file_path_location}"]
    hydra_overrides.extend(
        f"{k}={v}" for k, v in hydra_overrides_from_dict.items())

    cfg = initialize_cfg(
        config_path=config_path,
        hydra_overrides=hydra_overrides,
    )

    (memory_policy, memory_model, memory_evaluator, evolution_algorithm,
     auxiliary_loss) = make_eval_model(cfg=cfg)

    task_sampler = make_task_sampler(cfg=cfg, **task_sampler_kwargs)

    trainer = hydra.utils.instantiate(
        cfg.trainer,
        evaluation_model=memory_evaluator,
        task_sampler=task_sampler,
        evolution_algorithm=evolution_algorithm,
        auxiliary_loss=auxiliary_loss,
    )

    # NOTE(review): ``best=True`` presumably selects the best-performing
    # parameters tracked by the trainer - confirm against the trainer class.
    params, buffers = trainer.sample_and_synchronize_params(best=True)
    memory_model.set_memory_params(params=params)
    if memory_model.memory_policy_has_buffers_to_merge():
        memory_model.load_buffers_dict(buffers_dict=buffers)

    # NOTE(review): np.zeros yields a float64 array; confirm that
    # set_params_batch_idxs does not require an integer dtype.
    batch_idxs = np.zeros([batch_size])
    memory_policy.set_params_batch_idxs(batch_idxs)

    print("Time taken:", round(time.time() - start_time))
    return trainer  # contains all other models