Commit

Merge pull request #1030 from MouseLand/new3d
new gpu-accelerated mask creation step
carsen-stringer authored Oct 21, 2024
2 parents 52f75f9 + 7dfe3b7 commit 729b701
Showing 11 changed files with 878 additions and 706 deletions.
10 changes: 6 additions & 4 deletions cellpose/__main__.py
@@ -148,7 +148,7 @@ def main():
% (nimg, cstr0[channels[0]], cstr1[channels[1]]))

# handle built-in model exceptions
if builtin_size and restore_type is None:
if builtin_size and restore_type is None and not args.pretrained_model_ortho:
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type,
backbone=backbone)
else:
@@ -166,11 +166,13 @@ def main():

pretrained_model = None if model_type is not None else pretrained_model
if restore_type is None:
pretrained_model_ortho = None if args.pretrained_model_ortho is None else args.pretrained_model_ortho
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type,
nchan=nchan,
backbone=backbone)
backbone=backbone,
pretrained_model_ortho=pretrained_model_ortho)
else:
model = denoise.CellposeDenoiseModel(
gpu=gpu, device=device, pretrained_model=pretrained_model,
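For reference, a minimal sketch of how the new ortho-model option might be used directly from the Python API (the model paths below are hypothetical; the keyword names follow the CellposeModel call added above):

    from cellpose import models

    # hypothetical paths: a standard XY model plus a model trained on ZY/ZX planes
    model = models.CellposeModel(gpu=True,
                                 pretrained_model="/path/to/xy_model",
                                 pretrained_model_ortho="/path/to/ortho_model")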
@@ -209,7 +211,8 @@ def main():
invert=args.invert, batch_size=args.batch_size,
interp=(not args.no_interp), normalize=(not args.no_norm),
channel_axis=args.channel_axis, z_axis=args.z_axis,
anisotropy=args.anisotropy, niter=args.niter)
anisotropy=args.anisotropy, niter=args.niter,
dP_smooth=args.dP_smooth)
masks, flows = out[:2]
if len(out) > 3 and restore_type is None:
diams = out[-1]
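A hedged sketch of the corresponding eval call with the new dP_smooth option (file name and parameter values are placeholders; do_3D, anisotropy and niter are existing eval arguments):

    from cellpose import io, models

    img = io.imread("stack.tif")   # hypothetical z-stack, Lz x Ly x Lx (x nchan)
    model = models.CellposeModel(gpu=True, model_type="cyto3")
    masks, flows, styles = model.eval(img, do_3D=True, channels=[0, 0],
                                      anisotropy=2.0, niter=200,
                                      dP_smooth=1.0)  # gaussian-smooth flows before 3D dynamics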
@@ -240,7 +243,6 @@ def main():
io.save_rois(masks, image_name)
logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
else:

test_dir = None if len(args.test_dir) == 0 else args.test_dir
images, labels, image_names, train_probs = None, None, None, None
test_images, test_labels, image_names_test, test_probs = None, None, None, None
10 changes: 8 additions & 2 deletions cellpose/cli.py
@@ -66,7 +66,7 @@ def get_arg_parser():

# model settings
model_args = parser.add_argument_group("Model Arguments")
model_args.add_argument("--pretrained_model", required=False, default="cyto",
model_args.add_argument("--pretrained_model", required=False, default="cyto3",
type=str,
help="model to use for running or starting training")
model_args.add_argument("--restore_type", required=False, default=None, type=str,
@@ -79,7 +79,10 @@ def get_arg_parser():
model_args.add_argument(
"--transformer", action="store_true", help=
"use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")

model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
type=str,
help="model to use for running 3D ortho views (ZY and ZX)")

# algorithm settings
algorithm_args = parser.add_argument_group("Algorithm Arguments")
algorithm_args.add_argument(
@@ -105,6 +108,9 @@ def get_arg_parser():
algorithm_args.add_argument(
"--min_size", required=False, default=15, type=int,
help="minimum number of pixels per mask, can turn off with -1")
algorithm_args.add_argument(
"--dP_smooth", required=False, default=0, type=float,
help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")

algorithm_args.add_argument(
"--flow_threshold", default=0.4, type=float, help=
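Taken together, the two new options might be exercised through the command-line parser roughly as follows (directory and model paths are placeholders; the other flags are existing cellpose CLI arguments):

    from cellpose.cli import get_arg_parser

    args = get_arg_parser().parse_args([
        "--dir", "/data/stacks", "--do_3D", "--use_gpu",
        "--pretrained_model", "cyto3",
        "--pretrained_model_ortho", "/path/to/ortho_model",
        "--anisotropy", "2.0", "--dP_smooth", "1.0",
    ])
    print(args.pretrained_model_ortho, args.dP_smooth)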
251 changes: 88 additions & 163 deletions cellpose/core.py
@@ -55,14 +55,14 @@ def _use_gpu_torch(gpu_number=0):
"""
try:
device = torch.device("cuda:" + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
_ = torch.zeros((1,1)).to(device)
core_logger.info("** TORCH CUDA version installed and working. **")
return True
except:
pass
try:
device = torch.device('mps:' + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
_ = torch.zeros((1,1)).to(device)
core_logger.info('** TORCH MPS version installed and working. **')
return True
except:
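The device check above simply tries to place a tiny tensor on each candidate device; a standalone sketch of the same idea (the helper name here is made up):

    import torch

    def probe_device(device_str):
        """Return True if a small tensor can be allocated on the given device."""
        try:
            torch.zeros((1, 1)).to(torch.device(device_str))
            return True
        except Exception:
            return False

    use_gpu = probe_device("cuda:0") or probe_device("mps:0")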
@@ -153,7 +153,7 @@ def _to_device(x, device):
torch.Tensor: The converted tensor on the specified device.
"""
if not isinstance(x, torch.Tensor):
X = torch.from_numpy(x).float().to(device)
X = torch.from_numpy(x).to(device, dtype=torch.float32)
return X
else:
return x
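The updated conversion casts and transfers in a single .to() call rather than making a float32 copy on the CPU first; a small illustration, assuming a CUDA device is available (it falls back to CPU otherwise):

    import numpy as np
    import torch

    x = np.random.rand(224, 224)  # float64 numpy array
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    x_old = torch.from_numpy(x).float().to(device)               # extra float32 copy on CPU
    x_new = torch.from_numpy(x).to(device, dtype=torch.float32)  # cast and move in one step
    assert torch.allclose(x_old, x_new)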
@@ -195,20 +195,19 @@ def _forward(net, x):
return y, style


def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
bsize=224):
def run_net(net, imgi, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
rsz=None):
"""
Run network on image or stack of images.
Run network on stack of images.
(faster if augment is False)
Args:
net (class): cellpose network (model.net)
imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile (bool, optional): Tiles image to ensure GPU/CPU memory usage limited (recommended); cannot be turned off for 3D segmentation. Defaults to True.
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
@@ -217,149 +216,73 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
"""
if imgs.ndim == 4:
# make image Lz x nchan x Ly x Lx for net
imgs = np.transpose(imgs, (0, 3, 1, 2))
detranspose = (0, 2, 3, 1)
elif imgs.ndim == 3:
# make image nchan x Ly x Lx for net
imgs = np.transpose(imgs, (2, 0, 1))
detranspose = (1, 2, 0)
elif imgs.ndim == 2:
imgs = imgs[np.newaxis, :, :]
detranspose = (1, 2, 0)

# pad image for net so Ly and Lx are divisible by 4
imgs, ysub, xsub = transforms.pad_image_ND(imgs)
# slices from padding
# slc = [slice(0, self.nclasses) for n in range(imgs.ndim)] # changed from imgs.shape[n]+1 for first slice size
slc = [slice(0, imgs.shape[n] + 1) for n in range(imgs.ndim)]
slc[-3] = slice(0, net.nout)
slc[-2] = slice(ysub[0], ysub[-1] + 1)
slc[-1] = slice(xsub[0], xsub[-1] + 1)
slc = tuple(slc)

# run network
if tile or augment or imgs.ndim == 4:
y, style = _run_tiled(net, imgs, augment=augment, bsize=bsize,
batch_size=batch_size, tile_overlap=tile_overlap)
nout = net.nout
Lz, Ly0, Lx0, nchan = imgi.shape
if rsz is not None:
if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
rsz = [rsz, rsz]
Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
else:
imgs = np.expand_dims(imgs, axis=0)
y, style = _forward(net, imgs)
y, style = y[0], style[0]
style /= (style**2).sum()**0.5

# slice out padding
y = y[slc]
# transpose so channels axis is last again
y = np.transpose(y, detranspose)

return y, style


def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0.1):
"""
Run network on tiles of size [bsize x bsize]
Lyr, Lxr = Ly0, Lx0
ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr)
pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])
Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2
if augment:
ny = max(2, int(np.ceil(2. * Ly / bsize)))
nx = max(2, int(np.ceil(2. * Lx / bsize)))
ly, lx = bsize, bsize
else:
ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
ly, lx = min(bsize, Ly), min(bsize, Lx)
yf = np.zeros((Lz, nout, Ly, Lx), "float32")
styles = np.zeros((Lz, 256), "float32")

(faster if augment is False)
# run multiple slices at the same time
ntiles = ny * nx
nimgs = max(1, batch_size // ntiles) # number of imgs to run in the same batch
niter = int(np.ceil(Lz / nimgs))
ziterator = (trange(niter, file=tqdm_out, mininterval=30)
if niter > 10 or Lz > 1 else range(niter))
for k in ziterator:
inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
for i, b in enumerate(inds):
# pad image for net so Ly and Lx are divisible by 4
imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy()
imgb = np.pad(imgb.transpose(2,0,1), pads, mode="constant")
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(
imgb, bsize=bsize, augment=augment,
tile_overlap=tile_overlap)
IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
(ny * nx, nchan, ly, lx))

ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
stylea = np.zeros((IMGa.shape[0], 256), "float32")
for j in range(0, IMGa.shape[0], batch_size):
bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
ya[bslc], stylea[bslc] = _forward(net, IMGa[bslc])
for i, b in enumerate(inds):
y = ya[i * ntiles : (i + 1) * ntiles]
if augment:
y = np.reshape(y, (ny, nx, 3, ly, lx))
y = transforms.unaugment_tiles(y)
y = np.reshape(y, (-1, 3, ly, lx))
yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx)
yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
stylei /= (stylei**2).sum()**0.5
styles[b] = stylei
# slices from padding
yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
yf = yf.transpose(0,2,3,1)
return yf, np.array(styles)

Args:
imgs (np.ndarray): The input image or stack of images of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.

Returns:
y (np.ndarray): output of network, if tiled it is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
"""
nout = net.nout
if imgi.ndim == 4:
Lz, nchan, Ly, Lx = imgi.shape
if augment:
ny = max(2, int(np.ceil(2. * Ly / bsize)))
nx = max(2, int(np.ceil(2. * Lx / bsize)))
ly, lx = bsize, bsize
else:
ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
ly, lx = min(bsize, Ly), min(bsize, Lx)
yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32)
styles = []
if ny * nx > batch_size:
ziterator = (trange(Lz, file=tqdm_out, mininterval=30)
if Lz > 1 else range(Lz))
for i in ziterator:
yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize,
batch_size=batch_size, tile_overlap=tile_overlap)
yf[i] = yfi
styles.append(stylei)
else:
# run multiple slices at the same time
ntiles = ny * nx
nimgs = batch_size // ntiles # number of z-slices to run at the same time
niter = int(np.ceil(Lz / nimgs))
ziterator = (trange(niter, file=tqdm_out, mininterval=30)
if Lz > 1 else range(niter))
for k in ziterator:
inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
for i, b in enumerate(inds):
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(
imgi[b], bsize=bsize, augment=augment,
tile_overlap=tile_overlap)
IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
(ny * nx, nchan, ly, lx))
ya, stylea = _forward(net, IMGa)
for i, b in enumerate(inds):
y = ya[i * ntiles : (i + 1) * ntiles]
if augment:
y = np.reshape(y, (ny, nx, 3, ly, lx))
y = transforms.unaugment_tiles(y)
y = np.reshape(y, (-1, 3, ly, lx))
yfi = transforms.average_tiles(y, ysub, xsub, Ly, Lx)
yfi = yfi[:, :imgi.shape[2], :imgi.shape[3]]
yf[b] = yfi
stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
stylei /= (stylei**2).sum()**0.5
styles.append(stylei)
return yf, np.array(styles)
else:
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi, bsize=bsize,
augment=augment,
tile_overlap=tile_overlap)
ny, nx, nchan, ly, lx = IMG.shape
IMG = np.reshape(IMG, (ny * nx, nchan, ly, lx))
niter = int(np.ceil(IMG.shape[0] / batch_size))
y = np.zeros((IMG.shape[0], nout, ly, lx))
iterator = (trange(niter, file=tqdm_out, mininterval=30)
if niter > 25 else range(niter))
for k in iterator:
irange = slice(batch_size * k, min(IMG.shape[0],
batch_size * k + batch_size))
y0, style = _forward(net, IMG[irange])
y[irange] = y0.reshape(irange.stop - irange.start, y0.shape[-3],
y0.shape[-2], y0.shape[-1])
if k == 0:
styles = style.sum(axis=0)
else:
styles += style.sum(axis=0)
styles /= IMG.shape[0]
if augment:
y = np.reshape(y, (ny, nx, nout, bsize, bsize))
y = transforms.unaugment_tiles(y)
y = np.reshape(y, (-1, nout, bsize, bsize))

yf = transforms.average_tiles(y, ysub, xsub, Ly, Lx)
yf = yf[:, :imgi.shape[1], :imgi.shape[2]]
styles /= (styles**2).sum()**0.5
return yf, styles
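As a rough usage sketch of the rewritten run_net above (a hedged example assuming a standard 3-output network; input is now expected as a Lz x Ly x Lx x nchan stack):

    import numpy as np
    from cellpose import core, models

    model = models.CellposeModel(gpu=True, model_type="cyto3")  # hypothetical setup
    stack = np.random.rand(5, 256, 256, 2).astype("float32")    # Lz x Ly x Lx x nchan

    y, styles = core.run_net(model.net, stack, batch_size=8, rsz=None)
    print(y.shape)       # expected (5, 256, 256, 3): dY, dX, cell probability
    print(styles.shape)  # expected (5, 256)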


def run_3D(net, imgs, batch_size=8, rsz=1.0, anisotropy=None, augment=False, tile=True,
tile_overlap=0.1, bsize=224, progress=None):
def run_3D(net, imgs, batch_size=8, augment=False,
tile_overlap=0.1, bsize=224, net_ortho=None,
progress=None):
"""
Run network on image z-stack.
@@ -371,9 +294,9 @@ def run_3D(net, imgs, batch_size=8, rsz=1.0, anisotropy=None, augment=False, til
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile (bool, optional): Tiles image to ensure GPU/CPU memory usage limited (recommended); cannot be turned off for 3D segmentation. Defaults to True.
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
Returns:
@@ -382,26 +305,28 @@ def run_3D(net, imgs, batch_size=8, rsz=1.0, anisotropy=None, augment=False, til
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
"""
sstr = ["YX", "ZY", "ZX"]
if anisotropy is not None:
rescaling = [[rsz, rsz], [rsz * anisotropy, rsz], [rsz * anisotropy, rsz]]
else:
rescaling = [rsz] * 3
pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
ipm = [(3, 0, 1, 2), (3, 1, 0, 2), (3, 1, 2, 0)]
nout = net.nout
yf = np.zeros((3, nout, imgs.shape[0], imgs.shape[1], imgs.shape[2]), np.float32)
ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
cp = [(1, 2), (0, 2), (0, 1)]
cpy = [(0, 1), (0, 1), (0, 1)]
shape = imgs.shape[:-1]
#cellprob = np.zeros(shape, "float32")
yf = np.zeros((*shape, 4), "float32")
for p in range(3):
xsl = imgs.copy().transpose(pm[p])
# rescale image for flow computation
shape = xsl.shape
xsl = transforms.resize_image(xsl, rsz=rescaling[p])
xsl = imgs.transpose(pm[p])
# per image
core_logger.info("running %s: %d planes of size (%d, %d)" %
(sstr[p], shape[0], shape[1], shape[2]))
y, style = run_net(net, xsl, batch_size=batch_size, augment=augment, tile=tile,
bsize=bsize, tile_overlap=tile_overlap)
y = transforms.resize_image(y, shape[1], shape[2])
yf[p] = y.transpose(ipm[p])
(sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
y, style = run_net(net if p==0 or net_ortho is None else net_ortho,
xsl, batch_size=batch_size, augment=augment,
bsize=bsize, tile_overlap=tile_overlap,
rsz=None)
yf[..., -1] += y[..., -1].transpose(ipm[p])
for j in range(2):
yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
y = None; del y

if progress is not None:
progress.setValue(25 + 15 * p)

return yf, style
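And a hedged sketch of calling the updated run_3D directly (model setup is hypothetical; per the index mapping above, the last axis holds dZ, dY, dX and the summed cell probability):

    import numpy as np
    from cellpose import core, models

    model = models.CellposeModel(gpu=True, model_type="cyto3")  # hypothetical setup
    stack = np.random.rand(30, 256, 256, 2).astype("float32")   # Lz x Ly x Lx x nchan

    # net_ortho, if given, is used for the ZY and ZX passes instead of the main net
    yf, style = core.run_3D(model.net, stack, batch_size=8, net_ortho=None)
    dZ, dY, dX = yf[..., 0], yf[..., 1], yf[..., 2]  # flows summed over the three passes
    cellprob = yf[..., -1]                           # cell probability summed over the three views
    print(yf.shape)                                  # expected (30, 256, 256, 4)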