Skip to content

Commit

Permalink
Merge pull request #203 from mzouink/main
Browse files Browse the repository at this point in the history
Raise informative error if padding mode is not a valid option
  • Loading branch information
cmalinmayor authored Feb 14, 2024
2 parents 3e5f1f3 + b1d4412 commit ecbb63c
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 24 deletions.
6 changes: 3 additions & 3 deletions gunpowder/contrib/nodes/add_vector_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ def __get_vector_map(self, batch, request, vector_map_array_key):
if num_src_vectors_per_trg_loc > 0:
dist_to_locs = {}
for phys_loc in relevant_partner_loc:
dist_to_locs[
np.linalg.norm(node.location - phys_loc)
] = phys_loc
dist_to_locs[np.linalg.norm(node.location - phys_loc)] = (
phys_loc
)
for nr, dist in enumerate(
reversed(np.sort(list(dist_to_locs.keys())))
):
Expand Down
16 changes: 10 additions & 6 deletions gunpowder/nodes/defect_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,11 @@ def prepare(self, request):
logger.debug("before growth: %s" % spec.roi)
growth = Coordinate(
tuple(
0
if d == self.axis
else raw_voxel_size[d] * self.deformation_strength
(
0
if d == self.axis
else raw_voxel_size[d] * self.deformation_strength
)
for d in range(spec.roi.dims)
)
)
Expand Down Expand Up @@ -267,9 +269,11 @@ def process(self, batch, request):
old_roi = request[self.intensities].roi
logger.debug("resetting roi to %s" % old_roi)
crop = tuple(
slice(None)
if d == self.axis
else slice(self.deformation_strength, -self.deformation_strength)
(
slice(None)
if d == self.axis
else slice(self.deformation_strength, -self.deformation_strength)
)
for d in range(raw.spec.roi.dims)
)
raw.data = raw.data[crop]
Expand Down
16 changes: 13 additions & 3 deletions gunpowder/nodes/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def __init__(self, key, size, mode="constant", value=None):
self.key = key
self.size = size
self.mode = mode
if self.mode not in ["constant", "reflect"]:
raise ValueError(
"Invalid padding mode %s provided. Must be 'constant' or 'reflect'."
% self.mode
)
self.value = value

def setup(self):
Expand Down Expand Up @@ -129,8 +134,13 @@ def __expand(self, a, from_roi, to_roi, value):
lower_pad = from_roi.begin - to_roi.begin
upper_pad = to_roi.end - from_roi.end
pad_width = [(0, 0)] * num_channels + list(zip(lower_pad, upper_pad))
if self.mode == "constant":
padded = np.pad(a, pad_width, "constant", constant_values=value)
elif self.mode == "reflect":
if self.mode == "reflect":
padded = np.pad(a, pad_width, "reflect")
elif self.mode == "constant":
padded = np.pad(a, pad_width, "constant", constant_values=value)
else:
raise ValueError(
"Invalid padding mode %s provided. Must be 'constant' or 'reflect'."
% self.mode
)
return padded
16 changes: 10 additions & 6 deletions gunpowder/nodes/simple_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,11 @@ def process(self, batch, request):
location_in_total_offset = np.asarray(node.location) - total_roi_offset
node.location = np.asarray(
[
total_roi_end[dim] - location_in_total_offset[dim]
if m
else node.location[dim]
(
total_roi_end[dim] - location_in_total_offset[dim]
if m
else node.location[dim]
)
for dim, m in enumerate(self.mirror)
],
dtype=graph.spec.dtype,
Expand Down Expand Up @@ -255,9 +257,11 @@ def __mirror_roi(self, roi, total_roi, mirror):
end_of_roi_in_total = roi_in_total_offset + roi_shape
roi_in_total_offset_mirrored = total_roi_shape - end_of_roi_in_total
roi_offset = Coordinate(
total_roi_offset[d] + roi_in_total_offset_mirrored[d]
if mirror[d]
else roi_offset[d]
(
total_roi_offset[d] + roi_in_total_offset_mirrored[d]
if mirror[d]
else roi_offset[d]
)
for d in range(self.dims)
)

Expand Down
6 changes: 3 additions & 3 deletions gunpowder/nodes/specified_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def process(self, batch, request):
for array_key, spec in request.array_specs.items():
batch.arrays[array_key].spec.roi = spec.roi
if self.extra_data is not None:
batch.arrays[array_key].attrs[
"specified_location_extra_data"
] = self.extra_data[self.loc_i]
batch.arrays[array_key].attrs["specified_location_extra_data"] = (
self.extra_data[self.loc_i]
)

for graph_key, spec in request.graph_specs.items():
batch.points[graph_key].spec.roi = spec.roi
Expand Down
6 changes: 3 additions & 3 deletions tests/cases/simple_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def test_mismatched_voxel_multiples():

test_array = ArrayKey("TEST_ARRAY")
data = np.zeros([3, 3])
data[
2, 1
] = 1 # voxel has Roi((4, 2) (2, 2)). Contained in Roi((0, 0), (6, 4)). at 2, 1
data[2, 1] = (
1 # voxel has Roi((4, 2) (2, 2)). Contained in Roi((0, 0), (6, 4)). at 2, 1
)
source = ArraySource(
test_array,
Array(
Expand Down

0 comments on commit ecbb63c

Please sign in to comment.