agent.py
import numpy as np
import torch
from torch import from_numpy
from torch.optim import Adam

from model import PolicyNetwork, QvalueNetwork, ValueNetwork
from replay_memory import Memory, Transition


class SAC:
    def __init__(self, env_name, n_states, n_actions, memory_size, batch_size, gamma, alpha, lr, action_bounds,
                 reward_scale):
        self.env_name = env_name
        self.n_states = n_states
        self.n_actions = n_actions
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.alpha = alpha  # entropy temperature
        self.lr = lr
        self.action_bounds = action_bounds
        self.reward_scale = reward_scale

        self.memory = Memory(memory_size=self.memory_size)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Actor, twin critics, and the online/target state-value networks.
        self.policy_network = PolicyNetwork(n_states=self.n_states, n_actions=self.n_actions,
                                            action_bounds=self.action_bounds).to(self.device)
        self.q_value_network1 = QvalueNetwork(n_states=self.n_states, n_actions=self.n_actions).to(self.device)
        self.q_value_network2 = QvalueNetwork(n_states=self.n_states, n_actions=self.n_actions).to(self.device)
        self.value_network = ValueNetwork(n_states=self.n_states).to(self.device)
        self.value_target_network = ValueNetwork(n_states=self.n_states).to(self.device)
        self.value_target_network.load_state_dict(self.value_network.state_dict())
        self.value_target_network.eval()

        self.value_loss = torch.nn.MSELoss()
        self.q_value_loss = torch.nn.MSELoss()

        self.value_opt = Adam(self.value_network.parameters(), lr=self.lr)
        self.q_value1_opt = Adam(self.q_value_network1.parameters(), lr=self.lr)
        self.q_value2_opt = Adam(self.q_value_network2.parameters(), lr=self.lr)
        self.policy_opt = Adam(self.policy_network.parameters(), lr=self.lr)

    def store(self, state, reward, done, action, next_state):
        # Keep stored transitions on the CPU; they are moved to the training device in unpack().
        state = from_numpy(state).float().to("cpu")
        reward = torch.Tensor([reward]).to("cpu")
        done = torch.Tensor([done]).to("cpu")
        action = torch.Tensor([action]).to("cpu")
        next_state = from_numpy(next_state).float().to("cpu")
        self.memory.add(state, reward, done, action, next_state)

    def unpack(self, batch):
        # Convert a list of sampled transitions into batched tensors on the training device.
        batch = Transition(*zip(*batch))

        states = torch.cat(batch.state).view(self.batch_size, self.n_states).to(self.device)
        rewards = torch.cat(batch.reward).view(self.batch_size, 1).to(self.device)
        dones = torch.cat(batch.done).view(self.batch_size, 1).to(self.device)
        actions = torch.cat(batch.action).view(-1, self.n_actions).to(self.device)
        next_states = torch.cat(batch.next_state).view(self.batch_size, self.n_states).to(self.device)

        return states, rewards, dones, actions, next_states

    def train(self):
        if len(self.memory) < self.batch_size:
            return 0, 0, 0
        else:
            batch = self.memory.sample(self.batch_size)
            states, rewards, dones, actions, next_states = self.unpack(batch)

            # Value target: V(s) <- min_i Q_i(s, a~pi) - alpha * log pi(a|s), with gradients blocked.
            reparam_actions, log_probs = self.policy_network.sample_or_likelihood(states)
            q1 = self.q_value_network1(states, reparam_actions)
            q2 = self.q_value_network2(states, reparam_actions)
            q = torch.min(q1, q2)
            target_value = q.detach() - self.alpha * log_probs.detach()

            value = self.value_network(states)
            value_loss = self.value_loss(value, target_value)

            # Q-value target: scaled reward + gamma * V_target(s') for non-terminal transitions.
            with torch.no_grad():
                target_q = self.reward_scale * rewards + \
                           self.gamma * self.value_target_network(next_states) * (1 - dones)
            q1 = self.q_value_network1(states, actions)
            q2 = self.q_value_network2(states, actions)
            q1_loss = self.q_value_loss(q1, target_q)
            q2_loss = self.q_value_loss(q2, target_q)

            # Policy loss via the reparameterization trick: maximize min-Q minus the entropy penalty.
            policy_loss = (self.alpha * log_probs - q).mean()

            self.policy_opt.zero_grad()
            policy_loss.backward()
            self.policy_opt.step()

            self.value_opt.zero_grad()
            value_loss.backward()
            self.value_opt.step()

            # Stray critic gradients from the policy backward pass are cleared by zero_grad()
            # before each critic's own update.
            self.q_value1_opt.zero_grad()
            q1_loss.backward()
            self.q_value1_opt.step()

            self.q_value2_opt.zero_grad()
            q2_loss.backward()
            self.q_value2_opt.step()

            self.soft_update_target_network(self.value_network, self.value_target_network)

            return value_loss.item(), 0.5 * (q1_loss + q2_loss).item(), policy_loss.item()

    def choose_action(self, states):
        # Add a batch dimension, sample a (squashed) action from the policy, and return it as a numpy array.
        states = np.expand_dims(states, axis=0)
        states = from_numpy(states).float().to(self.device)
        action, _ = self.policy_network.sample_or_likelihood(states)
        return action.detach().cpu().numpy()[0]

    @staticmethod
    def soft_update_target_network(local_network, target_network, tau=0.005):
        # Polyak averaging: theta_target <- tau * theta_local + (1 - tau) * theta_target
        for target_param, local_param in zip(target_network.parameters(), local_network.parameters()):
            target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

    def save_weights(self):
        torch.save(self.policy_network.state_dict(), self.env_name + "_weights.pth")

    def load_weights(self):
        self.policy_network.load_state_dict(torch.load(self.env_name + "_weights.pth"))

    def set_to_eval_mode(self):
        self.policy_network.eval()
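

if __name__ == "__main__":
    # Usage sketch (not part of the original file): a minimal, hypothetical training loop
    # showing how this SAC agent could be driven by a classic Gym-style continuous-control
    # environment (reset() -> obs, step() -> (obs, reward, done, info)). The environment
    # name, the action_bounds format, and every hyperparameter below are illustrative
    # assumptions, not values taken from this repository.
    import gym

    env_name = "Pendulum-v1"  # assumed environment
    env = gym.make(env_name)
    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.shape[0]
    action_bounds = [env.action_space.low[0], env.action_space.high[0]]  # assumed [low, high] convention

    agent = SAC(env_name=env_name, n_states=n_states, n_actions=n_actions,
                memory_size=100_000, batch_size=256, gamma=0.99, alpha=0.2,
                lr=3e-4, action_bounds=action_bounds, reward_scale=5)

    for episode in range(10):
        state, done, episode_reward = env.reset(), False, 0.0
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.store(state, reward, done, action, next_state)
            value_loss, q_loss, policy_loss = agent.train()
            episode_reward += reward
            state = next_state
        print(f"Episode {episode}: reward = {episode_reward:.1f}")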