-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathISL_utils.py
102 lines (86 loc) · 4.28 KB
/
ISL_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import mediapipe as mp
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
# Dataset root directories (relative paths): training, validation and test.
train_dir, val_dir, test_dir = (
    './Data',
    './Val_Data',
    './Test_Data',
)
# Load the gesture vocabulary from the first data row of itos.csv: the cell at
# column position i holds the gesture whose integer class label is i.
df_itos = pd.read_csv('itos.csv')

# Ordered list of gestures; a gesture's position in the list is its label.
gestures = [df_itos.iloc[0, column] for column in range(len(df_itos.loc[0]))]

# Lookup tables in both directions: label -> gesture and gesture -> label.
labels_to_gestures = dict(enumerate(gestures))
gestures_to_labels = {gesture: label for label, gesture in enumerate(gestures)}
# Short aliases for the MediaPipe drawing helpers and the holistic solution
# module; the drawing function further down relies on both.
mp_drawing, mp_holistic = mp.solutions.drawing_utils, mp.solutions.holistic
def mediapipe_detection(image, model):
    """Run ``model`` on a BGR frame and return the frame with its results.

    Args:
        image: Frame in BGR channel order, as accepted by ``cv2.cvtColor``.
        model: Object exposing ``process(rgb_image)`` (presumably a MediaPipe
            Holistic instance -- confirm with callers).

    Returns:
        A ``(image, results)`` tuple: the frame converted back to BGR, and
        whatever ``model.process`` returned.
    """
    # The model is fed an RGB copy of the frame.
    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The copy is read-only only while the model runs (this mirrors MediaPipe's
    # sample code); it is made writeable again right after.
    rgb_frame.flags.writeable = False
    results = model.process(rgb_frame)
    rgb_frame.flags.writeable = True

    # Hand back a BGR frame so it can be used with OpenCV again.
    bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)
    return bgr_frame, results
def draw_styled_landmarks(image, results):
    """Draw the face, pose and hand landmarks from ``results`` onto ``image``.

    Nothing is returned; all drawing is delegated to
    ``mp_drawing.draw_landmarks``, called once per body part in a fixed order:
    face mesh, pose, left hand, right hand.

    Args:
        image: Image handed to ``mp_drawing.draw_landmarks`` for annotation.
        results: Object exposing ``face_landmarks``, ``pose_landmarks``,
            ``left_hand_landmarks`` and ``right_hand_landmarks`` (presumably a
            MediaPipe Holistic result -- confirm with callers).
    """
    spec = mp_drawing.DrawingSpec

    # One row per body part:
    #   (landmark list, connection set, landmark style, connection style)
    # The two styles are passed positionally as the 4th and 5th arguments of
    # draw_landmarks. Every colour channel must lie in 0-255; the face
    # connection colour is written as 255 rather than the out-of-range 256.
    styled_parts = (
        (
            results.face_landmarks,
            mp_holistic.FACEMESH_TESSELATION,
            spec(color=(80, 110, 10), thickness=1, circle_radius=1),
            spec(color=(80, 255, 121), thickness=1, circle_radius=1),
        ),
        (
            results.pose_landmarks,
            mp_holistic.POSE_CONNECTIONS,
            spec(color=(80, 22, 10), thickness=2, circle_radius=4),
            spec(color=(80, 44, 121), thickness=2, circle_radius=2),
        ),
        (
            results.left_hand_landmarks,
            mp_holistic.HAND_CONNECTIONS,
            spec(color=(121, 22, 76), thickness=2, circle_radius=4),
            spec(color=(121, 44, 250), thickness=2, circle_radius=2),
        ),
        (
            results.right_hand_landmarks,
            mp_holistic.HAND_CONNECTIONS,
            spec(color=(245, 117, 66), thickness=2, circle_radius=4),
            spec(color=(245, 66, 230), thickness=2, circle_radius=2),
        ),
    )

    for landmark_list, connections, landmark_spec, connection_spec in styled_parts:
        mp_drawing.draw_landmarks(
            image, landmark_list, connections, landmark_spec, connection_spec
        )
def extract_keypoints(results):
    """Flatten the landmarks in ``results`` into one 1-D feature vector.

    The vector is the concatenation, in this order, of:
      * pose:       33 landmarks x (x, y, z, visibility)
      * face:       468 landmarks x (x, y, z)
      * left hand:  21 landmarks x (x, y, z)
      * right hand: 21 landmarks x (x, y, z)

    A part whose landmark attribute is falsy (e.g. ``None``) contributes a
    zero block of the size listed above.

    Args:
        results: Object exposing ``pose_landmarks``, ``face_landmarks``,
            ``left_hand_landmarks`` and ``right_hand_landmarks``; each present
            part must have a ``landmark`` iterable of points.

    Returns:
        A 1-D ``numpy`` array holding the concatenated values.
    """
    def _flatten(part, fields, expected_points):
        # A missing part becomes a zero placeholder so the layout stays fixed.
        if not part:
            return np.zeros(expected_points * len(fields))
        rows = [[getattr(point, field) for field in fields]
                for point in part.landmark]
        return np.array(rows).flatten()

    xyz = ('x', 'y', 'z')
    blocks = [
        _flatten(results.pose_landmarks, xyz + ('visibility',), 33),
        _flatten(results.face_landmarks, xyz, 468),
        _flatten(results.left_hand_landmarks, xyz, 21),
        _flatten(results.right_hand_landmarks, xyz, 21),
    ]
    return np.concatenate(blocks)
def plot_loss_and_acc(history):
    """Plot train/validation loss and accuracy curves and save them as PNGs.

    Two figures are created, one per metric, and written to
    ``train_val_loss.png`` and ``train_val_acc.png``. The figures are not
    closed afterwards.

    Args:
        history: Mapping with the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``; each value is passed to
            ``plt.plot`` as a single series (presumably one entry per epoch).
    """
    charts = (
        # (train key, train label, val key, val label, title, y label, file)
        ('train_loss', 'train loss', 'val_loss', 'val loss',
         'Training Loss vs Validation Loss', 'Loss', 'train_val_loss.png'),
        ('train_acc', 'train acc', 'val_acc', 'val acc',
         'Training Acc vs Validation Acc', 'Acc', 'train_val_acc.png'),
    )
    for train_key, train_label, val_key, val_label, title, y_label, out_file in charts:
        plt.figure()
        plt.plot(history[train_key], label=train_label)
        plt.plot(history[val_key], label=val_label)
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend()
        plt.savefig(out_file)
def plot_cm_heatmap(cm):
    """Render a confusion matrix as an annotated heatmap and save it.

    The matrix is labelled on both axes with the module-level ``gestures``
    list, cast to ``int``, drawn with seaborn and written to
    ``confusion_matrix.png``.

    Args:
        cm: Confusion-matrix data accepted by ``pd.DataFrame``; it needs one
            row and one column per entry in ``gestures`` (rows are labelled as
            true classes, columns as predicted classes).
    """
    plt.figure(figsize=(20, 15))

    # Integer counts so the cells can be annotated with the "d" format.
    counts = pd.DataFrame(cm, index=gestures, columns=gestures).astype(int)
    axes = sns.heatmap(counts, annot=True, fmt="d", cmap='mako')

    # Re-apply the existing tick labels with a per-axis rotation; the y axis
    # is styled first, then the x axis.
    for axis, angle in ((axes.yaxis, 0), (axes.xaxis, 45)):
        axis.set_ticklabels(axis.get_ticklabels(), rotation=angle, ha='right', fontsize=15)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('confusion_matrix.png')