-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathai_image.py
90 lines (70 loc) · 2.61 KB
/
ai_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from config import IMAGE_MODEL_IMAGE_SIZE, IMAGE_MODEL_URL, LABELS_MAP_PATH
import tensorflow_hub as hub
import tensorflow as tf
from typing import List
# Initialize the TensorFlow model once, at import time. The module-level `m`
# is a single instance shared by everything that imports this module (one
# model for the whole server, for now).
# NOTE(review): hub.KerasLayer may fetch/load the model from IMAGE_MODEL_URL
# here, so importing this module can be slow — confirm for the deployed URL.
# Reset Keras' global state before building the model.
tf.keras.backend.clear_session()
m = tf.keras.Sequential([
hub.KerasLayer(IMAGE_MODEL_URL) # model handle/URL comes from config.IMAGE_MODEL_URL
])
# Fix the input signature: any batch size (None), square
# IMAGE_MODEL_IMAGE_SIZE x IMAGE_MODEL_IMAGE_SIZE images with 3 channels.
m.build([None, IMAGE_MODEL_IMAGE_SIZE, IMAGE_MODEL_IMAGE_SIZE, 3])
#main functions
def getDescriptionsOfImage(image_file_path: str, max_desc_number: int = 5):
    """Return the highest-scoring label names the model predicts for an image.

    The image is loaded and resized to IMAGE_MODEL_IMAGE_SIZE x
    IMAGE_MODEL_IMAGE_SIZE, rescaled, and run through the module-level model
    ``m``. Output scores are ranked from highest to lowest and mapped to names
    using the labels file at LABELS_MAP_PATH.

    Args:
        image_file_path: Path of the image file to describe.
        max_desc_number: Maximum number of labels to return. Values <= 0
            yield an empty list. (Previously results were silently capped
            at 20; that hidden limit has been removed.)

    Returns:
        Up to ``max_desc_number`` label strings, most likely first.
    """
    image = tf.keras.preprocessing.image.load_img(
        image_file_path,
        target_size=(IMAGE_MODEL_IMAGE_SIZE, IMAGE_MODEL_IMAGE_SIZE),
    )
    pixels = tf.keras.preprocessing.image.img_to_array(image)
    # Maps 8-bit pixel values 0..255 to [-1, 1); presumably the input range
    # the hub model expects — TODO confirm against the model's documentation.
    pixels = (pixels - 128.) / 128.
    # Add a batch dimension of 1 and run the model in inference mode.
    logits = m(tf.expand_dims(pixels, 0), training=False)
    # Class indices for the single image in the batch, highest score first.
    ranked = tf.argsort(logits[0])[::-1].numpy()
    classes = get_imagenet_labels(LABELS_MAP_PATH)
    # max(..., 0) keeps the original "non-positive count -> empty list"
    # behavior; a negative slice bound would otherwise drop from the end.
    results: List[str] = [classes[idx] for idx in ranked[:max(max_desc_number, 0)]]
    return results
#Helper functions
def get_imagenet_labels(filename: str) -> List[str]:
    """Read label names from a tab-separated labels file.

    Each non-blank line is expected to be ``<key><TAB><label>`` (any further
    tab-separated fields are ignored); the second field is taken as the
    label. Labels are returned in file order, so list position i corresponds
    to the i-th non-blank line. Blank lines (e.g. a trailing empty line) are
    skipped; note that a blank line in the middle of the file would shift the
    positions of every label after it.

    Args:
        filename: Path of the labels file (read as UTF-8).

    Returns:
        The label strings, in file order.

    Raises:
        IndexError: If a non-blank line contains no tab character.
    """
    labels: List[str] = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # Strip only the line terminator. The old `[:-1]` on the second
            # field chopped a real character when the last line had no
            # trailing newline or when a line had more than two fields.
            row = line.rstrip('\r\n')
            if not row:
                continue
            labels.append(row.split('\t')[1])
    return labels
