Skip to content

Commit

Permalink
Merge branch 'mrariden-dynamics_reshape_fix' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Nov 8, 2023
2 parents c5ccc32 + 0743187 commit 5792997
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
54 changes: 34 additions & 20 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,11 +718,29 @@ def get_masks(p, iscell=None, rpad=20):
M0 = np.reshape(M0, shape0)
return M0

def resize_and_compute_masks(dP, cellprob, p=None, niter=200,
cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False,
min_size=15, resize=None,
use_gpu=False, device=None):
""" compute masks using dynamics from dP, cellprob, and boundary """
mask, p = compute_masks(dP, cellprob, p=p, niter=niter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp,
do_3D=do_3D, min_size=min_size,
use_gpu=use_gpu, device=device)

if resize is not None:
mask = transforms.resize_image(mask, resize[0], resize[1], interpolation=cv2.INTER_NEAREST)
p = np.array([transforms.resize_image(pi, resize[0], resize[1], interpolation=cv2.INTER_NEAREST) for pi in p])

return mask, p


def compute_masks(dP, cellprob, p=None, niter=200,
cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False,
min_size=15, resize=None,
use_gpu=False,device=None):
min_size=15, use_gpu=False,device=None):
""" compute masks using dynamics from dP, cellprob, and boundary """

cp_mask = cellprob > cellprob_threshold
Expand All @@ -734,7 +752,7 @@ def compute_masks(dP, cellprob, p=None, niter=200,
use_gpu=use_gpu, device=device)
if inds is None:
dynamics_logger.info('No cell pixels found.')
shape = resize if resize is not None else cellprob.shape
shape = cellprob.shape
mask = np.zeros(shape, np.uint16)
p = np.zeros((len(shape), *shape), np.uint16)
return mask, p
Expand All @@ -744,31 +762,27 @@ def compute_masks(dP, cellprob, p=None, niter=200,

# flow thresholding factored out of get_masks
if not do_3D:
shape0 = p.shape[1:]
if mask.max()>0 and flow_threshold is not None and flow_threshold > 0:
# make sure labels are unique at output of get_masks
mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold, use_gpu=use_gpu, device=device)

if resize is not None:
#if verbose:
# dynamics_logger.info(f'resizing output with resize = {resize}')
if mask.max() > 2**16-1:
recast = True
mask = mask.astype(np.float32)
else:
recast = False
mask = mask.astype(np.uint16)
mask = transforms.resize_image(mask, resize[0], resize[1], interpolation=cv2.INTER_NEAREST)
if recast:
mask = mask.astype(np.uint32)
Ly,Lx = mask.shape
elif mask.max() < 2**16:
if mask.max() > 2**16-1:
recast = True
mask = mask.astype(np.float32)
else:
recast = False
mask = mask.astype(np.uint16)

if recast:
mask = mask.astype(np.uint32)

if mask.max() < 2**16:
mask = mask.astype(np.uint16)

else: # nothing to compute, just make it compatible
dynamics_logger.info('No cell pixels found.')
shape = resize if resize is not None else cellprob.shape
mask = np.zeros(shape, np.uint16)
shape = cellprob.shape
mask = np.zeros(cellprob.shape, np.uint16)
p = np.zeros((len(shape), *shape), np.uint16)
return mask, p

Expand Down
2 changes: 1 addition & 1 deletion cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,7 +1897,7 @@ def compute_cprob(self):
else:
logger.info('computing masks with cell prob=%0.3f, flow error threshold=%0.3f'%
(cellprob_threshold, flow_threshold))
maski = dynamics.compute_masks(self.flows[4][:-1],
maski = dynamics.resize_and_compute_masks(self.flows[4][:-1],
self.flows[4][-1],
p=self.flows[3].copy(),
cellprob_threshold=cellprob_threshold,
Expand Down
4 changes: 2 additions & 2 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False,
tic=time.time()
niter = 200 if (do_3D and not resample) else (1 / rescale * 200)
if do_3D:
masks, p = dynamics.compute_masks(dP, cellprob, niter=niter,
masks, p = dynamics.resize_and_compute_masks(dP, cellprob, niter=niter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold,
interp=interp, do_3D=do_3D, min_size=min_size,
Expand All @@ -653,7 +653,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False,
masks, p = [], []
resize = [shape[1], shape[2]] if not resample else None
for i in iterator:
outputs = dynamics.compute_masks(dP[:,i], cellprob[i], niter=niter, cellprob_threshold=cellprob_threshold,
outputs = dynamics.resize_and_compute_masks(dP[:,i], cellprob[i], niter=niter, cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, resize=resize,
min_size=min_size if stitch_threshold==0 or nimg==1 else -1, # turn off for 3D stitching
use_gpu=self.gpu, device=self.device)
Expand Down

0 comments on commit 5792997

Please sign in to comment.