Skip to content

Commit

Permalink
Add progress callback to Scan node
Browse files Browse the repository at this point in the history
  • Loading branch information
funkey committed Nov 29, 2023
1 parent a33e5b4 commit a6de9a2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ Iterative Processing Nodes
Scan
^^^^
.. autoclass:: Scan
.. autoclass:: ScanCallback

DaisyRequestBlocks
^^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .reject import Reject
from .renumber_connected_components import RenumberConnectedComponents
from .resample import Resample
from .scan import Scan
from .scan import Scan, ScanCallback
from .shift_augment import ShiftAugment
from .simple_augment import SimpleAugment
from .snapshot import Snapshot
Expand Down
78 changes: 74 additions & 4 deletions gunpowder/nodes/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import multiprocessing
import numpy as np
import tqdm
from abc import ABC
from gunpowder.array import Array
from gunpowder.batch import Batch
from gunpowder.coordinate import Coordinate
Expand All @@ -13,6 +14,55 @@
logger = logging.getLogger(__name__)


class ScanCallback(ABC):
    """Interface for objects that observe the progress of a :class:`Scan`.

    All hooks do nothing by default. Subclass and override whichever of
    ``start``, ``update``, and ``stop`` you are interested in.
    """

    def start(self, num_total):
        """Hook invoked a single time before :class:`Scan` begins scanning.

        Args:

            num_total (int):

                How many chunks are going to be processed in total.
        """
        return None

    def update(self, num_processed):
        """Hook invoked repeatedly by :class:`Scan` as chunks are processed.

        Args:

            num_processed (int):

                How many chunks have been processed so far.
        """
        return None

    def stop(self):
        """Hook invoked a single time once :class:`Scan` covered all chunks."""
        return None


class TqdmCallback(ScanCallback):
    """A default callback that uses ``tqdm`` to show a progress bar.

    The callback can be reused for several scans: every call to ``start``
    opens a fresh progress bar. ``update`` and ``stop`` are no-ops while no
    progress bar is active, so they are safe to call before ``start`` or
    more than once.
    """

    # Class-level defaults so that ``update``/``stop`` work even if ``start``
    # was never called on this instance.
    # The currently active ``tqdm`` bar, or ``None`` between scans.
    progress_bar = None
    # Number of chunks already reported to the active bar.
    num_processed = 0

    def start(self, num_total):
        """Log the number of chunks and open a new progress bar.

        Args:

            num_total (int):

                The total number of chunks to process.
        """
        logger.info("scanning over %d chunks", num_total)

        # A previous scan may have been aborted by an exception before
        # ``stop`` was reached; close its bar instead of leaking it.
        if self.progress_bar is not None:
            self.progress_bar.close()

        self.progress_bar = tqdm.tqdm(desc="Scan, chunks processed", total=num_total)
        self.num_processed = 0

    def update(self, num_processed):
        """Advance the progress bar to ``num_processed`` chunks.

        Args:

            num_processed (int):

                The number of chunks already processed (an absolute count,
                not an increment).
        """
        if self.progress_bar is None:
            return

        # tqdm expects an increment, the callback receives an absolute count
        self.progress_bar.update(num_processed - self.num_processed)
        self.num_processed = num_processed

    def stop(self):
        """Close the progress bar, if one is active."""
        if self.progress_bar is None:
            return

        self.progress_bar.close()
        self.progress_bar = None


class Scan(BatchFilter):
"""Iteratively requests batches of size ``reference`` from upstream
providers in a scanning fashion, until all requested ROIs are covered. If
Expand Down Expand Up @@ -40,14 +90,24 @@ class Scan(BatchFilter):
cache_size (``int``, optional):
If multiple workers are used, how many batches to hold at most.
progress_callback (:class:`ScanCallback`, optional):
A callback instance to get updated from this node while processing
chunks. See :class:`ScanCallback` for details. The default is a
callback that shows a ``tqdm`` progress bar.
"""

def __init__(self, reference, num_workers=1, cache_size=50):
def __init__(self, reference, num_workers=1, cache_size=50, progress_callback=None):
self.reference = reference.copy()
self.num_workers = num_workers
self.cache_size = cache_size
self.workers = None
self.batch = None
if progress_callback is None:
self.progress_callback = TqdmCallback()
else:
self.progress_callback = progress_callback

def setup(self):
if self.num_workers > 1:
Expand Down Expand Up @@ -75,7 +135,8 @@ def provide(self, request):
shifts = self._enumerate_shifts(shift_roi, stride)
num_chunks = len(shifts)

logger.info("scanning over %d chunks", num_chunks)
if self.progress_callback is not None:
self.progress_callback.start(num_chunks)

# the batch to return
self.batch = Batch()
Expand All @@ -85,24 +146,33 @@ def provide(self, request):
shifted_reference = self._shift_request(self.reference, shift)
self.request_queue.put(shifted_reference)

for i in tqdm.tqdm(range(num_chunks)):
for i in range(num_chunks):
chunk = self.workers.get()

if not empty_request:
self._add_to_batch(request, chunk)

if self.progress_callback is not None:
self.progress_callback.update(i + 1)

logger.debug("processed chunk %d/%d", i + 1, num_chunks)

else:
for i, shift in enumerate(tqdm.tqdm(shifts)):
for i, shift in enumerate(shifts):
shifted_reference = self._shift_request(self.reference, shift)
chunk = self._get_chunk(shifted_reference)

if not empty_request:
self._add_to_batch(request, chunk)

if self.progress_callback is not None:
self.progress_callback.update(i + 1)

logger.debug("processed chunk %d/%d", i + 1, num_chunks)

if self.progress_callback is not None:
self.progress_callback.stop()

batch = self.batch
self.batch = None

Expand Down

0 comments on commit a6de9a2

Please sign in to comment.