-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathpoint_pillars_training_run.py
68 lines (53 loc) · 2.84 KB
/
point_pillars_training_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import time
import numpy as np
import tensorflow as tf
from glob import glob
from config import Parameters
from loss import PointPillarNetworkLoss
from network import build_point_pillar_graph
from processors import SimpleDataGenerator
from readers import KittiDataReader
# Suppress TensorFlow's INFO/WARNING log spam; only errors are shown.
tf.get_logger().setLevel("ERROR")
# Root of the KITTI training split (expects velodyne/, label_2/, calib/ subdirs).
DATA_ROOT = "../training" # TODO make main arg
# Where checkpoints, TensorBoard logs, and interrupted-run snapshots are written.
MODEL_ROOT = "./logs"
# Pin GPU enumeration to PCI bus order so device 0 is deterministic,
# then restrict training to a single GPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == "__main__":
    params = Parameters()

    # Build the PointPillars graph and warm-start from the last checkpoint.
    # NOTE(review): this unconditionally requires logs/model.h5 to exist —
    # a cold start will fail here; confirm that is intended.
    pillar_net = build_point_pillar_graph(params)
    pillar_net.load_weights(os.path.join(MODEL_ROOT, "model.h5"))

    loss = PointPillarNetworkLoss(params)
    # Use the canonical `learning_rate` kwarg; `lr` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate, decay=params.decay_rate)
    pillar_net.compile(optimizer, loss=loss.losses())

    data_reader = KittiDataReader()

    # Sorted globs keep the three file lists aligned sample-for-sample.
    lidar_files = sorted(glob(os.path.join(DATA_ROOT, "velodyne", "*.bin")))
    label_files = sorted(glob(os.path.join(DATA_ROOT, "label_2", "*.txt")))
    calibration_files = sorted(glob(os.path.join(DATA_ROOT, "calib", "*.txt")))
    # Validate inputs with explicit exceptions: `assert` is stripped under -O.
    if not (len(lidar_files) == len(label_files) == len(calibration_files)):
        raise ValueError("Input dirs require equal number of files.")
    if not lidar_files:
        raise ValueError("No input files found under %s" % DATA_ROOT)

    # 70/30 train/validation split via an explicit split index.
    # The previous `files[:-validation_len]` slicing silently produced an
    # EMPTY training set whenever validation_len == 0 (fewer than 4 samples).
    validation_len = int(0.3 * len(label_files))
    split = len(label_files) - validation_len
    training_gen = SimpleDataGenerator(data_reader, params.batch_size,
                                       lidar_files[:split], label_files[:split], calibration_files[:split])
    validation_gen = SimpleDataGenerator(data_reader, params.batch_size,
                                         lidar_files[split:], label_files[split:], calibration_files[split:])

    log_dir = MODEL_ROOT
    # Convert the LR-decay interval from iterations to an epoch count.
    # NOTE(review): multiplying by steps-per-epoch here looks suspicious
    # (converting iterations -> epochs would divide by it); preserved as-is,
    # confirm against the Parameters definition.
    epoch_to_decay = int(
        np.round(params.iters_to_decay / params.batch_size * int(np.ceil(float(len(label_files)) / params.batch_size))))

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
        # Keep only the best-so-far weights by validation loss.
        tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(log_dir, "model.h5"),
                                           monitor='val_loss', save_best_only=True),
        # Step-decay the learning rate by 0.8 every `epoch_to_decay` epochs.
        tf.keras.callbacks.LearningRateScheduler(
            lambda epoch, lr: lr * 0.8 if ((epoch % epoch_to_decay == 0) and (epoch != 0)) else lr, verbose=True),
        tf.keras.callbacks.EarlyStopping(patience=20, monitor='val_loss'),
    ]

    try:
        pillar_net.fit(training_gen,
                       validation_data=validation_gen,
                       steps_per_epoch=len(training_gen),
                       callbacks=callbacks,
                       use_multiprocessing=True,
                       epochs=int(params.total_training_epochs),
                       workers=6)
    except KeyboardInterrupt:
        # On Ctrl-C, snapshot the current weights with a timestamped name
        # so an interrupted run is never lost.
        model_str = "interrupted_%s.h5" % time.strftime("%Y%m%d-%H%M%S")
        pillar_net.save(os.path.join(log_dir, model_str))
        print("Interrupt. Saving output to %s" % os.path.join(os.getcwd(), log_dir[1:], model_str))