Skip to content

Commit

Permalink
Added docstrings and getter methods for ElementwiseFunc classes
Browse files Browse the repository at this point in the history
Added stable API to retrieve implementation functions in each elementwise
function class instance to allow `dpnp` to access that information using
stable API.
  • Loading branch information
oleksandr-pavlyk committed Nov 3, 2023
1 parent 421b270 commit 41ec378
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 0 deletions.
150 changes: 150 additions & 0 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,31 @@
class UnaryElementwiseFunc:
"""
Class that implements unary element-wise functions.
Args:
name (str):
Name of the unary function
result_type_resovler_fn (callable):
Function that takes dtype of the input and
returns the dtype of the result if the
implementation functions supports it, or
returns `None` otherwise.
unary_dp_impl_fn (callable):
Data-parallel implementation function with signature
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
where the `src` is the argument array, `dst` is the
array to be populated with function values, effectively
evaluating `dst = func(src)`.
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
The first event corresponds to data-management host tasks,
including lifetime management of argument Python objects to ensure
that their associated USM allocation is not freed before offloaded
computational tasks complete execution, while the second event
corresponds to computational tasks associated with function
evaluation.
docs (str):
Documentation string for the unary function.
"""

def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
Expand All @@ -55,8 +80,31 @@ def __str__(self):
def __repr__(self):
return f"<{self.__name__} '{self.name_}'>"

def get_implementation_function(self):
"""Returns the implementation function for
this elementwise unary function.
"""
return self.unary_fn_

def get_type_result_resolver_function(self):
"""Returns the type resolver function for this
elementwise unary function.
"""
return self.result_type_resolver_fn_

@property
def types(self):
"""Returns information about types supported by
implementation function, using NumPy's character
encoding for data types, e.g.
:Example:
.. code-block:: python
dpctl.tensor.sin.types
# Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
"""
types = self.types_
if not types:
types = []
Expand Down Expand Up @@ -363,6 +411,56 @@ def _get_shape(o):
class BinaryElementwiseFunc:
"""
Class that implements binary element-wise functions.
Args:
name (str):
Name of the unary function
result_type_resovle_fn (callable):
Function that takes dtypes of the input and
returns the dtype of the result if the
implementation functions supports it, or
returns `None` otherwise.
binary_dp_impl_fn (callable):
Data-parallel umplementation function with signature
`impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
where the `src1` and `src2` are the argument arrays, `dst` is the
array to be populated with function values,
i.e. `dst=func(src1, src2)`.
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
The first event corresponds to data-management host tasks,
including lifetime management of argument Python objects to ensure
that their associated USM allocation is not freed before offloaded
computational tasks complete execution, while the second event
corresponds to computational tasks associated with function
evaluation.
docs (str):
Documentation string for the unary function.
binary_inplace_fn (callable, optional):
Data-parallel omplementation function with signature
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
where the `src` is the argument array, `dst` is the
array to be populated with function values,
i.e. `dst=func(dst, src)`.
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
The first event corresponds to data-management host tasks,
including async lifetime management of Python arguments,
while the second event corresponds to computational tasks
associated with function evaluation.
acceptance_fn (callable, optional):
Function to influence type promotion behavior of this binary
function. The function takes 6 arguments:
arg1_dtype - Data type of the first argument
arg2_dtype - Data type of the second argument
ret_buf1_dtype - Data type the first argument would be cast to
ret_buf2_dtype - Data type the second argument would be cast to
res_dtype - Data type of the output array with function values
sycl_dev - The :class:`dpctl.SyclDevice` where the function
evaluation is carried out.
The function is only called when both arguments of the binary
function require casting, e.g. both arguments of
`dpctl.tensor.logaddexp` are arrays with integral data type.
"""

def __init__(
Expand Down Expand Up @@ -392,8 +490,60 @@ def __str__(self):
def __repr__(self):
return f"<{self.__name__} '{self.name_}'>"

def get_implementation_function(self):
"""Returns the out-of-place implementation
function for this elementwise binary function.
"""
return self.binary_fn_

def get_implementation_inplace_function(self):
"""Returns the in-place implementation
function for this elementwise binary function.
"""
return self.binary_inplace_fn_

def get_type_result_resolver_function(self):
"""Returns the type resolver function for this
elementwise binary function.
"""
return self.result_type_resolver_fn_

def get_type_promotion_path_acceptance_function(self):
"""Returns the acceptance function for this
elementwise binary function.
Acceptance function influences the type promotion
behavior of this binary function.
The function takes 6 arguments:
arg1_dtype - Data type of the first argument
arg2_dtype - Data type of the second argument
ret_buf1_dtype - Data type the first argument would be cast to
ret_buf2_dtype - Data type the second argument would be cast to
res_dtype - Data type of the output array with function values
sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
is carried out.
The acceptance function is only invoked if both input arrays must be
cast to intermediary data types, as would happen during call of
`dpctl.tensor.hypot` with both arrays being of integral data type.
"""
return self.acceptance_fn_

@property
def types(self):
"""Returns information about types supported by
implementation function, using NumPy's character
encoding for data types, e.g.
:Example:
.. code-block:: python
dpctl.tensor.divide.types
# Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
# 'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
"""
types = self.types_
if not types:
types = []
Expand Down
80 changes: 80 additions & 0 deletions dpctl/tests/elementwise/test_elementwise_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dpctl.tensor as dpt

unary_fn = dpt.negative
binary_fn = dpt.divide


def test_unary_class_getters():
fn = unary_fn.get_implementation_function()
assert callable(fn)

fn = unary_fn.get_type_result_resolver_function()
assert callable(fn)


def test_unary_class_types_property():
loop_types = unary_fn.types
assert isinstance(loop_types, list)
assert len(loop_types) > 0
assert all(isinstance(sig, str) for sig in loop_types)
assert all("->" in sig for sig in loop_types)


def test_unary_class_str_repr():
s = str(unary_fn)
r = repr(unary_fn)

assert isinstance(s, str)
assert isinstance(r, str)
kl_n = unary_fn.__name__
assert kl_n in s
assert kl_n in r


def test_binary_class_getters():
fn = binary_fn.get_implementation_function()
assert callable(fn)

fn = binary_fn.get_implementation_inplace_function()
assert callable(fn)

fn = binary_fn.get_type_result_resolver_function()
assert callable(fn)

fn = binary_fn.get_type_promotion_path_acceptance_function()
assert callable(fn)


def test_binary_class_types_property():
loop_types = binary_fn.types
assert isinstance(loop_types, list)
assert len(loop_types) > 0
assert all(isinstance(sig, str) for sig in loop_types)
assert all("->" in sig for sig in loop_types)


def test_binary_class_str_repr():
s = str(binary_fn)
r = repr(binary_fn)

assert isinstance(s, str)
assert isinstance(r, str)
kl_n = binary_fn.__name__
assert kl_n in s
assert kl_n in r

0 comments on commit 41ec378

Please sign in to comment.