From a6de9a27a43a985a78a67eaa83f2500efd7add20 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Wed, 29 Nov 2023 14:38:12 -0500 Subject: [PATCH] Add progress callback to Scan node --- docs/source/api.rst | 1 + gunpowder/nodes/__init__.py | 2 +- gunpowder/nodes/scan.py | 78 +++++++++++++++++++++++++++++++++++-- 3 files changed, 76 insertions(+), 5 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 5f120a1c..21b7753b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -334,6 +334,7 @@ Iterative Processing Nodes Scan ^^^^ .. autoclass:: Scan + .. autoclass:: ScanCallback DaisyRequestBlocks ^^^^^^^^^^^^^^^^^^ diff --git a/gunpowder/nodes/__init__.py b/gunpowder/nodes/__init__.py index 4131b824..1a152f8e 100644 --- a/gunpowder/nodes/__init__.py +++ b/gunpowder/nodes/__init__.py @@ -34,7 +34,7 @@ from .reject import Reject from .renumber_connected_components import RenumberConnectedComponents from .resample import Resample -from .scan import Scan +from .scan import Scan, ScanCallback from .shift_augment import ShiftAugment from .simple_augment import SimpleAugment from .snapshot import Snapshot diff --git a/gunpowder/nodes/scan.py b/gunpowder/nodes/scan.py index ef6b378e..3473764e 100644 --- a/gunpowder/nodes/scan.py +++ b/gunpowder/nodes/scan.py @@ -2,6 +2,7 @@ import multiprocessing import numpy as np import tqdm +from abc import ABC from gunpowder.array import Array from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate @@ -13,6 +14,55 @@ logger = logging.getLogger(__name__) +class ScanCallback(ABC): + """Base class for :class:`Scan` callbacks. Implement any of ``start``, + ``update``, and ``stop`` in a subclass to create your own callback. + """ + + def start(self, num_total): + """Called once before :class:`Scan` starts scanning over chunks. + + Args: + + num_total (int): + + The total number of chunks to process. 
+ """ + pass + + def update(self, num_processed): + """Called periodically by :class:`Scan` while processing chunks. + + Args: + + num_processed (int): + + The number of chunks already processed. + """ + pass + + def stop(self): + """Called once after :class:`Scan` scanned over all chunks.""" + pass + + +class TqdmCallback(ScanCallback): + """A default callback that uses ``tqdm`` to show a progress bar.""" + + def start(self, num_total): + logger.info("scanning over %d chunks", num_total) + + self.progress_bar = tqdm.tqdm(desc="Scan, chunks processed", total=num_total) + self.num_processed = 0 + + def update(self, num_processed): + self.progress_bar.update(num_processed - self.num_processed) + self.num_processed = num_processed + + def stop(self): + self.progress_bar.close() + + class Scan(BatchFilter): """Iteratively requests batches of size ``reference`` from upstream providers in a scanning fashion, until all requested ROIs are covered. If @@ -40,14 +90,24 @@ class Scan(BatchFilter): cache_size (``int``, optional): If multiple workers are used, how many batches to hold at most. + + progress_callback (:class:`ScanCallback`, optional): + + A callback instance to get updated from this node while processing + chunks. See :class:`ScanCallback` for details. The default is a + callback that shows a ``tqdm`` progress bar. 
""" - def __init__(self, reference, num_workers=1, cache_size=50): + def __init__(self, reference, num_workers=1, cache_size=50, progress_callback=None): self.reference = reference.copy() self.num_workers = num_workers self.cache_size = cache_size self.workers = None self.batch = None + if progress_callback is None: + self.progress_callback = TqdmCallback() + else: + self.progress_callback = progress_callback def setup(self): if self.num_workers > 1: @@ -75,7 +135,8 @@ def provide(self, request): shifts = self._enumerate_shifts(shift_roi, stride) num_chunks = len(shifts) - logger.info("scanning over %d chunks", num_chunks) + if self.progress_callback is not None: + self.progress_callback.start(num_chunks) # the batch to return self.batch = Batch() @@ -85,24 +146,33 @@ def provide(self, request): shifted_reference = self._shift_request(self.reference, shift) self.request_queue.put(shifted_reference) - for i in tqdm.tqdm(range(num_chunks)): + for i in range(num_chunks): chunk = self.workers.get() if not empty_request: self._add_to_batch(request, chunk) + if self.progress_callback is not None: + self.progress_callback.update(i + 1) + logger.debug("processed chunk %d/%d", i + 1, num_chunks) else: - for i, shift in enumerate(tqdm.tqdm(shifts)): + for i, shift in enumerate(shifts): shifted_reference = self._shift_request(self.reference, shift) chunk = self._get_chunk(shifted_reference) if not empty_request: self._add_to_batch(request, chunk) + if self.progress_callback is not None: + self.progress_callback.update(i + 1) + logger.debug("processed chunk %d/%d", i + 1, num_chunks) + if self.progress_callback is not None: + self.progress_callback.stop() + batch = self.batch self.batch = None