Skip to content

Commit

Permalink
Dispatch for nested backends #24
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-czech committed May 18, 2020
1 parent 65ca04a commit 014ea31
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions notebooks/platform/xarray/lib/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class NumbaBackend(PackageBackend):
id = 'numba'
requirements = [Requirement('numba')]

class CudaBackend(PackageBackend):
id = 'cuda'
requirements = [Requirement('numba')]

class NetworkxBackend(PackageBackend):
id = 'networkx'
requirements = [Requirement('networkx')]
Expand Down Expand Up @@ -139,29 +143,35 @@ def register_backend(self, backend: Backend) -> Backend:
self._update_config()
return backend

def resolve(self, fn: Callable, *args, **kwargs) -> Backend:
# Passed parameters get highest priority
backend_id = kwargs.get('backend')
def _process(self, kwargs, default):
kwargs = dict(kwargs)
if 'backend' not in kwargs:
kwargs['backend'] = default
parts = kwargs.pop('backend').split('/')
if len(parts) > 1:
kwargs['backend'] = '/'.join(parts[1:])
return parts[0], kwargs

def resolve_backend(self, fn: Callable, *args, **kwargs) -> Backend:
# Passed parameters get highest priority for backend next
# to settings in configuration
backend_id, kwargs = self._process(kwargs, self.config.get(str(self.domain.append('backend'))))

# Followed by configuration
backend_id = backend_id or self.config.get(str(self.domain.append('backend')))
if backend_id and backend_id != 'auto':
if backend_id not in self.backends:
raise ValueError(f'Backend "{backend_id}" not implemented for function {fn.__name__}')
return self.backends[backend_id]

# And then automatic selection/validation:
return self.backends[backend_id], kwargs

# ** Analyze fn/args/kwargs here **
# For now, simply return the first compatible backend

backend = next((b for b in self.backends.values() if is_compatible(b)), None)
if backend is None:
raise ValueError(f'No backend found for function "{fn.__name__}" (domain = "{self.domain}")')
return backend
return backend, kwargs

def dispatch(self, fn: Callable, *args, **kwargs):
backend = self.resolve(fn, *args, **kwargs)
kwargs.pop('backend', None) # Pop this off since implementations cannot expect it
backend, kwargs = self.resolve_backend(fn, *args, **kwargs)
return backend.dispatch(fn, *args, **kwargs)


Expand Down Expand Up @@ -218,4 +228,4 @@ def decorator(backend_fn: Callable):
# Combine doc strings for backends that may add new parameters or details
backend_fn.__doc__ = (frontend_fn.__doc__ or '') + '\n' + (backend_fn.__doc__ or '')
return backend_fn
return decorator
return decorator

0 comments on commit 014ea31

Please sign in to comment.