-
Notifications
You must be signed in to change notification settings - Fork 94
/
Copy path replay_memory.py
40 lines (31 loc) · 1.35 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import random
import numpy as np
class ReplayMemory:
    """Fixed-capacity circular buffer of transitions for experience replay.

    Each slot holds a ``prestate`` vector and a ``poststate`` vector (both of
    length ``entry_size``), a scalar reward and an integer action. Once
    ``memory_size`` entries have been written, each new entry overwrites the
    oldest slot.

    Args:
        entry_size: Length of each stored state vector.
        memory_size: Maximum number of transitions kept. Defaults to 200000,
            the capacity that was previously hard-coded.
        batch_size: Default number of transitions returned by ``sample`` once
            the buffer holds at least that many. Defaults to 2000, the value
            that was previously hard-coded.

    Raises:
        ValueError: If ``memory_size`` or ``batch_size`` is not positive.
    """

    def __init__(self, entry_size, memory_size=200000, batch_size=2000):
        if memory_size <= 0:
            raise ValueError("memory_size must be positive")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.entry_size = entry_size
        self.memory_size = memory_size
        # Actions are stored as uint8, so each action must lie in 0..255.
        self.actions = np.empty(self.memory_size, dtype=np.uint8)
        self.rewards = np.empty(self.memory_size, dtype=np.float64)
        # States are stored as float16, which is lossy. This is presumably
        # done to keep the buffer small; confirm before changing precision.
        self.prestate = np.empty(
            (self.memory_size, self.entry_size), dtype=np.float16
        )
        self.poststate = np.empty(
            (self.memory_size, self.entry_size), dtype=np.float16
        )
        self.batch_size = batch_size
        # Number of slots that hold valid data (never exceeds memory_size).
        self.count = 0
        # Slot that the next call to ``add`` writes to.
        self.current = 0

    def add(self, prestate, poststate, reward, action):
        """Store one transition, overwriting the oldest slot when full.

        Args:
            prestate: Array-like of length ``entry_size`` (cast to float16).
            poststate: Array-like of length ``entry_size`` (cast to float16).
            reward: Scalar reward (stored as float64).
            action: Integer action in 0..255 (stored as uint8).
        """
        self.actions[self.current] = action
        self.rewards[self.current] = reward
        self.prestate[self.current] = prestate
        self.poststate[self.current] = poststate
        self.count = max(self.count, self.current + 1)
        # Advance the write cursor and wrap to slot 0 after the last slot.
        self.current = (self.current + 1) % self.memory_size

    def sample(self, batch_size=None):
        """Return a batch of stored transitions.

        If fewer than ``batch_size`` transitions are stored, every stored
        transition is returned in slot order. Otherwise, ``batch_size``
        distinct slots are drawn uniformly at random without replacement
        using the ``random`` module.

        Args:
            batch_size: Optional override of ``self.batch_size`` for this
                call. ``None`` means use ``self.batch_size``.

        Returns:
            Tuple ``(prestate, poststate, actions, rewards)`` of NumPy arrays.
            These are copies, not views of the buffer. Note that the order
            differs from ``add``'s ``(prestate, poststate, reward, action)``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if self.count < batch_size:
            indexes = np.arange(self.count)
        else:
            # Explicit intp dtype so an empty draw still indexes as integers.
            indexes = np.array(
                random.sample(range(self.count), batch_size), dtype=np.intp
            )
        # Integer-array ("fancy") indexing returns copies of the rows.
        prestate = self.prestate[indexes]
        poststate = self.poststate[indexes]
        actions = self.actions[indexes]
        rewards = self.rewards[indexes]
        return prestate, poststate, actions, rewards