-
Notifications
You must be signed in to change notification settings - Fork 72
/
config.py
83 lines (65 loc) · 2.45 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# coding=utf-8
import os
import platform
# Lower-cased host OS name as reported by platform.system().
os_name = platform.system().lower()


def _os_name_starts_with(prefix):
    """Return True if the lower-cased host OS name begins with ``prefix``."""
    return os_name.startswith(prefix)


def is_mac():
    """Return True when running on macOS (OS name starts with 'darwin')."""
    return _os_name_starts_with('darwin')


def is_windows():
    """Return True when running on Windows (OS name starts with 'windows')."""
    return _os_name_starts_with('windows')


def is_linux():
    """Return True when running on Linux (OS name starts with 'linux')."""
    return _os_name_starts_with('linux')
def parse_weigths(weights):
    """Parse ``(epoch, val_loss, val_acc)`` out of a checkpoint file path.

    The file name must follow the ``{epoch}-{val_loss}-{val_acc}.h5`` layout
    produced by the ``weights`` template in ``CONTEXT``, e.g.
    ``params/resnet/00012-0.5000-0.8750.h5`` -> ``(12, 0.5, 0.875)``.

    The path must contain at least one directory separator. Both ``/`` and
    ``\\`` are accepted, whatever the host OS: ``CONTEXT`` always emits ``/``,
    even on Windows.

    :param weights: checkpoint path (str) or a falsy value.
    :return: ``(epoch, val_loss, val_acc)`` as ``(int, float, float)``, or
        ``None`` if ``weights`` does not look like a checkpoint path (empty,
        not ``.h5``, no separator, no ``-``, or not exactly three fields).
    :raises ValueError: if the three fields are present but not numeric.
        ValueError subclasses Exception, so existing ``except Exception``
        handlers still catch it.
    """
    suffix = '.h5'
    if not weights or not weights.endswith(suffix) or '-' not in weights:
        return None
    if '/' not in weights and '\\' not in weights:
        return None

    # Normalise separators so the basename is found on any OS, then strip only
    # the trailing suffix (not every occurrence of '.h5').
    file_name = weights.replace('\\', '/').rsplit('/', 1)[-1]
    weights_info = file_name[:-len(suffix)].split('-')
    if len(weights_info) != 3:
        return None

    try:
        epoch = int(weights_info[0])
        val_loss = float(weights_info[1])
        val_acc = float(weights_info[2])
    except ValueError as e:
        # Format the message (the original passed a tuple to Exception).
        raise ValueError('Parse weights failure: %s' % e)
    return epoch, val_loss, val_acc
def CONTEXT(name, **kwargs):
    """Return the per-model file layout and settings for model ``name``.

    Keys of the returned dict:
      weights -- checkpoint path template under ``params/<name>/``; the
          ``{epoch}``, ``{val_loss}`` and ``{val_acc}`` placeholders are left
          unformatted for the caller to fill in.
      summary -- log directory ``log/<name>``.
      predictor_cache_dir -- cache directory ``cache/<name>``.
      load_imagenet_weights -- True only when running on Windows.
      path_json_dump -- ``eval_json/<name>/result[_<policy>].json``. The
          ``_<policy>`` part is added only when a ``policy`` keyword argument
          is given, and that value must be a string.
    """
    weights_template = 'params/%s/{epoch:05d}-{val_loss:.4f}-{val_acc:.4f}.h5' % name
    policy_suffix = ('_' + kwargs['policy']) if 'policy' in kwargs else ''
    return dict(
        weights=weights_template,
        summary='log/%s' % name,
        predictor_cache_dir='cache/%s' % name,
        load_imagenet_weights=is_windows(),
        path_json_dump='eval_json/%s/result%s.json' % (name, policy_suffix),
    )
# Dataset locations per development machine. Each value can be overridden with
# an environment variable (SCENE_PATH_TRAIN_BASE, SCENE_PATH_VAL_BASE,
# SCENE_PATH_TEST_B). This lets machines whose defaults below are empty or
# wrong (notably Linux) point at real data without editing this file. When no
# variable is set, the per-OS defaults are used unchanged.
if is_windows():
    PATH_TRAIN_BASE = 'G:/Dataset/SceneClassify/ai_challenger_scene_train_20170904'
    PATH_VAL_BASE = 'G:/Dataset/SceneClassify/ai_challenger_scene_validation_20170908'
    PATH_TEST_B = 'G:/Dataset/SceneClassify/ai_challenger_scene_test_b_20170922/scene_test_b_images_20170922'
elif is_mac():
    PATH_TRAIN_BASE = '/Users/zijiao/Desktop/ai_challenger_scene_train_20170904'
    PATH_VAL_BASE = '/Users/zijiao/Desktop/ai_challenger_scene_validation_20170908'
    PATH_TEST_B = ''
elif is_linux():
    # No default dataset location configured for Linux. The original comment
    # here was the nickname "Pipijiang", presumably the owner of the Linux
    # machine (TODO confirm).
    PATH_TRAIN_BASE = ''
    PATH_VAL_BASE = ''
    PATH_TEST_B = ''
else:
    raise Exception('No images configured on %s' % os_name)
PATH_TRAIN_BASE = os.environ.get('SCENE_PATH_TRAIN_BASE', PATH_TRAIN_BASE)
PATH_VAL_BASE = os.environ.get('SCENE_PATH_VAL_BASE', PATH_VAL_BASE)
PATH_TEST_B = os.environ.get('SCENE_PATH_TEST_B', PATH_TEST_B)
# Paths inside each dataset split: the 'classes' image folder and the split's
# annotation JSON file.
PATH_TRAIN_IMAGES, PATH_TRAIN_JSON = (
    os.path.join(PATH_TRAIN_BASE, 'classes'),
    os.path.join(PATH_TRAIN_BASE, 'scene_train_annotations_20170904.json'),
)
PATH_VAL_IMAGES, PATH_VAL_JSON = (
    os.path.join(PATH_VAL_BASE, 'classes'),
    os.path.join(PATH_VAL_BASE, 'scene_validation_annotations_20170908.json'),
)
# Fixed evaluation dump path; CONTEXT() builds per-model dump paths separately.
PATH_JSON_DUMP = 'eval_json/resnet.json'
# train info
# Image input sizes (presumably square side lengths in pixels -- confirm
# against the model code that consumes them).
IM_SIZE_299 = 299
IM_SIZE_224 = 224
BATCH_SIZE = 32
# Number of classes = number of entries in the training 'classes' directory.
# Hidden entries (e.g. the '.DS_Store' files macOS Finder creates) are not
# classes, so they are skipped to keep them from inflating the count.
CLASSES = len([entry for entry in os.listdir(PATH_TRAIN_IMAGES)
               if not entry.startswith('.')])
EPOCH = 100
if __name__ == '__main__':
    # Manual sanity check: show the resolved training image directory, then
    # the values of a sample CONTEXT for a model named 'test'.
    print(PATH_TRAIN_IMAGES)
    sample_context = CONTEXT('test')
    print(sample_context.values())