-
Notifications
You must be signed in to change notification settings - Fork 0
/
reduce_checkpoint_size.py
executable file
·40 lines (30 loc) · 1.18 KB
/
reduce_checkpoint_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
The checkpoint of a trained model also stores the Adam optimizer's momentum and variance.
This script saves only the actual model weights, reducing the checkpoint size.
"""
import argparse
import os
import tensorflow as tf
def _parse_args(argv=None):
    """Parse the command-line arguments of the checkpoint-slimming script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input_path``, ``output_path`` and
        ``exclude_prefix`` (a list of variable-name prefixes to drop).
    """
    parser = argparse.ArgumentParser(
        description="Copy a TensorFlow checkpoint, keeping only the model "
        "weights and dropping variables such as optimizer state.")
    # Both paths are mandatory: without them os.path.expanduser(None) would
    # raise an opaque TypeError, so let argparse report a usage error instead.
    parser.add_argument(
        "--input_path", type=str, required=True,
        help="Path prefix of the input checkpoint "
        "(the graph is read from <input_path>.meta).")
    parser.add_argument(
        "--output_path", type=str, required=True,
        help="Path prefix of the output checkpoint.")
    parser.add_argument(
        "--exclude_prefix", type=str, nargs="+", default=["training"],
        help="Variable-name prefixes to leave out of the output checkpoint "
        "(default: training).")
    return parser.parse_args(argv)


def _select_model_variables(variables, excluded_prefixes):
    """Return the variables whose ``name`` starts with none of the prefixes."""
    prefixes = tuple(excluded_prefixes)
    return [var for var in variables if not var.name.startswith(prefixes)]


def main(argv=None):
    """Re-save a checkpoint with only the variables that should be kept.

    Imports the graph from ``<input_path>.meta``, restores every variable
    from ``input_path``, then saves to ``output_path`` only the global
    variables whose names do not start with any ``--exclude_prefix``
    (default ``training``).

    Args:
        argv: Optional argument list; ``None`` parses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    input_path = os.path.expanduser(args.input_path)
    output_path = os.path.expanduser(args.output_path)

    # For a bare file name such as "model.ckpt" dirname() is "": there is no
    # directory to create, so don't hand an empty path to MakeDirs.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        tf.gfile.MakeDirs(output_dir)

    config = tf.ConfigProto(allow_soft_placement=True)
    # The context manager closes the session even if restore/save fails.
    with tf.Session(config=config) as sess:
        restorer = tf.train.import_meta_graph(input_path + '.meta')
        restorer.restore(sess, input_path)
        model_variables = _select_model_variables(
            tf.global_variables(), args.exclude_prefix)
        tf.train.Saver(model_variables).save(sess, output_path)


if __name__ == "__main__":
    main()