forked from pancetta/sdc-gym
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrl_test.py
129 lines (105 loc) · 3.83 KB
/
rl_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import datetime
import json
from pathlib import Path
from rl_playground import run_tests
import utils
def parse_args():
    """Parse the command line arguments of the test script.

    Defines one positional argument (`path`) and three options
    (`--model_path`, `--tests`, `--store_stats`) and returns the
    namespace produced by `argparse` for `sys.argv`.
    """
    path_help = (
        'Arguments JSON file to load, filename with '
        'timestamp right before extension, or timestamp.'
    )
    model_path_help = 'Model file to load. Has priority over `path`.'
    tests_help = (
        'Number of test runs for each preconditioning method. '
        'Defaults to the number of tests for the loaded model.'
    )
    store_stats_help = (
        'Whether to store statistics from the reinforcement learning test.'
    )

    cli = argparse.ArgumentParser()
    cli.add_argument('path', type=str, help=path_help)
    cli.add_argument(
        '--model_path', type=str, default=None, help=model_path_help)
    # NOTE(review): parsed as float, presumably so that values such as
    # `1e3` are accepted -- confirm against the training script.
    cli.add_argument('--tests', type=float, default=None, help=tests_help)
    cli.add_argument(
        '--store_stats',
        type=utils.parse_bool,
        default=False,
        help=store_stats_help,
    )
    return cli.parse_args()
def _script_start_from_arg(arg):
file_extension_start_index = arg.rfind('.')
file_extension_start_rindex = len(arg) - file_extension_start_index
if file_extension_start_rindex == 7:
# We don't have a file name but only the script start time
return arg[-26:]
return arg[-file_extension_start_rindex-26:-file_extension_start_rindex]
def get_prev_args(test_args):
    """Load the arguments of a previous run as an `argparse.Namespace`.

    If `test_args.path` names an existing file with a `.json` suffix,
    that file is loaded. Otherwise the script start timestamp is
    extracted from `test_args.path` and `args_<timestamp>.json`
    (relative to the current working directory) is loaded instead.

    When the loaded arguments have no `script_start` entry, it is
    filled in with the timestamp extracted from `test_args.path`.

    Errors from opening or decoding the file are propagated unchanged.
    """
    args_path = Path(test_args.path)
    if not args_path.exists() or args_path.suffix.lower() != '.json':
        prev_script_start = _script_start_from_arg(test_args.path)
        args_path = Path(f'args_{prev_script_start}.json')

    with open(args_path, 'r') as f:
        args = argparse.Namespace(**json.load(f))

    if not hasattr(args, 'script_start'):
        # Derive the timestamp here instead of relying on the variable
        # from the fallback branch above: that branch is skipped when
        # `path` is an existing JSON file, which used to end in an
        # `UnboundLocalError`.
        args.script_start = _script_start_from_arg(test_args.path)
    return args
def model_fname_from_args(test_args, args):
    """Return the path of the model checkpoint to test.

    `test_args.model_path` takes priority when given. Otherwise the
    file name is built from `args.model_class`, `args.policy_class`
    and `args.script_start`; if that file does not exist, a second
    name that additionally contains the value returned by
    `utils.compute_learning_rate(args)` is tried.

    Raises:
        FileNotFoundError: If the selected checkpoint file does not
            exist. An explicit exception is raised instead of using
            `assert` so the check is still performed under
            `python -O`.
    """
    if test_args.model_path is not None:
        fname = Path(test_args.model_path)
        if not fname.exists():
            raise FileNotFoundError(
                f"checkpoint file {fname} does not exist")
        return fname

    prefix = (f'sdc_model_{args.model_class.lower()}_'
              f'{args.policy_class.lower()}_')
    fname = Path(f'{prefix}{args.script_start}.zip')
    if fname.exists():
        return fname

    # Second naming scheme: the learning rate sits in front of the
    # script start timestamp.
    learning_rate = utils.compute_learning_rate(args)
    fname = Path(f'{prefix}{learning_rate}_{args.script_start}.zip')
    if not fname.exists():
        raise FileNotFoundError(
            "checkpoint file could not be determined from arguments; "
            "please use the `--model_path` argument.")
    return fname
def main():
    """Test a previously trained model and write the results to disk.

    Parses the command line, loads the arguments of the earlier run,
    stores the test arguments as `test_args_<timestamp>.json`, locates
    the model checkpoint and hands everything to `run_tests`. The
    figure goes to `test_results_<timestamp>.pdf`; statistics go to
    `test_stats_<timestamp>.npz` only when `--store_stats` is set.
    """
    # Timestamp of this run, made usable as part of a file name.
    script_start = (
        str(datetime.datetime.now()).replace(':', '-').replace(' ', 'T'))

    test_args = parse_args()
    test_args.script_start = script_start

    args = get_prev_args(test_args)

    # Only save when we were able to load the previous arguments
    args_path = Path(f'test_args_{script_start}.json')
    with open(args_path, 'w') as f:
        json.dump(vars(test_args), f, indent=4)

    utils.setup(args.use_sb3, args.debug_nans)

    # The evaluation seed is the loaded seed shifted by the number of
    # environments (presumably to avoid reusing the training seeds --
    # TODO confirm); no seed stays no seed.
    seed = args.seed
    eval_seed = seed
    if eval_seed is not None:
        eval_seed += args.num_envs

    policy_class = utils.get_policy_class(args.policy_class, args.model_class)
    utils.check_num_envs(args, policy_class)

    fname = model_fname_from_args(test_args, args)

    # ---------------- TESTING STARTS HERE ----------------

    fig_path = Path(f'test_results_{script_start}.pdf')
    if test_args.store_stats:
        stats_path = Path(f'test_stats_{script_start}.npz')
    else:
        stats_path = None

    if test_args.tests is not None:
        # `--tests` is parsed as a float, but it is a number of runs;
        # hand an integer to the test routine.
        args.tests = int(test_args.tests)

    run_tests(fname, args, seed=eval_seed, fig_path=fig_path,
              stats_path=stats_path)
# Run the test routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()