forked from haoyye/ResourceAllocationReinforcementLearning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
62 lines (48 loc) · 1.95 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import print_function
import os
class BaseModel(object):
    """Base model that mirrors a config object and manages checkpoints.

    On construction every collected config attribute is copied onto the
    instance (a single leading underscore is stripped from the name), so
    ``config._save_step`` becomes ``self.save_step``.

    The checkpoint helpers read ``self.sess``, which this class never sets.
    NOTE(review): presumably subclasses assign a TensorFlow session to
    ``self.sess`` before calling ``save_model`` / ``load_model`` -- confirm.

    NOTE(review): ``load_model`` and ``saver`` reference the module-level name
    ``tf``, but this file does not import it; ``import tensorflow as tf``
    (TF 1.x API) must be available at module scope for them to work.
    """

    def __init__(self, config):
        """Store ``config`` and copy its attributes onto this instance.

        Args:
            config: object (class or instance) holding the hyper-parameters.
                It must expose ``env_name``, which ``model_dir`` reads.
        """
        self._saver = None
        self.config = config

        try:
            # Presumably the legacy ``tf.app.flags`` container, which kept its
            # values in a ``'__flags'`` entry of ``__dict__`` -- TODO confirm.
            self._attrs = config.__dict__['__flags']
        except (AttributeError, KeyError):
            # Plain config class/instance: fall back to introspection.
            self._attrs = self._collect_attrs(config)

        for attr in self._attrs:
            # Expose "_private" config entries without the leading underscore.
            name = attr[1:] if attr.startswith('_') else attr
            setattr(self, name, getattr(self.config, attr))

    @staticmethod
    def _collect_attrs(config):
        """Return a dict of the non-dunder, non-callable attributes of ``config``."""
        attrs = {}
        for key in dir(config):
            if key.startswith('__'):
                continue
            value = getattr(config, key)
            # Skip methods so they are neither copied nor put in ``model_dir``.
            if not callable(value):
                attrs[key] = value
        return attrs

    def save_model(self, step=None):
        """Save the session variables under ``checkpoint_dir``.

        Args:
            step: optional value forwarded to the saver as ``global_step``.
        """
        print(" [*] Saving checkpoints...")

        # ``checkpoint_dir`` is a computed property; evaluate it once.
        checkpoint_dir = self.checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess, checkpoint_dir, global_step=step)

    def load_model(self):
        """Restore the latest checkpoint found in ``checkpoint_dir``.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        print(" [*] Loading checkpoints...")

        checkpoint_dir = self.checkpoint_dir
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Re-anchor the recorded file name on the current directory.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            fname = os.path.join(checkpoint_dir, ckpt_name)
            self.saver.restore(self.sess, fname)
            print(" [*] Load SUCCESS: %s" % fname)
            return True

        print(" [!] Load FAILED: %s" % checkpoint_dir)
        return False

    @property
    def checkpoint_dir(self):
        """Directory for this model's checkpoints: ``checkpoints/<model_dir>``."""
        return os.path.join('checkpoints', self.model_dir)

    @property
    def model_dir(self):
        """Relative path built from ``env_name`` and the public config values.

        Each attribute that does not start with an underscore (and is not
        ``display``) contributes one ``/<name>-<value>`` segment; list values
        are joined with commas. The result always ends with ``/``.
        """
        model_dir = self.config.env_name
        for k, v in self._attrs.items():
            if k.startswith('_') or k == 'display':
                continue
            value = ",".join([str(i) for i in v]) if isinstance(v, list) else v
            model_dir += "/%s-%s" % (k, value)
        return model_dir + '/'

    @property
    def saver(self):
        """Lazily created ``tf.train.Saver`` that keeps at most 10 checkpoints."""
        if self._saver is None:
            self._saver = tf.train.Saver(max_to_keep=10)
        return self._saver