-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlunar_lander_dqn.py
45 lines (36 loc) · 1.24 KB
/
lunar_lander_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gym
from stable_baselines import DQN
from stable_baselines.common.vec_env import VecVideoRecorder, DummyVecEnv
import os.path
from os import path
# Train a DQN agent on LunarLander (or load an existing checkpoint) and then
# render the agent acting in the environment.
#
# NOTE: to record a video instead of rendering live, the environment can be
# wrapped with the (already imported) vec-env helpers, e.g.
#   env = DummyVecEnv([lambda: gym.make(ENV_ID)])
#   env = VecVideoRecorder(env, 'videos/',
#                          record_video_trigger=lambda x: x == 0,
#                          video_length=1000,
#                          name_prefix="lunar_lander_testing_agent")

# Script configuration, kept in one place so that the training path and the
# evaluation path cannot drift apart.
ENV_ID = 'LunarLander-v2'
MODEL_PATH = "l_lander_dqn"          # checkpoint name passed to save()/load()
MODEL_FILE = MODEL_PATH + ".zip"     # file whose presence means "already trained"
TRAIN_TIMESTEPS = int(5e5)
EVAL_STEPS = 1000


def evaluate(model, env, n_steps=EVAL_STEPS):
    """Render ``n_steps`` environment steps with actions chosen by ``model``.

    Args:
        model: Object exposing ``predict(obs)`` that returns an
            ``(action, state)`` pair.
        env: Environment exposing ``reset()``, ``step(action)`` (returning a
            4-tuple ``(obs, reward, done, info)``) and ``render()``.
        n_steps: Total number of steps to run, across as many episodes as
            needed.
    """
    obs = env.reset()
    for _ in range(n_steps):
        action, _states = model.predict(obs)
        obs, _reward, done, _info = env.step(action)
        env.render()
        # Stepping an environment after its episode has finished is not
        # valid, so start a fresh episode before the next action.
        if done:
            obs = env.reset()


env = gym.make(ENV_ID)
try:
    if not path.isfile(MODEL_FILE):
        # Learning stage: no checkpoint on disk yet, so train one and save it.
        model = DQN('MlpPolicy', env, learning_rate=1e-3,
                    prioritized_replay=True, verbose=1)
        model.learn(total_timesteps=TRAIN_TIMESTEPS)
        model.save(MODEL_PATH)
        # Drop the in-memory model so that evaluation below always runs on
        # what was actually written to disk.
        del model

    # Evaluation stage: always evaluate the checkpoint loaded from disk.
    model = DQN.load(MODEL_PATH)
    evaluate(model, env)
finally:
    # Release the render window / simulator resources even if training or
    # evaluation is interrupted.
    env.close()