From 81c2ea397695790c370669ce57988c403f7cf7ea Mon Sep 17 00:00:00 2001 From: DocGarbanzo <47540921+DocGarbanzo@users.noreply.github.com> Date: Thu, 10 Jun 2021 21:36:17 +0100 Subject: [PATCH] * Bump version -> 4.2.1 (#881) * Revert to passing full path to model in training call which got accidentally broken in 4.2 master. --- donkeycar/__init__.py | 2 +- donkeycar/pipeline/database.py | 4 ++-- donkeycar/pipeline/training.py | 29 +++++++++++++++++++---------- setup.py | 2 +- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/donkeycar/__init__.py b/donkeycar/__init__.py index 2714c29a8..63475dbac 100644 --- a/donkeycar/__init__.py +++ b/donkeycar/__init__.py @@ -1,7 +1,7 @@ import sys from pyfiglet import Figlet -__version__ = '4.2.0' +__version__ = '4.2.1' f = Figlet(font='speed') print(f.renderText('Donkey Car')) diff --git a/donkeycar/pipeline/database.py b/donkeycar/pipeline/database.py index d6502dbee..9e829192b 100644 --- a/donkeycar/pipeline/database.py +++ b/donkeycar/pipeline/database.py @@ -33,8 +33,8 @@ def generate_model_name(self) -> Tuple[str, int]: else: this_num = 0 date = time.strftime('%y-%m-%d') - name = 'pilot_' + date + '_' + str(this_num) - return name, this_num + name = f'pilot_{date}_{this_num}.h5' + return os.path.join(self.cfg.MODELS_PATH, name), this_num def to_df(self) -> pd.DataFrame: if self.entries: diff --git a/donkeycar/pipeline/training.py b/donkeycar/pipeline/training.py index 56bd640fe..bec78a1aa 100644 --- a/donkeycar/pipeline/training.py +++ b/donkeycar/pipeline/training.py @@ -81,6 +81,15 @@ def create_tf_data(self) -> tf.data.Dataset: def get_model_train_details(cfg: Config, database: PilotDatabase, model: str = None, model_type: str = None) \ -> Tuple[str, int, str, bool]: + """ + Returns automatic model name if none is given + :param cfg: donkey config + :param database: model database with existing training data + :param model: model path + :param model_type: type of model, like 'linear', 'tflite_linear', etc + :return: tuple of model path, number, training type, and if + tflite is requested + """ if not model_type: model_type = cfg.DEFAULT_MODEL_TYPE train_type = model_type @@ -90,12 +99,13 @@ def get_model_train_details(cfg: Config, database: PilotDatabase, is_tflite = True model_num = 0 if not model: - model_name, model_num = database.generate_model_name() + model_path, model_num = database.generate_model_name() else: - model_name, model_ext = os.path.splitext(model) + _, model_ext = os.path.splitext(model) + model_path = model is_tflite = model_ext == '.tflite' - return model_name, model_num, train_type, is_tflite + return model_path, model_num, train_type, is_tflite def train(cfg: Config, tub_paths: str, model: str = None, @@ -105,10 +115,9 @@ def train(cfg: Config, tub_paths: str, model: str = None, Train the model """ database = PilotDatabase(cfg) - model_name, model_num, train_type, is_tflite = \ + model_path, model_num, train_type, is_tflite = \ get_model_train_details(cfg, database, model, model_type) - output_path = os.path.join(cfg.MODELS_PATH, model_name + '.h5') kl = get_model_by_type(train_type, cfg) if transfer: kl.load(transfer) @@ -135,7 +144,7 @@ def train(cfg: Config, tub_paths: str, model: str = None, assert val_size > 0, "Not enough validation data, decrease the batch " \ "size or add more data." - history = kl.train(model_path=output_path, + history = kl.train(model_path=model_path, train_data=dataset_train, train_steps=train_size, batch_size=cfg.BATCH_SIZE, @@ -146,14 +155,14 @@ def train(cfg: Config, tub_paths: str, model: str = None, min_delta=cfg.MIN_DELTA, patience=cfg.EARLY_STOP_PATIENCE, show_plot=cfg.SHOW_PLOT) - + base_path = os.path.splitext(model_path)[0] if is_tflite: - tf_lite_model_path = f'{os.path.splitext(output_path)[0]}.tflite' - keras_model_to_tflite(output_path, tf_lite_model_path) + tf_lite_model_path = f'{base_path}.tflite' + keras_model_to_tflite(model_path, tf_lite_model_path) database_entry = { 'Number': model_num, - 'Name': model_name, + 'Name': os.path.basename(base_path), 'Type': str(kl), 'Tubs': tub_paths, 'Time': time(), diff --git a/setup.py b/setup.py index 3eee96b75..c279a1c3a 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def package_files(directory, strip_leading): long_description = fh.read() setup(name='donkeycar', - version='4.2.0', + version='4.2.1', long_description=long_description, description='Self driving library for python.', url='https://github.com/autorope/donkeycar',