diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index a3ec3891..556ad0d2 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -946,6 +946,8 @@ def dropEvent(self, event): files = [u.toLocalFile() for u in event.mimeData().urls()] if os.path.splitext(files[0])[-1] == '.npy': io._load_seg(self, filename=files[0]) + elif os.path.isdir(files[0]) and os.path.basename(files[0]) == 'suite2p': + io._load_suite2p(self, pathname=files[0]) else: io._load_image(self, filename=files[0]) diff --git a/cellpose/gui/io.py b/cellpose/gui/io.py index be4bb658..302b8993 100644 --- a/cellpose/gui/io.py +++ b/cellpose/gui/io.py @@ -97,6 +97,61 @@ def _get_train_set(image_names): train_labels.append(masks) return train_data, train_labels, train_files + +def _load_suite2p_image(parent, pathname=None, image_name='meanImg_chan2'): + # Identify plane folders in suite2p folder + plane_folders = [os.path.basename(p) for p in glob.glob(f"{pathname}/plane*")] + if len(plane_folders)==0: + print(f'GUI_INFO: suite2p folder: {pathname} has no plane folders in it.') + return None + + # Identify ops.npy files + ops_files = [glob.glob(f"{pathname}/{plane}/ops.npy") for plane in plane_folders] + if not all([len(ofile)==1 for ofile in ops_files]): + print(f'GUI_INFO: plane folders of {pathname} do not all contain ops.npy files!') + return None + + # Get images from ops files + ops = [np.load(ofile[0], allow_pickle=True).item() for ofile in ops_files] + ops_image_shape = [(op['Ly'], op['Lx']) for op in ops] + if not all([image_name in op for op in ops]): + print(f'GUI_INFO: requested image ({image_name}) is not a key of every ops.npy file!') + return None + + ops_image = [op[image_name] for op in ops] + if not all([oimage.shape == oimage_shape for oimage, oimage_shape in zip(ops_image, ops_image_shape)]): + print(f'GUI_INFO: requested images ({image_name}) do not all have the expected shape...') + for plane, oimage, oimage_shape in zip(plane_folders, ops_image, ops_image_shape): + print(f' - in {plane}, ops["{image_name}"].shape={oimage.shape} but (ops["Ly"], ops["Lx"])={oimage_shape}') + return None + + return ops_image + +def _load_suite2p(parent, pathname=None, image_name='meanImg_chan2'): + # attempt to load image from ops.npy files in suite2p folder + ops_image = _load_suite2p_image(parent, pathname=pathname, image_name=image_name) + + # Load images into parent + try: + print(f'GUI_INFO: loading {image_name} from every plane in suite2p folder: {pathname}') + image = np.stack(ops_image) + parent.loaded = True + except Exception as e: + print('ERROR: failed to load images, they might have incompatible shapes') + print(f'ERROR: {e}') + + # Update GUI + if parent.loaded: + parent.reset() + parent.filename = pathname + parent.filetype = ['suite2p', image_name] + _initialize_images(parent, image, resize=parent.resize, X2=0) + parent.clear_all() + parent.loaded = True + parent.enable_buttons() + # TO DO: Consider integrating loadmasks into this load pipeline if they are already stored? + #if load_mask: _load_masks(parent, filename=mask_file) + def _load_image(parent, filename=None, load_seg=True): """ load image with filename; if None, open QFileDialog """ if filename is None: @@ -129,6 +184,7 @@ def _load_image(parent, filename=None, load_seg=True): if parent.loaded: parent.reset() parent.filename = filename + parent.filetype = 'image' filename = os.path.split(parent.filename)[-1] _initialize_images(parent, image, resize=parent.resize, X2=0) parent.clear_all() @@ -239,6 +295,11 @@ def _load_seg(parent, filename=None, image=None, image_file=None): parent.filename = dat['filename'] if os.path.isfile(parent.filename): parent.filename = dat['filename'] + parent.filetype = dat.get('filetype', 'image') + found_image = True + if os.path.isdir(parent.filename) and dat.get('filetype', 'image')[0] == 'suite2p': + parent.filename = dat['filename'] + parent.filetype = dat['filetype'] found_image = True else: imgname = os.path.split(parent.filename)[1] @@ -247,12 +308,22 @@ def _load_seg(parent, filename=None, image=None, image_file=None): if os.path.isfile(parent.filename): found_image = True if found_image: - try: - image = imread(parent.filename) - except: - parent.loaded = False - found_image = False - print('ERROR: cannot find image file, loading from npy') + if dat.get('filetype', 'image')[0] == 'suite2p': + try: + image_name = dat.get('filetype', ['suite2p', 'meanImg_chan2'])[1] + image = _load_suite2p_image(parent, pathname=parent.filename, image_name=image_name) + image = np.stack(image) + except: + parent.loaded = False + found_image = False + print('ERROR: cannot generate image from suite2p ops.npy files, loading from npy') + else: + try: + image = imread(parent.filename) + except: + parent.loaded = False + found_image = False + print('ERROR: cannot find image file, loading from npy') if not found_image: parent.filename = filename[:-11] if 'img' in dat: @@ -488,7 +559,10 @@ def _save_sets(parent): is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check """ filename = parent.filename - base = os.path.splitext(filename)[0] + if getattr(parent, 'filetype', 'image')[0] == 'suite2p': + base = os.path.join(filename, 'suite2p') + else: + base = os.path.splitext(filename)[0] flow_threshold, cellprob_threshold = parent.get_thresholds() if parent.NZ > 1 and parent.is_stack: np.save(base + '_seg.npy', @@ -497,6 +571,7 @@ def _save_sets(parent): 'masks': parent.cellpix, 'current_channel': (parent.color-2)%5, 'filename': parent.filename, + 'filetype': parent.filetype, 'flows': parent.flows, 'zdraw': parent.zdraw, 'model_path': parent.current_model_path if hasattr(parent, 'current_model_path') else 0, @@ -511,6 +586,7 @@ def _save_sets(parent): 'chan_choose': [parent.ChannelChoose[0].currentIndex(), parent.ChannelChoose[1].currentIndex()], 'filename': parent.filename, + 'filetype': parent.filetype, 'flows': parent.flows, 'ismanual': parent.ismanual, 'manual_changes': parent.track_changes,