diff --git a/lightning_pose/__init__.py b/lightning_pose/__init__.py index e69de29b..67bc602a 100644 --- a/lightning_pose/__init__.py +++ b/lightning_pose/__init__.py @@ -0,0 +1 @@ +__version__ = "1.3.0" diff --git a/lightning_pose/apps/utils.py b/lightning_pose/apps/utils.py index d5f80dce..d272fb43 100644 --- a/lightning_pose/apps/utils.py +++ b/lightning_pose/apps/utils.py @@ -16,7 +16,7 @@ @st.cache_resource -def update_labeled_file_list(model_preds_folders: list, use_ood: bool = False): +def update_labeled_file_list(model_preds_folders: List[str], use_ood: bool = False) -> List[list]: per_model_preds = [] for model_pred_folder in model_preds_folders: # pull labeled results from each model folder @@ -40,15 +40,20 @@ def update_labeled_file_list(model_preds_folders: list, use_ood: bool = False): @st.cache_resource -def update_vid_metric_files_list(video: str, model_preds_folders: list): +def update_vid_metric_files_list( + video: str, + model_preds_folders: List[str], + video_subdir: str = "video_preds", +) -> List[list]: per_vid_preds = [] for model_preds_folder in model_preds_folders: # pull each prediction file associated with a particular video # wrap in Path so that it looks like an UploadedFile object + video_dir = os.path.join(model_preds_folder, video_subdir) + if not os.path.isdir(video_dir): + continue model_preds = [ - f - for f in os.listdir(os.path.join(model_preds_folder, "video_preds")) - if os.path.isfile(os.path.join(model_preds_folder, "video_preds", f)) + f for f in os.listdir(video_dir) if os.path.isfile(os.path.join(video_dir, f)) ] ret_files = [] for file in model_preds: @@ -59,16 +64,17 @@ def update_vid_metric_files_list(video: str, model_preds_folders: list): @st.cache_resource -def get_all_videos(model_preds_folders: list): +def get_all_videos(model_preds_folders: List[str], video_subdir: str = "video_preds") -> list: # find each video that is predicted on by the models # wrap in Path so that it looks like an UploadedFile object # 
returned by streamlit's file_uploader ret_videos = set() for model_preds_folder in model_preds_folders: + video_dir = os.path.join(model_preds_folder, video_subdir) + if not os.path.isdir(video_dir): + continue model_preds = [ - f - for f in os.listdir(os.path.join(model_preds_folder, "video_preds")) - if os.path.isfile(os.path.join(model_preds_folder, "video_preds", f)) + f for f in os.listdir(video_dir) if os.path.isfile(os.path.join(video_dir, f)) ] for file in model_preds: if "temporal" in file: @@ -97,7 +103,7 @@ def concat_dfs(dframes: Dict[str, pd.DataFrame]) -> Tuple[pd.DataFrame, List[str @st.cache_data -def get_df_box(df_orig, keypoint_names, model_names): +def get_df_box(df_orig: pd.DataFrame, keypoint_names: list, model_names: list) -> pd.DataFrame: df_boxes = [] for keypoint in keypoint_names: for model_curr in model_names: @@ -112,7 +118,13 @@ def get_df_box(df_orig, keypoint_names, model_names): @st.cache_data -def get_df_scatter(df_0, df_1, data_type, model_names, keypoint_names): +def get_df_scatter( + df_0: pd.DataFrame, + df_1: pd.DataFrame, + data_type: str, + model_names: list, + keypoint_names: list +) -> pd.DataFrame: df_scatters = [] for keypoint in keypoint_names: df_scatters.append( @@ -147,7 +159,7 @@ def get_full_name(keypoint: str, coordinate: str, model: str) -> str: # ---------------------------------------------- @st.cache_data def build_precomputed_metrics_df( - dframes: Dict[str, pd.DataFrame], keypoint_names: List[str], **kwargs + dframes: Dict[str, pd.DataFrame], keypoint_names: List[str], **kwargs, ) -> dict: concat_dfs = defaultdict(list) for model_name, df_dict in dframes.items(): @@ -179,7 +191,7 @@ def build_precomputed_metrics_df( @st.cache_data def get_precomputed_error( - df: pd.DataFrame, keypoint_names: List[str], model_name: str + df: pd.DataFrame, keypoint_names: List[str], model_name: str, ) -> pd.DataFrame: # collect results df_ = df @@ -192,17 +204,17 @@ def get_precomputed_error( @st.cache_data def 
compute_confidence( - df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs + df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs, ) -> pd.DataFrame: + if df.shape[1] % 3 == 1: - # get rid of "set" column if present - tmp = df.iloc[:, :-1].to_numpy().reshape(df.shape[0], -1, 3) + # collect "set" column if present set = df.iloc[:, -1].to_numpy() else: - tmp = df.to_numpy().reshape(df.shape[0], -1, 3) set = None - results = tmp[:, :, 2] + mask = df.columns.get_level_values("coords").isin(["likelihood"]) + results = df.loc[:, mask].to_numpy() # collect results df_ = pd.DataFrame(columns=keypoint_names) @@ -219,7 +231,7 @@ def compute_confidence( # ------------ utils related to model finding in dir --------- # write a function that finds all model folders in the model_dir -def get_model_folders(model_dir): +def get_model_folders(model_dir: str) -> List[str]: # strip trailing slash if present if model_dir[-1] == os.sep: model_dir = model_dir[:-1] @@ -232,7 +244,7 @@ def get_model_folders(model_dir): # just to get the last two levels of the path -def get_model_folders_vis(model_folders): +def get_model_folders_vis(model_folders: List[str]) -> List[str]: fs = [] for f in model_folders: fs.append(f.split("/")[-2:]) diff --git a/lightning_pose/apps/video_diagnostics.py b/lightning_pose/apps/video_diagnostics.py index 4bcee624..4ee0c88d 100644 --- a/lightning_pose/apps/video_diagnostics.py +++ b/lightning_pose/apps/video_diagnostics.py @@ -32,6 +32,7 @@ def run(): + args = parser.parse_args() st.title("Video Diagnostics") @@ -53,23 +54,19 @@ def run(): # get the last two levels of each path to be presented to user model_folders_vis = get_model_folders_vis(model_folders) - selected_models_vis = st.sidebar.multiselect( - "Select models", model_folders_vis, default=None - ) + selected_models_vis = st.sidebar.multiselect("Select models", model_folders_vis, default=None) # append this to full path - selected_models = [ - "/" + 
os.path.join(args.model_dir, f) for f in selected_models_vis - ] + selected_models = ["/" + os.path.join(args.model_dir, f) for f in selected_models_vis] # ----- selecting videos to analyze ----- - all_videos_: list = get_all_videos(selected_models) + all_videos_: list = get_all_videos(selected_models, video_subdir=args.video_subdir) # choose from the different videos that were predicted video_to_plot = st.sidebar.selectbox("Select a video:", [*all_videos_], key="video") prediction_files = update_vid_metric_files_list( - video=video_to_plot, model_preds_folders=selected_models + video=video_to_plot, model_preds_folders=selected_models, video_subdir=args.video_subdir, ) model_names = copy.copy(selected_models_vis) @@ -100,9 +97,7 @@ def run(): dframe = pd.read_csv(model_pred_file_path, index_col=None) dframes_metrics[model_name][str(model_pred_file)] = dframe else: - dframe = pd.read_csv( - model_pred_file_path, header=[1, 2], index_col=0 - ) + dframe = pd.read_csv(model_pred_file_path, header=[1, 2], index_col=0) dframes_traces[model_name] = dframe dframes_metrics[model_name]["confidence"] = dframe # data_types = dframe.iloc[:, -1].unique() @@ -221,6 +216,7 @@ def run(): parser = argparse.ArgumentParser() parser.add_argument("--model_dir", type=str, default=[]) + parser.add_argument("--video_subdir", type=str, default="video_preds") parser.add_argument("--make_dir", action="store_true", default=False) run() diff --git a/lightning_pose/train.py b/lightning_pose/train.py new file mode 100644 index 00000000..4e6ee35e --- /dev/null +++ b/lightning_pose/train.py @@ -0,0 +1,239 @@ +"""Example model training function.""" + +import os + +import lightning.pytorch as pl +from omegaconf import DictConfig, OmegaConf, open_dict +from typeguard import typechecked + +from lightning_pose.utils import pretty_print_cfg, pretty_print_str +from lightning_pose.utils.io import ( + check_video_paths, + return_absolute_data_paths, + return_absolute_path, +) +from 
lightning_pose.utils.predictions import predict_dataset +from lightning_pose.utils.scripts import ( +    calculate_train_batches, +    compute_metrics, +    export_predictions_and_labeled_video, +    get_callbacks, +    get_data_module, +    get_dataset, +    get_imgaug_transform, +    get_loss_factories, +    get_model, +) + +# to ignore imports for sphinx-autoapidoc +__all__ = ["train"] + + +@typechecked +def train(cfg: DictConfig) -> None: + +    # record lightning-pose version +    from lightning_pose import __version__ as lightning_pose_version +    with open_dict(cfg): +        cfg.model.lightning_pose_version = lightning_pose_version + +    print("Our Hydra config file:") +    pretty_print_cfg(cfg) + +    # path handling for toy data +    data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data) + +    # ---------------------------------------------------------------------------------- +    # Set up data/model objects +    # ---------------------------------------------------------------------------------- + +    # imgaug transform +    imgaug_transform = get_imgaug_transform(cfg=cfg) + +    # dataset +    dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform) + +    # datamodule; breaks up dataset into train/val/test +    data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir) + +    # build loss factory which orchestrates different losses +    loss_factories = get_loss_factories(cfg=cfg, data_module=data_module) + +    # model +    model = get_model(cfg=cfg, data_module=data_module, loss_factories=loss_factories) + +    # ---------------------------------------------------------------------------------- +    # Set up and run training +    # ---------------------------------------------------------------------------------- + +    # logger +    logger = pl.loggers.TensorBoardLogger("tb_logs", name=cfg.model.model_name) + +    # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing +    callbacks = get_callbacks(cfg, early_stopping=False) + +    # calculate number of batches for both 
labeled and unlabeled data per epoch + limit_train_batches = calculate_train_batches(cfg, dataset) + + # set up trainer + trainer = pl.Trainer( # TODO: be careful with devices when scaling to multiple gpus + accelerator="gpu", # TODO: control from outside + devices=1, # TODO: control from outside + max_epochs=cfg.training.max_epochs, + min_epochs=cfg.training.min_epochs, + check_val_every_n_epoch=cfg.training.check_val_every_n_epoch, + log_every_n_steps=cfg.training.log_every_n_steps, + callbacks=callbacks, + logger=logger, + limit_train_batches=limit_train_batches, + accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1), + profiler=cfg.training.get("profiler", None), + ) + + # train model! + trainer.fit(model=model, datamodule=data_module) + + # ---------------------------------------------------------------------------------- + # Post-training analysis + # ---------------------------------------------------------------------------------- + hydra_output_directory = os.getcwd() + print("Hydra output directory: {}".format(hydra_output_directory)) + # get best ckpt + best_ckpt = os.path.abspath(trainer.checkpoint_callback.best_model_path) + # check if best_ckpt is a file + if not os.path.isfile(best_ckpt): + raise FileNotFoundError("Cannot find checkpoint. 
Have you trained for too few epochs?") + # save config file + cfg_file_local = os.path.join(hydra_output_directory, "config.yaml") + with open(cfg_file_local, "w") as fp: + OmegaConf.save(config=cfg, f=fp.name) + + # make unaugmented data_loader if necessary + if cfg.training.imgaug != "default": + cfg_pred = cfg.copy() + cfg_pred.training.imgaug = "default" + imgaug_transform_pred = get_imgaug_transform(cfg=cfg_pred) + dataset_pred = get_dataset( + cfg=cfg_pred, data_dir=data_dir, imgaug_transform=imgaug_transform_pred + ) + data_module_pred = get_data_module(cfg=cfg_pred, dataset=dataset_pred, video_dir=video_dir) + data_module_pred.setup() + else: + data_module_pred = data_module + + # ---------------------------------------------------------------------------------- + # predict on all labeled frames (train/val/test) + # ---------------------------------------------------------------------------------- + pretty_print_str("Predicting train/val/test images...") + # compute and save frame-wise predictions + preds_file = os.path.join(hydra_output_directory, "predictions.csv") + predict_dataset( + cfg=cfg, + trainer=trainer, + model=model, + data_module=data_module_pred, + ckpt_file=best_ckpt, + preds_file=preds_file, + ) + # compute and save various metrics + try: + compute_metrics(cfg=cfg, preds_file=preds_file, data_module=data_module_pred) + except Exception as e: + print(f"Error computing metrics\n{e}") + + # ---------------------------------------------------------------------------------- + # predict folder of videos + # ---------------------------------------------------------------------------------- + if cfg.eval.predict_vids_after_training: + pretty_print_str("Predicting videos...") + if cfg.eval.test_videos_directory is None: + filenames = [] + else: + filenames = check_video_paths( + return_absolute_path(cfg.eval.test_videos_directory) + ) + vidstr = "video" if (len(filenames) == 1) else "videos" + pretty_print_str( + f"Found {len(filenames)} {vidstr} to 
predict on " + f"(in cfg.eval.test_videos_directory)" + ) + + for video_file in filenames: + assert os.path.isfile(video_file) + pretty_print_str(f"Predicting video: {video_file}...") + # get save name for prediction csv file + video_pred_dir = os.path.join(hydra_output_directory, "video_preds") + video_pred_name = os.path.splitext(os.path.basename(video_file))[0] + prediction_csv_file = os.path.join(video_pred_dir, video_pred_name + ".csv") + # get save name labeled video csv + if cfg.eval.save_vids_after_training: + labeled_vid_dir = os.path.join(video_pred_dir, "labeled_videos") + labeled_mp4_file = os.path.join( + labeled_vid_dir, video_pred_name + "_labeled.mp4" + ) + else: + labeled_mp4_file = None + # predict on video + export_predictions_and_labeled_video( + video_file=video_file, + cfg=cfg, + ckpt_file=best_ckpt, + prediction_csv_file=prediction_csv_file, + labeled_mp4_file=labeled_mp4_file, + trainer=trainer, + model=model, + data_module=data_module_pred, + save_heatmaps=cfg.eval.get( + "predict_vids_after_training_save_heatmaps", False + ), + ) + # compute and save various metrics + try: + compute_metrics( + cfg=cfg, + preds_file=prediction_csv_file, + data_module=data_module_pred, + ) + except Exception as e: + print(f"Error predicting on video {video_file}:\n{e}") + continue + + # ---------------------------------------------------------------------------------- + # predict on OOD frames + # ---------------------------------------------------------------------------------- + # update config file to point to OOD data + csv_file_ood = os.path.join(cfg.data.data_dir, cfg.data.csv_file).replace( + ".csv", "_new.csv" + ) + if os.path.exists(csv_file_ood): + cfg_ood = cfg.copy() + cfg_ood.data.csv_file = csv_file_ood + cfg_ood.training.imgaug = "default" + cfg_ood.training.train_prob = 1 + cfg_ood.training.val_prob = 0 + cfg_ood.training.train_frames = 1 + # build dataset/datamodule + imgaug_transform_ood = get_imgaug_transform(cfg=cfg_ood) + dataset_ood = 
get_dataset( + cfg=cfg_ood, data_dir=data_dir, imgaug_transform=imgaug_transform_ood + ) + data_module_ood = get_data_module(cfg=cfg_ood, dataset=dataset_ood, video_dir=video_dir) + data_module_ood.setup() + pretty_print_str("Predicting OOD images...") + # compute and save frame-wise predictions + preds_file_ood = os.path.join(hydra_output_directory, "predictions_new.csv") + predict_dataset( + cfg=cfg_ood, + trainer=trainer, + model=model, + data_module=data_module_ood, + ckpt_file=best_ckpt, + preds_file=preds_file_ood, + ) + # compute and save various metrics + try: + compute_metrics( + cfg=cfg_ood, preds_file=preds_file_ood, data_module=data_module_ood + ) + except Exception as e: + print(f"Error computing metrics\n{e}") diff --git a/scripts/predict_new_vids.py b/scripts/predict_new_vids.py index d4cb46fa..d4d22f3c 100755 --- a/scripts/predict_new_vids.py +++ b/scripts/predict_new_vids.py @@ -49,38 +49,11 @@ def __init__( def video_basename(self) -> str: return os.path.basename(self.video_file).split(".")[0] - @property - def loss_str(self) -> str: - semi_supervised = check_if_semi_supervised(self.model_cfg.model.losses_to_use) - loss_names = [] - loss_weights = [] - loss_str = "" - if semi_supervised: # add the loss names and weights - loss_str = "" - if len(self.model_cfg.model.losses_to_use) > 0: - loss_names = list(self.model_cfg.model.losses_to_use) - for loss in loss_names: - loss_weights.append(self.model_cfg.losses[loss]["log_weight"]) - - loss_str = "" - for loss, weight in zip(loss_names, loss_weights): - loss_str += "_" + loss + "_" + str(weight) - - else: # fully supervised, return empty string - loss_str = "" - return loss_str - def check_input_paths(self) -> None: assert os.path.isfile(self.video_file) assert os.path.isdir(self.save_preds_dir) def build_pred_file_basename(self, extra_str="") -> str: - # return "%s_%s%s%s.csv" % ( - # self.video_basename, - # self.model_cfg.model.model_type, - # self.loss_str, - # extra_str, - # ) return 
f"{self.video_basename}.csv" def __call__(self, extra_str="") -> str: @@ -117,13 +90,8 @@ def predict_videos_in_dir(cfg: DictConfig): # absolute_cfg_path will be the path of the trained model we're using for predictions absolute_cfg_path = return_absolute_path(hydra_relative_path, n_dirs_back=2) - # debug - print(f"\n\n{absolute_cfg_path = }\n\n") - # load model - model_cfg = OmegaConf.load( - os.path.join(absolute_cfg_path, ".hydra/config.yaml") - ) + model_cfg = OmegaConf.load(os.path.join(absolute_cfg_path, ".hydra/config.yaml")) ckpt_file = ckpt_path_from_base_path( base_path=absolute_cfg_path, model_name=model_cfg.model.model_name ) @@ -134,9 +102,7 @@ def predict_videos_in_dir(cfg: DictConfig): print("getting imgaug transform...") imgaug_transform = get_imgaug_transform(cfg=cfg) print("getting dataset...") - dataset = get_dataset( - cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform - ) + dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform) print("getting data module...") data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir) @@ -145,14 +111,10 @@ def predict_videos_in_dir(cfg: DictConfig): # save to where the videos are. 
may get an exception save_preds_dir = cfg.eval.test_videos_directory else: - save_preds_dir = return_absolute_path( - cfg.eval.saved_vid_preds_dir, n_dirs_back=3 - ) + save_preds_dir = return_absolute_path(cfg.eval.saved_vid_preds_dir, n_dirs_back=3) # loop over videos in a provided directory - video_files = get_videos_in_dir( - return_absolute_path(cfg.eval.test_videos_directory) - ) + video_files = get_videos_in_dir(return_absolute_path(cfg.eval.test_videos_directory)) for video_file in video_files: @@ -180,9 +142,7 @@ def predict_videos_in_dir(cfg: DictConfig): trainer=trainer, model=model, data_module=data_module, - save_heatmaps=cfg.eval.get( - "predict_vids_after_training_save_heatmaps", False - ), + save_heatmaps=cfg.eval.get("predict_vids_after_training_save_heatmaps", False), ) # compute and save various metrics diff --git a/scripts/train_hydra.py b/scripts/train_hydra.py index 13ae2fbf..8d4f5eca 100755 --- a/scripts/train_hydra.py +++ b/scripts/train_hydra.py @@ -1,231 +1,39 @@ """Example model training script.""" -import os - import hydra -import lightning.pytorch as pl from omegaconf import DictConfig -from lightning_pose.utils import pretty_print_cfg, pretty_print_str -from lightning_pose.utils.io import ( - check_video_paths, - return_absolute_data_paths, - return_absolute_path, -) -from lightning_pose.utils.predictions import predict_dataset -from lightning_pose.utils.scripts import ( - calculate_train_batches, - compute_metrics, - export_predictions_and_labeled_video, - get_callbacks, - get_data_module, - get_dataset, - get_imgaug_transform, - get_loss_factories, - get_model, -) +from lightning_pose.train import train @hydra.main(config_path="configs", config_name="config_mirror-mouse-example") -def train(cfg: DictConfig): - """Main fitting function, accessed from command line.""" - - print("Our Hydra config file:") - pretty_print_cfg(cfg) - - # path handling for toy data - data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data) - - # 
---------------------------------------------------------------------------------- - # Set up data/model objects - # ---------------------------------------------------------------------------------- - - # imgaug transform - imgaug_transform = get_imgaug_transform(cfg=cfg) - - # dataset - dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform) - - # datamodule; breaks up dataset into train/val/test - data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir) - - # build loss factory which orchestrates different losses - loss_factories = get_loss_factories(cfg=cfg, data_module=data_module) - - # model - model = get_model(cfg=cfg, data_module=data_module, loss_factories=loss_factories) - - # ---------------------------------------------------------------------------------- - # Set up and run training - # ---------------------------------------------------------------------------------- - - # logger - logger = pl.loggers.TensorBoardLogger("tb_logs", name=cfg.model.model_name) - - # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing - callbacks = get_callbacks(cfg, early_stopping=False) +def train_model(cfg: DictConfig): + """Main fitting function, accessed from command line. 
- # calculate number of batches for both labeled and unlabeled data per epoch - limit_train_batches = calculate_train_batches(cfg, dataset) + To train a model on the example dataset provided with the Lightning Pose package with this + script, run the following command from inside the lightning-pose directory + (make sure you have activated your conda environment): - # set up trainer - trainer = pl.Trainer( # TODO: be careful with devices when scaling to multiple gpus - accelerator="gpu", # TODO: control from outside - devices=1, # TODO: control from outside - max_epochs=cfg.training.max_epochs, - min_epochs=cfg.training.min_epochs, - check_val_every_n_epoch=cfg.training.check_val_every_n_epoch, - log_every_n_steps=cfg.training.log_every_n_steps, - callbacks=callbacks, - logger=logger, - limit_train_batches=limit_train_batches, - accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1), - profiler=cfg.training.get("profiler", None), - ) + ``` + python scripts/train_hydra.py + ``` - # train model! - trainer.fit(model=model, datamodule=data_module) + Note there are no arguments - this tells the script to default to the example data. - # ---------------------------------------------------------------------------------- - # Post-training analysis - # ---------------------------------------------------------------------------------- - hydra_output_directory = os.getcwd() - print("Hydra output directory: {}".format(hydra_output_directory)) - # get best ckpt - best_ckpt = os.path.abspath(trainer.checkpoint_callback.best_model_path) - # check if best_ckpt is a file - if not os.path.isfile(best_ckpt): - raise FileNotFoundError("Cannot find checkpoint. 
Have you trained for too few epochs?") - # make unaugmented data_loader if necessary - if cfg.training.imgaug != "default": - cfg_pred = cfg.copy() - cfg_pred.training.imgaug = "default" - imgaug_transform_pred = get_imgaug_transform(cfg=cfg_pred) - dataset_pred = get_dataset( - cfg=cfg_pred, data_dir=data_dir, imgaug_transform=imgaug_transform_pred - ) - data_module_pred = get_data_module(cfg=cfg_pred, dataset=dataset_pred, video_dir=video_dir) - data_module_pred.setup() - else: - data_module_pred = data_module + To train a model on your own dataset, overwrite the default config_path and config_name args: - # ---------------------------------------------------------------------------------- - # predict on all labeled frames (train/val/test) - # ---------------------------------------------------------------------------------- - pretty_print_str("Predicting train/val/test images...") - # compute and save frame-wise predictions - preds_file = os.path.join(hydra_output_directory, "predictions.csv") - predict_dataset( - cfg=cfg, - trainer=trainer, - model=model, - data_module=data_module_pred, - ckpt_file=best_ckpt, - preds_file=preds_file, - ) - # compute and save various metrics - try: - compute_metrics(cfg=cfg, preds_file=preds_file, data_module=data_module_pred) - except Exception as e: - print(f"Error computing metrics\n{e}") + ``` + python scripts/train_hydra.py --config-path= --config-name= # noqa + ``` - # ---------------------------------------------------------------------------------- - # predict folder of videos - # ---------------------------------------------------------------------------------- - if cfg.eval.predict_vids_after_training: - pretty_print_str("Predicting videos...") - if cfg.eval.test_videos_directory is None: - filenames = [] - else: - filenames = check_video_paths( - return_absolute_path(cfg.eval.test_videos_directory) - ) - vidstr = "video" if (len(filenames) == 1) else "videos" - pretty_print_str( - f"Found {len(filenames)} {vidstr} to 
predict on (in cfg.eval.test_videos_directory)" - ) + For more information on training models, see the docs at + https://lightning-pose.readthedocs.io/en/latest/source/user_guide/training.html - for video_file in filenames: - assert os.path.isfile(video_file) - pretty_print_str(f"Predicting video: {video_file}...") - # get save name for prediction csv file - video_pred_dir = os.path.join(hydra_output_directory, "video_preds") - video_pred_name = os.path.splitext(os.path.basename(video_file))[0] - prediction_csv_file = os.path.join(video_pred_dir, video_pred_name + ".csv") - # get save name labeled video csv - if cfg.eval.save_vids_after_training: - labeled_vid_dir = os.path.join(video_pred_dir, "labeled_videos") - labeled_mp4_file = os.path.join( - labeled_vid_dir, video_pred_name + "_labeled.mp4" - ) - else: - labeled_mp4_file = None - # predict on video - export_predictions_and_labeled_video( - video_file=video_file, - cfg=cfg, - ckpt_file=best_ckpt, - prediction_csv_file=prediction_csv_file, - labeled_mp4_file=labeled_mp4_file, - trainer=trainer, - model=model, - data_module=data_module_pred, - save_heatmaps=cfg.eval.get( - "predict_vids_after_training_save_heatmaps", False - ), - ) - # compute and save various metrics - try: - compute_metrics( - cfg=cfg, - preds_file=prediction_csv_file, - data_module=data_module_pred, - ) - except Exception as e: - print(f"Error predicting on video {video_file}:\n{e}") - continue + """ - # ---------------------------------------------------------------------------------- - # predict on OOD frames - # ---------------------------------------------------------------------------------- - # update config file to point to OOD data - csv_file_ood = os.path.join(cfg.data.data_dir, cfg.data.csv_file).replace( - ".csv", "_new.csv" - ) - if os.path.exists(csv_file_ood): - cfg_ood = cfg.copy() - cfg_ood.data.csv_file = csv_file_ood - cfg_ood.training.imgaug = "default" - cfg_ood.training.train_prob = 1 - cfg_ood.training.val_prob = 0 - 
cfg_ood.training.train_frames = 1 - # build dataset/datamodule - imgaug_transform_ood = get_imgaug_transform(cfg=cfg_ood) - dataset_ood = get_dataset( - cfg=cfg_ood, data_dir=data_dir, imgaug_transform=imgaug_transform_ood - ) - data_module_ood = get_data_module(cfg=cfg_ood, dataset=dataset_ood, video_dir=video_dir) - data_module_ood.setup() - pretty_print_str("Predicting OOD images...") - # compute and save frame-wise predictions - preds_file_ood = os.path.join(hydra_output_directory, "predictions_new.csv") - predict_dataset( - cfg=cfg_ood, - trainer=trainer, - model=model, - data_module=data_module_ood, - ckpt_file=best_ckpt, - preds_file=preds_file_ood, - ) - # compute and save various metrics - try: - compute_metrics( - cfg=cfg_ood, preds_file=preds_file_ood, data_module=data_module_ood - ) - except Exception as e: - print(f"Error computing metrics\n{e}") + train(cfg) if __name__ == "__main__": - train() + train_model() diff --git a/setup.py b/setup.py index 1abe2bc1..9531a4cc 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,23 @@ - import re -import os import subprocess +from pathlib import Path from setuptools import find_packages, setup -VERSION = "1.2.3" -# add the README.md file to the long_description -with open("README.md", "r") as fh: - long_description = fh.read() +def read(rel_path): + here = Path(__file__).parent.absolute() + with open(here.joinpath(rel_path), "r") as fp: + return fp.read() + + +def get_version(rel_path): + for line in read(rel_path).splitlines(): + if line.startswith("__version__"): + delim = '"' if '"' in line else "'" + return line.split(delim)[1] + else: + raise RuntimeError("Unable to find version string.") def get_cuda_version(): @@ -28,7 +36,6 @@ def get_cuda_version(): cuda_version = get_cuda_version() - if cuda_version is not None: if 11.0 <= cuda_version < 12.0: dali = "nvidia-dali-cuda110" @@ -44,6 +51,10 @@ def get_cuda_version(): print(f"Found CUDA version: {cuda_version}, using DALI: {dali}") +# add the README.md file 
to the long_description +with open("README.md", "r") as fh: + long_description = fh.read() + # basic requirements install_requires = [ "fiftyone", @@ -94,7 +105,7 @@ def get_cuda_version(): setup( name="lightning-pose", packages=find_packages() + ["mirror_mouse_example"], # include data for wheel packaging - version=VERSION, + version=get_version(Path("lightning_pose").joinpath("__init__.py")), description="Semi-supervised pose estimation using pytorch lightning", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 00000000..77ad1640 --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,55 @@ +import copy +import os + + +def test_train(cfg, tmpdir): + + from lightning_pose.train import train + + pwd = os.getcwd() + + # copy config and update paths + cfg_tmp = copy.deepcopy(cfg) + cfg_tmp.data.data_dir = os.path.join(pwd, cfg_tmp.data.data_dir) + cfg_tmp.data.video_dir = os.path.join(cfg_tmp.data.data_dir, "videos") + cfg_tmp.eval.test_videos_directory = cfg_tmp.data.video_dir + + # don't train for long + cfg_tmp.training.min_epochs = 2 + cfg_tmp.training.max_epochs = 2 + cfg_tmp.training.check_val_every_n_epoch = 1 + cfg_tmp.training.log_every_n_steps = 1 + cfg_tmp.training.limit_train_batches = 2 + + # train simple model + cfg_tmp.model.model_type = "heatmap" + cfg_tmp.model.losses_to_use = [] + + # predict on vid + cfg_tmp.eval.predict_vids_after_training = True + cfg_tmp.eval.save_vids_after_training = True + + # change directory to save outputs elsewhere + os.chdir(tmpdir) + + # train model + train(cfg_tmp) + + # change directory back + os.chdir(pwd) + + # ensure labeled data was properly processed + assert os.path.isfile(os.path.join(tmpdir, "config.yaml")) + assert os.path.isfile(os.path.join(tmpdir, "predictions.csv")) + assert os.path.isfile(os.path.join(tmpdir, "predictions_pca_multiview_error.csv")) + assert os.path.isfile(os.path.join(tmpdir, 
"predictions_pca_singleview_error.csv")) + assert os.path.isfile(os.path.join(tmpdir, "predictions_pixel_error.csv")) + + # ensure video data was properly processed + assert os.path.isfile(os.path.join(tmpdir, "video_preds", "test_vid.csv")) + assert os.path.isfile(os.path.join(tmpdir, "video_preds", "test_vid_pca_multiview_error.csv")) + assert os.path.isfile(os.path.join(tmpdir, "video_preds", "test_vid_pca_singleview_error.csv")) + assert os.path.isfile(os.path.join(tmpdir, "video_preds", "test_vid_temporal_norm.csv")) + assert os.path.isfile(os.path.join( + tmpdir, "video_preds", "labeled_videos", "test_vid_labeled.mp4", + ))