diff --git a/REPRODUCING.md b/REPRODUCING.md
index 7b856ec..b9199db 100644
--- a/REPRODUCING.md
+++ b/REPRODUCING.md
@@ -9,31 +9,31 @@ All tasks require pretraining.
 ### Bottle
 
 ```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bottle training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bottle data_root=/data
 ```
 
 ### Bowl
 
 ```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bowl training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bowl data_root=/data
 ```
 
 ### Gripper
 
 ```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/gripper training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/gripper data_root=/data
 ```
 
 ### Mug
 
 ```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/mug training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/mug data_root=/data
 ```
 
 ### Rack
 
 ```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/rack training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/rack data_root=/data
 ```
 
 ### Slab
@@ -41,7 +41,7 @@ python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/rack
 Note: this one appears broken.
 
 ```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/slab training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/slab data_root=/data
 ```
 
 ## Table 1
diff --git a/scripts/pretrain_embedding.py b/scripts/pretrain_embedding.py
index 7d29245..9910705 100644
--- a/scripts/pretrain_embedding.py
+++ b/scripts/pretrain_embedding.py
@@ -16,8 +16,6 @@
     EquivariancePreTrainingModule,
 )
 
-# chuerp conda env: pytorch3d_38
-
 
 @hydra.main(version_base="1.1", config_path="../configs", config_name="pretraining")
 def main(cfg):
@@ -28,7 +26,6 @@ def main(cfg):
             indent=4,
         )
     )
-    # breakpoint()
 
     ######################################################################
     # Torch settings.
@@ -60,14 +57,9 @@ def main(cfg):
         logger=logger if TRAINING else None,
         accelerator="gpu",
         devices=[0],
-        # reload_dataloaders_every_n_epochs=1,
-        # val_check_interval=0.2,
-        # val_check_interval=10,
         log_every_n_steps=cfg.training.log_every_n_steps,
         check_val_every_n_epoch=cfg.training.check_val_every_n_epoch,
-        # limit_train_batches=10,
         max_epochs=cfg.training.epochs,
-        # callbacks=[SaverCallbackEmbnn()],
         callbacks=(
             [
                 # This checkpoint callback saves the latest model during training, i.e. so we can resume if it crashes.
@@ -100,18 +92,8 @@ def main(cfg):
         cfg=cfg.dataset,
         batch_size=cfg.training.batch_size,
         num_workers=cfg.resources.num_workers,
-        # batch_size=cfg.training.batch_size,
-        # num_workers=cfg.resources.num_workers,
-        # cloud_class=cfg.training.dataset.cloud_class,
-        # dataset_root=cfg.training.dataset.root,
-        # dataset_index=cfg.training.dataset.dataset_index,
-        # cloud_type=cfg.training.dataset.cloud_type,
-        # # overfit=cfg.overfit,
-        # pretraining_data_path=cfg.training.dataset.pretraining_data_path,
-        # obj_class=cfg.training.dataset.obj_class,
     )
 
-    # dm.setup()
     network = EquivariantFeatureEmbeddingNetwork(
         emb_dims=cfg.emb_dims, emb_nn=cfg.emb_nn
    )
@@ -124,17 +106,12 @@ def main(cfg):
         temperature=cfg.temperature,
         con_weighting=cfg.con_weighting,
     )
-    # model.cuda()
-    # model.train()
-    # logger.watch(network)
-    # if cfg.checkpoint_file is not None:
-    #     model.load_state_dict(torch.load(cfg.checkpoint_file)["state_dict"])
 
     trainer.fit(model, dm)
 
 
 if __name__ == "__main__":
     # torch.autograd.set_detect_anomaly(True)
     # torch.cuda.empty_cache()
-    # torch.multiprocessing.set_sharing_strategy("file_system")
+    torch.multiprocessing.set_sharing_strategy("file_system")
     main()
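
For background on the final hunk, which activates the previously commented-out `torch.multiprocessing.set_sharing_strategy("file_system")`: on Linux, PyTorch's default `file_descriptor` strategy keeps one open file descriptor per tensor passed between `DataLoader` workers and the main process, so long runs with many workers can exhaust the per-process fd limit (often surfacing as `RuntimeError: received 0 items of ancdata`). The `file_system` strategy backs shared tensors with named files in shared memory instead. A minimal standalone sketch of the call, separate from the patch itself:

```
# Standalone sketch of the sharing-strategy switch enabled above;
# not part of the repository.
import torch.multiprocessing as mp

if __name__ == "__main__":
    # Strategies available on this platform,
    # e.g. {'file_descriptor', 'file_system'} on Linux.
    print(mp.get_all_sharing_strategies())

    # Back shared CPU tensors with named files in shared memory rather
    # than passing file descriptors over unix sockets, avoiding fd-limit
    # exhaustion when many DataLoader workers run for a long time.
    mp.set_sharing_strategy("file_system")
    print(mp.get_sharing_strategy())  # -> file_system
```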