diff --git a/.github/workflows/build-site.yaml b/.github/workflows/build-site.yaml index c0866e8..f8bdaa8 100644 --- a/.github/workflows/build-site.yaml +++ b/.github/workflows/build-site.yaml @@ -28,24 +28,8 @@ jobs: - name: Install specific pip. run: pip install pip==23.0.0 - - name: Install CPU version of torch. - run: pip install torch==1.11.0+cpu torchvision==0.12.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu - - - name: Install the torch-geometric. - run: pip install pyg_lib==0.1.0 torch_scatter==2.0.9 torch_sparse==0.6.15 torch_cluster==1.6.0 torch_spline_conv==1.2.1 -f https://data.pyg.org/whl/torch-1.11.0+cpu.html - - - name: Install pytorch3d dependencies. - run: pip install fvcore iopath - - - name: Install pytorch3d - run: pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu113_pyt1110/download.html - # run: pip install pytorch3d - - - name: Install Dependencies - run: pip install -e ".[build_docs]" - - - name: Install ndf_robot - run: pip install -e third_party/ndf_robot + - name: Install doc requirements. + run: pip install mkdocs-material mkdocstrings[python] - name: Build mkdocs site working-directory: docs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4df8057..e1e4871 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -35,3 +35,18 @@ Run act ``` act -j develop ``` + + +## Testing + +To run the tests: + +``` +pytest +``` + +To run all the tests, including long ones + +``` +pytest -m "long" +``` diff --git a/REPRODUCING.md b/REPRODUCING.md new file mode 100644 index 0000000..7b856ec --- /dev/null +++ b/REPRODUCING.md @@ -0,0 +1,387 @@ +# Reproducing the paper. + +## High-level TODOS + +## Pretraining + +All tasks require pretraining. + +### Bottle + +``` +python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bottle training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf +``` + +### Bowl + +``` +python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/bowl training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf +``` + +### Gripper + +``` +python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/gripper training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf +``` + +### Mug + +``` +python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/mug training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf +``` + +### Rack + +``` +python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/rack training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf +``` + +### Slab + +Note: this one appears broken. + +``` +python scripts/pretrain_embedding.py --config-name commands/ndf/pretraining/slab training.dataset.pretraining_data_path=/data/ndf_original/data data_root=/data/ndf +``` + +## Table 1 + +This table trains the mug on the grasp and place tasks, and evaluates the model on the upright and arbitrary settings of mug-hanging. Reported results are success rates. + +### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/mug/train_grasp +``` + +### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/mug/train_place +``` + +### Evaluate + +Upright: + +``` +python scripts/evaluate_ndf_mug_standalone.py --config-name commands/ndf/mug/eval_ndf_upright checkpoint_file_grasp=??? checkpoint_file_place=??? seed=??? pybullet_viz=False +``` + +Arbitrary: + +``` +python scripts/evaluate_ndf_mug_standalone.py --config-name commands/ndf/mug/eval_ndf_arbitrary checkpoint_file_grasp=??? checkpoint_file_place=??? seed=??? pybullet_viz=False +``` + +## Table 2 + +This table compares sample-efficiency for {1, 5, 10} demos on the mug-hanging task, and evaluates only on upright setting. Reported results are Overall success rates. + +### Train Mug Grasp 1 Demo + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/n_demos/train_mug_grasp_1 +``` + +### Train Mug Place 1 Demo + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/n_demos/train_mug_place_1 +``` + +### Train Mug Grasp 5 Demos + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/n_demos/train_mug_grasp_5 +``` + +### Train Mug Place 5 Demos + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/n_demos/train_mug_place_5 +``` +### Evaluate + +``` +python scripts/evaluate_ndf_mug_standalone.py --config-name commands/ndf/mug/eval_ndf_upright checkpoint_file_grasp=??? checkpoint_file_place=??? seed=??? pybullet_viz=False +``` + + +## Table 3 + +This table contains ablations. All are trained on 10 demos of mug-hanging, and evaluated in the upright setting. + +### No residual + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/4_no_residuals/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/4_no_residuals/train_mug_place +``` + +#### Evaluate + +TODO. + +### Unweighted SVD + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/5_unweighted_svd/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/5_unweighted_svd/train_mug_place +``` + +#### Evaluate + +TODO. + +### No Cross-Attention + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/8_mlp/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/8_mlp/train_mug_place +``` + +#### Evaluate + +TODO. + +## Table 4 + +TODO: fill in this table. + +## Table 5 - Attention weight ablation + +Mug hanging, upright. + +TODO: Not sure what this was... + +### Train Mug Grasp + +### Train Mug Place + +### Evaluate + +## Table 6 + +### No L_disp + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/0_no_disp_loss/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/0_no_disp_loss/train_mug_place +``` + +#### Evaluate + +TODO. + +### No L_corr + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/1_no_corr_loss/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/1_no_corr_loss/train_mug_place +``` + +#### Evaluate + +TODO. + +### No L_cons + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/2_no_cons_loss/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/2_no_cons_loss/train_mug_place +``` + +#### Evaluate + +TODO + +### Scaled loss combo 1.1 * L_cons + L_corr + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/3_no_disp_loss_combined/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/3_no_disp_loss_combined/train_mug_place +``` + +#### Evaluate + +TODO. + +### No correspondence residuals. + +See above. + +### Unweighted SVD + +See above. + +### No finetuning of embedding network + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/6_no_finetuning/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/6_no_finetuning/train_mug_place +``` + +#### Evaluate + +TODO. + +### No pretraining of embedding network + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/7_no_pretraining/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/7_no_pretraining/train_mug_place +``` + +#### Evaluate + +### 3-layer MLP instead of Transformer + +See above. + +### Embedding network feature dim = 16 + +#### Train Mug Grasp + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/9_low_dim_embedding/train_mug_grasp +``` + +#### Train Mug Place + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/ablations/9_low_dim_embedding/train_mug_place +``` + +#### Evaluate + +TODO. + +## Table 7 - Pretraining + +TODO: fill in this table. + +## Table 8 - Bottle & Bowl + +### Train Bottle Grasp + +Broken: sampling is not working correctly. + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/bottle/train_grasp +``` + +### Train Bottle Place + +Broken: sampling is not working correctly. + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/bottle/train_place +``` + +### Evaluate Bottle + +### Train Bowl Grasp + +Broken: sampling is not working correctly. + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/bowl/train_grasp +``` + +### Train Bowl Place + +Broken: sampling is not working correctly. + +``` +python scripts/train_residual_flow.py --config-name commands/ndf/bowl/train_place +``` + +### Evaluate Bowl + +## Table 9 - PM Placement + +TODO: fill in this table. + +## Table 10 - PM Placement + +TODO: fill in this table. + +## Table 11 - PM Placement + +TODO: fill in this table. + +## Table 12 - PM Placement + +TODO: fill in this table. + +## Table 13 - PM Placement + +TODO: fill in this table. + +## Table 14 - PM Placement + +TODO: fill in this table. diff --git a/configs/ablation/0_no_disp_loss.yaml b/configs/ablation/0_no_disp_loss.yaml deleted file mode 100644 index b0a7d45..0000000 --- a/configs/ablation/0_no_disp_loss.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 0_no_disp_loss -displace_loss_weight: 0 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 - -residual_on: True -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/1_no_corr_loss.yaml b/configs/ablation/1_no_corr_loss.yaml deleted file mode 100644 index 8e91631..0000000 --- a/configs/ablation/1_no_corr_loss.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 1_no_corr_loss -displace_loss_weight: 1 -direct_correspondence_loss_weight: 0 -consistency_loss_weight: 0.1 - -residual_on: True -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/2_no_cons_loss.yaml b/configs/ablation/2_no_cons_loss.yaml deleted file mode 100644 index 5e9d86b..0000000 --- a/configs/ablation/2_no_cons_loss.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 2_no_cons_loss -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0 - -residual_on: True -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/3_no_disp_loss_combined.yaml b/configs/ablation/3_no_disp_loss_combined.yaml deleted file mode 100644 index 460b2eb..0000000 --- a/configs/ablation/3_no_disp_loss_combined.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 3_no_disp_loss_combined -displace_loss_weight: 0 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 1.1 - -residual_on: True -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/4_no_residuals.yaml b/configs/ablation/4_no_residuals.yaml deleted file mode 100644 index e629be0..0000000 --- a/configs/ablation/4_no_residuals.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 4_no_residuals -residual_on: False - -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/5_unweighted_svd.yaml b/configs/ablation/5_unweighted_svd.yaml deleted file mode 100644 index b830f39..0000000 --- a/configs/ablation/5_unweighted_svd.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 5_unweighted_svd -pred_weight: False - -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -residual_on: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/6_no_finetuning.yaml b/configs/ablation/6_no_finetuning.yaml deleted file mode 100644 index 04768c8..0000000 --- a/configs/ablation/6_no_finetuning.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 6_no_finetuning -freeze_embnn: True - -residual_on: True -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -pred_weight: True -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/ablation/7_no_pretraining.yaml b/configs/ablation/7_no_pretraining.yaml deleted file mode 100644 index 3bb6fb9..0000000 --- a/configs/ablation/7_no_pretraining.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 7_no_pretraining -checkpoint_file_action: Null -checkpoint_file_anchor: Null - -residual_on: True -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -pred_weight: True -freeze_embnn: False -mlp: False -emb_dims: 512 diff --git a/configs/ablation/8_mlp.yaml b/configs/ablation/8_mlp.yaml deleted file mode 100644 index e09bbad..0000000 --- a/configs/ablation/8_mlp.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 8_mlp -mlp: True - -residual_on: True -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -emb_dims: 512 diff --git a/configs/ablation/9_low_dim_embedding.yaml b/configs/ablation/9_low_dim_embedding.yaml deleted file mode 100644 index 21bdffc..0000000 --- a/configs/ablation/9_low_dim_embedding.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: 9_low_dim_embedding -emb_dims: 16 - -residual_on: True -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False diff --git a/configs/ablation/taxpose.yaml b/configs/ablation/taxpose.yaml deleted file mode 100644 index f32c40b..0000000 --- a/configs/ablation/taxpose.yaml +++ /dev/null @@ -1,11 +0,0 @@ -name: taxpose -displace_loss_weight: 1 -direct_correspondence_loss_weight: 1 -consistency_loss_weight: 0.1 -residual_on: True -pred_weight: True -freeze_embnn: False -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -mlp: False -emb_dims: 512 diff --git a/configs/commands/ndf/ablations/0_no_disp_loss/_ablation.yaml b/configs/commands/ndf/ablations/0_no_disp_loss/_ablation.yaml new file mode 100644 index 0000000..2b50d96 --- /dev/null +++ b/configs/commands/ndf/ablations/0_no_disp_loss/_ablation.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +ablation: + name: 0_no_disp_loss + + +displace_loss_weight: 0 diff --git a/configs/commands/ndf/ablations/0_no_disp_loss/train_mug_grasp.yaml b/configs/commands/ndf/ablations/0_no_disp_loss/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/0_no_disp_loss/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/0_no_disp_loss/train_mug_place.yaml b/configs/commands/ndf/ablations/0_no_disp_loss/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/0_no_disp_loss/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/1_no_corr_loss/_ablation.yaml b/configs/commands/ndf/ablations/1_no_corr_loss/_ablation.yaml new file mode 100644 index 0000000..e33d81f --- /dev/null +++ b/configs/commands/ndf/ablations/1_no_corr_loss/_ablation.yaml @@ -0,0 +1,6 @@ +# @package _global_ + +ablation: + name: 1_no_corr_loss + +direct_correspondence_loss_weight: 0 diff --git a/configs/commands/ndf/ablations/1_no_corr_loss/train_mug_grasp.yaml b/configs/commands/ndf/ablations/1_no_corr_loss/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/1_no_corr_loss/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/1_no_corr_loss/train_mug_place.yaml b/configs/commands/ndf/ablations/1_no_corr_loss/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/1_no_corr_loss/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/2_no_cons_loss/_ablation.yaml b/configs/commands/ndf/ablations/2_no_cons_loss/_ablation.yaml new file mode 100644 index 0000000..485aa45 --- /dev/null +++ b/configs/commands/ndf/ablations/2_no_cons_loss/_ablation.yaml @@ -0,0 +1,6 @@ +# @package _global_ + +ablation: + name: 2_no_cons_loss + +consistency_loss_weight: 0 diff --git a/configs/commands/ndf/ablations/2_no_cons_loss/train_mug_grasp.yaml b/configs/commands/ndf/ablations/2_no_cons_loss/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/2_no_cons_loss/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/2_no_cons_loss/train_mug_place.yaml b/configs/commands/ndf/ablations/2_no_cons_loss/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/2_no_cons_loss/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/3_no_disp_loss_combined/_ablation.yaml b/configs/commands/ndf/ablations/3_no_disp_loss_combined/_ablation.yaml new file mode 100644 index 0000000..7fdc2c5 --- /dev/null +++ b/configs/commands/ndf/ablations/3_no_disp_loss_combined/_ablation.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +ablation: + name: 3_no_disp_loss_combined + +displace_loss_weight: 0 +consistency_loss_weight: 0.1 diff --git a/configs/commands/ndf/ablations/3_no_disp_loss_combined/train_mug_grasp.yaml b/configs/commands/ndf/ablations/3_no_disp_loss_combined/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/3_no_disp_loss_combined/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/3_no_disp_loss_combined/train_mug_place.yaml b/configs/commands/ndf/ablations/3_no_disp_loss_combined/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/3_no_disp_loss_combined/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/4_no_residuals/_ablation.yaml b/configs/commands/ndf/ablations/4_no_residuals/_ablation.yaml new file mode 100644 index 0000000..4e2ae5e --- /dev/null +++ b/configs/commands/ndf/ablations/4_no_residuals/_ablation.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +ablation: + name: 4_no_residuals + +model: + residual_on: False diff --git a/configs/commands/ndf/ablations/4_no_residuals/train_mug_grasp.yaml b/configs/commands/ndf/ablations/4_no_residuals/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/4_no_residuals/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/4_no_residuals/train_mug_place.yaml b/configs/commands/ndf/ablations/4_no_residuals/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/4_no_residuals/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/5_unweighted_svd/_ablation.yaml b/configs/commands/ndf/ablations/5_unweighted_svd/_ablation.yaml new file mode 100644 index 0000000..4024f47 --- /dev/null +++ b/configs/commands/ndf/ablations/5_unweighted_svd/_ablation.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +ablation: + name: 5_unweighted_svd + +model: + pred_weight: False diff --git a/configs/commands/ndf/ablations/5_unweighted_svd/train_mug_grasp.yaml b/configs/commands/ndf/ablations/5_unweighted_svd/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/5_unweighted_svd/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/5_unweighted_svd/train_mug_place.yaml b/configs/commands/ndf/ablations/5_unweighted_svd/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/5_unweighted_svd/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/6_no_finetuning/_ablation.yaml b/configs/commands/ndf/ablations/6_no_finetuning/_ablation.yaml new file mode 100644 index 0000000..b6385a0 --- /dev/null +++ b/configs/commands/ndf/ablations/6_no_finetuning/_ablation.yaml @@ -0,0 +1,6 @@ +# @package _global_ + +ablation: + name: 6_no_finetuning + +freeze_embnn: True diff --git a/configs/commands/ndf/ablations/6_no_finetuning/train_mug_grasp.yaml b/configs/commands/ndf/ablations/6_no_finetuning/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/6_no_finetuning/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/6_no_finetuning/train_mug_place.yaml b/configs/commands/ndf/ablations/6_no_finetuning/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/6_no_finetuning/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/7_no_pretraining/_ablation.yaml b/configs/commands/ndf/ablations/7_no_pretraining/_ablation.yaml new file mode 100644 index 0000000..e901c1d --- /dev/null +++ b/configs/commands/ndf/ablations/7_no_pretraining/_ablation.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +ablation: + name: 7_no_pretraining + +pretraining: + checkpoint_file_action: Null + checkpoint_file_anchor: Null diff --git a/configs/commands/ndf/ablations/7_no_pretraining/train_mug_grasp.yaml b/configs/commands/ndf/ablations/7_no_pretraining/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/7_no_pretraining/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/7_no_pretraining/train_mug_place.yaml b/configs/commands/ndf/ablations/7_no_pretraining/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/7_no_pretraining/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/8_mlp/_ablation.yaml b/configs/commands/ndf/ablations/8_mlp/_ablation.yaml new file mode 100644 index 0000000..66e8890 --- /dev/null +++ b/configs/commands/ndf/ablations/8_mlp/_ablation.yaml @@ -0,0 +1,6 @@ +# @package _global_ + +ablation: + name: 8_mlp + +mlp: True diff --git a/configs/commands/ndf/ablations/8_mlp/train_mug_grasp.yaml b/configs/commands/ndf/ablations/8_mlp/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/8_mlp/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/8_mlp/train_mug_place.yaml b/configs/commands/ndf/ablations/8_mlp/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/8_mlp/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/9_low_dimensional_embeddings/_ablation.yaml b/configs/commands/ndf/ablations/9_low_dimensional_embeddings/_ablation.yaml new file mode 100644 index 0000000..ea8bd15 --- /dev/null +++ b/configs/commands/ndf/ablations/9_low_dimensional_embeddings/_ablation.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +ablation: + name: 9_low_dim_embedding + +model: + emb_dims: 16 diff --git a/configs/commands/ndf/ablations/9_low_dimensional_embeddings/train_mug_grasp.yaml b/configs/commands/ndf/ablations/9_low_dimensional_embeddings/train_mug_grasp.yaml new file mode 100644 index 0000000..95bbbe6 --- /dev/null +++ b/configs/commands/ndf/ablations/9_low_dimensional_embeddings/train_mug_grasp.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/9_low_dimensional_embeddings/train_mug_place.yaml b/configs/commands/ndf/ablations/9_low_dimensional_embeddings/train_mug_place.yaml new file mode 100644 index 0000000..b57fea8 --- /dev/null +++ b/configs/commands/ndf/ablations/9_low_dimensional_embeddings/train_mug_place.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- _ablation +- override /object_class: mug +- override /relationship: place +- _self_ + +mode: train diff --git a/configs/commands/ndf/ablations/n_demos/train_mug_grasp_1.yaml b/configs/commands/ndf/ablations/n_demos/train_mug_grasp_1.yaml new file mode 100644 index 0000000..6ae57fc --- /dev/null +++ b/configs/commands/ndf/ablations/n_demos/train_mug_grasp_1.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train +num_demo: 1 diff --git a/configs/commands/ndf/ablations/n_demos/train_mug_grasp_5.yaml b/configs/commands/ndf/ablations/n_demos/train_mug_grasp_5.yaml new file mode 100644 index 0000000..f8c08f9 --- /dev/null +++ b/configs/commands/ndf/ablations/n_demos/train_mug_grasp_5.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train +num_demo: 5 diff --git a/configs/commands/ndf/ablations/n_demos/train_mug_place_1.yaml b/configs/commands/ndf/ablations/n_demos/train_mug_place_1.yaml new file mode 100644 index 0000000..6ae57fc --- /dev/null +++ b/configs/commands/ndf/ablations/n_demos/train_mug_place_1.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train +num_demo: 1 diff --git a/configs/commands/ndf/ablations/n_demos/train_mug_place_5.yaml b/configs/commands/ndf/ablations/n_demos/train_mug_place_5.yaml new file mode 100644 index 0000000..f8c08f9 --- /dev/null +++ b/configs/commands/ndf/ablations/n_demos/train_mug_place_5.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: +- /train_ndf +- override /object_class: mug +- override /relationship: grasp +- _self_ + +mode: train +num_demo: 5 diff --git a/configs/commands/ndf/bottle/train_grasp.yaml b/configs/commands/ndf/bottle/train_grasp.yaml new file mode 100644 index 0000000..2e04db7 --- /dev/null +++ b/configs/commands/ndf/bottle/train_grasp.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: +- _self_ +- /train_ndf +- override /model: taxpose +- override /task: bottle_grasp + +mode: train diff --git a/configs/commands/ndf/bottle/train_place.yaml b/configs/commands/ndf/bottle/train_place.yaml new file mode 100644 index 0000000..764667c --- /dev/null +++ b/configs/commands/ndf/bottle/train_place.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: +- _self_ +- /train_ndf +- override /model: taxpose +- override /task: bottle_place + +mode: train diff --git a/configs/commands/ndf/bowl/train_grasp.yaml b/configs/commands/ndf/bowl/train_grasp.yaml new file mode 100644 index 0000000..e3b9422 --- /dev/null +++ b/configs/commands/ndf/bowl/train_grasp.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: +- _self_ +- /train_ndf +- override /model: taxpose +- override /task: bowl_grasp + +mode: train diff --git a/configs/commands/ndf/bowl/train_place.yaml b/configs/commands/ndf/bowl/train_place.yaml new file mode 100644 index 0000000..7c7649b --- /dev/null +++ b/configs/commands/ndf/bowl/train_place.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: +- _self_ +- /train_ndf +- override /model: taxpose +- override /task: bowl_place + +mode: train diff --git a/configs/commands/ndf/mug/eval_ndf_arbitrary.yaml b/configs/commands/ndf/mug/eval_ndf_arbitrary.yaml new file mode 100644 index 0000000..9e48f07 --- /dev/null +++ b/configs/commands/ndf/mug/eval_ndf_arbitrary.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: +- /eval_full_mug_standalone +- override /object_class: mug +- override /model: taxpose +- override /pose_dist: arbitrary +- _self_ diff --git a/configs/commands/ndf/mug/eval_ndf_upright.yaml b/configs/commands/ndf/mug/eval_ndf_upright.yaml new file mode 100644 index 0000000..8ed3fb9 --- /dev/null +++ b/configs/commands/ndf/mug/eval_ndf_upright.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: +- /eval_full_mug_standalone +- override /object_class: mug +- override /model: taxpose +- override /pose_dist: upright +- _self_ diff --git a/configs/commands/ndf/mug/train_grasp.yaml b/configs/commands/ndf/mug/train_grasp.yaml new file mode 100644 index 0000000..e942865 --- /dev/null +++ b/configs/commands/ndf/mug/train_grasp.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: +- _self_ +- /train_ndf +- override /model: taxpose +- override /task: mug_grasp + +mode: train diff --git a/configs/commands/ndf/mug/train_place.yaml b/configs/commands/ndf/mug/train_place.yaml new file mode 100644 index 0000000..57e831b --- /dev/null +++ b/configs/commands/ndf/mug/train_place.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: +- _self_ +- /train_ndf +- override /model: taxpose +- override /task: mug_place + +mode: train diff --git a/configs/commands/ndf/pretraining/bottle.yaml b/configs/commands/ndf/pretraining/bottle.yaml new file mode 100644 index 0000000..d461045 --- /dev/null +++ b/configs/commands/ndf/pretraining/bottle.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: +- /pretraining +- _self_ + +data_root: ${hydra:runtime.cwd}/data/ + +training: + dataset: + root: ${data_root}/bottle_place/train_data/renders + cloud_class: 0 diff --git a/configs/commands/ndf/pretraining/bowl.yaml b/configs/commands/ndf/pretraining/bowl.yaml new file mode 100644 index 0000000..24850ab --- /dev/null +++ b/configs/commands/ndf/pretraining/bowl.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: +- /pretraining +- _self_ + +data_root: ${hydra:runtime.cwd}/data/ + +training: + dataset: + root: ${data_root}/bowl_place/train_data/renders + cloud_class: 0 diff --git a/configs/commands/ndf/pretraining/gripper.yaml b/configs/commands/ndf/pretraining/gripper.yaml new file mode 100644 index 0000000..5b77a6d --- /dev/null +++ b/configs/commands/ndf/pretraining/gripper.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: +- /pretraining +- _self_ + +data_root: ${hydra:runtime.cwd}/data/ + +training: + dataset: + root: ${data_root}/mug_place/train_data/renders + cloud_class: 2 diff --git a/configs/commands/ndf/pretraining/mug.yaml b/configs/commands/ndf/pretraining/mug.yaml new file mode 100644 index 0000000..62cb56d --- /dev/null +++ b/configs/commands/ndf/pretraining/mug.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: +- /pretraining +- _self_ + +data_root: ${hydra:runtime.cwd}/data/ + +training: + dataset: + root: ${data_root}/mug_place/train_data/renders + cloud_class: 0 diff --git a/configs/commands/ndf/pretraining/rack.yaml b/configs/commands/ndf/pretraining/rack.yaml new file mode 100644 index 0000000..487702e --- /dev/null +++ b/configs/commands/ndf/pretraining/rack.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: +- /pretraining +- _self_ + +data_root: ${hydra:runtime.cwd}/data/ + +training: + dataset: + root: ${data_root}/mug_place/train_data/renders + cloud_class: 1 diff --git a/configs/commands/ndf/pretraining/slab.yaml b/configs/commands/ndf/pretraining/slab.yaml new file mode 100644 index 0000000..eee8b47 --- /dev/null +++ b/configs/commands/ndf/pretraining/slab.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: +- /pretraining +- _self_ + +data_root: ${hydra:runtime.cwd}/data/ + +training: + dataset: + root: ${data_root}/bottle_place/train_data/renders + cloud_class: 1 diff --git a/configs/eval_full_mug_ablation.yaml b/configs/eval_full_mug_ablation.yaml deleted file mode 100644 index e633729..0000000 --- a/configs/eval_full_mug_ablation.yaml +++ /dev/null @@ -1,99 +0,0 @@ -hydra: - run: - dir: ${log_dir}/${experiment}/${now:%Y-%m-%d_%H%M%S} - sweep: - dir: ${log_dir}/${experiment}/sweep/${now:%Y-%m-%d_%H%M%S} - subdir: ${hydra.job.num} - job: - chdir: True - -experiment: residual_flow_test_partial_cloud_initpose - -#### TO BE CHANGED #### -defaults: - - _self_ - - pose_dist: upright # {upright, arbitrary} - - ablation: 0_no_disp_loss - - # $ablation \in { - # 0_no_disp_loss, - # 1_no_corr_loss, - # 2_no_cons_loss, - # 3_no_disp_loss_combined, - # 4_no_residuals, - # 5_unweighted_svd, - # 6_no_finetuning, - # 7_no_pretraining, - # 8_mlp, - # 9_low_dim_embedding - #} -checkpoint_file_grasp: Null # to be filled with trained model -checkpoint_file_place: Null # to be filled with trained model -# checkpoint_file_grasp: ${hydra:runtime.cwd}/trained_models/ndf/mug/${pose_dist.name}/grasp.ckpt -# checkpoint_file_place: ${hydra:runtime.cwd}/trained_models/ndf/mug/${pose_dist.name}/place.ckpt -log_txt_file: ${hydra:runtime.cwd}/test_results_ablation.txt # abs path of file to log results -# log saving dir -log_dir: ./results/ndf/mug_place -num_iterations: 100 # number of trails -object_class: mug -data_dir: mug_place # directory name for data saved -log_every_trial: False # True - log success rate at every trial; False - log only at the end of the # of trials -#### TO BE CHANGED #### - -# Model Settings -flow_compute_type: 0 -emb_nn: dgcnn -num_points: 1024 - -# Dataset Settings -dataset_index: None -action_class: 0 -anchor_class: 1 -dataset_size: 300 -action_rotation_variance: 180 -translation_variance: 0.5 -batch_size: 1 -num_workers: 4 -cloud_type: init -no_transform_applied: True -point_loss_type: 0 -gripper_lr_label: False -return_flow_component: False -center_feature: True -overfit: False -diff_emb: True -diff_transformer: True - -#### TO BE CHANGED #### -weight_normalize_place: l1 -sigmoid_on: True -rotation_weight: 0 -consistency_weight: 1 -smoothness_weight: 0.1 -softmax_temperature: 1 -weight_normalize_grasp: softmax -softmax_temperature_grasp: 0.1 -#### TO BE CHANGED #### - -return_attn: True -rand_mesh_scale: True -loop: 1 -init_distribution_tranform_file: Null - -# Loss Settings -lr: 1e-4 - -# Logging Settings -image_logging_period: 100 -debug: False -seed: 10 -pybullet_viz: False -only_test_ids: True -demo_exp: grasp_rim_hang_handle_gaussian_precise_w_shelf -exp: debug_eval -num_demo: 12 -config: eval_mug_gen -model_path: multi_category_weights -n_demos: 0 -single_instance: False -start_iteration: 0 diff --git a/configs/eval_full_mug_place.yaml b/configs/eval_full_mug_place.yaml index 466a1b3..c5218fd 100644 --- a/configs/eval_full_mug_place.yaml +++ b/configs/eval_full_mug_place.yaml @@ -17,11 +17,13 @@ defaults: # Machinery to load the correct checkpoints. - model: release - - checkpoints/ndf@checkpoints.grasp: ${object_class}/${pose_dist}/${model}/grasp - - checkpoints/ndf@checkpoints.place: ${object_class}/${pose_dist}/${model}/place + - optional checkpoints/ndf@checkpoints.grasp: ${object_class}/${pose_dist}/${model}/grasp + - optional checkpoints/ndf@checkpoints.place: ${object_class}/${pose_dist}/${model}/place -checkpoint_file_grasp: ${checkpoints.grasp.ckpt_file} -checkpoint_file_place: ${checkpoints.place.ckpt_file} +# checkpoint_file_grasp: ${checkpoints.grasp.ckpt_file} +# checkpoint_file_place: ${checkpoints.place.ckpt_file} +checkpoint_file_grasp: ??? +checkpoint_file_place: ??? log_txt_file: ${hydra:runtime.cwd}/test_results.txt # abs path of file to log results # log saving dir @@ -35,7 +37,7 @@ log_every_trial: False # True - log success rate at every trial; False - log onl # flow_compute_type: 0 emb_nn: dgcnn num_points: 1024 -emb_dims: 512 +emb_dims: ${model.emb_dims} multilaterate: False # Dataset Settings @@ -63,11 +65,11 @@ sigmoid_on: True rotation_weight: 0 consistency_weight: 1 smoothness_weight: 0.1 -pred_weight: True +pred_weight: ${model.pred_weight} softmax_temperature: 1 weight_normalize_grasp: softmax softmax_temperature_grasp: 0.1 -residual_on: True +residual_on: ${model.residual_on} #### TO BE CHANGED #### mlp: False diff --git a/configs/eval_full_mug_standalone.yaml b/configs/eval_full_mug_standalone.yaml index e651f0d..264666b 100644 --- a/configs/eval_full_mug_standalone.yaml +++ b/configs/eval_full_mug_standalone.yaml @@ -22,12 +22,19 @@ defaults: # Load the checkpoints appropriately. - checkpoints/ndf@checkpoints: ${task}/${model} + # - checkpoints/ndf@checkpoints: ${object_class}/${model} + + +# checkpoint_file_grasp: ${checkpoints.${pose_dist.name}.grasp} +# checkpoint_file_place: ${checkpoints.${pose_dist.name}.place} +# checkpoint_file_place_refinement: ${checkpoints.${pose_dist.name}.place_refinement} +# checkpoint_file_grasp_refinement: ${checkpoints.${pose_dist.name}.grasp_refinement} +checkpoint_file_grasp: null +checkpoint_file_place: null +checkpoint_file_place_refinement: null +checkpoint_file_grasp_refinement: null -checkpoint_file_grasp: ${checkpoints.${pose_dist.name}.grasp} -checkpoint_file_place: ${checkpoints.${pose_dist.name}.place} -checkpoint_file_place_refinement: ${checkpoints.${pose_dist.name}.place_refinement} -checkpoint_file_grasp_refinement: ${checkpoints.${pose_dist.name}.grasp_refinement} # Extra random irrelevant model settings, legacy. loop: 1 diff --git a/configs/task/ndf/bottle/phase/grasp.yaml b/configs/task/ndf/bottle/phase/grasp.yaml index 08c45cb..aaf7c09 100644 --- a/configs/task/ndf/bottle/phase/grasp.yaml +++ b/configs/task/ndf/bottle/phase/grasp.yaml @@ -7,5 +7,7 @@ anchor_class: 0 cloud_type: pre_grasp softmax_temperature: 0.1 weight_normalize: softmax -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt +# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt +# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt +checkpoint_file_action: null +checkpoint_file_anchor: null diff --git a/configs/task/ndf/bottle/phase/place.yaml b/configs/task/ndf/bottle/phase/place.yaml index 94698cb..040547e 100644 --- a/configs/task/ndf/bottle/phase/place.yaml +++ b/configs/task/ndf/bottle/phase/place.yaml @@ -7,5 +7,7 @@ anchor_class: 1 cloud_type: teleport softmax_temperature: 1 weight_normalize: l1 -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt +# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt +# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt +checkpoint_file_action: null +checkpoint_file_anchor: null diff --git a/configs/task/ndf/bowl/phase/grasp.yaml b/configs/task/ndf/bowl/phase/grasp.yaml index 352a451..7f4a159 100644 --- a/configs/task/ndf/bowl/phase/grasp.yaml +++ b/configs/task/ndf/bowl/phase/grasp.yaml @@ -7,5 +7,7 @@ anchor_class: 0 cloud_type: pre_grasp softmax_temperature: 0.1 weight_normalize: softmax -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt +# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt +# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt +checkpoint_file_action: null +checkpoint_file_anchor: null diff --git a/configs/task/ndf/bowl/phase/place.yaml b/configs/task/ndf/bowl/phase/place.yaml index 5231695..2667a69 100644 --- a/configs/task/ndf/bowl/phase/place.yaml +++ b/configs/task/ndf/bowl/phase/place.yaml @@ -7,5 +7,7 @@ anchor_class: 1 cloud_type: teleport softmax_temperature: 1 weight_normalize: l1 -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt +# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt +# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt +checkpoint_file_action: null +checkpoint_file_anchor: null diff --git a/configs/task/ndf/mug/phase/grasp.yaml b/configs/task/ndf/mug/phase/grasp.yaml index 495f575..5a7a847 100644 --- a/configs/task/ndf/mug/phase/grasp.yaml +++ b/configs/task/ndf/mug/phase/grasp.yaml @@ -7,5 +7,7 @@ anchor_class: 0 cloud_type: pre_grasp softmax_temperature: 0.1 weight_normalize: softmax -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt +# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt +# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt +checkpoint_file_action: null +checkpoint_file_anchor: null diff --git a/configs/task/ndf/mug/phase/place.yaml b/configs/task/ndf/mug/phase/place.yaml index ad42edf..fc4a550 100644 --- a/configs/task/ndf/mug/phase/place.yaml +++ b/configs/task/ndf/mug/phase/place.yaml @@ -7,5 +7,7 @@ softmax_temperature: 1 weight_normalize: l1 action_name: mug anchor_name: rack -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt +# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt +# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt +checkpoint_file_action: null +checkpoint_file_anchor: null diff --git a/configs/train_mug_residual_ablation.yaml b/configs/train_mug_residual_ablation.yaml deleted file mode 100644 index 6086d93..0000000 --- a/configs/train_mug_residual_ablation.yaml +++ /dev/null @@ -1,71 +0,0 @@ -# Logging Settings -hydra: - run: - dir: ${log_dir}/${experiment}/${now:%Y-%m-%d_%H%M%S} - sweep: - dir: ${log_dir}/${experiment}/sweep/${now:%Y-%m-%d_%H%M%S} - subdir: ${hydra.job.num} - job: - chdir: True - -log_dir: logs -experiment: residual_flow_occlusion -image_logging_period: 100 -log_txt_file: ${hydra:runtime.cwd}/train_ablation.txt -# $task \in {mug_place, mug_grasp} -defaults: - - _self_ - - task: mug_place - - ablation: 0_no_disp_loss - - # $ablation \in { - # 0_no_disp_loss, - # 1_no_corr_loss, - # 2_no_cons_loss, - # 3_no_disp_loss_combined, - # 4_no_residuals, - # 5_unweighted_svd, - # 6_no_finetuning, - # 7_no_pretraining, - # 8_mlp, - # 9_low_dim_embedding - #} - -# Dataset Settings -dataset_root: ${hydra:runtime.cwd}/data -train_data_dir: ${dataset_root}/${task.name}/train_data/renders -test_data_dir: ${dataset_root}/${task.name}/test_data/renders -num_workers: 52 -batch_size: 8 -num_points: 1024 -num_demo: 10 -dataset_index: None -object_type: mug # -dataset_size: 1000 -action_rotation_variance: 180 -translation_variance: 0.5 -synthetic_occlusion: True -plane_occlusion: True -plane_standoff: 0.04 -ball_occlusion: True -ball_radius: 0.1 -gripper_lr_label: False # -occlusion_class: 0 -overfit: False -seed: 0 - -# Network Settings -center_feature: True -emb_nn: dgcnn -sigmoid_on: True -return_flow_component: False - -# Loss Settings -flow_supervision: both - -# Training Settings -checkpoint_file: Null -checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt -checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt -lr: 1e-4 -max_epochs: 1000 diff --git a/pyproject.toml b/pyproject.toml index b8b322f..5ae9b16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ module = [ ignore_missing_imports = true [tool.pytest.ini_options] -addopts = "--ignore=third_party/ -n auto" +addopts = "--ignore=third_party/ -m 'not long' -n auto" testpaths = "tests" [tool.pylint] diff --git a/scripts/evaluate_ndf_mug.py b/scripts/evaluate_ndf_mug.py index a20993e..85e9d39 100644 --- a/scripts/evaluate_ndf_mug.py +++ b/scripts/evaluate_ndf_mug.py @@ -47,7 +47,10 @@ from ndf_robot.utils.util import np2img from pytorch3d.ops import sample_farthest_points -from taxpose.nets.transformer_flow import ResidualFlow_DiffEmbTransformer +from taxpose.nets.transformer_flow import ( + CorrespondenceFlow_DiffEmbMLP, + ResidualFlow_DiffEmbTransformer, +) from taxpose.training.flow_equivariance_training_module_nocentering_eval_init import ( EquivarianceTestingModule, ) @@ -480,17 +483,24 @@ def show_link(obj_id, link_id, color): pl.seed_everything(hydra_cfg.seed) - place_network = ResidualFlow_DiffEmbTransformer( - emb_dims=hydra_cfg.emb_dims, - emb_nn=hydra_cfg.emb_nn, - center_feature=hydra_cfg.center_feature, - pred_weight=hydra_cfg.pred_weight, - residual_on=hydra_cfg.residual_on, - return_flow_component=hydra_cfg.return_flow_component, - freeze_embnn=hydra_cfg.freeze_embnn, - return_attn=hydra_cfg.return_attn, - multilaterate=hydra_cfg.multilaterate, - ) + if hydra_cfg.mlp: + place_network = CorrespondenceFlow_DiffEmbMLP( + emb_dims=hydra_cfg.emb_dims, + emb_nn=hydra_cfg.emb_nn, + center_feature=hydra_cfg.center_feature, + ) + else: + place_network = ResidualFlow_DiffEmbTransformer( + emb_dims=hydra_cfg.emb_dims, + emb_nn=hydra_cfg.emb_nn, + center_feature=hydra_cfg.center_feature, + pred_weight=hydra_cfg.pred_weight, + residual_on=hydra_cfg.residual_on, + return_flow_component=hydra_cfg.return_flow_component, + freeze_embnn=hydra_cfg.freeze_embnn, + return_attn=hydra_cfg.return_attn, + multilaterate=hydra_cfg.multilaterate, + ) place_model = EquivarianceTestingModule( place_network, @@ -508,17 +518,24 @@ def show_link(obj_id, link_id, color): ) log_info("Model Loaded from " + str(hydra_cfg.checkpoint_file_place)) - grasp_network = ResidualFlow_DiffEmbTransformer( - emb_dims=hydra_cfg.emb_dims, - emb_nn=hydra_cfg.emb_nn, - center_feature=hydra_cfg.center_feature, - pred_weight=hydra_cfg.pred_weight, - residual_on=hydra_cfg.residual_on, - return_flow_component=hydra_cfg.return_flow_component, - freeze_embnn=hydra_cfg.freeze_embnn, - return_attn=hydra_cfg.return_attn, - multilaterate=hydra_cfg.multilaterate, - ) + if hydra_cfg.mlp: + grasp_network = CorrespondenceFlow_DiffEmbMLP( + emb_dims=hydra_cfg.emb_dims, + emb_nn=hydra_cfg.emb_nn, + center_feature=hydra_cfg.center_feature, + ) + else: + grasp_network = ResidualFlow_DiffEmbTransformer( + emb_dims=hydra_cfg.emb_dims, + emb_nn=hydra_cfg.emb_nn, + center_feature=hydra_cfg.center_feature, + pred_weight=hydra_cfg.pred_weight, + residual_on=hydra_cfg.residual_on, + return_flow_component=hydra_cfg.return_flow_component, + freeze_embnn=hydra_cfg.freeze_embnn, + return_attn=hydra_cfg.return_attn, + multilaterate=hydra_cfg.multilaterate, + ) grasp_model = EquivarianceTestingModule( grasp_network, diff --git a/scripts/evaluate_ndf_mug_ablation.py b/scripts/evaluate_ndf_mug_ablation.py deleted file mode 100644 index 0ccf6c1..0000000 --- a/scripts/evaluate_ndf_mug_ablation.py +++ /dev/null @@ -1,1300 +0,0 @@ -"""This is a fork of https://github.com/anthonysimeonov/ndf_robot/blob/master/src/ndf_robot/eval/evaluate_ndf.py""" - -import os -import os.path as osp -import random -import signal -import time -from pathlib import Path - -import hydra -import numpy as np -import pybullet as p -import pytorch_lightning as pl -import torch -from airobot import Robot, log_info, log_warn, set_log_level -from airobot.utils import common -from airobot.utils.common import euler2quat -from ndf_robot.config.default_eval_cfg import get_eval_cfg_defaults -from ndf_robot.config.default_obj_cfg import get_obj_cfg_defaults -from ndf_robot.robot.multicam import MultiCams -from ndf_robot.share.globals import ( - bad_shapenet_bottles_ids_list, - bad_shapenet_bowls_ids_list, - bad_shapenet_mug_ids_list, -) -from ndf_robot.utils import path_util, util -from ndf_robot.utils.eval_gen_utils import ( - constraint_grasp_close, - constraint_grasp_open, - constraint_obj_world, - get_ee_offset, - object_is_still_grasped, - process_demo_data_rack, - process_demo_data_shelf, - process_xq_data, - process_xq_rs_data, - safeCollisionFilterPair, - safeRemoveConstraint, - soft_grasp_close, -) -from ndf_robot.utils.franka_ik import FrankaIK -from ndf_robot.utils.util import np2img -from pytorch3d.ops import sample_farthest_points - -from taxpose.nets.transformer_flow import ( - CorrespondenceFlow_DiffEmbMLP, - ResidualFlow_DiffEmbTransformer, -) -from taxpose.training.flow_equivariance_training_module_nocentering_eval_init import ( - EquivarianceTestingModule, -) -from taxpose.utils.ndf_sim_utils import get_clouds, get_object_clouds - -# Gotta do some path hacking to convince ndf_robot to work. -NDF_ROOT = Path(__file__).parent.parent / "third_party" / "ndf_robot" -os.environ["NDF_SOURCE_DIR"] = str(NDF_ROOT / "src" / "ndf_robot") -os.environ["PB_PLANNING_SOURCE_DIR"] = str(NDF_ROOT / "pybullet-planning") - - -def get_world_transform(pred_T_action_mat, obj_start_pose, point_cloud, invert=False): - """ - pred_T_action_mat: normal SE(3) [R|t] - [0|1] in object frame - obj_start_pose: stamped_pose of obj in world frame - """ - point_cloud_mean = point_cloud.squeeze(0).mean(axis=0).tolist() - obj_start_pose_list = util.pose_stamped2list(obj_start_pose) - # import pdb - # pdb.set_trace() - pose = util.pose_from_matrix(pred_T_action_mat) - centering_mat = np.eye(4) - # centering_mat[:3, 3] = -np.array(obj_start_pose_list[:3]) - centering_mat[:3, 3] = -np.array(point_cloud_mean) - centering_pose = util.pose_from_matrix(centering_mat) - uncentering_pose = util.pose_from_matrix(np.linalg.inv(centering_mat)) - - centered_pose = util.transform_pose( - pose_source=obj_start_pose, pose_transform=centering_pose - ) # obj_start_pose: stamped_pose - trans_pose = util.transform_pose(pose_source=centered_pose, pose_transform=pose) - final_pose = util.transform_pose( - pose_source=trans_pose, pose_transform=uncentering_pose - ) - if invert: - final_pose = util.pose_from_matrix( - np.linalg.inv(util.matrix_from_pose(final_pose)) - ) - - return final_pose - - -def load_data(num_points, clouds, classes, action_class, anchor_class): - points_raw_np = clouds - classes_raw_np = classes - - points_action_np = points_raw_np[classes_raw_np == action_class].copy() - points_action_mean_np = points_action_np.mean(axis=0) - points_action_np = points_action_np - points_action_mean_np - - points_anchor_np = points_raw_np[classes_raw_np == anchor_class].copy() - points_anchor_np = points_anchor_np - points_action_mean_np - points_anchor_mean_np = points_anchor_np.mean(axis=0) - - points_action = torch.from_numpy(points_action_np).float().unsqueeze(0) - points_anchor = torch.from_numpy(points_anchor_np).float().unsqueeze(0) - - points_action, points_anchor = subsample(num_points, points_action, points_anchor) - - return points_action.cuda(), points_anchor.cuda() - - -def load_data_raw(num_points, clouds, classes, action_class, anchor_class): - points_raw_np = clouds - classes_raw_np = classes - - points_action_np = points_raw_np[classes_raw_np == action_class].copy() - - points_anchor_np = points_raw_np[classes_raw_np == anchor_class].copy() - - points_action = torch.from_numpy(points_action_np).float().unsqueeze(0) - points_anchor = torch.from_numpy(points_anchor_np).float().unsqueeze(0) - - points_action, points_anchor = subsample(num_points, points_action, points_anchor) - if points_action is None: - return None, None - - return points_action.cuda(), points_anchor.cuda() - - -def subsample(num_points, points_action, points_anchor): - if points_action.shape[1] > num_points: - points_action, _ = sample_farthest_points( - points_action, K=num_points, random_start_point=True - ) - elif points_action.shape[1] < num_points: - log_info( - f"Action point cloud is smaller than cloud size ({points_action.shape[1]} < {num_points})" - ) - return None, None - # raise NotImplementedError( - # f'Action point cloud is smaller than cloud size ({points_action.shape[1]} < {num_points})') - - if points_anchor.shape[1] > num_points: - points_anchor, _ = sample_farthest_points( - points_anchor, K=num_points, random_start_point=True - ) - elif points_anchor.shape[1] < num_points: - log_info( - f"Anchor point cloud is smaller than cloud size ({points_anchor.shape[1]} < {num_points})" - ) - return None, None - # raise NotImplementedError( - # f'Anchor point cloud is smaller than cloud size ({points_anchor.shape[1]} < {num_points})') - - return points_action, points_anchor - - -def write_to_file(file_name, string): - with open(file_name, "a") as f: - f.writelines(string) - f.write("\n") - f.close() - log_info("file dir: {}".format(os.getcwd())) - - -################################################################### -# WHAT TO CHANGE FOR LOCAL USE -# or Search "#### TO BE CHANGED ####" -# -# 1. $config_path$ - absolute path of where the eval config files are stroed -# 2. $config_name$ - eval config file name -# 3. $save_dir$ - the data directory where the scene captures during eval are saved -################################################################### - -#### TO BE CHANGED #### - - -#### TO BE CHANGED #### -@hydra.main(config_path="../configs", config_name="eval_full_mug_ablation") -def main(hydra_cfg): - data_dir = hydra_cfg.data_dir - # '/home/exx/Documents/taxpose/search_existing_models.txt' - log_txt_file = hydra_cfg.log_txt_file - save_dir = os.path.join(hydra.utils.get_original_cwd(), data_dir) - - if not os.path.isdir(save_dir): - os.makedirs(save_dir) - - eval_data_dir = hydra_cfg.data_dir - obj_class = hydra_cfg.object_class - shapenet_obj_dir = osp.join( - path_util.get_ndf_obj_descriptions(), obj_class + "_centered_obj_normalized" - ) - - demo_load_dir = osp.join( - path_util.get_ndf_data(), "demos", obj_class, hydra_cfg.demo_exp - ) - - expstr = "exp--" + str(hydra_cfg.exp) - modelstr = "model--" + str(hydra_cfg.model_path) - seedstr = "seed--" + str(hydra_cfg.seed) - full_experiment_name = "_".join([expstr, modelstr, seedstr]) - - eval_save_dir = osp.join( - path_util.get_ndf_eval_data(), eval_data_dir, full_experiment_name - ) - util.safe_makedirs(eval_save_dir) - - vnn_model_path = osp.join( - path_util.get_ndf_model_weights(), hydra_cfg.model_path + ".pth" - ) - - global_dict = dict( - shapenet_obj_dir=shapenet_obj_dir, - demo_load_dir=demo_load_dir, - eval_save_dir=eval_save_dir, - object_class=obj_class, - vnn_checkpoint_path=vnn_model_path, - ) - - if hydra_cfg.debug: - set_log_level("debug") - else: - set_log_level("info") - - robot = Robot( - "franka", - pb_cfg={"gui": hydra_cfg.pybullet_viz}, - arm_cfg={"self_collision": False, "seed": hydra_cfg.seed}, - ) - ik_helper = FrankaIK(gui=False) - torch.manual_seed(hydra_cfg.seed) - random.seed(hydra_cfg.seed) - np.random.seed(hydra_cfg.seed) - - # general experiment + environment setup/scene generation configs - cfg = get_eval_cfg_defaults() - config_fname = osp.join( - path_util.get_ndf_config(), "eval_cfgs", hydra_cfg.config + ".yaml" - ) - if osp.exists(config_fname): - cfg.merge_from_file(config_fname) - else: - pass - # log_info("Config file %s does not exist, using defaults" % - # config_fname) - cfg.freeze() - - # object specific configs - obj_cfg = get_obj_cfg_defaults() - obj_config_name = osp.join( - path_util.get_ndf_config(), hydra_cfg.object_class + "_obj_cfg.yaml" - ) - obj_cfg.merge_from_file(obj_config_name) - obj_cfg.freeze() - - shapenet_obj_dir = global_dict["shapenet_obj_dir"] - obj_class = global_dict["object_class"] - eval_save_dir = global_dict["eval_save_dir"] - - eval_grasp_imgs_dir = osp.join(eval_save_dir, "grasp_imgs") - eval_teleport_imgs_dir = osp.join(eval_save_dir, "teleport_imgs") - util.safe_makedirs(eval_grasp_imgs_dir) - util.safe_makedirs(eval_teleport_imgs_dir) - - test_shapenet_ids = np.loadtxt( - osp.join(path_util.get_ndf_share(), "%s_test_object_split.txt" % obj_class), - dtype=str, - ).tolist() - if obj_class == "mug": - avoid_shapenet_ids = bad_shapenet_mug_ids_list + cfg.MUG.AVOID_SHAPENET_IDS - elif obj_class == "bowl": - avoid_shapenet_ids = bad_shapenet_bowls_ids_list + cfg.BOWL.AVOID_SHAPENET_IDS - elif obj_class == "bottle": - avoid_shapenet_ids = ( - bad_shapenet_bottles_ids_list + cfg.BOTTLE.AVOID_SHAPENET_IDS - ) - else: - test_shapenet_ids = [] - - finger_joint_id = 9 - left_pad_id = 9 - right_pad_id = 10 - p.changeDynamics(robot.arm.robot_id, left_pad_id, lateralFriction=1.0) - p.changeDynamics(robot.arm.robot_id, right_pad_id, lateralFriction=1.0) - - x_low, x_high = cfg.OBJ_SAMPLE_X_HIGH_LOW - y_low, y_high = cfg.OBJ_SAMPLE_Y_HIGH_LOW - table_z = cfg.TABLE_Z - - preplace_horizontal_tf_list = cfg.PREPLACE_HORIZONTAL_OFFSET_TF - preplace_horizontal_tf = util.list2pose_stamped(cfg.PREPLACE_HORIZONTAL_OFFSET_TF) - preplace_offset_tf = util.list2pose_stamped(cfg.PREPLACE_OFFSET_TF) - - if cfg.DEMOS.PLACEMENT_SURFACE == "shelf": - load_shelf = True - else: - load_shelf = False - - # get filenames of all the demo files - demo_filenames = os.listdir(global_dict["demo_load_dir"]) - assert len(demo_filenames), ( - "No demonstrations found in path: %s!" % global_dict["demo_load_dir"] - ) - - # strip the filenames to properly pair up each demo file - grasp_demo_filenames_orig = [ - osp.join(global_dict["demo_load_dir"], fn) - for fn in demo_filenames - if "grasp_demo" in fn - ] # use the grasp names as a reference - - place_demo_filenames = [] - grasp_demo_filenames = [] - for i, fname in enumerate(grasp_demo_filenames_orig): - shapenet_id_npz = fname.split("/")[-1].split("grasp_demo_")[-1] - place_fname = osp.join( - "/".join(fname.split("/")[:-1]), "place_demo_" + shapenet_id_npz - ) - if osp.exists(place_fname): - grasp_demo_filenames.append(fname) - place_demo_filenames.append(place_fname) - else: - log_warn( - "Could not find corresponding placement demo: %s, skipping " - % place_fname - ) - - success_list = [] - place_success_list = [] - place_success_teleport_list = [] - grasp_success_list = [] - - place_fail_list = [] - place_fail_teleport_list = [] - grasp_fail_list = [] - - demo_shapenet_ids = [] - - # get info from all demonstrations - demo_target_info_list = [] - demo_rack_target_info_list = [] - - if hydra_cfg.n_demos > 0: - gp_fns = list(zip(grasp_demo_filenames, place_demo_filenames)) - gp_fns = random.sample(gp_fns, hydra_cfg.n_demos) - grasp_demo_filenames, place_demo_filenames = zip(*gp_fns) - grasp_demo_filenames, place_demo_filenames = list(grasp_demo_filenames), list( - place_demo_filenames - ) - log_warn("USING ONLY %d DEMONSTRATIONS" % len(grasp_demo_filenames)) - print(grasp_demo_filenames, place_demo_filenames) - else: - log_warn("USING ALL %d DEMONSTRATIONS" % len(grasp_demo_filenames)) - - grasp_demo_filenames = grasp_demo_filenames[: hydra_cfg.num_demo] - place_demo_filenames = place_demo_filenames[: hydra_cfg.num_demo] - - max_bb_volume = 0 - place_xq_demo_idx = 0 - grasp_data_list = [] - place_data_list = [] - demo_rel_mat_list = [] - - # load all the demo data and look at objects to help decide on query points - for i, fname in enumerate(grasp_demo_filenames): - print("Loading demo from fname: %s" % fname) - grasp_demo_fn = grasp_demo_filenames[i] - place_demo_fn = place_demo_filenames[i] - grasp_data = np.load(grasp_demo_fn, allow_pickle=True) - place_data = np.load(place_demo_fn, allow_pickle=True) - - grasp_data_list.append(grasp_data) - place_data_list.append(place_data) - - start_ee_pose = grasp_data["ee_pose_world"].tolist() - end_ee_pose = place_data["ee_pose_world"].tolist() - place_rel_mat = util.get_transform( - pose_frame_target=util.list2pose_stamped(end_ee_pose), - pose_frame_source=util.list2pose_stamped(start_ee_pose), - ) - place_rel_mat = util.matrix_from_pose(place_rel_mat) - demo_rel_mat_list.append(place_rel_mat) - - if i == 0: - ( - optimizer_gripper_pts, - rack_optimizer_gripper_pts, - shelf_optimizer_gripper_pts, - ) = process_xq_data(grasp_data, place_data, shelf=load_shelf) - ( - optimizer_gripper_pts_rs, - rack_optimizer_gripper_pts_rs, - shelf_optimizer_gripper_pts_rs, - ) = process_xq_rs_data(grasp_data, place_data, shelf=load_shelf) - - if cfg.DEMOS.PLACEMENT_SURFACE == "shelf": - print("Using shelf points") - place_optimizer_pts = shelf_optimizer_gripper_pts - place_optimizer_pts_rs = shelf_optimizer_gripper_pts_rs - else: - print("Using rack points") - place_optimizer_pts = rack_optimizer_gripper_pts - place_optimizer_pts_rs = rack_optimizer_gripper_pts_rs - - if cfg.DEMOS.PLACEMENT_SURFACE == "shelf": - target_info, rack_target_info, shapenet_id = process_demo_data_shelf( - grasp_data, place_data, cfg=None - ) - else: - target_info, rack_target_info, shapenet_id = process_demo_data_rack( - grasp_data, place_data, cfg=None - ) - - if cfg.DEMOS.PLACEMENT_SURFACE == "shelf": - rack_target_info["demo_query_pts"] = place_optimizer_pts - demo_target_info_list.append(target_info) - demo_rack_target_info_list.append(rack_target_info) - demo_shapenet_ids.append(shapenet_id) - - # get objects that we can use for testing - test_object_ids = [] - shapenet_id_list = ( - [fn.split("_")[0] for fn in os.listdir(shapenet_obj_dir)] - if obj_class == "mug" - else os.listdir(shapenet_obj_dir) - ) - for s_id in shapenet_id_list: - valid = s_id not in demo_shapenet_ids and s_id not in avoid_shapenet_ids - if hydra_cfg.only_test_ids: - valid = valid and (s_id in test_shapenet_ids) - - if valid: - test_object_ids.append(s_id) - - if hydra_cfg.single_instance: - test_object_ids = [demo_shapenet_ids[0]] - - # reset - robot.arm.reset(force_reset=True) - robot.cam.setup_camera( - focus_pt=[0.4, 0.0, table_z], dist=0.9, yaw=45, pitch=-25, roll=0 - ) - - cams = MultiCams(cfg.CAMERA, robot.pb_client, n_cams=cfg.N_CAMERAS) - cam_info = {} - cam_info["pose_world"] = [] - for cam in cams.cams: - cam_info["pose_world"].append(util.pose_from_matrix(cam.cam_ext_mat)) - - # put table at right spot - table_ori = euler2quat([0, 0, np.pi / 2]) - - # this is the URDF that was used in the demos -- make sure we load an identical one - tmp_urdf_fname = osp.join( - path_util.get_ndf_descriptions(), "hanging/table/table_rack_tmp.urdf" - ) - open(tmp_urdf_fname, "w").write(grasp_data["table_urdf"].item()) - table_id = robot.pb_client.load_urdf( - tmp_urdf_fname, cfg.TABLE_POS, table_ori, scaling=cfg.TABLE_SCALING - ) - - if obj_class == "mug": - rack_link_id = 0 - shelf_link_id = 1 - elif obj_class in ["bowl", "bottle"]: - rack_link_id = None - shelf_link_id = 0 - - if cfg.DEMOS.PLACEMENT_SURFACE == "shelf": - placement_link_id = shelf_link_id - else: - placement_link_id = rack_link_id - - def hide_link(obj_id, link_id): - if link_id is not None: - p.changeVisualShape(obj_id, link_id, rgbaColor=[0, 0, 0, 0]) - - def show_link(obj_id, link_id, color): - if link_id is not None: - p.changeVisualShape(obj_id, link_id, rgbaColor=color) - - pl.seed_everything(hydra_cfg.seed) - - if hydra_cfg.ablation.mlp: - network = CorrespondenceFlow_DiffEmbMLP( - emb_dims=hydra_cfg.ablation.emb_dims, - emb_nn=hydra_cfg.emb_nn, - center_feature=hydra_cfg.center_feature, - ) - - network = ResidualFlow_DiffEmbTransformer( - emb_dims=hydra_cfg.ablation.emb_dims, - emb_nn=hydra_cfg.emb_nn, - center_feature=hydra_cfg.center_feature, - pred_weight=hydra_cfg.ablation.pred_weight, - residual_on=hydra_cfg.ablation.residual_on, - return_flow_component=hydra_cfg.return_flow_component, - freeze_embnn=hydra_cfg.ablation.freeze_embnn, - return_attn=hydra_cfg.return_attn, - ) - - place_model = EquivarianceTestingModule( - network, - lr=hydra_cfg.lr, - image_log_period=hydra_cfg.image_logging_period, - weight_normalize=hydra_cfg.weight_normalize_place, - loop=hydra_cfg.loop, - ) - - place_model.cuda() - - if hydra_cfg.checkpoint_file_place is not None: - place_model.load_state_dict( - torch.load(hydra_cfg.checkpoint_file_place)["state_dict"] - ) - log_info("Model Loaded from " + str(hydra_cfg.checkpoint_file_place)) - - grasp_model = EquivarianceTestingModule( - network, - lr=hydra_cfg.lr, - image_log_period=hydra_cfg.image_logging_period, - weight_normalize=hydra_cfg.weight_normalize_grasp, - softmax_temperature=hydra_cfg.softmax_temperature_grasp, - loop=hydra_cfg.loop, - ) - - grasp_model.cuda() - - if hydra_cfg.checkpoint_file_grasp is not None: - grasp_model.load_state_dict( - torch.load(hydra_cfg.checkpoint_file_grasp)["state_dict"] - ) - log_info("Model Loaded from " + str(hydra_cfg.checkpoint_file_grasp)) - - for iteration in range(hydra_cfg.start_iteration, hydra_cfg.num_iterations): - # load a test object - obj_shapenet_id = random.sample(test_object_ids, 1)[0] - id_str = "Shapenet ID: %s" % obj_shapenet_id - log_info(id_str) - - viz_dict = {} # will hold information that's useful for post-run visualizations - eval_iter_dir = osp.join(eval_save_dir, "trial_%d" % iteration) - util.safe_makedirs(eval_iter_dir) - - if obj_class in ["bottle", "jar", "bowl", "mug"]: - upright_orientation = common.euler2quat([np.pi / 2, 0, 0]).tolist() - else: - upright_orientation = common.euler2quat([0, 0, 0]).tolist() - - # for testing, use the "normalized" object - obj_obj_file = osp.join( - shapenet_obj_dir, obj_shapenet_id, "models/model_normalized.obj" - ) - obj_obj_file_dec = obj_obj_file.split(".obj")[0] + "_dec.obj" - - scale_high, scale_low = cfg.MESH_SCALE_HIGH, cfg.MESH_SCALE_LOW - scale_default = cfg.MESH_SCALE_DEFAULT - if hydra_cfg.rand_mesh_scale: - mesh_scale = [np.random.random() * (scale_high - scale_low) + scale_low] * 3 - else: - mesh_scale = [scale_default] * 3 - - if hydra_cfg.pose_dist.any_pose: - if obj_class in ["bowl", "bottle"]: - rp = np.random.rand(2) * (2 * np.pi / 3) - (np.pi / 3) - ori = common.euler2quat([rp[0], rp[1], 0]).tolist() - else: - rpy = np.random.rand(3) * (2 * np.pi / 3) - (np.pi / 3) - ori = common.euler2quat([rpy[0], rpy[1], rpy[2]]).tolist() - - pos = [ - np.random.random() * (x_high - x_low) + x_low, - np.random.random() * (y_high - y_low) + y_low, - table_z, - ] - pose = pos + ori - rand_yaw_T = util.rand_body_yaw_transform( - pos, min_theta=-np.pi, max_theta=np.pi - ) - pose_w_yaw = util.transform_pose( - util.list2pose_stamped(pose), util.pose_from_matrix(rand_yaw_T) - ) - pos, ori = ( - util.pose_stamped2list(pose_w_yaw)[:3], - util.pose_stamped2list(pose_w_yaw)[3:], - ) - else: - pos = [ - np.random.random() * (x_high - x_low) + x_low, - np.random.random() * (y_high - y_low) + y_low, - table_z, - ] - pose = util.list2pose_stamped(pos + upright_orientation) - rand_yaw_T = util.rand_body_yaw_transform( - pos, min_theta=-np.pi, max_theta=np.pi - ) - pose_w_yaw = util.transform_pose(pose, util.pose_from_matrix(rand_yaw_T)) - pos, ori = ( - util.pose_stamped2list(pose_w_yaw)[:3], - util.pose_stamped2list(pose_w_yaw)[3:], - ) - - viz_dict["shapenet_id"] = obj_shapenet_id - viz_dict["obj_obj_file"] = obj_obj_file - if "normalized" not in shapenet_obj_dir: - viz_dict["obj_obj_norm_file"] = osp.join( - shapenet_obj_dir + "_normalized", - obj_shapenet_id, - "models/model_normalized.obj", - ) - else: - viz_dict["obj_obj_norm_file"] = osp.join( - shapenet_obj_dir, obj_shapenet_id, "models/model_normalized.obj" - ) - viz_dict["obj_obj_file_dec"] = obj_obj_file_dec - viz_dict["mesh_scale"] = mesh_scale - - # convert mesh with vhacd - if not osp.exists(obj_obj_file_dec): - p.vhacd( - obj_obj_file, - obj_obj_file_dec, - "log.txt", - concavity=0.0025, - alpha=0.04, - beta=0.05, - gamma=0.00125, - minVolumePerCH=0.0001, - resolution=1000000, - depth=20, - planeDownsampling=4, - convexhullDownsampling=4, - pca=0, - mode=0, - convexhullApproximation=1, - ) - - robot.arm.go_home(ignore_physics=True) - robot.arm.move_ee_xyz([0, 0, 0.2]) - - if hydra_cfg.pose_dist.any_pose: - robot.pb_client.set_step_sim(True) - if obj_class in ["bowl"]: - robot.pb_client.set_step_sim(True) - - obj_id = robot.pb_client.load_geom( - "mesh", - mass=0.01, - mesh_scale=mesh_scale, - visualfile=obj_obj_file_dec, - collifile=obj_obj_file_dec, - base_pos=pos, - base_ori=ori, - ) - p.changeDynamics(obj_id, -1, lateralFriction=0.5) - log_info("any_pose:{}".format(hydra_cfg.pose_dist.any_pose)) - if obj_class == "bowl": - safeCollisionFilterPair( - bodyUniqueIdA=obj_id, - bodyUniqueIdB=table_id, - linkIndexA=-1, - linkIndexB=rack_link_id, - enableCollision=False, - ) - safeCollisionFilterPair( - bodyUniqueIdA=obj_id, - bodyUniqueIdB=table_id, - linkIndexA=-1, - linkIndexB=shelf_link_id, - enableCollision=False, - ) - robot.pb_client.set_step_sim(False) - - o_cid = None - if hydra_cfg.pose_dist.any_pose: - o_cid = constraint_obj_world(obj_id, pos, ori) - robot.pb_client.set_step_sim(False) - safeCollisionFilterPair(obj_id, table_id, -1, -1, enableCollision=True) - p.changeDynamics(obj_id, -1, linearDamping=5, angularDamping=5) - time.sleep(1.5) - - hide_link(table_id, rack_link_id) - - obj_pose_world = p.getBasePositionAndOrientation(obj_id) - obj_pose_world = util.list2pose_stamped( - list(obj_pose_world[0]) + list(obj_pose_world[1]) - ) - viz_dict["start_obj_pose"] = util.pose_stamped2list(obj_pose_world) - - if obj_class == "mug": - rack_color = p.getVisualShapeData(table_id)[rack_link_id][7] - show_link(table_id, rack_link_id, rack_color) - - time.sleep(1.5) - teleport_rgb = robot.cam.get_images(get_rgb=True)[0] - teleport_img_fname = osp.join(eval_teleport_imgs_dir, "%d_init.png" % iteration) - np2img(teleport_rgb.astype(np.uint8), teleport_img_fname) - cloud_points, cloud_colors, cloud_classes = get_clouds(cams) - obj_points, obj_colors, obj_classes = get_object_clouds(cams) - - points_mug_raw, points_rack_raw = load_data_raw( - num_points=1024, - clouds=obj_points, - classes=obj_classes, - action_class=0, - anchor_class=1, - ) - if points_mug_raw is None: - continue - points_gripper_raw, points_mug_raw = load_data_raw( - num_points=1024, - clouds=obj_points, - classes=obj_classes, - action_class=2, - anchor_class=0, - ) - points_mug, points_rack = load_data( - num_points=1024, - clouds=obj_points, - classes=obj_classes, - action_class=0, - anchor_class=1, - ) - - ans = place_model.get_transform(points_mug, points_rack) # 1, 4, 4 - - pred_T_action_init = ans["pred_T_action"] - pred_T_action_mat = pred_T_action_init.get_matrix()[0].T.detach().cpu().numpy() - obj_pose_world = p.getBasePositionAndOrientation(obj_id) # list - obj_pose_world = util.list2pose_stamped( - list(obj_pose_world[0]) + list(obj_pose_world[1]) - ) # stamped_pose - obj_start_pose = obj_pose_world - rack_relative_pose = get_world_transform( - pred_T_action_mat, obj_start_pose, points_mug_raw - ) # pose_stamped - obj_end_pose_list = util.pose_stamped2list(rack_relative_pose) - transform_rack_relative_pose = util.get_transform( - rack_relative_pose, obj_start_pose - ) - pose_tuple = robot.arm.get_ee_pose() - ee_pose_world = util.list2pose_stamped( - list(pose_tuple[0]) + list(pose_tuple[1]) - ) - # Get Grasp Pose - points_gripper, points_mug = load_data( - num_points=1024, - clouds=obj_points, - classes=obj_classes, - action_class=2, - anchor_class=0, - ) - ans_grasp = grasp_model.get_transform(points_gripper, points_mug) # 1, 4, 4 - pred_T_action_init_gripper2mug = ans_grasp["pred_T_action"] - pred_T_action_mat_gripper2mug = ( - pred_T_action_init_gripper2mug.get_matrix()[0].T.detach().cpu().numpy() - ) - pred_T_action_mat_gripper2mug[2, -1] -= 0.001 - - gripper_relative_pose = get_world_transform( - pred_T_action_mat_gripper2mug, ee_pose_world, points_gripper_raw - ) # transform from gripper to mug in world frame - pre_grasp_ee_pose = util.pose_stamped2list(gripper_relative_pose) - - np.savez( - f"{save_dir}/{iteration}_init_all_points.npz", - clouds=cloud_points, - colors=cloud_colors, - classes=cloud_classes, - shapenet_id=obj_shapenet_id, - ) - - np.savez( - f"{save_dir}/{iteration}_init_obj_points.npz", - clouds=obj_points, - colors=obj_colors, - classes=obj_classes, - shapenet_id=obj_shapenet_id, - points_mug_raw=points_mug_raw.detach().cpu(), - points_gripper_raw=points_gripper_raw.detach().cpu(), - points_rack_raw=points_rack_raw.detach().cpu(), - pred_T_action_mat=pred_T_action_mat, - pred_T_action_mat_gripper2mug=pred_T_action_mat_gripper2mug, - ) - # log_info("Saved point cloud data to:") - # log_info(f"{save_dir}/{iteration}_init_obj_points.npz") - - # optimize grasp pose - viz_dict["start_ee_pose"] = pre_grasp_ee_pose - - ########################### grasp post-process ############################# - - pregrasp_offset_tf = get_ee_offset(ee_pose=pre_grasp_ee_pose) - # pre_pre_grasp_ee_pose = pre_grasp_ee_pose - pre_pre_grasp_ee_pose = util.pose_stamped2list( - util.transform_pose( - pose_source=util.list2pose_stamped(pre_grasp_ee_pose), - pose_transform=util.list2pose_stamped(pregrasp_offset_tf), - ) - ) - - # reset object to placement pose to detect placement success - safeCollisionFilterPair(obj_id, table_id, -1, -1, enableCollision=False) - safeCollisionFilterPair( - obj_id, table_id, -1, placement_link_id, enableCollision=False - ) - robot.pb_client.set_step_sim(True) - safeRemoveConstraint(o_cid) - robot.pb_client.reset_body(obj_id, obj_end_pose_list[:3], obj_end_pose_list[3:]) - - cloud_points, cloud_colors, cloud_classes = get_clouds(cams) - obj_points, obj_colors, obj_classes = get_object_clouds(cams) - - np.savez( - f"{save_dir}/{iteration}_teleport_all_points.npz", - clouds=cloud_points, - colors=cloud_colors, - classes=cloud_classes, - shapenet_id=obj_shapenet_id, - ) - - np.savez( - f"{save_dir}/{iteration}_teleport_obj_points.npz", - clouds=obj_points, - colors=obj_colors, - classes=obj_classes, - shapenet_id=obj_shapenet_id, - ) - - time.sleep(1.0) - teleport_rgb = robot.cam.get_images(get_rgb=True)[0] - teleport_img_fname = osp.join( - eval_teleport_imgs_dir, "teleport_%d.png" % iteration - ) - np2img(teleport_rgb.astype(np.uint8), teleport_img_fname) - safeCollisionFilterPair( - obj_id, table_id, -1, placement_link_id, enableCollision=True - ) - robot.pb_client.set_step_sim(False) - time.sleep(1.0) - - cloud_points, cloud_colors, cloud_classes = get_clouds(cams) - obj_points, obj_colors, obj_classes = get_object_clouds(cams) - - np.savez( - f"{save_dir}/{iteration}_post_teleport_all_points.npz", - clouds=cloud_points, - colors=cloud_colors, - classes=cloud_classes, - shapenet_id=obj_shapenet_id, - ) - - np.savez( - f"{save_dir}/{iteration}_post_teleport_obj_points.npz", - clouds=obj_points, - colors=obj_colors, - classes=obj_classes, - shapenet_id=obj_shapenet_id, - ) - - time.sleep(1.0) - teleport_rgb = robot.cam.get_images(get_rgb=True)[0] - teleport_img_fname = osp.join( - eval_teleport_imgs_dir, "post_teleport_%d.png" % iteration - ) - np2img(teleport_rgb.astype(np.uint8), teleport_img_fname) - - obj_surf_contacts = p.getContactPoints(obj_id, table_id, -1, placement_link_id) - touching_surf = len(obj_surf_contacts) > 0 - place_success_teleport = touching_surf - place_success_teleport_list.append(place_success_teleport) - if not place_success_teleport: - place_fail_teleport_list.append(iteration) - - time.sleep(1.0) - safeCollisionFilterPair(obj_id, table_id, -1, -1, enableCollision=True) - robot.pb_client.reset_body(obj_id, pos, ori) - - # attempt grasp and solve for plan to execute placement with arm - jnt_pos = grasp_jnt_pos = grasp_plan = None - place_success = grasp_success = False - for g_idx in range(2): - # reset everything - robot.pb_client.set_step_sim(False) - safeCollisionFilterPair(obj_id, table_id, -1, -1, enableCollision=True) - if hydra_cfg.pose_dist.any_pose: - robot.pb_client.set_step_sim(True) - safeRemoveConstraint(o_cid) - p.resetBasePositionAndOrientation(obj_id, pos, ori) - print(p.getBasePositionAndOrientation(obj_id)) - time.sleep(0.5) - - if hydra_cfg.pose_dist.any_pose: - o_cid = constraint_obj_world(obj_id, pos, ori) - robot.pb_client.set_step_sim(False) - robot.arm.go_home(ignore_physics=True) - - # turn OFF collisions between robot and object / table, and move to pre-grasp pose - for i in range(p.getNumJoints(robot.arm.robot_id)): - safeCollisionFilterPair( - bodyUniqueIdA=robot.arm.robot_id, - bodyUniqueIdB=table_id, - linkIndexA=i, - linkIndexB=-1, - enableCollision=False, - physicsClientId=robot.pb_client.get_client_id(), - ) - safeCollisionFilterPair( - bodyUniqueIdA=robot.arm.robot_id, - bodyUniqueIdB=obj_id, - linkIndexA=i, - linkIndexB=-1, - enableCollision=False, - physicsClientId=robot.pb_client.get_client_id(), - ) - robot.arm.eetool.open() - - if jnt_pos is None or grasp_jnt_pos is None: - jnt_pos = ik_helper.get_feasible_ik(pre_pre_grasp_ee_pose) - grasp_jnt_pos = ik_helper.get_feasible_ik(pre_grasp_ee_pose) - - if jnt_pos is None or grasp_jnt_pos is None: - jnt_pos = ik_helper.get_ik(pre_pre_grasp_ee_pose) - grasp_jnt_pos = ik_helper.get_ik(pre_grasp_ee_pose) - - if jnt_pos is None or grasp_jnt_pos is None: - jnt_pos = robot.arm.compute_ik( - pre_pre_grasp_ee_pose[:3], pre_pre_grasp_ee_pose[3:] - ) - # this is the pose that's at the grasp, where we just need to close the fingers - grasp_jnt_pos = robot.arm.compute_ik( - pre_grasp_ee_pose[:3], pre_grasp_ee_pose[3:] - ) - - if grasp_jnt_pos is not None and jnt_pos is not None: - if g_idx == 0: - robot.pb_client.set_step_sim(True) - robot.arm.set_jpos(grasp_jnt_pos, ignore_physics=True) - robot.arm.eetool.close(ignore_physics=True) - time.sleep(0.2) - grasp_rgb = robot.cam.get_images(get_rgb=True)[0] - grasp_img_fname = osp.join( - eval_grasp_imgs_dir, "pre_grasp_%d.png" % iteration - ) - np2img(grasp_rgb.astype(np.uint8), grasp_img_fname) - cloud_points, cloud_colors, cloud_classes = get_clouds(cams) - obj_points, obj_colors, obj_classes = get_object_clouds(cams) - - np.savez( - f"{save_dir}/{iteration}_pre_grasp_all_points.npz", - clouds=cloud_points, - colors=cloud_colors, - classes=cloud_classes, - shapenet_id=obj_shapenet_id, - ) - - np.savez( - f"{save_dir}/{iteration}_pre_grasp_obj_points.npz", - clouds=obj_points, - colors=obj_colors, - classes=obj_classes, - shapenet_id=obj_shapenet_id, - ) - - continue - - ########################### planning to pre_pre_grasp and pre_grasp ########################## - if grasp_plan is None: - plan1 = ik_helper.plan_joint_motion(robot.arm.get_jpos(), jnt_pos) - plan2 = ik_helper.plan_joint_motion(jnt_pos, grasp_jnt_pos) - - if plan1 is not None and plan2 is not None: - grasp_plan = plan1 + plan2 - - robot.arm.eetool.open() - for jnt in plan1: - robot.arm.set_jpos(jnt, wait=False) - time.sleep(0.025) - robot.arm.set_jpos(plan1[-1], wait=True) - for jnt in plan2: - robot.arm.set_jpos(jnt, wait=False) - time.sleep(0.04) - robot.arm.set_jpos(grasp_plan[-1], wait=True) - - # get pose that's straight up - offset_pose = util.transform_pose( - pose_source=util.list2pose_stamped( - np.concatenate(robot.arm.get_ee_pose()[:2]).tolist() - ), - pose_transform=util.list2pose_stamped( - [0, 0, 0.15, 0, 0, 0, 1] - ), - ) - offset_pose_list = util.pose_stamped2list(offset_pose) - offset_jnts = ik_helper.get_feasible_ik(offset_pose_list) - - # turn ON collisions between robot and object, and close fingers - for i in range(p.getNumJoints(robot.arm.robot_id)): - safeCollisionFilterPair( - bodyUniqueIdA=robot.arm.robot_id, - bodyUniqueIdB=obj_id, - linkIndexA=i, - linkIndexB=-1, - enableCollision=True, - physicsClientId=robot.pb_client.get_client_id(), - ) - safeCollisionFilterPair( - bodyUniqueIdA=robot.arm.robot_id, - bodyUniqueIdB=table_id, - linkIndexA=i, - linkIndexB=rack_link_id, - enableCollision=False, - physicsClientId=robot.pb_client.get_client_id(), - ) - - time.sleep(0.8) - obj_pos_before_grasp = p.getBasePositionAndOrientation(obj_id)[ - 0 - ] - jnt_pos_before_grasp = robot.arm.get_jpos() - soft_grasp_close(robot, finger_joint_id, force=50) - safeRemoveConstraint(o_cid) - time.sleep(0.8) - safeCollisionFilterPair( - obj_id, table_id, -1, -1, enableCollision=False - ) - time.sleep(0.8) - grasp_rgb = robot.cam.get_images(get_rgb=True)[0] - grasp_img_fname = osp.join( - eval_grasp_imgs_dir, "post_grasp_%d.png" % iteration - ) - np2img(grasp_rgb.astype(np.uint8), grasp_img_fname) - cloud_points, cloud_colors, cloud_classes = get_clouds(cams) - obj_points, obj_colors, obj_classes = get_object_clouds(cams) - - np.savez( - f"{save_dir}/{iteration}_post_grasp_all_points.npz", - clouds=cloud_points, - colors=cloud_colors, - classes=cloud_classes, - shapenet_id=obj_shapenet_id, - ) - - np.savez( - f"{save_dir}/{iteration}_post_grasp_obj_points.npz", - clouds=obj_points, - colors=obj_colors, - classes=obj_classes, - shapenet_id=obj_shapenet_id, - ) - - if g_idx == 1: - grasp_success = object_is_still_grasped( - robot, obj_id, right_pad_id, left_pad_id - ) - - if grasp_success: - # turn OFF collisions between object / table and object / rack, and move to pre-place pose - safeCollisionFilterPair( - obj_id, table_id, -1, -1, enableCollision=True - ) - robot.arm.eetool.open() - p.resetBasePositionAndOrientation( - obj_id, obj_pos_before_grasp, ori - ) - soft_grasp_close(robot, finger_joint_id, force=40) - robot.arm.set_jpos( - jnt_pos_before_grasp, ignore_physics=True - ) - cid = constraint_grasp_close(robot, obj_id) - # grasp_rgb = robot.cam.get_images(get_rgb=True)[ - # 0] - # grasp_img_fname = osp.join( - # eval_grasp_imgs_dir, 'after_grasp_success_%d.png' % iteration) - # np2img(grasp_rgb.astype( - # np.uint8), grasp_img_fname) - ######################################################################################################### - - if offset_jnts is not None: - offset_plan = ik_helper.plan_joint_motion( - robot.arm.get_jpos(), offset_jnts - ) - - if offset_plan is not None: - for jnt in offset_plan: - robot.arm.set_jpos(jnt, wait=False) - time.sleep(0.04) - robot.arm.set_jpos(offset_plan[-1], wait=True) - - # turn OFF collisions between object / table and object / rack, and move to pre-place pose - safeCollisionFilterPair( - obj_id, table_id, -1, -1, enableCollision=False - ) - safeCollisionFilterPair( - obj_id, table_id, -1, rack_link_id, enableCollision=False - ) - time.sleep(1.0) - - if grasp_success: - # optimize placement pose - ee_end_pose = util.transform_pose( - pose_source=util.list2pose_stamped(pre_grasp_ee_pose), - pose_transform=transform_rack_relative_pose, - ) - pre_ee_end_pose2 = util.transform_pose( - pose_source=ee_end_pose, pose_transform=preplace_offset_tf - ) - pre_ee_end_pose1 = util.transform_pose( - pose_source=pre_ee_end_pose2, pose_transform=preplace_horizontal_tf - ) - - ee_end_pose_list = util.pose_stamped2list(ee_end_pose) - pre_ee_end_pose1_list = util.pose_stamped2list(pre_ee_end_pose1) - pre_ee_end_pose2_list = util.pose_stamped2list(pre_ee_end_pose2) - - ####################################### get place pose ########################################### - - pre_place_jnt_pos1 = ik_helper.get_feasible_ik(pre_ee_end_pose1_list) - pre_place_jnt_pos2 = ik_helper.get_feasible_ik(pre_ee_end_pose2_list) - place_jnt_pos = ik_helper.get_feasible_ik(ee_end_pose_list) - - if ( - place_jnt_pos is not None - and pre_place_jnt_pos2 is not None - and pre_place_jnt_pos1 is not None - ): - plan1 = ik_helper.plan_joint_motion( - robot.arm.get_jpos(), pre_place_jnt_pos1 - ) - plan2 = ik_helper.plan_joint_motion( - pre_place_jnt_pos1, pre_place_jnt_pos2 - ) - plan3 = ik_helper.plan_joint_motion(pre_place_jnt_pos2, place_jnt_pos) - - if plan1 is not None and plan2 is not None and plan3 is not None: - place_plan = plan1 + plan2 - - for jnt in place_plan: - robot.arm.set_jpos(jnt, wait=False) - time.sleep(0.035) - robot.arm.set_jpos(place_plan[-1], wait=True) - - ################################################################################################################ - - # turn ON collisions between object and rack, and open fingers - safeCollisionFilterPair( - obj_id, table_id, -1, -1, enableCollision=True - ) - safeCollisionFilterPair( - obj_id, table_id, -1, rack_link_id, enableCollision=True - ) - - for jnt in plan3: - robot.arm.set_jpos(jnt, wait=False) - time.sleep(0.075) - robot.arm.set_jpos(plan3[-1], wait=True) - - p.changeDynamics(obj_id, -1, linearDamping=5, angularDamping=5) - constraint_grasp_open(cid) - robot.arm.eetool.open() - - time.sleep(0.2) - for i in range(p.getNumJoints(robot.arm.robot_id)): - safeCollisionFilterPair( - bodyUniqueIdA=robot.arm.robot_id, - bodyUniqueIdB=obj_id, - linkIndexA=i, - linkIndexB=-1, - enableCollision=False, - physicsClientId=robot.pb_client.get_client_id(), - ) - robot.arm.move_ee_xyz([0, 0.075, 0.075]) - safeCollisionFilterPair( - obj_id, table_id, -1, -1, enableCollision=False - ) - time.sleep(4.0) - - # observe and record outcome - obj_surf_contacts = p.getContactPoints( - obj_id, table_id, -1, placement_link_id - ) - touching_surf = len(obj_surf_contacts) > 0 - obj_floor_contacts = p.getContactPoints( - obj_id, robot.arm.floor_id, -1, -1 - ) - touching_floor = len(obj_floor_contacts) > 0 - place_success = touching_surf and not touching_floor - - robot.arm.go_home(ignore_physics=True) - - place_success_list.append(place_success) - grasp_success_list.append(grasp_success) - if not place_success: - place_fail_list.append(iteration) - if not grasp_success: - grasp_fail_list.append(iteration) - log_str = "Iteration: %d, " % iteration - kvs = {} - # kvs["Place [teleport] Success"] = place_success_teleport_list[-1] - # kvs["Grasp Success"] = grasp_success_list[-1] - - overall_success_num = 0 - for i in range(len(grasp_success_list)): - if place_success_teleport_list[i] == 1 and grasp_success_list[i] == 1: - overall_success_num += 1 - kvs["Grasp Success Rate"] = sum(grasp_success_list) / float( - len(grasp_success_list) - ) - kvs["Place [teleport] Success Rate"] = sum(place_success_teleport_list) / float( - len(place_success_teleport_list) - ) - kvs["overall success Rate"] = overall_success_num / float( - len(grasp_success_list) - ) - if iteration == 0: - write_to_file(log_txt_file, "\n") - write_to_file(log_txt_file, "cwd:" + os.getcwd()) - write_to_file( - log_txt_file, - "pose_distribution: {}".format( - "arbitrary" if hydra_cfg.pose_dist.any_pose else "upright" - ), - ) - write_to_file(log_txt_file, "seed: {}".format(hydra_cfg.seed)) - log_info("checkpoint_file_grasp: " + hydra_cfg.checkpoint_file_grasp) - write_to_file( - log_txt_file, - "checkpoint_file_grasp: " + hydra_cfg.checkpoint_file_grasp, - ) - log_info("checkpoint_file_place: " + hydra_cfg.checkpoint_file_place) - write_to_file( - log_txt_file, - "checkpoint_file_place: " + hydra_cfg.checkpoint_file_place, - ) - if hydra_cfg.log_every_trial: - for k, v in kvs.items(): - log_str += "%s: %.3f, " % (k, v) - # id_str = ", shapenet_id: %s" % obj_shapenet_id - log_info(log_str) - write_to_file(log_txt_file, log_str) - - else: - if iteration == hydra_cfg.num_iterations - 1: - write_to_file(log_txt_file, "cwd:" + os.getcwd()) - write_to_file( - log_txt_file, - "pose_distribution: {}".format( - "arbitrary" if hydra_cfg.pose_dist.any_pose else "upright" - ), - ) - write_to_file(log_txt_file, "seed: {}".format(hydra_cfg.seed)) - for k, v in kvs.items(): - log_str += "%s: %.3f, " % (k, v) - # id_str = ", shapenet_id: %s" % obj_shapenet_id - log_info(log_str) - write_to_file(log_txt_file, log_str) - log_info("checkpoint_file_grasp: " + hydra_cfg.checkpoint_file_grasp) - write_to_file( - log_txt_file, - "checkpoint_file_grasp: " + hydra_cfg.checkpoint_file_grasp, - ) - log_info("checkpoint_file_place: " + hydra_cfg.checkpoint_file_place) - write_to_file( - log_txt_file, - "checkpoint_file_place: " + hydra_cfg.checkpoint_file_place, - ) - write_to_file(log_txt_file, "\n") - - eval_iter_dir = osp.join(eval_save_dir, "trial_%d" % iteration) - if not osp.exists(eval_iter_dir): - os.makedirs(eval_iter_dir) - sample_fname = osp.join(eval_iter_dir, "success_rate_eval_implicit.npz") - np.savez( - sample_fname, - obj_shapenet_id=obj_shapenet_id, - success=success_list, - grasp_success=grasp_success, - place_success=place_success, - place_success_teleport=place_success_teleport, - grasp_success_list=grasp_success_list, - place_success_list=place_success_list, - place_success_teleport_list=place_success_teleport_list, - start_obj_pose=util.pose_stamped2list(obj_start_pose), - best_place_obj_pose=obj_end_pose_list, - mesh_file=obj_obj_file, - distractor_info=None, - args=hydra_cfg.__dict__, - global_dict=global_dict, - cfg=util.cn2dict(cfg), - obj_cfg=util.cn2dict(obj_cfg), - ) - - robot.pb_client.remove_body(obj_id) - - -if __name__ == "__main__": - signal.signal(signal.SIGINT, util.signal_handler) - - main() diff --git a/scripts/train_residual_flow.py b/scripts/train_residual_flow.py index ddfef97..1f7204b 100644 --- a/scripts/train_residual_flow.py +++ b/scripts/train_residual_flow.py @@ -10,7 +10,10 @@ from pytorch_lightning.loggers import WandbLogger from taxpose.datasets.point_cloud_data_module import MultiviewDataModule -from taxpose.nets.transformer_flow import ResidualFlow_DiffEmbTransformer +from taxpose.nets.transformer_flow import ( + CorrespondenceFlow_DiffEmbMLP, + ResidualFlow_DiffEmbTransformer, +) from taxpose.training.flow_equivariance_training_module_nocentering import ( EquivarianceTrainingModule, ) @@ -58,6 +61,8 @@ def main(cfg): print(OmegaConf.to_yaml(cfg, resolve=True)) # torch.set_float32_matmul_precision("medium") + TESTING = "PYTEST_CURRENT_TEST" in os.environ + pl.seed_everything(cfg.seed) logger = WandbLogger( entity=cfg.wandb.entity, @@ -74,35 +79,40 @@ def main(cfg): # logger.log_hyperparams(cfg) # logger.log_hyperparams({"working_dir": os.getcwd()}) trainer = pl.Trainer( - logger=logger, + logger=logger if not TESTING else False, accelerator="gpu", devices=[0], log_every_n_steps=cfg.training.log_every_n_steps, check_val_every_n_epoch=cfg.training.check_val_every_n_epoch, # reload_dataloaders_every_n_epochs=1, # callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()], - callbacks=[ - # This checkpoint callback saves the latest model during training, i.e. so we can resume if it crashes. - # It saves everything, and you can load by referencing last.ckpt. - ModelCheckpoint( - dirpath=cfg.lightning.checkpoint_dir, - filename="{epoch}-{step}", - monitor="step", - mode="max", - save_weights_only=False, - save_last=True, - every_n_epochs=1, - ), - # This checkpoint will get saved to WandB. The Callback mechanism in lightning is poorly designed, so we have to put it last. - ModelCheckpoint( - dirpath=cfg.lightning.checkpoint_dir, - filename="{epoch}-{step}-{train_loss:.2f}-weights-only", - monitor="val_loss", - mode="min", - save_weights_only=True, - ), - ], + callbacks=( + [ + # This checkpoint callback saves the latest model during training, i.e. so we can resume if it crashes. + # It saves everything, and you can load by referencing last.ckpt. + ModelCheckpoint( + dirpath=cfg.lightning.checkpoint_dir, + filename="{epoch}-{step}", + monitor="step", + mode="max", + save_weights_only=False, + save_last=True, + every_n_epochs=1, + ), + # This checkpoint will get saved to WandB. The Callback mechanism in lightning is poorly designed, so we have to put it last. + ModelCheckpoint( + dirpath=cfg.lightning.checkpoint_dir, + filename="{epoch}-{step}-{train_loss:.2f}-weights-only", + monitor="val_loss", + mode="min", + save_weights_only=True, + ), + ] + if not TESTING + else [] + ), max_epochs=cfg.training.max_epochs, + fast_dev_run=5 if "PYTEST_CURRENT_TEST" in os.environ else False, ) dm = MultiviewDataModule( @@ -113,18 +123,25 @@ def main(cfg): dm.setup() - network = ResidualFlow_DiffEmbTransformer( - emb_dims=cfg.model.emb_dims, - emb_nn=cfg.model.emb_nn, - return_flow_component=cfg.model.return_flow_component, - center_feature=cfg.model.center_feature, - pred_weight=cfg.model.pred_weight, - multilaterate=cfg.model.multilaterate, - sample=cfg.model.mlat_sample, - mlat_nkps=cfg.model.mlat_nkps, - break_symmetry=cfg.break_symmetry, - conditional=cfg.model.conditional if "conditional" in cfg.model else False, - ) + if cfg.mlp: + network = CorrespondenceFlow_DiffEmbMLP( + emb_dims=cfg.emb_dims, + emb_nn=cfg.emb_nn, + center_feature=cfg.center_feature, + ) + else: + network = ResidualFlow_DiffEmbTransformer( + emb_dims=cfg.model.emb_dims, + emb_nn=cfg.model.emb_nn, + return_flow_component=cfg.model.return_flow_component, + center_feature=cfg.model.center_feature, + pred_weight=cfg.model.pred_weight, + multilaterate=cfg.model.multilaterate, + sample=cfg.model.mlat_sample, + mlat_nkps=cfg.model.mlat_nkps, + break_symmetry=cfg.break_symmetry, + conditional=cfg.model.conditional if "conditional" in cfg.model else False, + ) model = EquivarianceTrainingModule( network, diff --git a/scripts/train_residual_flow_ablation.py b/scripts/train_residual_flow_ablation.py deleted file mode 100644 index 6f779c1..0000000 --- a/scripts/train_residual_flow_ablation.py +++ /dev/null @@ -1,174 +0,0 @@ -import os - -import hydra -import numpy as np -import pytorch_lightning as pl -import torch -from pytorch_lightning.loggers import WandbLogger - -from taxpose.datasets.point_cloud_data_module import MultiviewDataModule -from taxpose.nets.transformer_flow import ( - CorrespondenceFlow_DiffEmbMLP, - ResidualFlow_DiffEmbTransformer, -) -from taxpose.training.flow_equivariance_training_module_nocentering import ( - EquivarianceTrainingModule, -) -from taxpose.utils.callbacks import SaverCallbackEmbnnActionAnchor, SaverCallbackModel - - -def write_to_file(file_name, string): - with open(file_name, "a") as f: - f.writelines(string) - f.write("\n") - f.close() - - -@hydra.main(config_path="../configs", config_name="train_mug_residual_ablation") -def main(cfg): - pl.seed_everything(cfg.seed) - logger = WandbLogger(project=cfg.experiment) - logger.log_hyperparams(cfg) - logger.log_hyperparams({"working_dir": os.getcwd()}) - trainer = pl.Trainer( - logger=logger, - gpus=1, - reload_dataloaders_every_n_epochs=1, - callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()], - max_epochs=cfg.max_epochs, - ) - log_txt_file = cfg.log_txt_file - - if cfg.ablation.name == "7_no_pretraining": - cfg.checkpoint_file_action = cfg.ablation.checkpoint_file_action - cfg.checkpoint_file_anchor = cfg.ablation.checkpoint_file_anchor - else: - cfg.checkpoint_file_action = cfg.task.checkpoint_file_action - cfg.checkpoint_file_anchor = cfg.task.checkpoint_file_anchor - - write_to_file(log_txt_file, "working_dir: {}".format(os.getcwd())) - write_to_file(log_txt_file, "ablation: {}".format(cfg.ablation.name)) - write_to_file( - log_txt_file, - "consistency_loss_weight: {}".format(cfg.ablation.consistency_loss_weight), - ) - write_to_file( - log_txt_file, - "direct_correspondence_loss_weight: {}".format( - cfg.ablation.direct_correspondence_loss_weight - ), - ) - write_to_file( - log_txt_file, - "displace_loss_weight: {}".format(cfg.ablation.displace_loss_weight), - ) - write_to_file(log_txt_file, "residual_on: {}".format(cfg.ablation.residual_on)) - write_to_file(log_txt_file, "pred_weight: {}".format(cfg.ablation.pred_weight)) - write_to_file(log_txt_file, "freeze_embnn: {}".format(cfg.ablation.freeze_embnn)) - write_to_file( - log_txt_file, "checkpoint_file_action: {}".format(cfg.checkpoint_file_action) - ) - write_to_file( - log_txt_file, "checkpoint_file_anchor: {}".format(cfg.checkpoint_file_anchor) - ) - write_to_file(log_txt_file, "mlp: {}".format(cfg.ablation.mlp)) - - write_to_file(log_txt_file, "") - dm = MultiviewDataModule( - dataset_root=hydra.utils.to_absolute_path(cfg.train_data_dir), - test_dataset_root=hydra.utils.to_absolute_path(cfg.test_data_dir), - dataset_index=cfg.dataset_index, - action_class=cfg.task.action_class, - anchor_class=cfg.task.anchor_class, - dataset_size=cfg.dataset_size, - rotation_variance=np.pi / 180 * cfg.rotation_variance, - translation_variance=cfg.translation_variance, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - cloud_type=cfg.task.cloud_type, - num_points=cfg.num_points, - overfit=cfg.overfit, - synthetic_occlusion=cfg.synthetic_occlusion, - ball_radius=cfg.ball_radius, - ball_occlusion=cfg.ball_occlusion, - plane_standoff=cfg.plane_standoff, - plane_occlusion=cfg.plane_occlusion, - num_demo=cfg.num_demo, - occlusion_class=cfg.occlusion_class, - ) - - dm.setup() - - if cfg.ablation.mlp: - network = CorrespondenceFlow_DiffEmbMLP( - emb_dims=cfg.ablation.emb_dims, - emb_nn=cfg.emb_nn, - center_feature=cfg.center_feature, - ) - else: - network = ResidualFlow_DiffEmbTransformer( - emb_dims=cfg.ablation.emb_dims, - emb_nn=cfg.emb_nn, - return_flow_component=cfg.return_flow_component, - center_feature=cfg.center_feature, - pred_weight=cfg.ablation.pred_weight, - residual_on=cfg.ablation.residual_on, - freeze_embnn=cfg.ablation.freeze_embnn, - ) - - model = EquivarianceTrainingModule( - network, - lr=cfg.lr, - image_log_period=cfg.image_logging_period, - displace_loss_weight=cfg.ablation.displace_loss_weight, - consistency_loss_weight=cfg.ablation.consistency_loss_weight, - direct_correspondence_loss_weight=cfg.ablation.direct_correspondence_loss_weight, - weight_normalize=cfg.task.weight_normalize, - sigmoid_on=cfg.sigmoid_on, - softmax_temperature=cfg.task.softmax_temperature, - flow_supervision=cfg.flow_supervision, - ) - - model.cuda() - model.train() - if cfg.checkpoint_file is not None: - print("loaded checkpoint from") - print(cfg.checkpoint_file) - model.load_state_dict( - torch.load(hydra.utils.to_absolute_path(cfg.checkpoint_file))["state_dict"] - ) - - else: - if cfg.checkpoint_file_action is not None: - model.model.emb_nn_action.load_state_dict( - torch.load(hydra.utils.to_absolute_path(cfg.checkpoint_file_action))[ - "embnn_state_dict" - ] - ) - print( - "-----------------------Pretrained EmbNN Action Model Loaded!-----------------------" - ) - print( - "Loaded Pretrained EmbNN Action: {}".format(cfg.checkpoint_file_action) - ) - if cfg.checkpoint_file_anchor is not None: - model.model.emb_nn_anchor.load_state_dict( - torch.load(hydra.utils.to_absolute_path(cfg.checkpoint_file_anchor))[ - "embnn_state_dict" - ] - ) - print( - "-----------------------Pretrained EmbNN Anchor Model Loaded!-----------------------" - ) - print( - "Loaded Pretrained EmbNN Anchor: {}".format(cfg.checkpoint_file_anchor) - ) - - trainer.fit(model, dm) - - -if __name__ == "__main__": - torch.autograd.set_detect_anomaly(True) - torch.cuda.empty_cache() - torch.multiprocessing.set_sharing_strategy("file_system") - main() diff --git a/taxpose/nets/transformer_flow.py b/taxpose/nets/transformer_flow.py index d3a539b..a4282ba 100644 --- a/taxpose/nets/transformer_flow.py +++ b/taxpose/nets/transformer_flow.py @@ -85,6 +85,10 @@ def forward(self, *input): scores=None, ).permute(0, 2, 1) + outputs = { + "flow_action": flow_action, + } + if self.cycle: flow_anchor = self.head_anchor( anchor_embedding_tf, @@ -93,8 +97,9 @@ def forward(self, *input): action_points, scores=None, ).permute(0, 2, 1) - return flow_action, flow_anchor - return flow_action + outputs["flow_anchor"] = flow_anchor + + return outputs class CorrespondenceMLPHead(nn.Module): diff --git a/tests/config_test.py b/tests/config_test.py index 195089c..cb5ffda 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -52,7 +52,7 @@ def test_rlbench_commands_compile(config_name): assert cfg.wandb.save_dir is not None -@pytest.mark.parametrize("config_name", _get_config_names("rlbench")) +@pytest.mark.parametrize("config_name", _get_config_names("ndf")) def test_ndf_commands_compile(config_name): with initialize(version_base=None, config_path="../configs"): cfg = compose( @@ -70,6 +70,3 @@ def test_ndf_commands_compile(config_name): # Resolve to yaml. yaml_cfg = OmegaConf.to_yaml(cfg, resolve=True) - - assert cfg.job_type is not None - assert cfg.wandb.save_dir is not None diff --git a/tests/train_test.py b/tests/train_test.py new file mode 100644 index 0000000..ff57bd2 --- /dev/null +++ b/tests/train_test.py @@ -0,0 +1,112 @@ +# Much around with the path to make the import work +import os +import sys +from pathlib import Path + +import pytest +import torch +from hydra import compose, initialize +from hydra.core.hydra_config import HydraConfig + +# Add the parent directory to the path to import the script. Hacky, but it works. +THIS_DIR = Path(__file__).resolve().parent +sys.path.append(str(THIS_DIR.parent)) + +from scripts.train_residual_flow import main + + +def _get_training_config_names(bmark, ablation=False): + # Get config paths from the configs/commands directory, relative to the commands directory. + configs = [path for path in Path(f"configs/commands/{bmark}").rglob("*.yaml")] + + # Strip the "configs/" prefix. + configs = [str(path)[8:] for path in configs] + + # Filter out paths with basenames that have a leading underscore. + configs = [config for config in configs if not Path(config).name.startswith("_")] + + # Filter out paths that don't include the word "train" in the path. + configs = [config for config in configs if "train_" in config] + + if ablation: + # Filter out paths that don't include the word "ablation" in the path except for + configs = [ + config + for config in configs + if "ablation" in config and "n_demo" not in config + ] + else: + # Filter out paths that include the word "ablation" in the path. + configs = [ + config + for config in configs + if "ablation" not in config or "n_demo" in config + ] + + # Filter out paths with any folder that have a leading underscore. + configs = [ + config + for config in configs + if not any(folder.startswith("_") for folder in Path(config).parts) + ] + + return configs + + +DEFAULT_NDF_PATH = "/data/ndf" + + +def _test_commands_run(config_name): + dataset_root = ( + os.environ["NDF_DATASET_ROOT"] + if "NDF_DATASET_ROOT" in os.environ + else DEFAULT_NDF_PATH + ) + torch.multiprocessing.set_sharing_strategy("file_system") + + with initialize(version_base=None, config_path="../configs"): + cfg = compose( + config_name=config_name, + overrides=[ + "hydra.verbose=true", + "hydra.job.num=1", + "hydra.runtime.output_dir=.", + "seed=1234", + f"dataset_root={dataset_root}", + "batch_size=2", + ], + return_hydra_config=True, + ) + # Resolve the config + HydraConfig.instance().set_config(cfg) + + # Just for this function call, set the environment variable to WANDB_MODE=disabled + os.environ["WANDB_MODE"] = "disabled" + # Run the training script. + main(cfg) + + +# Skip this if the environment variable is not set or the path does not exist. +@pytest.mark.training +@pytest.mark.skipif( + ("NDF_DATASET_ROOT" not in os.environ or not os.path.exists(DEFAULT_NDF_PATH)) + and not torch.cuda.is_available(), + reason="NDF_DATASET_ROOT environment variable is not set or the path does not exist.", +) +@pytest.mark.parametrize("config_name", _get_training_config_names("ndf")) +def test_training_commands_run(config_name): + _test_commands_run(config_name) + + +# Do the same for the ablation configs. +@pytest.mark.ablations +@pytest.mark.skipif( + ("NDF_DATASET_ROOT" not in os.environ or not os.path.exists(DEFAULT_NDF_PATH)) + and not torch.cuda.is_available(), + reason="NDF_DATASET_ROOT environment variable is not set or the path does not exist.", +) +@pytest.mark.parametrize( + "config_name", _get_training_config_names("ndf", ablation=True) +) +def test_training_ablation_commands_run(config_name): + _test_commands_run(config_name)