forked from StanfordVL/mini_behavior
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path train_APT.py
61 lines (45 loc) · 1.52 KB
/
train_APT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gym
import os
from mini_behavior.register import register
from algorithms.APT_PPO import APT_PPO
from env_wrapper import CustomObservationWrapper
# ---------------------------------------------------------------------------
# Experiment configuration for APT-PPO training on a mini_behavior task.
# ---------------------------------------------------------------------------

# Task name; used to build the gym id and the entry point "<TASK>Env".
TASK = 'MultiToy'
# Not read anywhere in this script (kept for reference / other tooling).
PARTIAL_OBS = True
# Side length of the square room; forwarded as env kwarg "room_size".
ROOM_SIZE = 16
# Episode step limit; forwarded as env kwarg "max_steps".
MAX_STEPS = 1000
# Total environment steps handed to APT_PPO. This is an integer (not 1e6)
# because a step count may be used with `//` or range() downstream, where a
# float would produce float results or a TypeError.
TOTAL_TIMESTEPS = 1_000_000
# Not read anywhere in this script.
DENSE_REWARD = False
# Not read anywhere in this script.
POLICY_TYPE = 'CnnPolicy'
# Number of parallel environments in the SyncVectorEnv.
NUM_ENVS = 8
# Passed to APT_PPO as `num_steps`; presumably the rollout length per env per
# update. TODO confirm against APT_PPO.
NUM_STEPS = 125
# Passed to APT_PPO as `save_freq`; units (updates vs. steps) are defined by
# APT_PPO. TODO confirm.
SAVE_FREQUENCY = 100

# Registered gym id, e.g. "MiniGrid-MultiToy-16x16-N2-v0".
env_name = f"MiniGrid-{TASK}-{ROOM_SIZE}x{ROOM_SIZE}-N2-v0"
# Constructor kwargs shared by registration, gym.make and model.save.
env_kwargs = {"room_size": ROOM_SIZE, "max_steps": MAX_STEPS}
def make_env(env_id, seed, idx, env_kwargs):
    """Return a zero-argument callable that builds one wrapped environment.

    The callable is what ``gym.vector.SyncVectorEnv`` expects per sub-env.

    Args:
        env_id: Registered gym id passed to ``gym.make``.
        seed: Base seed shared by all sub-environments.
        idx: Index of this sub-environment; added to ``seed`` so each
            sub-env gets a distinct seed.
        env_kwargs: Extra keyword arguments forwarded to ``gym.make``.

    Returns:
        A function that, when called, creates the env, wraps it in
        ``CustomObservationWrapper``, seeds it with ``seed + idx`` and
        returns it.
    """
    def _create():
        raw_env = gym.make(env_id, **env_kwargs)
        wrapped_env = CustomObservationWrapper(raw_env)
        # Offset by idx so parallel copies do not share a seed.
        wrapped_env.seed(seed + idx)
        return wrapped_env

    return _create
def init_env(num_envs: int, seed):
    """Create a synchronous vector of ``num_envs`` task environments.

    Every sub-environment uses the module-level ``env_name`` and
    ``env_kwargs``. Sub-env ``i`` is seeded with ``seed + i``
    (see ``make_env``).

    Args:
        num_envs: Number of sub-environments.
        seed: Base seed.

    Returns:
        A ``gym.vector.SyncVectorEnv`` over the constructed environments.
    """
    env_factories = [
        make_env(env_name, seed, rank, env_kwargs)
        for rank in range(num_envs)
    ]
    return gym.vector.SyncVectorEnv(env_factories)
if __name__ == "__main__":
    # Register the task under env_name so gym.make(env_name) inside make_env
    # can resolve it; the entry point is mini_behavior.envs:<TASK>Env.
    register(
        id=env_name,
        entry_point=f'mini_behavior.envs:{TASK}Env',
        kwargs=env_kwargs,
    )

    env = init_env(NUM_ENVS, seed=1)
    try:
        print('begin training')

        # Policy training
        model = APT_PPO(
            env,
            env_id=env_name,
            num_envs=NUM_ENVS,
            total_timesteps=TOTAL_TIMESTEPS,
            num_steps=NUM_STEPS,
            save_freq=SAVE_FREQUENCY,
        )
        model.train()

        # Derive the output directory from TASK so that switching tasks does
        # not silently write into another task's folder.
        save_dir = os.path.join("models", f"APT_PPO_{TASK}")
        # exist_ok=True replaces the racy exists()-then-makedirs() check.
        os.makedirs(save_dir, exist_ok=True)
        model.save(os.path.join(save_dir, env_name), env_kwargs=env_kwargs)
    finally:
        # Release the vectorized environments even if training fails.
        env.close()