diff --git a/.github/workflows/publish-docs.yaml b/.github/workflows/publish-docs.yaml index 0f9b78bb..efeb8e85 100644 --- a/.github/workflows/publish-docs.yaml +++ b/.github/workflows/publish-docs.yaml @@ -3,8 +3,7 @@ name: Deploy Docs to GitHub Pages on: push: branches: [main] - pull_request: - branches: [main] + tags: "*" workflow_dispatch: # Allow this job to clone the repo and create a page deployment diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 706882c7..2a345506 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] platform: [ubuntu-latest] steps: diff --git a/docs/source/conf.py b/docs/source/conf.py index 529dbe8b..b0da21cf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,7 +54,7 @@ ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/examples/cremi/mknet.py b/examples/cremi/mknet.py index aac8c0df..fe13a7a5 100644 --- a/examples/cremi/mknet.py +++ b/examples/cremi/mknet.py @@ -2,8 +2,8 @@ import tensorflow as tf import json -def create_network(input_shape, name): +def create_network(input_shape, name): tf.reset_default_graph() # create a placeholder for the 3D raw input tensor @@ -11,20 +11,17 @@ def create_network(input_shape, name): # create a U-Net raw_batched = tf.reshape(raw, (1, 1) + input_shape) - unet_output = unet(raw_batched, 6, 4, [[1,3,3],[1,3,3],[1,3,3]]) + unet_output = unet(raw_batched, 6, 4, [[1, 3, 3], [1, 3, 3], [1, 3, 3]]) # add a convolution layer to create 3 output maps representing affinities # in z, y, and x pred_affs_batched = conv_pass( - unet_output, - kernel_size=1, - num_fmaps=3, - num_repetitions=1, - activation='sigmoid') + unet_output, kernel_size=1, num_fmaps=3, num_repetitions=1, activation="sigmoid" + ) # get the shape of the output output_shape_batched = pred_affs_batched.get_shape().as_list() - output_shape = output_shape_batched[1:] # strip the batch dimension + output_shape = output_shape_batched[1:] # strip the batch dimension # the 4D output tensor (3, depth, height, width) pred_affs = tf.reshape(pred_affs_batched, output_shape) @@ -33,46 +30,39 @@ def create_network(input_shape, name): gt_affs = tf.placeholder(tf.float32, shape=output_shape) # create a placeholder for per-voxel loss weights - loss_weights = tf.placeholder( - tf.float32, - shape=output_shape) + loss_weights = tf.placeholder(tf.float32, shape=output_shape) # compute the loss as the weighted mean squared error between the # predicted and the ground-truth affinities - loss = tf.losses.mean_squared_error( - gt_affs, - pred_affs, - loss_weights) + loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights) # use the Adam optimizer to minimize the loss opt = tf.train.AdamOptimizer( - learning_rate=0.5e-4, - beta1=0.95, - beta2=0.999, - epsilon=1e-8) + learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8 + ) optimizer = opt.minimize(loss) # store the network in a meta-graph file - tf.train.export_meta_graph(filename=name + '.meta') + tf.train.export_meta_graph(filename=name + ".meta") # store network configuration for use in train and predict scripts config = { - 'raw': raw.name, - 'pred_affs': pred_affs.name, - 'gt_affs': gt_affs.name, - 'loss_weights': loss_weights.name, - 'loss': loss.name, - 'optimizer': optimizer.name, - 'input_shape': input_shape, - 'output_shape': output_shape[1:] + "raw": raw.name, + "pred_affs": pred_affs.name, + "gt_affs": gt_affs.name, + "loss_weights": loss_weights.name, + "loss": loss.name, + "optimizer": optimizer.name, + "input_shape": input_shape, + "output_shape": output_shape[1:], } - with open(name + '_config.json', 'w') as f: + with open(name + "_config.json", "w") as f: json.dump(config, f) -if __name__ == "__main__": +if __name__ == "__main__": # create a network for training - create_network((84, 268, 268), 'train_net') + create_network((84, 268, 268), "train_net") # create a larger network for faster prediction - create_network((120, 322, 322), 'test_net') + create_network((120, 322, 322), "test_net") diff --git a/examples/cremi/predict.py b/examples/cremi/predict.py index 8693786f..4f229b14 100644 --- a/examples/cremi/predict.py +++ b/examples/cremi/predict.py @@ -2,29 +2,29 @@ import gunpowder as gp import json -def predict(iteration): +def predict(iteration): ################## # DECLARE ARRAYS # ################## # raw intensities - raw = gp.ArrayKey('RAW') + raw = gp.ArrayKey("RAW") # the predicted affinities - pred_affs = gp.ArrayKey('PRED_AFFS') + pred_affs = gp.ArrayKey("PRED_AFFS") #################### # DECLARE REQUESTS # #################### - with open('test_net_config.json', 'r') as f: + with open("test_net_config.json", "r") as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) - input_size = gp.Coordinate(net_config['input_shape'])*voxel_size - output_size = gp.Coordinate(net_config['output_shape'])*voxel_size + input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size + output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size context = input_size - output_size # formulate the request for what a batch should contain @@ -37,10 +37,8 @@ def predict(iteration): ############################# source = gp.Hdf5Source( - 'sample_A_padded_20160501.hdf', - datasets = { - raw: 'volumes/raw' - }) + "sample_A_padded_20160501.hdf", datasets={raw: "volumes/raw"} + ) # get the ROI provided for raw (we need it later to calculate the ROI in # which we can make predictions) @@ -48,41 +46,35 @@ def predict(iteration): raw_roi = source.spec[raw].roi pipeline = ( - # read from HDF5 file - source + - + source + + # convert raw to float in [0, 1] - gp.Normalize(raw) + - + gp.Normalize(raw) + + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Predict( - graph='test_net.meta', - checkpoint='train_net_checkpoint_%d'%iteration, - inputs={ - net_config['raw']: raw - }, - outputs={ - net_config['pred_affs']: pred_affs - }, - array_specs={ - pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context)) - }) + - + graph="test_net.meta", + checkpoint="train_net_checkpoint_%d" % iteration, + inputs={net_config["raw"]: raw}, + outputs={net_config["pred_affs"]: pred_affs}, + array_specs={pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))}, + ) + + # store all passing batches in the same HDF5 file gp.Hdf5Write( { - raw: '/volumes/raw', - pred_affs: '/volumes/pred_affs', + raw: "/volumes/raw", + pred_affs: "/volumes/pred_affs", }, - output_filename='predictions_sample_A.hdf', - compression_type='gzip' - ) + - + output_filename="predictions_sample_A.hdf", + compression_type="gzip", + ) + + # show a summary of time spend in each node every 10 iterations - gp.PrintProfilingStats(every=10) + - + gp.PrintProfilingStats(every=10) + + # iterate over the whole dataset in a scanning fashion, emitting # requests that match the size of the network gp.Scan(reference=request) @@ -93,5 +85,6 @@ def predict(iteration): # without keeping the complete dataset in memory pipeline.request_batch(gp.BatchRequest()) + if __name__ == "__main__": predict(200000) diff --git a/examples/cremi/train.py b/examples/cremi/train.py index 8edd12f7..6faf7e50 100644 --- a/examples/cremi/train.py +++ b/examples/cremi/train.py @@ -6,41 +6,41 @@ logging.basicConfig(level=logging.INFO) -def train(iterations): +def train(iterations): ################## # DECLARE ARRAYS # ################## # raw intensities - raw = gp.ArrayKey('RAW') + raw = gp.ArrayKey("RAW") # objects labelled with unique IDs - gt_labels = gp.ArrayKey('LABELS') + gt_labels = gp.ArrayKey("LABELS") # array of per-voxel affinities to direct neighbors - gt_affs= gp.ArrayKey('AFFINITIES') + gt_affs = gp.ArrayKey("AFFINITIES") # weights to use to balance the loss - loss_weights = gp.ArrayKey('LOSS_WEIGHTS') + loss_weights = gp.ArrayKey("LOSS_WEIGHTS") # the predicted affinities - pred_affs = gp.ArrayKey('PRED_AFFS') + pred_affs = gp.ArrayKey("PRED_AFFS") # the gredient of the loss wrt to the predicted affinities - pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS') + pred_affs_gradients = gp.ArrayKey("PRED_AFFS_GRADIENTS") #################### # DECLARE REQUESTS # #################### - with open('train_net_config.json', 'r') as f: + with open("train_net_config.json", "r") as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) - input_size = gp.Coordinate(net_config['input_shape'])*voxel_size - output_size = gp.Coordinate(net_config['output_shape'])*voxel_size + input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size + output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size # formulate the request for what a batch should (at least) contain request = gp.BatchRequest() @@ -60,44 +60,38 @@ def train(iterations): ############################## pipeline = ( - # a tuple of sources, one for each sample (A, B, and C) provided by the # CREMI challenge tuple( - # read batches from the HDF5 file gp.Hdf5Source( - 'sample_'+s+'_padded_20160501.hdf', - datasets = { - raw: 'volumes/raw', - gt_labels: 'volumes/labels/neuron_ids' - } - ) + - + "sample_" + s + "_padded_20160501.hdf", + datasets={raw: "volumes/raw", gt_labels: "volumes/labels/neuron_ids"}, + ) + + # convert raw to float in [0, 1] gp.Normalize(raw) + - # chose a random location for each requested batch gp.RandomLocation() - - for s in ['A', 'B', 'C'] - ) + - + for s in ["A", "B", "C"] + ) + + # chose a random source (i.e., sample) from the above - gp.RandomProvider() + - + gp.RandomProvider() + + # elastically deform the batch gp.ElasticAugment( - [4,40,40], - [0,2,2], - [0,math.pi/2.0], + [4, 40, 40], + [0, 2, 2], + [0, math.pi / 2.0], prob_slip=0.05, prob_shift=0.05, - max_misalign=25) + - + max_misalign=25, + ) + + # apply transpose and mirror augmentations - gp.SimpleAugment(transpose_only=[1, 2]) + - + gp.SimpleAugment(transpose_only=[1, 2]) + + # scale and shift the intensity of the raw array gp.IntensityAugment( raw, @@ -105,65 +99,54 @@ def train(iterations): scale_max=1.1, shift_min=-0.1, shift_max=0.1, - z_section_wise=True) + - + z_section_wise=True, + ) + + # grow a boundary between labels - gp.GrowBoundary( - gt_labels, - steps=3, - only_xy=True) + - + gp.GrowBoundary(gt_labels, steps=3, only_xy=True) + + # convert labels into affinities between voxels - gp.AddAffinities( - [[-1, 0, 0], [0, -1, 0], [0, 0, -1]], - gt_labels, - gt_affs) + - + gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels, gt_affs) + + # create a weight array that balances positive and negative samples in # the affinity array - gp.BalanceLabels( - gt_affs, - loss_weights) + - + gp.BalanceLabels(gt_affs, loss_weights) + + # pre-cache batches from the point upstream - gp.PreCache( - cache_size=10, - num_workers=5) + - + gp.PreCache(cache_size=10, num_workers=5) + + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Train( - 'train_net', - net_config['optimizer'], - net_config['loss'], + "train_net", + net_config["optimizer"], + net_config["loss"], inputs={ - net_config['raw']: raw, - net_config['gt_affs']: gt_affs, - net_config['loss_weights']: loss_weights + net_config["raw"]: raw, + net_config["gt_affs"]: gt_affs, + net_config["loss_weights"]: loss_weights, }, - outputs={ - net_config['pred_affs']: pred_affs - }, - gradients={ - net_config['pred_affs']: pred_affs_gradients - }, - save_every=1) + - + outputs={net_config["pred_affs"]: pred_affs}, + gradients={net_config["pred_affs"]: pred_affs_gradients}, + save_every=1, + ) + + # save the passing batch as an HDF5 file for inspection gp.Snapshot( { - raw: '/volumes/raw', - gt_labels: '/volumes/labels/neuron_ids', - gt_affs: '/volumes/labels/affs', - pred_affs: '/volumes/pred_affs', - pred_affs_gradients: '/volumes/pred_affs_gradients' + raw: "/volumes/raw", + gt_labels: "/volumes/labels/neuron_ids", + gt_affs: "/volumes/labels/affs", + pred_affs: "/volumes/pred_affs", + pred_affs_gradients: "/volumes/pred_affs_gradients", }, - output_dir='snapshots', - output_filename='batch_{iteration}.hdf', + output_dir="snapshots", + output_filename="batch_{iteration}.hdf", every=100, additional_request=snapshot_request, - compression_type='gzip') + - + compression_type="gzip", + ) + + # show a summary of time spend in each node every 10 iterations gp.PrintProfilingStats(every=10) ) @@ -180,6 +163,6 @@ def train(iterations): print("Finished") + if __name__ == "__main__": train(200000) - \ No newline at end of file diff --git a/gunpowder/array_spec.py b/gunpowder/array_spec.py index ec271488..9002ae4f 100644 --- a/gunpowder/array_spec.py +++ b/gunpowder/array_spec.py @@ -14,13 +14,12 @@ class ArraySpec(Freezable): roi (:class:`Roi`): The region of interested represented by this array spec. Can be - ``None`` for :class:`BatchProviders` that allow - requests for arrays everywhere, but will always be set for array - specs that are part of a :class:`Array`. + ``None`` for nonspatial arrays or to indicate the true value is unknown. voxel_size (:class:`Coordinate`): - The size of the spatial axises in world units. + The size of the spatial axises in world units. Can be ``None`` for + nonspatial arrays or to indicate the true value is unknown. interpolatable (``bool``): @@ -55,7 +54,7 @@ def __init__( if nonspatial: assert roi is None, "Non-spatial arrays can not have a ROI" - assert voxel_size is None, "Non-spatial arrays can not " "have a voxel size" + assert voxel_size is None, "Non-spatial arrays can not have a voxel size" self.freeze() diff --git a/gunpowder/ext/__init__.py b/gunpowder/ext/__init__.py index 7aec50c9..fdfcfa02 100644 --- a/gunpowder/ext/__init__.py +++ b/gunpowder/ext/__init__.py @@ -3,6 +3,7 @@ import traceback import sys +from typing import Optional, Any logger = logging.getLogger(__name__) @@ -58,6 +59,7 @@ def __getattr__(self, item): except ImportError as e: augment = NoSuchModule("augment") +ZarrFile: Optional[Any] = None try: import zarr from .zarr_file import ZarrFile diff --git a/gunpowder/jax/nodes/predict.py b/gunpowder/jax/nodes/predict.py index 4c46f233..496d0fd0 100644 --- a/gunpowder/jax/nodes/predict.py +++ b/gunpowder/jax/nodes/predict.py @@ -6,7 +6,7 @@ import pickle import logging -from typing import Dict, Union +from typing import Dict, Union, Optional logger = logging.getLogger(__name__) @@ -52,8 +52,8 @@ def __init__( model: GenericJaxModel, inputs: Dict[str, ArrayKey], outputs: Dict[Union[str, int], ArrayKey], - array_specs: Dict[ArrayKey, ArraySpec] = None, - checkpoint: str = None, + array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, + checkpoint: Optional[str] = None, spawn_subprocess=False, ): self.array_specs = array_specs if array_specs is not None else {} diff --git a/gunpowder/jax/nodes/train.py b/gunpowder/jax/nodes/train.py index 4d1f17a3..9621b129 100644 --- a/gunpowder/jax/nodes/train.py +++ b/gunpowder/jax/nodes/train.py @@ -11,7 +11,7 @@ from gunpowder.nodes.generic_train import GenericTrain from gunpowder.jax import GenericJaxModel -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any logger = logging.getLogger(__name__) @@ -108,7 +108,7 @@ def __init__( checkpoint_basename: str = "model", save_every: int = 2000, keep_n_checkpoints: Optional[int] = None, - log_dir: str = None, + log_dir: Optional[str] = None, log_every: int = 1, spawn_subprocess: bool = False, n_devices: Optional[int] = None, @@ -141,7 +141,7 @@ def __init__( if log_dir is not None: logger.warning("log_dir given, but tensorboardX is not installed") - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.validate_fn = validate_fn self.validate_every = validate_every diff --git a/gunpowder/nodes/batch_provider.py b/gunpowder/nodes/batch_provider.py index dc641c8e..304e1e3a 100644 --- a/gunpowder/nodes/batch_provider.py +++ b/gunpowder/nodes/batch_provider.py @@ -3,6 +3,8 @@ import copy import logging import random +import traceback +from typing import Optional from gunpowder.coordinate import Coordinate from gunpowder.provider_spec import ProviderSpec @@ -15,17 +17,22 @@ class BatchRequestError(Exception): - def __init__(self, provider, request, batch): + def __init__( + self, provider, request, batch, original_traceback: Optional[list[str]] = None + ): self.provider = provider self.request = request self.batch = batch + self.original_traceback = original_traceback def __str__(self): return ( f"Exception in {self.provider.name()} while processing request" - f"{self.request} \n" + f"{self.request}" "Batch returned so far:\n" - f"{self.batch}" + f"{self.batch}" + ("\n\n" + "".join(self.original_traceback)) + if self.original_traceback is not None + else "" ) @@ -174,7 +181,6 @@ def request_batch(self, request): batch = None try: - self.set_seeds(request) logger.debug("%s got request %s", self.name(), request) @@ -195,7 +201,12 @@ def request_batch(self, request): logger.debug("%s provides %s", self.name(), batch) except Exception as e: - raise BatchRequestError(self, request, batch) from e + tb = traceback.format_exception(type(e), e, e.__traceback__) + if isinstance(e, BatchRequestError): + tb = tb[-1:] + raise BatchRequestError( + self, request, batch, original_traceback=tb + ) from None return batch diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index cdf5eeff..6d7e23af 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -21,6 +21,7 @@ import logging import math import random +from typing import Optional logger = logging.getLogger(__name__) @@ -93,8 +94,8 @@ def __init__( spatial_dims=3, use_fast_points_transform=False, recompute_missing_points=True, - transform_key: ArrayKey = None, - graph_raster_voxel_size: Coordinate = None, + transform_key: Optional[ArrayKey] = None, + graph_raster_voxel_size: Optional[Coordinate] = None, ): self.control_point_spacing = Coordinate(control_point_spacing) self.jitter_sigma = Coordinate(jitter_sigma) @@ -129,7 +130,6 @@ def setup(self): self.provides(self.transform_key, spec) def prepare(self, request): - # get the total ROI of all requests total_roi = request.get_total_roi() logger.debug("total ROI is %s" % total_roi) diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index a70f7866..d999d6fe 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -124,7 +124,6 @@ def __init__( self.recompute_missing_points = recompute_missing_points def prepare(self, request): - # get the voxel size self.voxel_size = self.__get_common_voxel_size(request) diff --git a/gunpowder/nodes/generic_predict.py b/gunpowder/nodes/generic_predict.py index 524967b8..e3f4ec5b 100644 --- a/gunpowder/nodes/generic_predict.py +++ b/gunpowder/nodes/generic_predict.py @@ -89,7 +89,7 @@ def setup(self): if self.spawn_subprocess: # start prediction as a producer pool, so that we can gracefully # exit if anything goes wrong - self.worker = ProducerPool([self.__produce_predict_batch], queue_size=1) + self.worker = ProducerPool([self._produce_predict_batch], queue_size=1) self.batch_in = multiprocessing.Queue(maxsize=1) self.batch_in_lock = multiprocessing.Lock() self.batch_out_lock = multiprocessing.Lock() @@ -177,7 +177,7 @@ def stop(self): """ pass - def __produce_predict_batch(self): + def _produce_predict_batch(self): """Process one batch.""" if not self.initialized: diff --git a/gunpowder/nodes/generic_train.py b/gunpowder/nodes/generic_train.py index ae93b7de..a26a285f 100644 --- a/gunpowder/nodes/generic_train.py +++ b/gunpowder/nodes/generic_train.py @@ -104,7 +104,7 @@ def setup(self): if self.spawn_subprocess: # start training as a producer pool, so that we can gracefully exit if # anything goes wrong - self.worker = ProducerPool([self.__produce_train_batch], queue_size=1) + self.worker = ProducerPool([self._produce_train_batch], queue_size=1) self.batch_in = multiprocessing.Queue(maxsize=1) self.worker.start() else: @@ -208,7 +208,7 @@ def natural_keys(text): return None, 0 - def __produce_train_batch(self): + def _produce_train_batch(self): """Process one train batch.""" if not self.initialized: diff --git a/gunpowder/nodes/noise_augment.py b/gunpowder/nodes/noise_augment.py index f4bfb5ba..5275a2c0 100644 --- a/gunpowder/nodes/noise_augment.py +++ b/gunpowder/nodes/noise_augment.py @@ -57,13 +57,11 @@ def process(self, batch, request): seed = request.random_seed try: - raw.data = skimage.util.random_noise( raw.data, mode=self.mode, rng=seed, clip=self.clip, **self.kwargs ).astype(raw.data.dtype) except ValueError: - # legacy version of skimage random_noise raw.data = skimage.util.random_noise( raw.data, mode=self.mode, seed=seed, clip=self.clip, **self.kwargs diff --git a/gunpowder/nodes/pad.py b/gunpowder/nodes/pad.py index 6bbfdc58..81150835 100644 --- a/gunpowder/nodes/pad.py +++ b/gunpowder/nodes/pad.py @@ -7,6 +7,8 @@ from gunpowder.coordinate import Coordinate from gunpowder.batch_request import BatchRequest +from itertools import product + logger = logging.getLogger(__name__) @@ -27,15 +29,22 @@ class Pad(BatchFilter): a coordinate, this amount will be added to the ROI in the positive and negative direction. + mode (string): + + One of 'constant' or 'reflect'. + Default is 'constant' + value (scalar or ``None``): The value to report inside the padding. If not given, 0 is used. + Only used in case of 'constant' mode. Only used for :class:`Array`. """ - def __init__(self, key, size, value=None): + def __init__(self, key, size, mode="constant", value=None): self.key = key self.size = size + self.mode = mode self.value = value def setup(self): @@ -119,18 +128,42 @@ def __expand(self, a, from_roi, to_roi, value): num_channels = len(a.shape) - from_roi.dims channel_shapes = a.shape[:num_channels] - b = np.zeros(channel_shapes + to_roi.shape, dtype=a.dtype) - if value != 0: - b[:] = value - shift = -to_roi.offset + if self.mode == "constant": + if value != 0: + b[:] = value + elif self.mode == "reflect": + if a.shape == b.shape: + pass # handled later + else: + diff = Coordinate(b.shape) - Coordinate(a.shape) + slices = [ + ( + (slice(None),) * num_channels + + tuple( + slice(diff[i], None) if d == 1 else slice(None, diff[i]) + for i, d in enumerate(selected_dims) + ), + (slice(None),) * num_channels + + tuple( + slice(None, diff[i]) if d == 1 else slice(None) + for i, d in enumerate(selected_dims) + ), + (slice(None),) * num_channels + + tuple( + slice(None, None, -1) if d == 1 else slice(None) + for i, d in enumerate(selected_dims) + ), + ) + for selected_dims in product((0, 1), repeat=from_roi.dims) + ] + for output_slices, input_slices, rev_slices in slices: + b[output_slices] = a[input_slices][rev_slices] + logger.debug("shifting 'from' by " + str(shift)) a_in_b = from_roi.shift(shift).to_slices() - logger.debug("target shape is " + str(b.shape)) logger.debug("target slice is " + str(a_in_b)) - b[(slice(None),) * num_channels + a_in_b] = a - return b diff --git a/gunpowder/nodes/random_location.py b/gunpowder/nodes/random_location.py index fccbd6cb..d5b6c1e2 100644 --- a/gunpowder/nodes/random_location.py +++ b/gunpowder/nodes/random_location.py @@ -172,7 +172,6 @@ def setup(self): self.provides(self.random_shift_key, ArraySpec(nonspatial=True)) def prepare(self, request): - logger.debug("request: %s", request.array_specs) logger.debug("my spec: %s", self.spec) @@ -383,9 +382,7 @@ def __select_random_location_with_points( logger.debug("belongs to lcm voxel %s", lcm_location) # align the point request ROI with lcm voxel grid - lcm_roi = request_points_roi.snap_to_grid( - lcm_voxel_size, - mode="shrink") + lcm_roi = request_points_roi.snap_to_grid(lcm_voxel_size, mode="shrink") lcm_roi = lcm_roi / lcm_voxel_size logger.debug("Point request ROI: %s", request_points_roi) logger.debug("Point request lcm ROI shape: %s", lcm_roi.shape) diff --git a/gunpowder/nodes/random_provider.py b/gunpowder/nodes/random_provider.py index dfb086f8..a9ae1081 100644 --- a/gunpowder/nodes/random_provider.py +++ b/gunpowder/nodes/random_provider.py @@ -69,7 +69,6 @@ def setup(self): self.provides(self.random_provider_key, ArraySpec(nonspatial=True)) def provide(self, request): - if self.random_provider_key is not None: del request[self.random_provider_key] diff --git a/gunpowder/nodes/reject.py b/gunpowder/nodes/reject.py index b6a47436..87bb83aa 100644 --- a/gunpowder/nodes/reject.py +++ b/gunpowder/nodes/reject.py @@ -55,7 +55,6 @@ def setup(self): self.upstream_provider = self.get_upstream_provider() def provide(self, request): - report_next_timeout = 10 num_rejected = 0 diff --git a/gunpowder/nodes/shift_augment.py b/gunpowder/nodes/shift_augment.py index 8fe6524b..8761a563 100644 --- a/gunpowder/nodes/shift_augment.py +++ b/gunpowder/nodes/shift_augment.py @@ -24,7 +24,6 @@ def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0): self.lcm_voxel_size = None def prepare(self, request): - self.ndim = request.get_total_roi().dims assert self.shift_axis in range(self.ndim) diff --git a/gunpowder/nodes/simple_augment.py b/gunpowder/nodes/simple_augment.py index dc756e3a..f5a97333 100644 --- a/gunpowder/nodes/simple_augment.py +++ b/gunpowder/nodes/simple_augment.py @@ -106,7 +106,6 @@ def setup(self): self.permutation_dict[k] = v def prepare(self, request): - self.mirror = [ random.random() < self.mirror_probs[d] if self.mirror_mask[d] else 0 for d in range(self.dims) diff --git a/gunpowder/nodes/stack.py b/gunpowder/nodes/stack.py index 21f53acc..5d7feabd 100644 --- a/gunpowder/nodes/stack.py +++ b/gunpowder/nodes/stack.py @@ -25,7 +25,6 @@ def __init__(self, num_repetitions): self.num_repetitions = num_repetitions def provide(self, request): - batches = [] for _ in range(self.num_repetitions): upstream_request = request.copy() diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 812769f3..b7133580 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -107,9 +107,8 @@ def _get_offset(self, dataset): def _rev_metadata(self): with ZarrFile(self.store, mode="a") as store: - return ( - isinstance(store.chunk_store, N5Store) or - isinstance(store.chunk_store, N5FSStore) + return isinstance(store.chunk_store, N5Store) or isinstance( + store.chunk_store, N5FSStore ) def _open_file(self, store): diff --git a/gunpowder/nodes/zarr_write.py b/gunpowder/nodes/zarr_write.py index 35965b6d..3beba3ae 100644 --- a/gunpowder/nodes/zarr_write.py +++ b/gunpowder/nodes/zarr_write.py @@ -5,10 +5,10 @@ from zarr import N5FSStore, N5Store from .batch_filter import BatchFilter +from gunpowder.array import ArrayKey from gunpowder.batch_request import BatchRequest from gunpowder.coordinate import Coordinate from gunpowder.roi import Roi -from gunpowder.coordinate import Coordinate from gunpowder.ext import ZarrFile import logging @@ -71,7 +71,7 @@ def __init__( else: self.dataset_dtypes = dataset_dtypes - self.dataset_offsets = {} + self.dataset_offsets: dict[ArrayKey, Coordinate] = {} def _get_voxel_size(self, dataset): if "resolution" not in dataset.attrs: diff --git a/gunpowder/producer_pool.py b/gunpowder/producer_pool.py index 035f0d74..0f6c2888 100644 --- a/gunpowder/producer_pool.py +++ b/gunpowder/producer_pool.py @@ -143,9 +143,7 @@ def _run_worker(self, target): try: result = target() except Exception as e: - logger.error(e, exc_info=True) result = e - traceback.print_exc() # don't stop on normal exceptions -- place them in result queue # and let them be handled by caller except: diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 3e5ba8f1..89c9ac0c 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -4,7 +4,7 @@ from gunpowder.nodes.generic_predict import GenericPredict import logging -from typing import Dict, Union +from typing import Dict, Union, Optional, Any logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ def __init__( model, inputs: Dict[str, ArrayKey], outputs: Dict[Union[str, int], ArrayKey], - array_specs: Dict[ArrayKey, ArraySpec] = None, - checkpoint: str = None, + array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, + checkpoint: Optional[str] = None, device="cuda", spawn_subprocess=False, ): @@ -82,14 +82,16 @@ def __init__( self.model = model self.checkpoint = checkpoint - self.intermediate_layers = {} - self.register_hooks() + self.intermediate_layers: dict[ArrayKey, Any] = {} def start(self): - self.use_cuda = torch.cuda.is_available() and self.device_string == "cuda" - logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue #188 + self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__( + "cuda" + ) + logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") + self.device = torch.device(self.device_string if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) except RuntimeError as e: @@ -106,6 +108,8 @@ def start(self): else: self.model.load_state_dict(checkpoint) + self.register_hooks() + def predict(self, batch, request): inputs = self.get_inputs(batch) with torch.no_grad(): diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index ed2df002..e913d8ad 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -6,7 +6,8 @@ from gunpowder.ext import torch, tensorboardX, NoSuchModule from gunpowder.nodes.generic_train import GenericTrain -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any +import itertools logger = logging.getLogger(__name__) @@ -78,6 +79,12 @@ class Train(GenericTrain): spawn_subprocess (``bool``, optional): Whether to run the ``train_step`` in a separate process. Default is false. + + device (``str``, optional): + + Accepts a cuda gpu specifically to train on (e.g. `cuda:1`, `cuda:2`), helps in multi-card systems. + defaults to ``cuda`` + """ def __init__( @@ -95,6 +102,7 @@ def __init__( log_dir: str = None, log_every: int = 1, spawn_subprocess: bool = False, + device: str = "cuda", ): if not model.training: logger.warning( @@ -104,12 +112,18 @@ def __init__( # not yet implemented gradients = gradients - inputs.update( - {k: v for k, v in loss_inputs.items() if v not in outputs.values()} - ) + all_inputs = { + k: v + for k, v in itertools.chain(inputs.items(), loss_inputs.items()) + if v not in outputs.values() + } super(Train, self).__init__( - inputs, outputs, gradients, array_specs, spawn_subprocess=spawn_subprocess + all_inputs, + outputs, + gradients, + array_specs, + spawn_subprocess=spawn_subprocess, ) self.model = model @@ -118,6 +132,7 @@ def __init__( self.loss_inputs = loss_inputs self.checkpoint_basename = checkpoint_basename self.save_every = save_every + self.dev = device self.iteration = 0 @@ -129,7 +144,7 @@ def __init__( if log_dir is not None: logger.warning("log_dir given, but tensorboardX is not installed") - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.register_hooks() def register_hooks(self): @@ -160,7 +175,8 @@ def retain_gradients(self, request, outputs): def start(self): self.use_cuda = torch.cuda.is_available() - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue: #188 + self.device = torch.device(self.dev if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) @@ -278,13 +294,6 @@ def train_step(self, batch, request): spec.roi = request[array_key].roi batch.arrays[array_key] = Array(tensor.grad.cpu().detach().numpy(), spec) - for array_key, array_name in requested_outputs.items(): - spec = self.spec[array_key].copy() - spec.roi = request[array_key].roi - batch.arrays[array_key] = Array( - outputs[array_name].cpu().detach().numpy(), spec - ) - batch.loss = loss.cpu().detach().numpy() self.iteration += 1 batch.iteration = self.iteration diff --git a/gunpowder/version_info.py b/gunpowder/version_info.py index 01724d07..e45efbfd 100644 --- a/gunpowder/version_info.py +++ b/gunpowder/version_info.py @@ -1,6 +1,6 @@ __major__ = 1 __minor__ = 3 -__patch__ = 1 +__patch__ = 2 __tag__ = "" __version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".") diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..6daa39e0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,37 @@ +[mypy] + +# ext +[mypy-dvision.*] +ignore_missing_imports = True +[mypy-pyklb.*] +ignore_missing_imports = True +[mypy-malis.*] +ignore_missing_imports = True +[mypy-haiku.*] +ignore_missing_imports = True +[mypy-optax.*] +ignore_missing_imports = True + +# dependencies +[mypy-tensorflow.*] +ignore_missing_imports = True +[mypy-tensorboardX.*] +ignore_missing_imports = True +[mypy-torch.*] +ignore_missing_imports = True +[mypy-jax.*] +ignore_missing_imports = True +[mypy-daisy.*] +ignore_missing_imports = True +[mypy-scipy.*] +ignore_missing_imports = True +[mypy-h5py.*] +ignore_missing_imports = True +[mypy-augment.*] +ignore_missing_imports = True +[mypy-zarr.*] +ignore_missing_imports = True +[mypy-networkx.*] +ignore_missing_imports = True +[mypy-Queue.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 11ab82bc..a389a33c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,77 +6,71 @@ requires = ["setuptools", "wheel"] name = "gunpowder" description = "A library to facilitate machine learning on large, multi-dimensional images." authors = [ - {name = "Jan Funke", email = "funkej@hhmi.org"}, - {name = "William Patton", email = "pattonw@hhmi.org"}, - {name = "Renate Krause"}, - {name = "Julia Buhmann"}, - {name = "Rodrigo Ceballos Lentini"}, - {name = "William Grisaitis"}, - {name = "Chris Barnes"}, - {name = "Caroline Malin-Mayor"}, - {name = "Larissa Heinrich"}, - {name = "Philipp Hanslovsky"}, - {name = "Sherry Ding"}, - {name = "Andrew Champion"}, - {name = "Arlo Sheridan"}, - {name = "Constantin Pape"}, + { name = "Jan Funke", email = "funkej@hhmi.org" }, + { name = "William Patton", email = "pattonw@hhmi.org" }, + { name = "Renate Krause" }, + { name = "Julia Buhmann" }, + { name = "Rodrigo Ceballos Lentini" }, + { name = "William Grisaitis" }, + { name = "Chris Barnes" }, + { name = "Caroline Malin-Mayor" }, + { name = "Larissa Heinrich" }, + { name = "Philipp Hanslovsky" }, + { name = "Sherry Ding" }, + { name = "Andrew Champion" }, + { name = "Arlo Sheridan" }, + { name = "Constantin Pape" }, ] -license = {text = "MIT"} +license = { text = "MIT" } readme = "README.md" dynamic = ["version"] -classifiers = [ - "Programming Language :: Python :: 3", -] +classifiers = ["Programming Language :: Python :: 3"] keywords = [] requires-python = ">=3.7" dependencies = [ - "numpy", - "scipy", - "h5py", - "scikit-image", - "requests", - "augment-nd>=0.1.3", - "tqdm", - "funlib.geometry", - "zarr", - "networkx", + "numpy>=1.24", + "scipy>=1.6", + "h5py>=3.10", + "scikit-image", + "requests", + "augment-nd>=0.1.3", + "tqdm", + "funlib.geometry>=0.2", + "zarr", + "networkx>=3.1", ] [project.optional-dependencies] -dev = [ - "pytest", - "pytest-cov", - "flake8", -] +dev = ["pytest", "pytest-cov", "flake8", "mypy", "types-requests", "types-tqdm"] docs = [ - "sphinx", - "sphinx_rtd_theme", - "sphinx_togglebutton", - "tomli", - "jupyter_sphinx", - "ipykernel", - "matplotlib", - "torch", + "sphinx", + "sphinx_rtd_theme", + "sphinx_togglebutton", + "tomli", + "jupyter_sphinx", + "ipykernel", + "matplotlib", + "torch", ] pytorch = ['torch'] tensorflow = [ - # TF doesn't provide <2.0 wheels for py>=3.8 on pypi - 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 - 'protobuf==3.20.*; python_version=="3.7"', + # TF doesn't provide <2.0 wheels for py>=3.8 on pypi + 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 + 'protobuf==3.20.*; python_version=="3.7"', ] full = [ - 'torch', - 'tensorflow<2.0; python_version<"3.8"', - 'protobuf==3.20.*; python_version=="3.7"', + 'torch', + 'tensorflow<2.0; python_version<"3.8"', + 'protobuf==3.20.*; python_version=="3.7"', ] [tool.setuptools.dynamic] -version = {attr = "gunpowder.version_info.__version__"} +version = { attr = "gunpowder.version_info.__version__" } [tool.black] -target_version = ['py36', 'py37', 'py38', 'py39', 'py310'] +target_version = ['py38', 'py39', 'py310'] [tool.setuptools.packages.find] include = ["gunpowder*"] diff --git a/tests/cases/deform_augment.py b/tests/cases/deform_augment.py index 2134a708..f722b0bb 100644 --- a/tests/cases/deform_augment.py +++ b/tests/cases/deform_augment.py @@ -160,6 +160,9 @@ def test_3d_basics(rotate, spatial_dims, fast_points): loc = (loc - labels.spec.roi.begin) / labels.spec.voxel_size loc = np.array(loc) com = center_of_mass(labels.data == node.id) + if any(np.isnan(com)): + # cannot assume that the rasterized data will exist after defomation + continue assert ( np.linalg.norm(com - loc) < np.linalg.norm(labels.spec.voxel_size) * 2 diff --git a/tests/cases/helper_sources.py b/tests/cases/helper_sources.py index 219044b1..630333d6 100644 --- a/tests/cases/helper_sources.py +++ b/tests/cases/helper_sources.py @@ -13,7 +13,10 @@ def setup(self): def provide(self, request): outputs = Batch() - outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) + if self.array.spec.nonspatial: + outputs[self.key] = copy.deepcopy(self.array) + else: + outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) return outputs diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 8b7ab179..c66332fe 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -1,71 +1,75 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource, GraphSource from gunpowder import ( - BatchProvider, BatchRequest, - Batch, - ArrayKeys, ArraySpec, Roi, Coordinate, + Graph, GraphKey, - GraphKeys, GraphSpec, Array, ArrayKey, Pad, build, + MergeProvider, ) -import numpy as np - -class ExampleSourcePad(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.TEST_LABELS, - ArraySpec(roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2)), - ) +import pytest +import numpy as np - self.provides( - GraphKeys.TEST_GRAPH, GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180))) - ) +from itertools import product - def provide(self, request): - batch = Batch() - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size +@pytest.mark.parametrize("mode", ["constant", "reflect"]) +def test_output(mode): + array_key = ArrayKey("TEST_ARRAY") + graph_key = GraphKey("TEST_GRAPH") - data = np.zeros(roi_voxel.shape, dtype=np.uint32) - data[:, ::2] = 100 + array_spec = ArraySpec( + roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2) + ) + roi_voxel = array_spec.roi / array_spec.voxel_size + data = np.zeros(roi_voxel.shape, dtype=np.uint32) + data[:, ::2] = 100 + array = Array(data, spec=array_spec) - spec = self.spec[ArrayKeys.TEST_LABELS].copy() - spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + graph_spec = GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180))) + graph = Graph([], [], graph_spec) - return batch + source = ( + ArraySource(array_key, array), + GraphSource(graph_key, graph), + ) + MergeProvider() + pipeline = ( + source + + Pad(array_key, Coordinate((20, 20, 20)), value=1, mode=mode) + + Pad(graph_key, Coordinate((10, 10, 10)), mode=mode) + ) -class TestPad(ProviderTest): - def test_output(self): - graph = GraphKey("TEST_GRAPH") - labels = ArrayKey("TEST_LABELS") + with build(pipeline): + assert pipeline.spec[array_key].roi == Roi((180, 0, 0), (1840, 220, 220)) + assert pipeline.spec[graph_key].roi == Roi((190, 10, 10), (1820, 200, 200)) - pipeline = ( - ExampleSourcePad() - + Pad(labels, Coordinate((20, 20, 20)), value=1) - + Pad(graph, Coordinate((10, 10, 10))) + batch = pipeline.request_batch( + BatchRequest({array_key: ArraySpec(Roi((180, 0, 0), (40, 40, 40)))}) ) - with build(pipeline): - self.assertTrue( - pipeline.spec[labels].roi == Roi((180, 0, 0), (1840, 220, 220)) - ) - self.assertTrue( - pipeline.spec[graph].roi == Roi((190, 10, 10), (1820, 200, 200)) + data = batch.arrays[array_key].data + if mode == "constant": + octants = [ + (1 * 10 * 10) if zi + yi + xi < 3 else 100 * 1 * 5 * 10 + for zi, yi, xi in product(range(2), range(2), range(2)) + ] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + np.unique(data), ) - - batch = pipeline.request_batch( - BatchRequest({labels: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))}) + elif mode == "reflect": + octants = [100 * 1 * 5 * 10 for _ in range(8)] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + data, ) - - self.assertEqual(np.sum(batch.arrays[labels].data), 1 * 10 * 10) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index c213a1c9..b368f9d6 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -1,230 +1,217 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource from gunpowder import ( - BatchProvider, BatchRequest, ArraySpec, Roi, - Coordinate, - ArrayKeys, ArrayKey, Array, - Batch, Scan, PreCache, + MergeProvider, build, ) from gunpowder.ext import torch, NoSuchModule from gunpowder.torch import Train, Predict -from unittest import skipIf, expectedFailure +from unittest import skipIf import numpy as np +import pytest import logging -class ExampleTorchTrain2DSource(BatchProvider): - def __init__(self): - pass +# Example 2D source +def example_2d_source(array_key: ArrayKey): + array_spec = ArraySpec( + roi=Roi((0, 0), (17, 17)), + dtype=np.float32, + interpolatable=True, + voxel_size=(1, 1), + ) + data = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) + data = data + data.T + array = Array(data, array_spec) + return ArraySource(array_key, array) - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (17, 17)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) - def provide(self, request): - batch = Batch() +def example_train_source(a_key, b_key, c_key): + spec1 = ArraySpec( + roi=Roi((0, 0), (2, 2)), + dtype=np.float32, + interpolatable=True, + voxel_size=(1, 1), + ) + spec2 = ArraySpec(nonspatial=True) - spec = self.spec[ArrayKeys.A] + data1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + data2 = np.array([1], dtype=np.float32) - x = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) - x = x + x.T + source_a = ArraySource(a_key, Array(data1, spec1)) + source_b = ArraySource(b_key, Array(data1, spec1)) + source_c = ArraySource(c_key, Array(data2, spec2)) - batch.arrays[ArrayKeys.A] = Array(x, spec).crop(request[ArrayKeys.A].roi) + return (source_a, source_b, source_c) + MergeProvider() - return batch - - -class ExampleTorchTrainSource(BatchProvider): - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (2, 2)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) - self.provides(ArrayKeys.B, spec) - - spec = ArraySpec(nonspatial=True) - self.provides(ArrayKeys.C, spec) - - def provide(self, request): - batch = Batch() - - spec = self.spec[ArrayKeys.A] - spec.roi = request[ArrayKeys.A].roi - - batch.arrays[ArrayKeys.A] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) - - spec = self.spec[ArrayKeys.B] - spec.roi = request[ArrayKeys.B].roi - - batch.arrays[ArrayKeys.B] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) - - spec = self.spec[ArrayKeys.C] - - batch.arrays[ArrayKeys.C] = Array(np.array([1], dtype=np.float32), spec) - - return batch +if not isinstance(torch, NoSuchModule): + class ExampleLinearModel(torch.nn.Module): + def __init__(self): + super(ExampleLinearModel, self).__init__() + self.linear = torch.nn.Linear(4, 1, False) + self.linear.weight.data = torch.Tensor([0, 1, 2, 3]) + + def forward(self, a, b): + a = a.reshape(-1) + b = b.reshape(-1) + c_pred = self.linear(a * b) + d_pred = c_pred * 2 + return d_pred + + +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchTrain(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.torch.nodes.train").setLevel(logging.INFO) - - checkpoint_basename = self.path_to("model") - - ArrayKey("A") - ArrayKey("B") - ArrayKey("C") - ArrayKey("C_PREDICTED") - ArrayKey("C_GRADIENT") - - class ExampleModel(torch.nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() - self.linear = torch.nn.Linear(4, 1, False) - - def forward(self, a, b): - a = a.reshape(-1) - b = b.reshape(-1) - return self.linear(a * b) - - model = ExampleModel() - loss = torch.nn.MSELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999) - - source = ExampleTorchTrainSource() - train = Train( - model=model, - optimizer=optimizer, - loss=loss, - inputs={"a": ArrayKeys.A, "b": ArrayKeys.B}, - loss_inputs={0: ArrayKeys.C_PREDICTED, 1: ArrayKeys.C}, - outputs={0: ArrayKeys.C_PREDICTED}, - gradients={0: ArrayKeys.C_GRADIENT}, - array_specs={ - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - }, - checkpoint_basename=checkpoint_basename, - save_every=100, - spawn_subprocess=True, - ) - pipeline = source + train - - request = BatchRequest( - { - ArrayKeys.A: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.B: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.C: ArraySpec(nonspatial=True), - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - } - ) - - # train for a couple of iterations - with build(pipeline): +def test_loss_drops(tmpdir, device): + checkpoint_basename = str(tmpdir / "model") + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_predicted_key = ArrayKey("C_PREDICTED") + c_gradient_key = ArrayKey("C_GRADIENT") + + model = ExampleLinearModel() + loss = torch.nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.999) + + source = example_train_source(a_key, b_key, c_key) + train = Train( + model=model, + optimizer=optimizer, + loss=loss, + inputs={"a": a_key, "b": b_key}, + loss_inputs={0: c_predicted_key, 1: c_key}, + outputs={0: c_predicted_key}, + gradients={0: c_gradient_key}, + array_specs={ + c_predicted_key: ArraySpec(nonspatial=True), + c_gradient_key: ArraySpec(nonspatial=True), + }, + checkpoint_basename=checkpoint_basename, + save_every=100, + spawn_subprocess=False, + device=device, + ) + pipeline = source + train + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_predicted_key: ArraySpec(nonspatial=True), + c_gradient_key: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + + for i in range(200 - 1): + loss1 = batch.loss batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 - for i in range(200 - 1): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - # resume training - with build(pipeline): - for i in range(100): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - + # resume training + with build(pipeline): + for i in range(100): + loss1 = batch.loss + batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 + + +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + pytest.mark.xfail( + reason="failing to move model to device when using a subprocess" + ), + ], + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchPredict(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - b = ArrayKey("B") - c = ArrayKey("C") - c_pred = ArrayKey("C_PREDICTED") - d_pred = ArrayKey("D_PREDICTED") - - class ExampleModel(torch.nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() - self.linear = torch.nn.Linear(4, 1, False) - self.linear.weight.data = torch.Tensor([1, 1, 1, 1]) - - def forward(self, a, b): - a = a.reshape(-1) - b = b.reshape(-1) - c_pred = self.linear(a * b) - d_pred = c_pred * 2 - return d_pred - - model = ExampleModel() - - source = ExampleTorchTrainSource() - predict = Predict( - model=model, - inputs={"a": a, "b": b}, - outputs={"linear": c_pred, 0: d_pred}, - array_specs={ - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - }, - spawn_subprocess=True, - ) - pipeline = source + predict - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (2, 2))), - b: ArraySpec(roi=Roi((0, 0), (2, 2))), - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch1 = pipeline.request_batch(request) - batch2 = pipeline.request_batch(request) - - assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) - assert np.isclose(batch1[c_pred].data, 1 + 4 + 9) - assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 + 9)) +def test_output(device): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_pred = ArrayKey("C_PREDICTED") + d_pred = ArrayKey("D_PREDICTED") + + model = ExampleLinearModel() + + source = example_train_source(a_key, b_key, c_key) + predict = Predict( + model=model, + inputs={"a": a_key, "b": b_key}, + outputs={"linear": c_pred, 0: d_pred}, + array_specs={ + c_key: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + }, + spawn_subprocess=True, + device=device, + ) + pipeline = source + predict + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch1 = pipeline.request_batch(request) + batch2 = pipeline.request_batch(request) + + assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) + assert np.isclose(batch1[c_pred].data, 1 + 4 * 2 + 9 * 3) + assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3)) if not isinstance(torch, NoSuchModule): - class ExampleModel(torch.nn.Module): + class Example2DModel(torch.nn.Module): def __init__(self): - super(ExampleModel, self).__init__() + super(Example2DModel, self).__init__() self.linear = torch.nn.Conv2d(1, 1, 3) def forward(self, a): @@ -235,70 +222,107 @@ def forward(self, a): return pred +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + pytest.mark.xfail( + reason="failing to move model to device in multiprocessing context" + ), + ], + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchPredictMultiprocessing(ProviderTest): - def test_scan(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - pred = ArrayKey("PRED") - - model = ExampleModel() - - reference_request = BatchRequest() - reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7))) - reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) - - source = ExampleTorchTrain2DSource() - predict = Predict( - model=model, - inputs={"a": a}, - outputs={0: pred}, - array_specs={pred: ArraySpec()}, - ) - pipeline = source + predict + Scan(reference_request, num_workers=2) - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (17, 17))), - pred: ArraySpec(roi=Roi((0, 0), (15, 15))), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) - assert pred in batch - - def test_precache(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - pred = ArrayKey("PRED") - - model = ExampleModel() - - reference_request = BatchRequest() - reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7))) - reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) - - source = ExampleTorchTrain2DSource() - predict = Predict( - model=model, - inputs={"a": a}, - outputs={0: pred}, - array_specs={pred: ArraySpec()}, - ) - pipeline = source + predict + PreCache(cache_size=3, num_workers=2) - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (17, 17))), - pred: ArraySpec(roi=Roi((0, 0), (15, 15))), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) - assert pred in batch +def test_scan(device): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + pred = ArrayKey("PRED") + + model = Example2DModel() + + reference_request = BatchRequest() + reference_request[a_key] = ArraySpec(roi=Roi((0, 0), (7, 7))) + reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) + + source = example_2d_source(a_key) + predict = Predict( + model=model, + inputs={"a": a_key}, + outputs={0: pred}, + array_specs={pred: ArraySpec()}, + device=device, + ) + pipeline = source + predict + Scan(reference_request, num_workers=2) + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (17, 17))), + pred: ArraySpec(roi=Roi((0, 0), (15, 15))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + assert pred in batch + + +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + pytest.mark.xfail( + reason="failing to move model to device in multiprocessing context" + ), + ], + ), + ], +) +@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") +def test_precache(device): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + pred = ArrayKey("PRED") + + model = Example2DModel() + + reference_request = BatchRequest() + reference_request[a_key] = ArraySpec(roi=Roi((0, 0), (7, 7))) + reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) + + source = example_2d_source(a_key) + predict = Predict( + model=model, + inputs={"a": a_key}, + outputs={0: pred}, + array_specs={pred: ArraySpec()}, + device=device, + ) + pipeline = source + predict + PreCache(cache_size=3, num_workers=2) + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (17, 17))), + pred: ArraySpec(roi=Roi((0, 0), (15, 15))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + assert pred in batch diff --git a/tests/conftest.py b/tests/conftest.py index a8f65ea1..1386c6b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,10 +6,10 @@ # cannot parametrize unittest.TestCase. We should test both # fork and spawn but I'm not sure how to. # @pytest.fixture(params=["fork", "spawn"], autouse=True) -@pytest.fixture(autouse=True) -def context(monkeypatch): - ctx = mp.get_context("spawn") - monkeypatch.setattr(mp, "Queue", ctx.Queue) - monkeypatch.setattr(mp, "Process", ctx.Process) - monkeypatch.setattr(mp, "Event", ctx.Event) - monkeypatch.setattr(mp, "Value", ctx.Value) +# @pytest.fixture(autouse=True) +# def context(monkeypatch): +# ctx = mp.get_context("spawn") +# monkeypatch.setattr(mp, "Queue", ctx.Queue) +# monkeypatch.setattr(mp, "Process", ctx.Process) +# monkeypatch.setattr(mp, "Event", ctx.Event) +# monkeypatch.setattr(mp, "Value", ctx.Value)