diff --git a/notebooks/platform/xarray/lib/dispatch.py b/notebooks/platform/xarray/lib/dispatch.py index 0801c74..c5a6190 100755 --- a/notebooks/platform/xarray/lib/dispatch.py +++ b/notebooks/platform/xarray/lib/dispatch.py @@ -96,6 +96,10 @@ class NumbaBackend(PackageBackend): id = 'numba' requirements = [Requirement('numba')] +class CudaBackend(PackageBackend): + id = 'cuda' + requirements = [Requirement('numba')] + class NetworkxBackend(PackageBackend): id = 'networkx' requirements = [Requirement('networkx')] @@ -139,29 +143,35 @@ def register_backend(self, backend: Backend) -> Backend: self._update_config() return backend - def resolve(self, fn: Callable, *args, **kwargs) -> Backend: - # Passed parameters get highest priority - backend_id = kwargs.get('backend') + def _process(self, kwargs, default): + kwargs = dict(kwargs) + if 'backend' not in kwargs: + kwargs['backend'] = default + parts = kwargs.pop('backend').split('/') + if len(parts) > 1: + kwargs['backend'] = '/'.join(parts[1:]) + return parts[0], kwargs + + def resolve_backend(self, fn: Callable, *args, **kwargs) -> Backend: + # Passed parameters get highest priority for backend next + # to settings in configuration + backend_id, kwargs = self._process(kwargs, self.config.get(str(self.domain.append('backend')))) - # Followed by configuration - backend_id = backend_id or self.config.get(str(self.domain.append('backend'))) if backend_id and backend_id != 'auto': if backend_id not in self.backends: raise ValueError(f'Backend "{backend_id}" not implemented for function {fn.__name__}') - return self.backends[backend_id] - - # And then automatic selection/validation: + return self.backends[backend_id], kwargs # ** Analyze fn/args/kwargs here ** # For now, simply return the first compatible backend + backend = next((b for b in self.backends.values() if is_compatible(b)), None) if backend is None: raise ValueError(f'No backend found for function "{fn.__name__}" (domain = "{self.domain}")') - return backend + return backend, kwargs def dispatch(self, fn: Callable, *args, **kwargs): - backend = self.resolve(fn, *args, **kwargs) - kwargs.pop('backend', None) # Pop this off since implementations cannot expect it + backend, kwargs = self.resolve_backend(fn, *args, **kwargs) return backend.dispatch(fn, *args, **kwargs) @@ -218,4 +228,4 @@ def decorator(backend_fn: Callable): # Combine doc strings for backends that may add new parameters or details backend_fn.__doc__ = (frontend_fn.__doc__ or '') + '\n' + (backend_fn.__doc__ or '') return backend_fn - return decorator \ No newline at end of file + return decorator