Skip to content

Commit

Permalink
[update] update dict config parse
Browse files Browse the repository at this point in the history
  • Loading branch information
yzqin committed Mar 21, 2024
1 parent 5e1e65e commit 947b66e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 8 deletions.
21 changes: 14 additions & 7 deletions dex_retargeting/retargeting_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sapien.core as sapien
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any
from typing import Union

import numpy as np
Expand Down Expand Up @@ -45,6 +45,9 @@ class RetargetingConfig:
normal_delta: float = 4e-3
huber_delta: float = 2e-2

# Constraint parameters
constraint_map: Optional[Dict[str, np.ndarray]] = None

# Joint limit tag
has_joint_limits: bool = True

Expand Down Expand Up @@ -110,12 +113,16 @@ def load_from_file(cls, config_path: Union[str, Path], override: Optional[Dict]
with path.open("r") as f:
yaml_config = yaml.load(f, Loader=yaml.FullLoader)
cfg = yaml_config["retargeting"]
if "target_link_human_indices" in cfg:
cfg["target_link_human_indices"] = np.array(cfg["target_link_human_indices"])
if override is not None:
for key, value in override.items():
cfg[key] = value
config = RetargetingConfig(**cfg)
return cls.from_dict(cfg, override)

@classmethod
def from_dict(cls, cfg: Dict[str, Any], override: Optional[Dict] = None):
if "target_link_human_indices" in cfg:
cfg["target_link_human_indices"] = np.array(cfg["target_link_human_indices"])
if override is not None:
for key, value in override.items():
cfg[key] = value
config = RetargetingConfig(**cfg)
return config

def build(self, scene: Optional[sapien.Scene] = None) -> SeqRetargeting:
Expand Down
66 changes: 65 additions & 1 deletion tests/test_retargeting_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import pytest
import yaml

from dex_retargeting.retargeting_config import RetargetingConfig
from dex_retargeting.seq_retarget import SeqRetargeting
Expand Down Expand Up @@ -36,8 +37,71 @@ class TestRetargetingConfig:
)

@pytest.mark.parametrize("config_path", config_paths)
def test_config_parsing(self, config_path):
def test_path_config_parsing(self, config_path):
config_path = self.config_dir / config_path
config = RetargetingConfig.load_from_file(config_path)
retargeting = config.build()
assert isinstance(retargeting, SeqRetargeting)

def test_dict_config_parsing(self):
cfg_str = """
type: vector
urdf_path: allegro_hand/allegro_hand_right.urdf
wrist_link_name: "wrist"
# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ]
target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
scaling_factor: 1.6
# Source refers to the retargeting input, which usually corresponds to the human hand
# The joint indices of human hand joint which corresponds to each link in the target_link_names
target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ]
# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
"""
cfg_dict = yaml.safe_load(cfg_str)
config = RetargetingConfig.from_dict(cfg_dict)
retargeting = config.build()
assert type(retargeting) == SeqRetargeting

def test_multi_dict_config_parsing(self):
cfg_str = """
- type: vector
urdf_path: allegro_hand/allegro_hand_right.urdf
wrist_link_name: "wrist"
# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ]
target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
scaling_factor: 1.6
# Source refers to the retargeting input, which usually corresponds to the human hand
# The joint indices of human hand joint which corresponds to each link in the target_link_names
target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ]
# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
- type: DexPilot
urdf_path: leap_hand/leap_hand_right.urdf
wrist_link_name: "base"
# Target refers to the retargeting target, which is the robot hand
target_joint_names: null
finger_tip_link_names: [ "thumb_tip_head", "index_tip_head", "middle_tip_head", "ring_tip_head" ]
scaling_factor: 1.6
# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
"""
cfg_dict_list = yaml.safe_load(cfg_str)
retargetings = []
for cfg_dict in cfg_dict_list:
config = RetargetingConfig.from_dict(cfg_dict)
retargeting = config.build()
retargetings.append(retargeting)
assert isinstance(retargeting, SeqRetargeting)

0 comments on commit 947b66e

Please sign in to comment.