Skip to content

Commit

Permalink
add training tests
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed Apr 30, 2024
1 parent 8b3b701 commit 9190a8f
Show file tree
Hide file tree
Showing 23 changed files with 213 additions and 69 deletions.
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,18 @@ Run act
```
act -j develop
```


## Testing

To run the tests:

```
pytest
```

To run only the long-running tests (those marked `long`, which the default `-m 'not long'` filter excludes):

```
pytest -m "long"
```
8 changes: 0 additions & 8 deletions configs/commands/ndf/bottle/train_bottle_grasp.yaml

This file was deleted.

8 changes: 0 additions & 8 deletions configs/commands/ndf/bottle/train_bottle_place.yaml

This file was deleted.

1 change: 1 addition & 0 deletions configs/commands/ndf/bottle/train_grasp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
defaults:
- _self_
- /train_ndf
- override /model: taxpose
- override /task: bottle_grasp

mode: train
1 change: 1 addition & 0 deletions configs/commands/ndf/bottle/train_place.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
defaults:
- _self_
- /train_ndf
- override /model: taxpose
- override /task: bottle_place

mode: train
8 changes: 0 additions & 8 deletions configs/commands/ndf/bowl/train_bowl_grasp.yaml

This file was deleted.

8 changes: 0 additions & 8 deletions configs/commands/ndf/bowl/train_bowl_place.yaml

This file was deleted.

1 change: 1 addition & 0 deletions configs/commands/ndf/bowl/train_grasp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
defaults:
- _self_
- /train_ndf
- override /model: taxpose
- override /task: bowl_grasp

mode: train
1 change: 1 addition & 0 deletions configs/commands/ndf/bowl/train_place.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
defaults:
- _self_
- /train_ndf
- override /model: taxpose
- override /task: bowl_place

mode: train
8 changes: 0 additions & 8 deletions configs/commands/ndf/mug/train_mug_grasp.yaml

This file was deleted.

8 changes: 0 additions & 8 deletions configs/commands/ndf/mug/train_mug_place.yaml

This file was deleted.

6 changes: 4 additions & 2 deletions configs/task/bottle_grasp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ anchor_class: 0
cloud_type: pre_grasp
softmax_temperature: 0.1
weight_normalize: softmax
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
6 changes: 4 additions & 2 deletions configs/task/bottle_place.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ anchor_class: 1
cloud_type: teleport
softmax_temperature: 1
weight_normalize: l1
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bottle_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
6 changes: 4 additions & 2 deletions configs/task/bowl_grasp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ anchor_class: 0
cloud_type: pre_grasp
softmax_temperature: 0.1
weight_normalize: softmax
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
6 changes: 4 additions & 2 deletions configs/task/bowl_place.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ anchor_name: slab
cloud_type: teleport
softmax_temperature: 1
weight_normalize: l1
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_bowl_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_slab_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
6 changes: 4 additions & 2 deletions configs/task/mug_grasp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ anchor_class: 0
cloud_type: pre_grasp
softmax_temperature: 0.1
weight_normalize: softmax
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_gripper_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
6 changes: 4 additions & 2 deletions configs/task/mug_place.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ softmax_temperature: 1
weight_normalize: l1
action_name: mug
anchor_name: rack
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
6 changes: 4 additions & 2 deletions configs/train_mug_residual_ablation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ flow_supervision: both

# Training Settings
checkpoint_file: Null
checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
# checkpoint_file_action: ${hydra:runtime.cwd}/trained_models/pretraining_mug_embnn_weights.ckpt
# checkpoint_file_anchor: ${hydra:runtime.cwd}/trained_models/pretraining_rack_embnn_weights.ckpt
checkpoint_file_action: null
checkpoint_file_anchor: null
lr: 1e-4
max_epochs: 1000
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ module = [
ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = "--ignore=third_party/"
addopts = "--ignore=third_party/ -m 'not long'"

[tool.pylint]
disable = [
Expand Down
12 changes: 10 additions & 2 deletions scripts/train_residual_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,27 @@ def write_to_file(file_name, string):
def main(cfg):
print(OmegaConf.to_yaml(cfg, resolve=True))

TESTING = "PYTEST_CURRENT_TEST" in os.environ

# breakpoint()
# torch.set_float32_matmul_precision("medium")
pl.seed_everything(cfg.seed)
logger = WandbLogger(project="taxpose", job_type=cfg.job_name)
logger.log_hyperparams(cfg)
logger.log_hyperparams({"working_dir": os.getcwd()})
trainer = pl.Trainer(
logger=logger,
logger=logger if not TESTING else False,
accelerator="gpu",
devices=[0],
reload_dataloaders_every_n_epochs=1,
callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()],
callbacks=(
[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()]
if not TESTING
else []
),
max_epochs=cfg.max_epochs,
# Check if PYTEST is running, and run for 5 steps if it is.
fast_dev_run=5 if "PYTEST_CURRENT_TEST" in os.environ else False,
)
log_txt_file = cfg.log_txt_file
if cfg.mode == "train":
Expand Down
13 changes: 11 additions & 2 deletions scripts/train_residual_flow_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,25 @@ def write_to_file(file_name, string):
@hydra.main(config_path="../configs", config_name="train_mug_residual_ablation")
def main(cfg):
pl.seed_everything(cfg.seed)

TESTING = "PYTEST_CURRENT_TEST" in os.environ

logger = WandbLogger(project=cfg.experiment)
logger.log_hyperparams(cfg)
logger.log_hyperparams({"working_dir": os.getcwd()})
trainer = pl.Trainer(
logger=logger,
logger=logger if not TESTING else False,
accelerator="gpu",
devices=[0],
reload_dataloaders_every_n_epochs=1,
callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()],
callbacks=(
[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()]
if not TESTING
else []
),
max_epochs=cfg.max_epochs,
# Check if PYTEST is running, and run for 5 steps if it is.
fast_dev_run=5 if "PYTEST_CURRENT_TEST" in os.environ else False,
)
log_txt_file = cfg.log_txt_file

Expand Down
9 changes: 7 additions & 2 deletions taxpose/nets/transformer_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def forward(self, *input):
scores=None,
).permute(0, 2, 1)

outputs = {
"flow_action": flow_action,
}

if self.cycle:
flow_anchor = self.head_anchor(
anchor_embedding_tf,
Expand All @@ -90,8 +94,9 @@ def forward(self, *input):
action_points,
scores=None,
).permute(0, 2, 1)
return flow_action, flow_anchor
return flow_action
outputs["flow_anchor"] = flow_anchor

return outputs


class CorrespondenceMLPHead(nn.Module):
Expand Down
Loading

0 comments on commit 9190a8f

Please sign in to comment.