
Commit: changes to reproducing

beneisner committed May 2, 2024
1 parent 11a37c5 · commit 76224e1
Showing 2 changed files with 7 additions and 30 deletions.
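In REPRODUCING.md, each pretraining command below drops the explicit `training.dataset.pretraining_data_path` override and points `data_root` at `/data`; both are Hydra command-line overrides consumed by `scripts/pretrain_embedding.py`. As a minimal sketch of how such an override resolves, assuming a `pretraining` config that defines a `data_root` key (a hypothetical standalone script, not the repository's actual entry point):

```python
# Illustrative sketch of Hydra override resolution (hypothetical standalone
# script). Assumes ../configs/pretraining.yaml exists and defines `data_root`.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base="1.1", config_path="../configs", config_name="pretraining")
def main(cfg: DictConfig) -> None:
    # A CLI argument like `data_root=/data` is merged into the composed
    # config before this function runs, replacing the default value.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```

Run as `python sketch.py data_root=/data`, the printed config would show `data_root: /data` with no code changes needed.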
REPRODUCING.md — 12 changes: 6 additions & 6 deletions
@@ -9,39 +9,39 @@ All tasks require pretraining.
### Bottle

```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bottle training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bottle data_root=/data
```

### Bowl

```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bowl training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bowl data_root=/data
```

### Gripper

```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/gripper training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/gripper data_root=/data
```

### Mug

```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/mug training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/mug data_root=/data
```

### Rack

```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/rack training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/rack data_root=/data
```

### Slab

Note: this one appears broken.

```
-python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/slab training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf
+python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/slab data_root=/data
```

## Table 1
scripts/pretrain_embedding.py — 25 changes: 1 addition & 24 deletions
@@ -16,8 +16,6 @@
     EquivariancePreTrainingModule,
 )
 
-# chuerp conda env: pytorch3d_38
-
 
 @hydra.main(version_base="1.1", config_path="../configs", config_name="pretraining")
 def main(cfg):
@@ -28,7 +26,6 @@ def main(cfg):
             indent=4,
         )
     )
-    # breakpoint()
 
     ######################################################################
     # Torch settings.
@@ -60,14 +57,9 @@ def main(cfg):
         logger=logger if TRAINING else None,
         accelerator="gpu",
         devices=[0],
-        # reload_dataloaders_every_n_epochs=1,
-        # val_check_interval=0.2,
-        # val_check_interval=10,
         log_every_n_steps=cfg.training.log_every_n_steps,
         check_val_every_n_epoch=cfg.training.check_val_every_n_epoch,
-        # limit_train_batches=10,
         max_epochs=cfg.training.epochs,
-        # callbacks=[SaverCallbackEmbnn()],
         callbacks=(
             [
                 # This checkpoint callback saves the latest model during training, i.e. so we can resume if it crashes.
@@ -100,18 +92,8 @@
         cfg=cfg.dataset,
         batch_size=cfg.training.batch_size,
         num_workers=cfg.resources.num_workers,
-        # batch_size=cfg.training.batch_size,
-        # num_workers=cfg.resources.num_workers,
-        # cloud_class=cfg.training.dataset.cloud_class,
-        # dataset_root=cfg.training.dataset.root,
-        # dataset_index=cfg.training.dataset.dataset_index,
-        # cloud_type=cfg.training.dataset.cloud_type,
-        # # overfit=cfg.overfit,
-        # pretraining_data_path=cfg.training.dataset.pretraining_data_path,
-        # obj_class=cfg.training.dataset.obj_class,
     )
 
-    # dm.setup()
     network = EquivariantFeatureEmbeddingNetwork(
         emb_dims=cfg.emb_dims, emb_nn=cfg.emb_nn
    )
@@ -124,17 +106,12 @@
         temperature=cfg.temperature,
         con_weighting=cfg.con_weighting,
     )
-    # model.cuda()
-    # model.train()
-    # logger.watch(network)
 
-    # if cfg.checkpoint_file is not None:
-    #     model.load_state_dict(torch.load(cfg.checkpoint_file)["state_dict"])
     trainer.fit(model, dm)
 
 
 if __name__ == "__main__":
     # torch.autograd.set_detect_anomaly(True)
     # torch.cuda.empty_cache()
-    # torch.multiprocessing.set_sharing_strategy("file_system")
+    torch.multiprocessing.set_sharing_strategy("file_system")
     main()
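The single line this commit adds (rather than deletes) is the uncommented `torch.multiprocessing.set_sharing_strategy("file_system")` call. With many `DataLoader` workers, the default file-descriptor sharing strategy on Linux can exhaust the open-file limit; the `file_system` strategy passes tensors between processes through files in shared memory instead. A small sketch using only real PyTorch APIs:

```python
# Sketch: select the tensor-sharing strategy before any DataLoader workers
# are spawned, mirroring what the patched script now does in __main__.
import torch.multiprocessing as mp

if __name__ == "__main__":
    mp.set_sharing_strategy("file_system")
    print(mp.get_sharing_strategy())  # -> file_system
```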
