-
Notifications
You must be signed in to change notification settings - Fork 5
/
iknet_test.py
38 lines (29 loc) · 1.1 KB
/
iknet_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
import torch
from torch.utils.data import DataLoader
from iknet import IKDataset, IKNet
def main():
    """Evaluate a trained IKNet checkpoint on a held-out IK dataset.

    Command-line options:
        --kinematics-pose-csv: first CSV path passed to ``IKDataset``.
        --joint-states-csv: second CSV path passed to ``IKDataset``.
        --batch-size: number of samples per evaluation batch.
        --model-path: state-dict checkpoint passed to ``torch.load``
            (defaults to ``iknet.pth``, the previously hard-coded path).

    Prints the accumulated loss: for every batch, the 2-norm of
    ``output - target`` taken over all elements, divided by the number of
    samples in that batch, summed over all batches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kinematics-pose-csv", type=str, default="./dataset/test/kinematics_pose.csv"
    )
    parser.add_argument(
        "--joint-states-csv", type=str, default="./dataset/test/joint_states.csv"
    )
    parser.add_argument("--batch-size", type=int, default=10000)
    parser.add_argument("--model-path", type=str, default="iknet.pth")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet()
    # map_location remaps stored tensors onto the chosen device, so a
    # checkpoint saved on a GPU can still be evaluated on a CPU-only machine
    # (the fallback selected just above).
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.to(device)
    model.eval()

    dataset = IKDataset(args.kinematics_pose_csv, args.joint_states_csv)
    test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    total_loss = 0.0
    # Inference only: no_grad stops autograd from recording a graph for
    # every (potentially very large) batch, saving memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Normalize by the real number of samples in this batch. The final
            # batch, or the only batch when the dataset is smaller than
            # --batch-size, can be shorter than args.batch_size.
            total_loss += (output - target).norm().item() / len(data)
    print(f"Total loss = {total_loss}")


if __name__ == "__main__":
    main()