Skip to content

Commit

Permalink
feat(Zettaset): allow sharing of a user-specified mask volume across …
Browse files Browse the repository at this point in the history
…different annotations
  • Loading branch information
torms3 committed Jan 18, 2024
1 parent 4cac781 commit 4c3ae03
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
15 changes: 14 additions & 1 deletion deepem/data/dataset/multi_zettaset.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def load_sample(
zettaset_lookup: dict[str, str] | None = None,
zettaset_resolution: tuple[int, int, int] | None = None,
requires_binarize: list[str] = [],
zettaset_share_mask: str | None = None,
**kwargs
) -> dict[str, np.ndarray]:
"""Load image and labels from a Sample."""
Expand Down Expand Up @@ -164,6 +165,16 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
# Assumes that zettaset's annotation names follow DeepEM's convention.
zettaset_lookup = zettaset_lookup or {x: x for x in sample.annotation_names}

# Shared mask
shared_mask = None
if (not no_mask) and zettaset_share_mask:
key = zettaset_share_mask
mask_key = f"{zettaset_share_mask}_mask"
if key not in sample.masks:
raise KeyError(f"Mask '{mask_key}' not found.")
mask_vol = sample.read_mask(key)[key]
shared_mask = convert_array(mask_vol).astype("uint8")

# Process annotations
for name, key in zettaset_lookup.items():

Expand All @@ -178,7 +189,9 @@ def convert_array(arr: ArrayLike) -> np.ndarray:

# Mask
mask_key = f"{name}_mask"
if (not no_mask) and (key in sample.masks):
if shared_mask is not None:
dset[mask_key] = shared_mask
elif (not no_mask) and (key in sample.masks):
mask_vol = sample.read_mask(key)[key]
dset[mask_key] = convert_array(mask_vol).astype("uint8")
else:
Expand Down
2 changes: 2 additions & 0 deletions deepem/train/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def initialize(self):
self.parser.add_argument('--zettaset_padding_spec', type=json.loads, default={})
self.parser.add_argument('--zettaset_resolution', type=vec3f, default=None)
self.parser.add_argument('--zettaset_no_mask', action='store_true')
self.parser.add_argument('--zettaset_share_mask', type=str, default=None)

# file synchronization for spot/preemptible training
self.parser.add_argument('--samwise_map', nargs='*', default=None)
Expand Down Expand Up @@ -319,6 +320,7 @@ def parse(self):
zettaset_resolution=opt.zettaset_resolution,
zettaset_mask=not opt.zettaset_no_mask,
requires_binarize=requires_binarize,
zettaset_share_mask=opt.zettaset_share_mask,
)

# ONNX
Expand Down

0 comments on commit 4c3ae03

Please sign in to comment.