-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
82 lines (61 loc) · 1.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import warnings
from dataclasses import dataclass
import sys
from enum import Enum
import torch
import numpy as np
@dataclass
class Args:
    """Empty dataclass used as an abstract type for sharing argument objects.

    It declares no fields; it only serves as a common type name.
    """
@dataclass
class TemplateArgs:
    """General project parameters for the dataset.

    All fields are plain strings with defaults, so the class can be
    instantiated without arguments and filled in as needed.
    """

    facesQs_path: str = "path-to-source"  # root of the faceQs source tree (for development)
    rootdir: str = ""  # root directory of the project; see add_root_path()
    data_path: str = ""  # path to the data
    device: str = "cuda"  # device string; no availability check/fallback is done here
    data_fname: str = ""  # filename of the data description

    def add_root_path(self):
        """Append ``rootdir`` to ``sys.path`` unless it is already present.

        The membership check makes the call idempotent, so repeated calls
        (e.g. once per experiment) do not pile up duplicate entries.
        """
        # NOTE: with the default rootdir == "" this adds "" (the current
        # working directory) to the import search path.
        if self.rootdir not in sys.path:
            sys.path.append(self.rootdir)
def verbatimT(verbose, true, text, deep=0):
    """Print ``text`` if its message level is enabled by ``verbose``.

    Args:
        verbose: Current verbosity setting. Level-1 messages are printed
            when ``verbose >= 1``; level-2 messages when ``verbose >= 2``.
        true: Message level of ``text`` (1 or 2); any other value prints
            nothing.
        text: Message string to print.
        deep: Nesting depth, used for level-1 messages only. ``0`` prefixes
            the message with ``"-> "``; any other depth uses the indented
            prefix ``" "``.
    """
    if verbose >= 1 and true == 1:
        # Top-level messages get an arrow marker; nested ones are indented.
        # Every non-zero depth shares the indent so `pre` is always bound.
        pre = "-> " if deep == 0 else " "
        print(pre + text)
    elif verbose >= 2 and true == 2:
        pre = " "
        # Level-2 (detail) messages are followed by an extra blank line.
        print(pre + text + "\n")
def verbatimO(verbose, obj, level=0):
    """Print ``obj`` when the detailed verbosity level is enabled.

    The gate is ``verbose >= 2``, the same threshold ``verbatimT`` uses for
    its level-2 messages, so higher verbosity settings keep showing objects.

    Args:
        verbose: Current verbosity setting.
        obj: Any object; printed with ``print``.
        level: Accepted for backward compatibility; currently unused.
    """
    if verbose >= 2:
        print(obj)
def warn(text):
    """Emit ``text`` as a ``UserWarning`` attributed to the caller.

    ``stacklevel=2`` makes the reported file/line point at the code that
    called ``warn`` instead of this wrapper. Each call site is then reported,
    and de-duplicated by the warnings filters, separately.
    """
    warnings.warn(text, stacklevel=2)
class Task(Enum):
    """Task type selector: classification or regression."""

    CLASSIFICATION = 0
    REGRESSION = 1
class Split(Enum):
    """Identifiers for the dataset-splitting strategies.

    This enum only names the strategies; it contains no splitting logic.
    """

    PERCENTAGE_SPLIT = 0
    CROSS_VALIDATION = 1
    RAVDESS_SPLIT = 2
    RAVDESS_SPLIT_5F = 3  # presumably a 5-fold variant of RAVDESS_SPLIT -- confirm
class TrainEvaluationMetrics(Enum):
    """Metric selector for training evaluation: accuracy or loss."""

    ACCURACY = 0
    LOSS = 1
def frame_resampling(x, max_frame=60, method="pad"):
    """Force a tensor to exactly ``max_frame`` entries along dim 0.

    Longer inputs are truncated to their first ``max_frame`` entries.
    Shorter (or equal-length) inputs are zero-padded at the end when
    ``method == "pad"``. The padding uses ``x``'s dtype and device, so the
    result keeps the input's dtype and also works for tensors on an
    accelerator. Any number of trailing dimensions is supported.

    Args:
        x: ``torch.Tensor`` with at least one dimension; dim 0 is treated as
            the frame axis.
        max_frame: Target length along dim 0.
        method: Strategy for inputs not longer than ``max_frame``; only
            ``"pad"`` is supported.

    Returns:
        ``torch.Tensor`` with ``max_frame`` entries along dim 0: a slice of
        ``x`` when truncating, a new tensor when padding.

    Raises:
        ValueError: If ``x`` is not longer than ``max_frame`` and ``method``
            is not ``"pad"``. Previously this case silently returned ``None``.
    """
    n_frames = len(x)
    if n_frames > max_frame:
        return x[0:max_frame]  # cut
    if method == "pad":
        padding = torch.zeros(
            (max_frame - n_frames, *x.shape[1:]), dtype=x.dtype, device=x.device
        )
        # cat along dim 0 matches vstack for >=2-D inputs and is also correct for 1-D.
        return torch.cat((x, padding), dim=0)
    raise ValueError(f"unsupported resampling method: {method!r}")
def frame_resampling_np(x, max_frame=60, method="pad"):
    """Force a NumPy array to exactly ``max_frame`` entries along axis 0.

    Longer inputs are truncated to their first ``max_frame`` entries.
    Shorter (or equal-length) inputs are zero-padded at the end when
    ``method == "pad"``. The padding uses ``x.dtype``, so e.g. a float32
    input stays float32 instead of being promoted to float64. Any number of
    trailing dimensions is supported.

    Args:
        x: ``numpy.ndarray`` with at least one dimension; axis 0 is treated
            as the frame axis.
        max_frame: Target length along axis 0.
        method: Strategy for inputs not longer than ``max_frame``; only
            ``"pad"`` is supported.

    Returns:
        ``numpy.ndarray`` with ``max_frame`` entries along axis 0: a slice
        of ``x`` when truncating, a new array when padding.

    Raises:
        ValueError: If ``x`` is not longer than ``max_frame`` and ``method``
            is not ``"pad"``. Previously this case silently returned ``None``.
    """
    n_frames = len(x)
    if n_frames > max_frame:
        return x[0:max_frame]  # cut
    if method == "pad":
        padding = np.zeros((max_frame - n_frames, *x.shape[1:]), dtype=x.dtype)
        # concatenate on axis 0 matches vstack for >=2-D inputs and is also correct for 1-D.
        return np.concatenate((x, padding), axis=0)
    raise ValueError(f"unsupported resampling method: {method!r}")