From 9190a8f5ac6a5245a455356e187ac244871c90b3 Mon Sep 17 00:00:00 2001
From: Ben Eisner
Date: Tue, 30 Apr 2024 18:20:51 -0400
Subject: [PATCH] add training tests

---
 CONTRIBUTING.md                               |  15 ++
 .../ndf/bottle/train_bottle_grasp.yaml        |   8 -
 .../ndf/bottle/train_bottle_place.yaml        |   8 -
 configs/commands/ndf/bottle/train_grasp.yaml  |   1 +
 configs/commands/ndf/bottle/train_place.yaml  |   1 +
 .../commands/ndf/bowl/train_bowl_grasp.yaml   |   8 -
 .../commands/ndf/bowl/train_bowl_place.yaml   |   8 -
 configs/commands/ndf/bowl/train_grasp.yaml    |   1 +
 configs/commands/ndf/bowl/train_place.yaml    |   1 +
 configs/commands/ndf/mug/train_mug_grasp.yaml |   8 -
 configs/commands/ndf/mug/train_mug_place.yaml |   8 -
 configs/task/bottle_grasp.yaml                |   6 +-
 configs/task/bottle_place.yaml                |   6 +-
 configs/task/bowl_grasp.yaml                  |   6 +-
 configs/task/bowl_place.yaml                  |   6 +-
 configs/task/mug_grasp.yaml                   |   6 +-
 configs/task/mug_place.yaml                   |   6 +-
 configs/train_mug_residual_ablation.yaml      |   6 +-
 pyproject.toml                                |   2 +-
 scripts/train_residual_flow.py                |  12 +-
 scripts/train_residual_flow_ablation.py       |  13 +-
 taxpose/nets/transformer_flow.py              |   9 +-
 tests/train_test.py                           | 137 ++++++++++++++++++
 23 files changed, 213 insertions(+), 69 deletions(-)
 delete mode 100644 configs/commands/ndf/bottle/train_bottle_grasp.yaml
 delete mode 100644 configs/commands/ndf/bottle/train_bottle_place.yaml
 delete mode 100644 configs/commands/ndf/bowl/train_bowl_grasp.yaml
 delete mode 100644 configs/commands/ndf/bowl/train_bowl_place.yaml
 delete mode 100644 configs/commands/ndf/mug/train_mug_grasp.yaml
 delete mode 100644 configs/commands/ndf/mug/train_mug_place.yaml
 create mode 100644 tests/train_test.py

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4df8057..e1e4871 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -35,3 +35,18 @@ Run act
 ```
 act -j develop
 ```
+
+
+## Testing
+
+To run the tests:
+
+```
+pytest
+```
+
+To run only the long-running tests (deselected by default):
+
+```
+pytest -m "long"
+```
diff --git a/configs/commands/ndf/bottle/train_bottle_grasp.yaml b/configs/commands/ndf/bottle/train_bottle_grasp.yaml
deleted file mode 100644
index 9c190af..0000000
--- a/configs/commands/ndf/bottle/train_bottle_grasp.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# @package _global_
-
-defaults:
-- _self_
-- /train_ndf
-- override /task: bottle_grasp
-
-mode: train
diff --git a/configs/commands/ndf/bottle/train_bottle_place.yaml b/configs/commands/ndf/bottle/train_bottle_place.yaml
deleted file mode 100644
index 85d5ddb..0000000
--- a/configs/commands/ndf/bottle/train_bottle_place.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# @package _global_
-
-defaults:
-- _self_
-- /train_ndf
-- override /task: bottle_place
-
-mode: train
diff --git a/configs/commands/ndf/bottle/train_grasp.yaml b/configs/commands/ndf/bottle/train_grasp.yaml
index 9c190af..2e04db7 100644
--- a/configs/commands/ndf/bottle/train_grasp.yaml
+++ b/configs/commands/ndf/bottle/train_grasp.yaml
@@ -3,6 +3,7 @@
 defaults:
 - _self_
 - /train_ndf
+- override /model: taxpose
 - override /task: bottle_grasp
 
 mode: train
diff --git a/configs/commands/ndf/bottle/train_place.yaml b/configs/commands/ndf/bottle/train_place.yaml
index 85d5ddb..764667c 100644
--- a/configs/commands/ndf/bottle/train_place.yaml
+++ b/configs/commands/ndf/bottle/train_place.yaml
@@ -3,6 +3,7 @@
 defaults:
 - _self_
 - /train_ndf
+- override /model: taxpose
 - override /task: bottle_place
 
 mode: train
diff --git a/configs/commands/ndf/bowl/train_bowl_grasp.yaml b/configs/commands/ndf/bowl/train_bowl_grasp.yaml
deleted file mode 100644
index 18f9448..0000000
--- a/configs/commands/ndf/bowl/train_bowl_grasp.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# @package _global_
-
-defaults:
-- _self_
-- /train_ndf
-- override /task: bowl_grasp
-
-mode: train
diff --git a/configs/commands/ndf/bowl/train_bowl_place.yaml b/configs/commands/ndf/bowl/train_bowl_place.yaml
deleted file mode 100644
index b51185c..0000000
--- a/configs/commands/ndf/bowl/train_bowl_place.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# @package _global_
-
-defaults:
-- _self_
-- /train_ndf
-- override /task: bowl_place
-
-mode: train
diff --git a/configs/commands/ndf/bowl/train_grasp.yaml b/configs/commands/ndf/bowl/train_grasp.yaml
index 18f9448..e3b9422 100644
--- a/configs/commands/ndf/bowl/train_grasp.yaml
+++ b/configs/commands/ndf/bowl/train_grasp.yaml
@@ -3,6 +3,7 @@
 defaults:
 - _self_
 - /train_ndf
+- override /model: taxpose
 - override /task: bowl_grasp
 
 mode: train
diff --git a/configs/commands/ndf/bowl/train_place.yaml b/configs/commands/ndf/bowl/train_place.yaml
index b51185c..7c7649b 100644
--- a/configs/commands/ndf/bowl/train_place.yaml
+++ b/configs/commands/ndf/bowl/train_place.yaml
@@ -3,6 +3,7 @@
 defaults:
 - _self_
 - /train_ndf
+- override /model: taxpose
 - override /task: bowl_place
 
 mode: train
diff --git a/configs/commands/ndf/mug/train_mug_grasp.yaml b/configs/commands/ndf/mug/train_mug_grasp.yaml
deleted file mode 100644
index 4586d50..0000000
--- a/configs/commands/ndf/mug/train_mug_grasp.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# @package _global_
-
-defaults:
-- _self_
-- /train_ndf
-- override /task: mug_grasp
-
-mode: train
diff --git a/configs/commands/ndf/mug/train_mug_place.yaml b/configs/commands/ndf/mug/train_mug_place.yaml
deleted file mode 100644
index b4e34e5..0000000
--- a/configs/commands/ndf/mug/train_mug_place.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# @package _global_
-
-defaults:
-- _self_
-- /train_ndf
-- override /task: mug_place
-
-mode: train
diff --git a/configs/task/bottle_grasp.yaml b/configs/task/bottle_grasp.yaml
index 08c45cb..aaf7c09 100644
--- a/configs/task/bottle_grasp.yaml
+++ b/configs/task/bottle_grasp.yaml
@@ -7,5 +7,7 @@ anchor_class: 0
 cloud_type: pre_grasp
 softmax_temperature: 0.1
 weight_normalize: softmax
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
diff --git a/configs/task/bottle_place.yaml b/configs/task/bottle_place.yaml
index 94698cb..040547e 100644
--- a/configs/task/bottle_place.yaml
+++ b/configs/task/bottle_place.yaml
@@ -7,5 +7,7 @@ anchor_class: 1
 cloud_type: teleport
 softmax_temperature: 1
 weight_normalize: l1
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
diff --git a/configs/task/bowl_grasp.yaml b/configs/task/bowl_grasp.yaml
index 352a451..7f4a159 100644
--- a/configs/task/bowl_grasp.yaml
+++ b/configs/task/bowl_grasp.yaml
@@ -7,5 +7,7 @@ anchor_class: 0
 cloud_type: pre_grasp
 softmax_temperature: 0.1
 weight_normalize: softmax
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
diff --git a/configs/task/bowl_place.yaml b/configs/task/bowl_place.yaml
index 64a13c0..5ab5162 100644
--- a/configs/task/bowl_place.yaml
+++ b/configs/task/bowl_place.yaml
@@ -7,5 +7,7 @@ anchor_name: slab
 cloud_type: teleport
 softmax_temperature: 1
 weight_normalize: l1
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
diff --git a/configs/task/mug_grasp.yaml b/configs/task/mug_grasp.yaml
index 495f575..5a7a847 100644
--- a/configs/task/mug_grasp.yaml
+++ b/configs/task/mug_grasp.yaml
@@ -7,5 +7,7 @@ anchor_class: 0
 cloud_type: pre_grasp
 softmax_temperature: 0.1
 weight_normalize: softmax
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
diff --git a/configs/task/mug_place.yaml b/configs/task/mug_place.yaml
index ad42edf..fc4a550 100644
--- a/configs/task/mug_place.yaml
+++ b/configs/task/mug_place.yaml
@@ -7,5 +7,7 @@ softmax_temperature: 1
 weight_normalize: l1
 action_name: mug
 anchor_name: rack
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
diff --git a/configs/train_mug_residual_ablation.yaml b/configs/train_mug_residual_ablation.yaml
index 5c52290..e67c84c 100644
--- a/configs/train_mug_residual_ablation.yaml
+++ b/configs/train_mug_residual_ablation.yaml
@@ -65,7 +65,9 @@ flow_supervision: both
 
 # Training Settings
 checkpoint_file: Null
-checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
-checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
+# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
+# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
+checkpoint_file_action: null
+checkpoint_file_anchor: null
 lr: 1e-4
 max_epochs: 1000
diff --git a/pyproject.toml b/pyproject.toml
index b51de55..b377bcb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -101,7 +101,7 @@ module = [
 ignore_missing_imports = true
 
 [tool.pytest.ini_options]
-addopts = "--ignore=third_party/"
+addopts = "--ignore=third_party/ -m 'not long'"
 
 [tool.pylint]
 disable = [
diff --git a/scripts/train_residual_flow.py b/scripts/train_residual_flow.py
index c236ef4..8cbceb4 100644
--- a/scripts/train_residual_flow.py
+++ b/scripts/train_residual_flow.py
@@ -26,6 +26,8 @@ def write_to_file(file_name, string):
 def main(cfg):
     print(OmegaConf.to_yaml(cfg, resolve=True))
 
+    TESTING = "PYTEST_CURRENT_TEST" in os.environ
+
     # breakpoint()
     # torch.set_float32_matmul_precision("medium")
     pl.seed_everything(cfg.seed)
@@ -33,12 +35,18 @@ def main(cfg):
     logger.log_hyperparams(cfg)
     logger.log_hyperparams({"working_dir": os.getcwd()})
     trainer = pl.Trainer(
-        logger=logger,
+        logger=logger if not TESTING else False,
         accelerator="gpu",
         devices=[0],
         reload_dataloaders_every_n_epochs=1,
-        callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()],
+        callbacks=(
+            [SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()]
+            if not TESTING
+            else []
+        ),
         max_epochs=cfg.max_epochs,
+        # If pytest is running, do a fast dev run of only 5 batches.
+        fast_dev_run=5 if TESTING else False,
     )
     log_txt_file = cfg.log_txt_file
     if cfg.mode == "train":
diff --git a/scripts/train_residual_flow_ablation.py b/scripts/train_residual_flow_ablation.py
index 56012ee..8b5f016 100644
--- a/scripts/train_residual_flow_ablation.py
+++ b/scripts/train_residual_flow_ablation.py
@@ -27,16 +27,25 @@ def write_to_file(file_name, string):
 
 @hydra.main(config_path="../configs", config_name="train_mug_residual_ablation")
 def main(cfg):
     pl.seed_everything(cfg.seed)
+
+    TESTING = "PYTEST_CURRENT_TEST" in os.environ
+
     logger = WandbLogger(project=cfg.experiment)
     logger.log_hyperparams(cfg)
     logger.log_hyperparams({"working_dir": os.getcwd()})
     trainer = pl.Trainer(
-        logger=logger,
+        logger=logger if not TESTING else False,
         accelerator="gpu",
         devices=[0],
         reload_dataloaders_every_n_epochs=1,
-        callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()],
+        callbacks=(
+            [SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()]
+            if not TESTING
+            else []
+        ),
         max_epochs=cfg.max_epochs,
+        # If pytest is running, do a fast dev run of only 5 batches.
+        fast_dev_run=5 if TESTING else False,
     )
     log_txt_file = cfg.log_txt_file
diff --git a/taxpose/nets/transformer_flow.py b/taxpose/nets/transformer_flow.py
index 5f1fee8..21504a9 100644
--- a/taxpose/nets/transformer_flow.py
+++ b/taxpose/nets/transformer_flow.py
@@ -82,6 +82,10 @@ def forward(self, *input):
             scores=None,
         ).permute(0, 2, 1)
 
+        outputs = {
+            "flow_action": flow_action,
+        }
+
         if self.cycle:
             flow_anchor = self.head_anchor(
                 anchor_embedding_tf,
@@ -90,8 +94,9 @@ def forward(self, *input):
                 action_points,
                 scores=None,
             ).permute(0, 2, 1)
-            return flow_action, flow_anchor
-        return flow_action
+            outputs["flow_anchor"] = flow_anchor
+
+        return outputs
 
 
 class CorrespondenceMLPHead(nn.Module):
diff --git a/tests/train_test.py b/tests/train_test.py
new file mode 100644
index 0000000..665663f
--- /dev/null
+++ b/tests/train_test.py
@@ -0,0 +1,137 @@
+# Muck around with the path to make the import work
+import os
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from hydra import compose, initialize
+from hydra.core.hydra_config import HydraConfig
+
+# Add the parent directory to the path to import the script. Hacky, but it works.
+THIS_DIR = Path(__file__).resolve().parent
+sys.path.append(str(THIS_DIR.parent))
+
+from scripts.train_residual_flow import main
+from scripts.train_residual_flow_ablation import main as main_ablation
+
+
+def _get_training_config_names(bmark, ablation=False):
+    # Get config paths from the configs/commands directory, relative to the configs directory.
+    configs = [path for path in Path(f"configs/commands/{bmark}").rglob("*.yaml")]
+
+    # Strip the "configs/" prefix.
+    configs = [str(path)[8:] for path in configs]
+
+    # Filter out paths with basenames that have a leading underscore.
+    configs = [config for config in configs if not Path(config).name.startswith("_")]
+
+    # Keep only paths that contain "train_".
+    configs = [config for config in configs if "train_" in config]
+
+    if ablation:
+        # Keep only the ablation configs, excluding the n_demo ones.
+        configs = [
+            config
+            for config in configs
+            if "ablation" in config and "n_demo" not in config
+        ]
+    else:
+        # Drop the ablation configs, except for the n_demo ones.
+        configs = [
+            config
+            for config in configs
+            if "ablation" not in config or "n_demo" in config
+        ]
+
+    # Filter out paths with any folder that has a leading underscore.
+    configs = [
+        config
+        for config in configs
+        if not any(folder.startswith("_") for folder in Path(config).parts)
+    ]
+
+    return configs
+
+
+DEFAULT_NDF_PATH = "/data/ndf"
+
+
+# Skip if the NDF dataset is unavailable or if no GPU is available.
+@pytest.mark.training
+@pytest.mark.skipif(
+    ("NDF_DATASET_ROOT" not in os.environ and not os.path.exists(DEFAULT_NDF_PATH))
+    or not torch.cuda.is_available(),
+    reason="The NDF dataset is unavailable or no GPU is available.",
+)
+@pytest.mark.parametrize("config_name", _get_training_config_names("ndf"))
+def test_training_commands_run(config_name):
+    dataset_root = (
+        os.environ["NDF_DATASET_ROOT"]
+        if "NDF_DATASET_ROOT" in os.environ
+        else DEFAULT_NDF_PATH
+    )
+
+    torch.multiprocessing.set_sharing_strategy("file_system")
+
+    with initialize(version_base=None, config_path="../configs"):
+        cfg = compose(
+            config_name=config_name,
+            overrides=[
+                "hydra.verbose=true",
+                "hydra.job.num=1",
+                "hydra.runtime.output_dir=.",
+                "seed=1234",
+                f"dataset_root={dataset_root}",
+                "batch_size=2",
+            ],
+            return_hydra_config=True,
+        )
+        # Resolve the config.
+        HydraConfig.instance().set_config(cfg)
+
+        # Disable wandb logging while the test runs.
+        os.environ["WANDB_MODE"] = "disabled"
+        # Run the training script.
+        main(cfg)
+
+
+# Do the same for the ablation configs.
+@pytest.mark.ablations
+@pytest.mark.skipif(
+    ("NDF_DATASET_ROOT" not in os.environ and not os.path.exists(DEFAULT_NDF_PATH))
+    or not torch.cuda.is_available(),
+    reason="The NDF dataset is unavailable or no GPU is available.",
+)
+@pytest.mark.parametrize(
+    "config_name", _get_training_config_names("ndf", ablation=True)
+)
+def test_training_ablation_commands_run(config_name):
+    dataset_root = (
+        os.environ["NDF_DATASET_ROOT"]
+        if "NDF_DATASET_ROOT" in os.environ
+        else DEFAULT_NDF_PATH
+    )
+
+    torch.multiprocessing.set_sharing_strategy("file_system")
+
+    with initialize(version_base=None, config_path="../configs"):
+        cfg = compose(
+            config_name=config_name,
+            overrides=[
+                "hydra.verbose=true",
+                "hydra.job.num=1",
+                "hydra.runtime.output_dir=.",
+                "seed=1234",
+                f"dataset_root={dataset_root}",
+                "batch_size=2",
+            ],
+            return_hydra_config=True,
+        )
+        # Resolve the config.
+        HydraConfig.instance().set_config(cfg)
+
+        # Disable wandb logging while the test runs.
+        os.environ["WANDB_MODE"] = "disabled"
+        # Run the training script.
+        main_ablation(cfg)
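
For downstream readers of the `taxpose/nets/transformer_flow.py` change above: `forward` previously returned a bare flow tensor, or a `(flow_action, flow_anchor)` tuple when the network was built with `cycle` enabled, and now always returns a dict. Any caller that unpacked the old return value needs a matching update. Below is a minimal sketch of that adaptation; `model`, `action_points`, and `anchor_points` are hypothetical placeholder names, and only the output keys and the `cycle` flag come from the patch itself:

```
# Hypothetical caller of the patched forward(); `model`, `action_points`,
# and `anchor_points` are placeholders, not names from this repo.
outputs = model(action_points, anchor_points)

# "flow_action" is always present in the returned dict.
flow_action = outputs["flow_action"]

# "flow_anchor" is present only when the model was built with cycle=True,
# so look it up defensively instead of unpacking a tuple.
flow_anchor = outputs.get("flow_anchor")
if flow_anchor is not None:
    pass  # e.g., add a cycle-consistency term to the loss
```

Returning a dict keyed by output name keeps the call signature stable as optional heads are added, at the cost of this one-time breaking change for tuple-unpacking callers.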