-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcts_hallway_test.py
154 lines (135 loc) · 4.91 KB
/
mcts_hallway_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp
from utils import *
import config
from mdp import *
from mdp.policies import *
import mcts
import linalg
import time
# Experiment configuration.
writetodisk = True            # NOTE(review): not referenced elsewhere in this script
root = 'data/mcts_hallway'    # path prefix for every pickle written below
nodes = 50                    # hallway size passed to the problem/MDP builders and HallwayPolicy
action_n = 3                  # number of discrete actions
batch_size = 2                # passed to batch_simulate
batch = True                  # True: batch_simulate; False: serial simulate
if batch:
    # Scale the number of start states with (cpu_count - 1) workers.
    # max(1, ...) keeps this positive on a single-core machine, where
    # cpu_count() - 1 == 0 would otherwise give zero start states.
    num_start_states = 5 * max(1, mp.cpu_count() - 1) * batch_size
else:
    num_start_states = batch_size
horizon = 50                  # simulation steps per trajectory
# Cache flags for do_or_load: rebuild_all forces every artifact to be
# recomputed; the individual flags force just one stage.
rebuild_all = True
build_problem = False
build_mdp = False
run_solver = False
# Build problem: construct the hallway problem via config.mdp, going through
# do_or_load with the cache file below (presumably rebuilt when the flag is
# true and loaded from prob_file otherwise -- confirm in utils).
def prob_build_fn():
    return config.mdp.make_hallway_problem(nodes)


prob_file = root + '.prob.pickle'
problem = do_or_load(prob_build_fn, prob_file,
                     rebuild_all or build_problem, 'problem build')
# Generate MDP and discretizer.
# The action set is a column vector of shape (3, 1); the values presumably
# mean step -1 / stay / step +1 along the hallway.
actions = np.array([-1, 0, 1]).reshape((-1, 1))


def mdp_build_fn():
    return config.mdp.make_trivial_mdp(problem, nodes, actions)


mdp_file = root + '.mdp.pickle'
mdp_obj, disc = do_or_load(mdp_build_fn, mdp_file,
                           rebuild_all or build_mdp, 'mdp build')
# Solve with Kojima (config.solver.solve_with_kojima). The arguments 1e-8
# and 1000 are presumably a tolerance and an iteration cap -- confirm
# against config.solver. The result unpacks into (p, d).
solve_fn = lambda: config.solver.solve_with_kojima(mdp_obj,
                                                   1e-8, 1000)
# The dot separator keeps this cache consistent with the other
# '<root>.<artifact>.pickle' files (prob, mdp, vfn, policies, ...).
soln_file = root + '.soln.pickle'
(p, d) = do_or_load(solve_fn,
                    soln_file,
                    rebuild_all or run_solver,
                    'solver')
# Build value function: split the solver output p into the value vector
# and the flow, wrap the values as an InterpolatedFunction over the
# discretizer, and cache it.
print('Building value function')
v, flow = split_solution(mdp_obj, p)
v_fn = InterpolatedFunction(disc, v)
dump(v_fn, root + '.vfn.pickle')
#######################
# Build policies
# Every entry of policy_dict is simulated and scored later in this script.
print('Building policies')
policy_dict = {}

# Q-based policy from a noisy value function: v plus Gaussian noise with
# standard deviation 0.3. v_pert_fn is also passed to every MCTSPolicy below.
v_pert = v + 0.3 * np.random.randn(v.size)
v_pert_fn = InterpolatedFunction(disc, v_pert)
q = q_vectors(mdp_obj, v_pert)
q_fns = build_functions(mdp_obj, disc, q)
# NOTE(review): this key contains a space, so its per-policy result file
# becomes '<root>.q pert.result.pickle'; kept as-is because downstream
# readers may depend on the key.
policy_dict['q pert'] = IndexPolicyWrapper(MinFunPolicy(q_fns),
                                           mdp_obj.actions)

# The same construction from the unperturbed solver values.
q = q_vectors(mdp_obj, v)
q_fns = build_functions(mdp_obj, disc, q)
policy_dict['q'] = IndexPolicyWrapper(MinFunPolicy(q_fns),
                                      mdp_obj.actions)

# Hand-written hallway policy, plus an epsilon-fuzzed copy of it that also
# serves as the MCTS rollout policy. The action count comes from the
# action_n config constant rather than a bare literal; 0.2 is presumably
# the fuzzing probability.
policy_dict['hand'] = HallwayPolicy(nodes)
rollout_policy = EpsilonFuzzedPolicy(action_n,
                                     0.2,
                                     HallwayPolicy(nodes))
policy_dict['hand_pert'] = IndexPolicyWrapper(rollout_policy,
                                              mdp_obj.actions)

# Flow-based policy: add Gaussian noise (sd 5) to the solver flow, clip
# negatives to 0, and act via MaxFunPolicy over the resulting functions.
# The same functions define initial_prob for MCTS.
# NOTE(review): `probs` is not imported explicitly in this file; presumably
# it arrives through one of the star imports -- confirm.
pert_flow = np.maximum(0, flow + 5 * np.random.randn(*flow.shape))
pert_flow_fns = build_functions(mdp_obj, disc, pert_flow)
initial_prob = probs.FunctionProbability(pert_flow_fns)
pert_flow_policy = MaxFunPolicy(pert_flow_fns)
policy_dict['flow_pert'] = IndexPolicyWrapper(pert_flow_policy,
                                              mdp_obj.actions)

# MCTS policies for search budgets 10, 20, ..., 240. rollout and prob_scale
# are single-value sweeps, kept as loops so more values can be added.
for budget in range(10, 250, 10):
    for rollout in [5]:
        for prob_scale in [2]:
            name = 'mcts_{0}_{1}_{2}'.format(budget,
                                             rollout,
                                             prob_scale)
            policy_dict[name] = mcts.MCTSPolicy(problem,
                                                mdp_obj.actions,
                                                rollout_policy,
                                                initial_prob,
                                                v_pert_fn,
                                                rollout,
                                                prob_scale,
                                                budget)
dump(policy_dict, root + '.policies.pickle')
"""
start_states = np.random.randint(nodes,
size=(num_start_states,1))
"""
start_states = 10*np.ones((num_start_states,1))
dump(start_states,root + '.starts.pickle')
# Simulate: run every policy from the shared start states, checking that
# costs come back with shape (num_start_states, horizon), and cache each
# result both individually and all together.
print('Simulating')
results = {}
start = time.time()
for name, policy in policy_dict.items():
    print('\tRunning {0} jobs'.format(name))
    if not batch:
        result = simulate(problem, policy, start_states, horizon)
    else:
        result = batch_simulate(problem, policy, start_states,
                                horizon, batch_size)
    assert (num_start_states, horizon) == result.costs.shape
    results[name] = result
    # One file per policy, presumably so a partial run still leaves output.
    dump({name: result}, root + '.' + name + '.result.pickle')
# Wall-clock time for simulating all policies.
print('**Multithread total %s' % (time.time() - start))
dump(results, root + '.results.pickle')
# Return: discounted return of every trajectory, keyed by policy name, with
# one return per start state.
returns = {}
for name, result in results.items():
    returns[name] = discounted_return(result.costs, problem.discount)
    assert (num_start_states,) == returns[name].shape
dump(returns, root + '.return.pickle')

# V: the solver's interpolated value function evaluated at each start state,
# saved alongside the returns.
vals = v_fn.evaluate(start_states)
dump(vals, root + '.vals.pickle')