Skip to content

Commit

Permalink
WIP(mean): fix a bug when mask_background is False
Browse files Browse the repository at this point in the history
  • Loading branch information
torms3 committed Apr 25, 2024
1 parent 63770e2 commit a2e3f2b
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions deepem/loss/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def forward(

# Recompute external matrix
mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device)
vecs = self.generate_vecs(embd, trgt, ids)
vecs = self.generate_vecs(embd, trgt, mask, ids)
means = [torch.mean(vec, dim=0) for vec in vecs]
weights = [1.0] * len(vecs)

Expand Down Expand Up @@ -194,16 +194,21 @@ def generate_vecs(
self,
embd: torch.Tensor,
trgt: torch.Tensor,
mask: torch.Tensor,
ids: Sequence[int],
) -> list[torch.Tensor]:
"""
Generate a list of vectorized embeddings for each ground truth object.
"""
mask_bool = mask.bool() if not self.mask_background else None
result = []
for obj_id in ids:
obj = torch.nonzero(trgt == int(obj_id))
z, y, x = obj[:, -3], obj[:, -2], obj[:, -1]
vec = embd[0, :, z, y, x].transpose(0, 1) # Count x Dim
obj_mask = (trgt == int(obj_id)) & mask_bool if mask_bool is not None else (trgt == int(obj_id))
idx = torch.nonzero(obj_mask, as_tuple=True)
if idx[0].numel() == 0:
# If there are no indices for this ID, skip to the next one
continue
vec = embd[0, :, idx[-3], idx[-2], idx[-1]].transpose(0, 1) # Count x Dim
result.append(vec)
return result

Expand Down

0 comments on commit a2e3f2b

Please sign in to comment.