-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinspect_checkpoint.py
44 lines (35 loc) · 1.28 KB
/
inspect_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
python inspect_checkpoint.py model.12345
"""
import tensorflow as tf
import sys
import numpy as np
def classify_tensor(tensor):
    """Classify one checkpoint tensor by the finiteness of its values.

    Args:
        tensor: A value as returned by ``CheckpointReader.get_tensor``
            (a numpy array or scalar). NOTE(review): string-typed
            variables presumably come back as bytes / string arrays --
            confirm against the TF version in use.

    Returns:
        One of:
        - ``"finite"``: every element is finite (vacuously true for an
          empty tensor);
        - ``"all_infnan"``: no element is finite;
        - ``"some_infnan"``: a mix of finite and inf/NaN elements;
        - ``"non_numeric"``: ``np.isfinite`` does not support the dtype
          (e.g. strings), so finiteness is undefined.
    """
    try:
        is_finite = np.isfinite(tensor)
    except TypeError:
        # Without this, a single string-typed entry in the checkpoint would
        # abort the whole scan with a ufunc TypeError.
        return "non_numeric"
    # Build the elementwise mask once and reduce it twice, instead of
    # re-running np.isfinite for the "all" and the "any" test.
    if np.all(is_finite):
        return "finite"
    if not np.any(is_finite):
        return "all_infnan"
    return "some_infnan"


def find_infnan_variables(reader):
    """Group every variable in a checkpoint by ``classify_tensor``'s verdict.

    Args:
        reader: An object exposing ``get_variable_to_shape_map()`` and
            ``get_tensor(key)``, such as the reader returned by
            ``tf.train.NewCheckpointReader``.

    Returns:
        A dict mapping each category (``"finite"``, ``"all_infnan"``,
        ``"some_infnan"``, ``"non_numeric"``) to the list of variable names
        in that category, in sorted order.
    """
    groups = {"finite": [], "all_infnan": [], "some_infnan": [], "non_numeric": []}
    for key in sorted(reader.get_variable_to_shape_map()):
        groups[classify_tensor(reader.get_tensor(key))].append(key)
    return groups


def main(argv):
    """Scan the checkpoint named by ``argv[1]`` and print an inf/NaN report.

    Args:
        argv: ``sys.argv``-style argument list (program name first).

    Raises:
        SystemExit: With the usage message when ``argv`` does not hold
            exactly one checkpoint argument.
    """
    if len(argv) != 2:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # rather than dumping a traceback for a simple usage error.
        sys.exit("Usage: python inspect_checkpoint.py <file_name>\n"
                 "Note: Do not include the .data .index or .meta part of the "
                 "model checkpoint in file_name.")
    file_name = argv[1]

    # TF1 exposes tf.train.NewCheckpointReader; TF2 dropped it from the main
    # namespace and provides tf.train.load_checkpoint, which returns a
    # CheckpointReader for the same checkpoint prefix.
    make_reader = getattr(tf.train, "NewCheckpointReader", None) or tf.train.load_checkpoint
    groups = find_infnan_variables(make_reader(file_name))

    sections = [
        ("\nFINITE VARIABLES:", groups["finite"]),
        ("\nVARIABLES THAT ARE ALL INF/NAN:", groups["all_infnan"]),
        ("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:", groups["some_infnan"]),
    ]
    if groups["non_numeric"]:
        # Only shown when present, so the report for purely numeric
        # checkpoints is unchanged.
        sections.append(("\nVARIABLES SKIPPED (DTYPE HAS NO INF/NAN, E.G. STRINGS):",
                         groups["non_numeric"]))
    for header, keys in sections:
        print(header)
        for key in keys:
            print(key)

    # Non-numeric variables cannot hold inf/NaN, so they do not fail the check.
    if groups["all_infnan"] or groups["some_infnan"]:
        print("CHECK FAILED: checkpoint contains some inf/NaN values")
    else:
        print("CHECK PASSED: checkpoint contains no inf/NaN values")


if __name__ == '__main__':
    main(sys.argv)