Commit

things seem to work

beneisner committed May 2, 2024
1 parent 8956e92 commit 11a37c5

Showing 10 changed files with 63 additions and 138 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -87,7 +87,9 @@ module = [
     "display.*",
     "dgl.*",
     "h5py.*",
+    "joblib.*",
     "matplotlib.*",
+    "open3d.*",
     "pandas.*",
     "plotly.*",
     "pybullet.*",
@@ -101,7 +103,8 @@ module = [
 ignore_missing_imports = true
 
 [tool.pytest.ini_options]
-addopts = "--ignore=third_party/ -m 'not long'"
+addopts = "--ignore=third_party/ -m 'not training and not pretraining and not ablations'"
+testpaths = "tests"
 
 [tool.pylint]
 disable = [
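Note: the new addopts expression deselects anything marked training, pretraining, or ablations by default (the commit adds one such marker in tests/pretrain_test.py below). Custom marks that are not registered trigger PytestUnknownMarkWarning, so they are normally declared either under a markers key in [tool.pytest.ini_options] or programmatically in conftest.py. A minimal sketch of the conftest.py route, with illustrative description strings:

# conftest.py -- register the custom marks so the default
# `-m 'not training and not pretraining and not ablations'` expression
# can deselect them without unknown-mark warnings.
def pytest_configure(config):
    config.addinivalue_line("markers", "training: long-running training tests")
    config.addinivalue_line("markers", "pretraining: pretraining pipeline tests")
    config.addinivalue_line("markers", "ablations: ablation experiment tests")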
7 changes: 4 additions & 3 deletions taxpose/datasets/augmentations.py
@@ -24,11 +24,12 @@ def maybe_downsample(
 
     # raise ValueError("Cannot downsample to more points than exist in the cloud.")
 
-    points, ids = sample_farthest_points(
+    points_pt, ids = sample_farthest_points(
         torch.from_numpy(points), K=num_points, random_start_point=True
     )
 
-    return points.numpy()
+    points = points_pt.numpy()
+    return points
 
 
 @dataclass
@@ -70,6 +71,6 @@ def occlusion(points: npt.NDArray[np.float32], obj_class: int, min_num_points: int
         # Ignore the occlusion if it's going to mess us up later...
         if points_new.shape[0] > min_num_points:
             points = points_new.unsqueeze(0)
-        return points if isinstance(points, np.ndarray) else points.numpy()
+        return points if isinstance(points, np.ndarray) else points.numpy()  # type: ignore
 
     return occlusion
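Note: the rename from points to points_pt avoids rebinding one variable to both a torch.Tensor and a NumPy array, which mypy rejects by default because a variable keeps its first inferred type. A minimal sketch of the same downsampling pattern, assuming pytorch3d is available and a batched (B, N, 3) float32 input as in the dataset code (an illustration, not the repository's exact function):

import numpy as np
import numpy.typing as npt
import torch
from pytorch3d.ops import sample_farthest_points


def downsample(
    points: npt.NDArray[np.float32], num_points: int
) -> npt.NDArray[np.float32]:
    # Bind the tensor result to a distinct name so `points` stays an
    # ndarray as far as the type checker is concerned.
    points_pt, _ids = sample_farthest_points(
        torch.from_numpy(points), K=num_points, random_start_point=True
    )
    return points_pt.numpy()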
6 changes: 5 additions & 1 deletion taxpose/datasets/env_mod_utils.py
@@ -1,3 +1,7 @@
+# type: ignore
+
+from typing import Dict, Tuple
+
 import numpy as np
 import torch
 from pytorch3d.ops import sample_farthest_points
@@ -146,7 +150,7 @@ def get_random_distractor_demo(
     transform_base=True,
     return_debug=False,
     rot_sample_method="axis_angle",
-):
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
     debug = {}
     if transform_base:
         N = points_anchor_base.shape[0]
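Note: the file-level # type: ignore exempts the whole module from mypy, so the new Tuple[...] return annotation documents the five-value contract for callers rather than serving the checker. A small illustration of annotating a multi-return function (the function below is hypothetical, not from env_mod_utils.py):

from typing import Dict, Tuple

import torch


def split_cloud(points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
    """Split a (B, N, 3) cloud in half along the point dimension."""
    half = points.shape[1] // 2
    debug: Dict[str, int] = {"num_points": int(points.shape[1])}
    return points[:, :half], points[:, half:], debug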
10 changes: 5 additions & 5 deletions taxpose/datasets/ndf.py
@@ -16,7 +16,7 @@
 )
 from taxpose.datasets.base import PlacementPointCloudData
 from taxpose.datasets.enums import ObjectClass, Phase
-from taxpose.datasets.env_mod_utils import get_random_distractor_demo
+from taxpose.datasets.env_mod_utils import get_random_distractor_demo  # type: ignore
 from taxpose.datasets.symmetry_utils import (
     compute_demo_symmetry_features as new_compute_demo_symmetry_features,
 )
@@ -193,7 +193,7 @@ def __init__(self, cfg: NDFPointCloudDatasetConfig):
 
         self.filenames = [
             self.dataset_root / f"{idx}_{self.cloud_type}_obj_points.npz"
-            for idx in self.dataset_indices
+            for idx in self.dataset_indices  # type: ignore
             if idx not in self.bad_demo_id
         ]
 
@@ -303,7 +303,7 @@ def __getitem__(self, index: int) -> PlacementPointCloudData:
                 rot_sample_method=self.distractor_rot_sample_method,
             )
             points_action = points_action.numpy()
-            points_anchor = torch.cat([points_anchor1, points_anchor2], axis=1).numpy()
+            points_anchor = torch.cat([points_anchor1, points_anchor2], dim=1).numpy()
 
         # Apply occlusions
         if self.occlusion_cfg is not None:
@@ -327,8 +327,8 @@ def __getitem__(self, index: int) -> PlacementPointCloudData:
         ) = new_compute_demo_symmetry_features(
             points_action[0],
             points_anchor[0],
-            self.action_class,
-            self.anchor_class,
+            OBJECT_LABELS_TO_CLASS[(self.object_type, self.action_class)],  # type: ignore
+            OBJECT_LABELS_TO_CLASS[(self.object_type, self.anchor_class)],  # type: ignore
         )
 
         assert not isinstance(action_symmetry_features, torch.Tensor)
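Note: dim is torch.cat's documented keyword; the NumPy-style axis spelling may be accepted at runtime in recent PyTorch, but it is the form type checkers flag, and the same fix appears for .mean(axis=0) in taxpose/utils/symmetry_utils.py below. A quick sketch of the concatenation with made-up shapes:

import torch

points_anchor1 = torch.rand(1, 100, 3)
points_anchor2 = torch.rand(1, 50, 3)

# dim=1 concatenates the two (B, N_i, 3) clouds along the point dimension.
points_anchor = torch.cat([points_anchor1, points_anchor2], dim=1)
assert points_anchor.shape == (1, 150, 3)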
16 changes: 2 additions & 14 deletions taxpose/datasets/point_cloud_dataset.py
@@ -51,18 +51,8 @@ def make_dataset(
         import taxpose.datasets.ndf as ndf
 
         return ndf.NDFPointCloudDataset(cast(ndf.NDFPointCloudDatasetConfig, cfg))
-    elif cfg.dataset_type == "rlbench":
-        import taxpose.datasets.rlbench as rlbench
-
-        return rlbench.RLBenchPointCloudDataset(
-            cast(rlbench.RLBenchPointCloudDatasetConfig, cfg)
-        )
-    elif cfg.dataset_type == "real_world_mug":
-        import taxpose.datasets.real_world_mug as real_world_mug
-
-        return real_world_mug.RealWorldMugPointCloudDataset(
-            cast(real_world_mug.RealWorldMugPointCloudDatasetConfig, cfg)
-        )
     else:
         raise NotImplementedError(f"Unknown dataset type: {cfg.dataset_type}")
 
 
 class PointCloudDataset(Dataset):
@@ -84,8 +74,6 @@ def __init__(self, cfg: PointCloudDatasetConfig):
         self.overfit = cfg.overfit
         self.gripper_lr_label = cfg.gripper_lr_label
         self.num_overfit_transforms = cfg.num_overfit_transforms
-        self.T0_list = []
-        self.T1_list = []
         self.synthetic_occlusion = cfg.synthetic_occlusion
         self.ball_radius = cfg.ball_radius
         self.plane_standoff = cfg.plane_standoff
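Note: make_dataset now dispatches only to the NDF dataset, with unknown types still falling through to NotImplementedError. The cast(...) in the surviving branch is typing.cast, a no-op at runtime that only narrows the static type of the config. A condensed sketch of the dispatch-with-cast shape, using stand-in config classes rather than the repository's real ones:

from dataclasses import dataclass
from typing import cast


@dataclass
class BaseConfig:
    dataset_type: str


@dataclass
class NDFConfig(BaseConfig):
    root: str = "/data"


def make_dataset(cfg: BaseConfig) -> str:
    if cfg.dataset_type == "ndf":
        # cast() tells the checker which subclass this branch expects;
        # it performs no runtime conversion or validation.
        ndf_cfg = cast(NDFConfig, cfg)
        return f"NDFDataset(root={ndf_cfg.root})"
    raise NotImplementedError(f"Unknown dataset type: {cfg.dataset_type}")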
130 changes: 24 additions & 106 deletions taxpose/datasets/pretraining_point_cloud_data_module.py
@@ -1,28 +1,40 @@
 from dataclasses import dataclass
+from typing import Optional, cast
 
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader
 
-from taxpose.datasets.base import PretrainingPointCloudDatasetConfig
-from taxpose.datasets.ndf_pretraining import NDFPretrainingPointCloudDataset
-from taxpose.datasets.shapenet_pretraining import ShapeNetPretrainingPointCloudDataset
+from taxpose.datasets.base import (
+    PretrainingPointCloudDataset,
+    PretrainingPointCloudDatasetConfig,
+)
+from taxpose.datasets.ndf_pretraining import (
+    NDFPretrainingPointCloudDataset,
+    NDFPretrainingPointCloudDatasetConfig,
+)
+from taxpose.datasets.shapenet_pretraining import (
+    ShapeNetPretrainingPointCloudDataset,
+    ShapeNetPretrainingPointCloudDatasetConfig,
+)
 
 
 def make_dataset(cfg: PretrainingPointCloudDatasetConfig):
     dataset_type = cfg.dataset_type
     if dataset_type == "ndf_pretraining":
-        return NDFPretrainingPointCloudDataset(cfg)
+        ndf_cfg = cast(NDFPretrainingPointCloudDatasetConfig, cfg)
+        return NDFPretrainingPointCloudDataset(ndf_cfg)
     elif dataset_type == "shapenet_pretraining":
-        return ShapeNetPretrainingPointCloudDataset(cfg)
+        shapenet_cfg = cast(ShapeNetPretrainingPointCloudDatasetConfig, cfg)
+        return ShapeNetPretrainingPointCloudDataset(shapenet_cfg)
     else:
         raise ValueError(f"Unknown dataset type: {dataset_type}")
 
 
 @dataclass
 class PretrainingMultiviewDMConfig:
-    train_config: PretrainingPointCloudDatasetConfig
-    val_config: PretrainingPointCloudDatasetConfig
-    test_config: PretrainingPointCloudDatasetConfig
+    train_dset: PretrainingPointCloudDatasetConfig
+    val_dset: PretrainingPointCloudDatasetConfig
+    test_dset: PretrainingPointCloudDatasetConfig
 
 
 class PretrainingMultiviewDataModule(pl.LightningDataModule):
@@ -31,117 +43,23 @@ def __init__(
         cfg: PretrainingMultiviewDMConfig,
         batch_size=8,
         num_workers=8,
-        # cloud_class=0,
-        # batch_size=8,
-        # num_workers=8,
-        # cloud_type="final",
-        # dataset_index=None,
-        # dataset_root=None,
-        # obj_class="mug",
-        # pretraining_data_path=None,
     ):
         super().__init__()
 
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.cfg = cfg
-        # self.cloud_class = cloud_class
-        # self.cloud_type = cloud_type
-        # self.dataset_indices = dataset_index
-        # self.dataset_root = dataset_root
-        # self.pretraining_data_path = pretraining_data_path
-        # self.obj_class = obj_class
-
-        # 0 for mug, 1 for rack, 2 for gripper
-        # if self.cloud_class == 0:
-        #     self.obj_class = obj_class
-        # else:
-        #     self.obj_class = "non_mug"
-
-    # def pass_loss(self, loss):
-    #     self.loss = loss.to(self.device)
-
-    # def prepare_data(self):
-    #     """called only once and on 1 GPU"""
-
-    # def update_dataset(self):
-    #     if self.obj_class != "non_mug":
-    #         self.train_dataset = ShapeNetPretrainingPointCloudDataset(
-    #             ShapeNetPretrainingPointCloudDatasetConfig(
-    #                 ndf_data_path=self.pretraining_data_path,
-    #                 obj_class=[self.obj_class],
-    #                 phase="train",
-    #             )
-    #         )
-    #     else:
-    #         self.train_dataset = NDFPretrainingPointCloudDataset(
-    #             NDFPretrainingPointCloudDatasetConfig(
-    #                 dataset_root=self.dataset_root,
-    #                 dataset_indices=self.dataset_indices,
-    #                 cloud_type=self.cloud_type,
-    #                 action_class=self.cloud_class,
-    #             )
-    #         )
+        self.train_dataset: Optional[PretrainingPointCloudDataset] = None
+        self.val_dataset: Optional[PretrainingPointCloudDataset] = None
+        self.test_dataset: Optional[PretrainingPointCloudDataset] = None
 
     def setup(self, stage: str):
         if stage == "fit":
             print("TRAIN Dataset")
             self.train_dataset = make_dataset(self.cfg.train_dset)
-            # if self.obj_class != "non_mug":
-            #     self.train_dataset = ShapeNetPretrainingPointCloudDataset(
-            #         ShapeNetPretrainingPointCloudDatasetConfig(
-            #             ndf_data_path=self.pretraining_data_path,
-            #             obj_class=[self.obj_class],
-            #             phase="train",
-            #         )
-            #     )
-            # else:
-            #     self.train_dataset = NDFPretrainingPointCloudDataset(
-            #         NDFPretrainingPointCloudDatasetConfig(
-            #             dataset_root=self.dataset_root,
-            #             dataset_indices=self.dataset_indices,
-            #             cloud_type=self.cloud_type,
-            #             action_class=self.cloud_class,
-            #         )
-            #     )
             print("VAL Dataset")
             self.val_dataset = make_dataset(self.cfg.val_dset)
-            # if self.obj_class != "non_mug":
-            #     self.val_dataset = ShapeNetPretrainingPointCloudDataset(
-            #         ShapeNetPretrainingPointCloudDatasetConfig(
-            #             ndf_data_path=self.pretraining_data_path,
-            #             obj_class=[self.obj_class],
-            #             phase="val",
-            #         )
-            #     )
-            # else:
-            #     self.val_dataset = NDFPretrainingPointCloudDataset(
-            #         NDFPretrainingPointCloudDatasetConfig(
-            #             dataset_root=self.dataset_root,
-            #             dataset_indices=self.dataset_indices,
-            #             cloud_type=self.cloud_type,
-            #             action_class=self.cloud_class,
-            #         )
-            #     )
 
         if stage == "test":
             print("TEST Dataset")
             self.test_dataset = make_dataset(self.cfg.test_dset)
-            # if self.obj_class != "non_mug":
-            #     self.test_dataset = ShapeNetPretrainingPointCloudDataset(
-            #         ShapeNetPretrainingPointCloudDatasetConfig(
-            #             ndf_data_path=self.pretraining_data_path,
-            #             obj_class=[self.obj_class],
-            #         )
-            #     )
-            # else:
-            #     self.test_dataset = NDFPretrainingPointCloudDataset(
-            #         NDFPretrainingPointCloudDatasetConfig(
-            #             dataset_root=self.dataset_root,
-            #             dataset_indices=self.dataset_indices,
-            #             cloud_type=self.cloud_type,
-            #             action_class=self.cloud_class,
-            #         )
-            #     )
 
     def train_dataloader(self):
         return DataLoader(
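Note: the rewritten module follows the usual LightningDataModule lifecycle: datasets start as Optional[...] = None in __init__ and are only constructed in setup(), which Lightning calls per process with stage="fit" or stage="test". A minimal sketch of that lifecycle, assuming pytorch_lightning is installed (the TensorDataset payloads are placeholders, not the repository's point-cloud datasets):

from typing import Optional

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class SketchDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset: Optional[TensorDataset] = None
        self.test_dataset: Optional[TensorDataset] = None

    def setup(self, stage: str):
        # Called by Lightning on every process, once per stage.
        if stage == "fit":
            self.train_dataset = TensorDataset(torch.rand(64, 3))
        if stage == "test":
            self.test_dataset = TensorDataset(torch.rand(16, 3))

    def train_dataloader(self):
        assert self.train_dataset is not None, "setup('fit') must run first"
        return DataLoader(self.train_dataset, batch_size=self.batch_size)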
5 changes: 4 additions & 1 deletion taxpose/datasets/shapenet_pretraining.py
@@ -46,6 +46,9 @@ def __init__(
         obj_class = cfg.obj_class
         train_num = cfg.train_num
 
+        if ndf_data_path is None:
+            raise ValueError("ndf_data_path must be provided.")
+
         self.ndf_data_path = Path(ndf_data_path)
         # Path setup (change to folder where your training data is kept)
         # these are the names of the full dataset folders
@@ -85,7 +88,7 @@ def __init__(
         files = list(sorted(glob.glob(path + "/*.npz")))
         n = len(files)
 
-        if train_num == None:
+        if train_num is None:
             idx = int(0.9 * n)
         else:
             idx = train_num
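Note: train_num is None is the PEP 8 comparison for None. == None goes through __eq__, which objects can override to return something other than a bool; NumPy arrays are the classic case, which makes the identity check the safe habit in this codebase:

import numpy as np

arr = np.array([1.0, 2.0, 3.0])

print(arr is None)   # False -- identity check, always a plain bool
print(arr == None)   # [False False False] -- elementwise comparison
# `if arr == None:` would raise "truth value of an array is ambiguous".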
17 changes: 12 additions & 5 deletions taxpose/datasets/symmetry_utils.py
@@ -19,14 +19,14 @@ def scalars_to_rgb(symmetry_labels: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]:
 
     # Convert the color to RGB.
     color = color * 255
-    color = color.round().astype(np.uint8)
+    icolor = color.round().astype(np.uint8)
 
-    return color
+    return icolor
 
 
 def gripper_symmetry_labels(
     gripper_pcd: npt.NDArray[np.float32],
-) -> Tuple[np.float32, np.float32, np.float32]:
+) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
     """Compute gripper symmetry labels.
 
     Args:
@@ -72,7 +72,12 @@ def rotational_symmetry_labels(
     obj_class: ObjectClass,
     look_at: Optional[npt.NDArray[np.float32]] = None,
     seed: Optional[int] = None,
-) -> Tuple[np.float32, np.float32, np.float32, np.float32]:
+) -> Tuple[
+    npt.NDArray[np.float32],
+    npt.NDArray[np.float32],
+    npt.NDArray[np.float32],
+    npt.NDArray[np.float32],
+]:
     """Computes object symmetry labels.
 
     Args:
@@ -143,7 +148,9 @@ def rotational_symmetry_labels(
     return l_obj[..., None], principal_axis, s_obj, centroid
 
 
-def nonsymmetric_labels(obj_pcd) -> Tuple[np.float32, np.float32]:
+def nonsymmetric_labels(
+    obj_pcd,
+) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
     """Computes nonsymmetric labels. This should just be ones.
 
     Args:
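Note: the annotation fixes matter because np.float32 names a scalar type, so Tuple[np.float32, ...] claimed the functions return bare scalars; npt.NDArray[np.float32] is the annotation for a float32 array. A short contrast using a hypothetical helper, not code from symmetry_utils.py:

import numpy as np
import numpy.typing as npt


def unit_axis(v: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
    """Return v scaled to unit length (an array, not a scalar)."""
    norm: np.float32 = np.float32(np.linalg.norm(v))  # a float32 scalar
    return v / norm


axis = unit_axis(np.array([3.0, 0.0, 4.0], dtype=np.float32))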
4 changes: 2 additions & 2 deletions taxpose/utils/symmetry_utils.py
@@ -157,7 +157,7 @@ def get_sym_label_pca_grasp(
         points_sym = anchor_cloud[0]
         points_nonsym = action_cloud[0]
 
-        non_sym_center = points_nonsym.mean(axis=0)
+        non_sym_center = points_nonsym.mean(dim=0)
         points_sym_np = to_np(points_sym)
 
         pcd = o3d.geometry.PointCloud()
@@ -254,7 +254,7 @@ def get_sym_label_pca_place(
         points_sym = anchor_cloud[0]
         points_nonsym = action_cloud[0]
 
-        non_sym_center = points_nonsym.mean(axis=0)
+        non_sym_center = points_nonsym.mean(dim=0)
        points_sym_np = to_np(points_sym)
 
         pcd = o3d.geometry.PointCloud()
1 change: 1 addition & 0 deletions tests/pretrain_test.py
@@ -40,6 +40,7 @@ def _get_pretraining_config_names():
 DEFAULT_NDF_PATH = "/data"
 
 
+@pytest.mark.pretraining
 @pytest.mark.skipif(
     ("NDF_DATASET_ROOT" not in os.environ or not os.path.exists(DEFAULT_NDF_PATH))
     and not torch.cuda.is_available(),
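Note: stacking the new @pytest.mark.pretraining on top of @pytest.mark.skipif gives two independent gates: the marker lets the default -m expression in pyproject.toml deselect the test cheaply, and the skipif condition still guards it when the test is selected explicitly but the NDF data or a GPU is unavailable. A sketch of the combination, mirroring the structure above without copying the test body:

import os

import pytest
import torch

DEFAULT_NDF_PATH = "/data"


@pytest.mark.pretraining
@pytest.mark.skipif(
    ("NDF_DATASET_ROOT" not in os.environ or not os.path.exists(DEFAULT_NDF_PATH))
    and not torch.cuda.is_available(),
    reason="requires the NDF dataset and/or a GPU",
)
def test_pretraining_pipeline():
    # Placeholder body; runs only when selected and not skipped.
    assert os.path.exists(DEFAULT_NDF_PATH) or torch.cuda.is_available()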