from config import IMAGE_MODEL_IMAGE_SIZE, IMAGE_MODEL_URL, LABELS_MAP_PATH
import tensorflow_hub as hub
import tensorflow as tf
from typing import List
# Initialize the TensorFlow model once, at import time. The module-level `m`
# is a single instance shared by everything that imports this module (one
# model for the whole server, for now).
# NOTE(review): this repeats an identical initialization earlier in this
# file, so the Keras session is cleared and the model is built a second time
# on import; this later `m` is the one in effect — confirm the duplication is
# intended.
# NOTE(review): hub.KerasLayer may fetch/load the model from IMAGE_MODEL_URL
# here, so importing this module can be slow — confirm for the deployed URL.
# Reset Keras' global state before building the model.
tf.keras.backend.clear_session()
m = tf.keras.Sequential([
hub.KerasLayer(IMAGE_MODEL_URL) # model handle/URL comes from config.IMAGE_MODEL_URL
])
# Fix the input signature: any batch size (None), square
# IMAGE_MODEL_IMAGE_SIZE x IMAGE_MODEL_IMAGE_SIZE images with 3 channels.
m.build([None, IMAGE_MODEL_IMAGE_SIZE, IMAGE_MODEL_IMAGE_SIZE, 3])
#main functions
def getDescriptionsOfImage(image_file_path: str, max_desc_number: int = 5):
    """Return the highest-scoring label names the model predicts for an image.

    The image is loaded and resized to IMAGE_MODEL_IMAGE_SIZE x
    IMAGE_MODEL_IMAGE_SIZE, rescaled, and run through the module-level model
    ``m``. Output scores are ranked from highest to lowest and mapped to names
    using the labels file at LABELS_MAP_PATH.

    Args:
        image_file_path: Path of the image file to describe.
        max_desc_number: Maximum number of labels to return. Values <= 0
            yield an empty list. (Previously results were silently capped
            at 20; that hidden limit has been removed.)

    Returns:
        Up to ``max_desc_number`` label strings, most likely first.
    """
    image = tf.keras.preprocessing.image.load_img(
        image_file_path,
        target_size=(IMAGE_MODEL_IMAGE_SIZE, IMAGE_MODEL_IMAGE_SIZE),
    )
    pixels = tf.keras.preprocessing.image.img_to_array(image)
    # Maps 8-bit pixel values 0..255 to [-1, 1); presumably the input range
    # the hub model expects — TODO confirm against the model's documentation.
    pixels = (pixels - 128.) / 128.
    # Add a batch dimension of 1 and run the model in inference mode.
    logits = m(tf.expand_dims(pixels, 0), training=False)
    # Class indices for the single image in the batch, highest score first.
    ranked = tf.argsort(logits[0])[::-1].numpy()
    classes = get_imagenet_labels(LABELS_MAP_PATH)
    # max(..., 0) keeps the original "non-positive count -> empty list"
    # behavior; a negative slice bound would otherwise drop from the end.
    results: List[str] = [classes[idx] for idx in ranked[:max(max_desc_number, 0)]]
    return results
#Helper functions
def get_imagenet_labels(filename: str) -> List[str]:
    """Read label names from a tab-separated labels file.

    Each non-blank line is expected to be ``<key><TAB><label>`` (any further
    tab-separated fields are ignored); the second field is taken as the
    label. Labels are returned in file order, so list position i corresponds
    to the i-th non-blank line. Blank lines (e.g. a trailing empty line) are
    skipped; note that a blank line in the middle of the file would shift the
    positions of every label after it.

    Args:
        filename: Path of the labels file (read as UTF-8).

    Returns:
        The label strings, in file order.

    Raises:
        IndexError: If a non-blank line contains no tab character.
    """
    labels: List[str] = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # Strip only the line terminator. The old `[:-1]` on the second
            # field chopped a real character when the last line had no
            # trailing newline or when a line had more than two fields.
            row = line.rstrip('\r\n')
            if not row:
                continue
            labels.append(row.split('\t')[1])
    return labels