# train_pg.py
# Forked from pmuens/alphago.
# 37 lines (31 loc) · 1.19 KB
import argparse
import h5py
from dlgo import agent
from dlgo import rl
def main():
    """Train a policy agent on stored experience and save the result.

    Reads its configuration from the command line:

    * ``--learning-agent`` (required): path of the HDF5 file the agent is
      loaded from via ``agent.load_policy_agent``.
    * ``--agent-out`` (required): path of the HDF5 file the trained agent
      is serialized to. The file is opened in ``'w'`` mode, so any existing
      content is replaced.
    * ``--lr`` (float, default 0.0001), ``--clipnorm`` (float, default 1.0)
      and ``--bs`` (int, default 512): forwarded to the agent's ``train``
      method as ``lr``, ``clipnorm`` and ``batch_size``.
    * ``experience`` (one or more positional paths): HDF5 files loaded with
      ``rl.load_experience``; the agent is trained on each one in the order
      given.

    Exits via ``SystemExit`` (raised by argparse) when required arguments
    are missing or malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning-agent', required=True)
    parser.add_argument('--agent-out', required=True)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--clipnorm', type=float, default=1.0)
    parser.add_argument('--bs', type=int, default=512)
    parser.add_argument('experience', nargs='+')
    args = parser.parse_args()

    # Input files are opened explicitly read-only and inside ``with`` blocks
    # so every handle is closed, even if loading or training raises.
    # NOTE(review): the handles are deliberately kept open while the loaded
    # objects are in use, because it is not visible here whether the loaders
    # copy everything into memory -- confirm before narrowing the scopes.
    with h5py.File(args.learning_agent, 'r') as agent_file:
        learning_agent = agent.load_policy_agent(agent_file)

        for exp_filename in args.experience:
            with h5py.File(exp_filename, 'r') as exp_file:
                exp_buffer = rl.load_experience(exp_file)
                learning_agent.train(
                    exp_buffer,
                    lr=args.lr,
                    clipnorm=args.clipnorm,
                    batch_size=args.bs)

    # The source agent file is closed by now, so writing the output cannot
    # collide with a still-open read handle on the same path.
    with h5py.File(args.agent_out, 'w') as updated_agent_outf:
        learning_agent.serialize(updated_agent_outf)
# Script entry point: run the training routine only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()