ray_train_dqn.py
import argparse
import os

import ray
from ray import air, train, tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.framework import try_import_tf
from ray.tune.registry import register_env

from env_creator import qsimpy_env_creator

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--stop-iters", type=int, default=100, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
if __name__ == "__main__":
    args = parser.parse_args()
    # ray.init(num_cpus=args.num_cpus or None)

    # Register the QSimPy environment so RLlib can create it by name.
    register_env("QSimPyEnv", qsimpy_env_creator)

    # Prioritized replay buffer settings for DQN.
    replay_config = {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 60000,
        "prioritized_replay_alpha": 0.5,
        "prioritized_replay_beta": 0.5,
        "prioritized_replay_eps": 3e-6,
    }

    config = (
        DQNConfig()
        .framework(framework=args.framework)
        .environment(
            env="QSimPyEnv",
            env_config={
                "obs_filter": "rescale_-1_1",
                "reward_filter": None,
                "dataset": "qdataset/qsimpyds_1000_sub_26.csv",
            },
        )
        .training(
            lr=tune.grid_search([0.01]),
            train_batch_size=tune.grid_search([78]),
            replay_buffer_config=replay_config,
            num_atoms=tune.grid_search([10]),
            n_step=tune.grid_search([5]),
            noisy=True,
            v_min=-10.0,
            v_max=10.0,
        )
        .rollouts(num_rollout_workers=8)
    )

    # Stop a trial once either limit is reached.
    stop_config = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    # Get the absolute path of the current directory.
    current_directory = os.getcwd()
    # Append the "results" folder to the current directory path.
    result_directory = os.path.join(current_directory, "results")
    # Create the storage_path with the "file://" scheme.
    storage_path = f"file://{result_directory}"

    results = tune.Tuner(
        "DQN",
        run_config=air.RunConfig(
            stop=stop_config,
            # Save checkpoints every 10 iterations.
            checkpoint_config=train.CheckpointConfig(checkpoint_frequency=10),
            storage_path=storage_path,
            name="DQN_QCE_1000",
        ),
        param_space=config.to_dict(),
    ).fit()

    ray.shutdown()
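
# A minimal sketch of inspecting the ResultGrid returned by Tuner.fit(), assuming
# the RLlib DQN trainable reports the standard "episode_reward_mean" metric
# (the exact metric name can vary across Ray versions):
#
#     best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
#     print("Best trial config:", best_result.config)
#     print("Best checkpoint:", best_result.checkpoint)
#
# Example invocation:
#
#     python ray_train_dqn.py --framework torch --stop-iters 100 --stop-timesteps 100000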