-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: to TensorFlow RaggedTensor (#3210)
* feat: to TensorFlow RaggedTensor * style: pre-commit fixes * fix the tensorflow library import * style: pre-commit fixes * update exception * added some tests for different data types conversions * change end of line formats * add tensorflow library to the test-full-requirements * add tensorflow library to the test-full-requirements * change tensorflow library version * change tensorflow library version * change tensorflow library version * add a new github actions test for ml libraries * delete tensorflow from full test requirements * update requirements-test-ml.txt * update requirements-test-full.txt * update requirements-test-ml.txt * update the docstring for the main function * add a new function from_raggedtensor * delete import of tensorflow library * minor changes * fix the tests names Co-authored-by: Ianna Osborne <[email protected]> * Apply suggestions from code review Co-authored-by: Jim Pivarski <[email protected]> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ianna Osborne <[email protected]> Co-authored-by: Jim Pivarski <[email protected]>
- Loading branch information
1 parent
769ee40
commit 81c48fc
Showing
6 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
fsspec>=2022.11.0;sys_platform != "win32" | ||
pytest>=6 | ||
pytest-cov | ||
pytest-xdist | ||
tensorflow >= 2.12 | ||
torch >= 2.4.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._dispatch import high_level_function | ||
|
||
__all__ = ("from_raggedtensor",) | ||
|
||
|
||
@high_level_function()
def from_raggedtensor(array):
    """
    Args:
        array (`tensorflow.RaggedTensor`): The RaggedTensor to convert
            into an Awkward Array.

    Converts a TensorFlow RaggedTensor into an Awkward Array: the tensor's
    flat values become the innermost numeric content and each of its nested
    row-split vectors becomes one level of variable-length lists.

    A `TypeError` is raised if `array` does not expose the RaggedTensor
    interface (`flat_values`).
    """
    # Dispatch: hand the argument to the high-level-function machinery first.
    yield (array,)

    # Implementation: the actual conversion lives in the private helper.
    return _impl(array)
|
||
|
||
def _impl(array):
    """
    Build an `ak.Array` from a TensorFlow RaggedTensor.

    The flat values are wrapped in a NumpyArray, every entry of
    `array.nested_row_splits` is wrapped in an Index64, and the pieces are
    assembled into nested ListOffsetArrays.

    Raises:
        TypeError: if `array` has no `flat_values` attribute, or its flat
            values have no `numpy()` method.
    """
    try:
        flat = array.flat_values.numpy()
    except AttributeError as err:
        raise TypeError(
            """only RaggedTensor can be converted to awkward array"""
        ) from err

    # innermost (leaf) content of the resulting layout
    leaves = ak.contents.NumpyArray(flat)

    # one Index64 per ragged dimension, outermost first
    level_offsets = [
        ak.index.Index64(splits.numpy()) for splits in array.nested_row_splits
    ]

    # several ragged dimensions: delegate the nesting to the helper
    if len(level_offsets) != 1:
        return ak.Array(_recursive_call(leaves, level_offsets, 0))

    # exactly one ragged dimension: a single list level over the leaves
    (only_offsets,) = level_offsets
    return ak.Array(ak.contents.ListOffsetArray(only_offsets, leaves))
|
||
|
||
def _recursive_call(content, offsets_arr, count):
    """
    Nest `content` inside one ListOffsetArray per entry of
    `offsets_arr[count:]`.

    Args:
        content: the innermost layout (the flat values).
        offsets_arr: sequence of offsets indexes, outermost level first.
        count: position in `offsets_arr` of the outermost level to build.

    Returns:
        A ListOffsetArray whose outermost offsets are `offsets_arr[count]`
        and whose innermost list level uses `offsets_arr[-1]` over `content`.
        If `count` is past the end of `offsets_arr`, `content` is returned
        unchanged.

    The previous implementation recursed with an unchanged `count`, so any
    input with three or more ragged dimensions never reached its base case
    and ended in a RecursionError. Building the layout from the inside out
    visits each level exactly once and needs no base case.
    """
    layout = content
    # wrap from the innermost offsets outwards so that offsets_arr[count]
    # ends up as the outermost list level
    for offsets in reversed(offsets_arr[count:]):
        layout = ak.contents.ListOffsetArray(offsets, layout)
    return layout
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._dispatch import high_level_function | ||
|
||
__all__ = ("to_raggedtensor",) | ||
|
||
|
||
@high_level_function()
def to_raggedtensor(array):
    """
    Args:
        array: Array-like data. May be a high-level #ak.Array or a
            low-level #ak.contents.ListOffsetArray, #ak.contents.ListArray,
            #ak.contents.RegularArray or #ak.contents.NumpyArray.

    Converts `array` into a TensorFlow RaggedTensor, if possible. Only
    layouts made of ListOffsetArray, ListArray, RegularArray and NumpyArray
    nodes are supported.

    If `array` contains any other node type (a RecordArray, for example),
    the function raises an error.
    """
    # Dispatch: hand the argument to the high-level-function machinery first.
    yield (array,)

    # Implementation: the actual conversion lives in the private helper.
    return _impl(array)
|
||
|
||
def _impl(array):
    """
    Convert `array` into a `tf.RaggedTensor`.

    TensorFlow is imported lazily so that it is only required when this
    function is actually called.

    Raises:
        ImportError: if the `tensorflow` package cannot be imported.
    """
    try:
        import tensorflow as tf
    except ImportError as err:
        raise ImportError(
            """to use ak.to_raggedtensor, you must install the 'tensorflow' package with:
pip install tensorflow
or
conda install tensorflow"""
        ) from err

    # Unwrap a high-level ak.Array (or convert a plain Python list) into a
    # low-level layout.
    layout = ak.to_layout(array, allow_record=False)

    # Anything that is not a bare NumpyArray has its row splits gathered level
    # by level by the helper.
    if not isinstance(layout, ak.contents.numpyarray.NumpyArray):
        flat_values, nested_row_splits = _recursive_call(layout, ())
        return tf.RaggedTensor.from_nested_row_splits(flat_values, nested_row_splits)

    # A bare NumpyArray has no list structure: expose it as a single row that
    # spans all of its values.
    return tf.RaggedTensor.from_row_splits(
        values=layout.data, row_splits=[0, len(layout)]
    )
|
||
|
||
def _recursive_call(layout, offsets_arr):
    """
    Gather the flat values and the row splits of every list level of `layout`.

    Args:
        layout: a low-level layout made of (possibly nested) ListArray,
            ListOffsetArray or RegularArray nodes ending in a NumpyArray.
        offsets_arr: tuple of row-split arrays already collected from outer
            levels, outermost first; callers start with `()`.

    Returns:
        `(flat_values, nested_row_splits)`: the data of the innermost
        NumpyArray and a tuple with one offsets array per list level,
        outermost first.

    Raises:
        TypeError: if any node of `layout` is neither a supported list type
            nor a NumpyArray.

    Every list level is normalised before its offsets are recorded, because
    TensorFlow requires each row-splits vector to start at 0 and to end at the
    number of values below it. Sliced arrays, whose offsets start past 0 or
    whose content extends beyond the last offset, previously produced row
    splits that TensorFlow rejects.
    """
    list_types = (
        ak.contents.listarray.ListArray,
        ak.contents.listoffsetarray.ListOffsetArray,
        ak.contents.regulararray.RegularArray,
    )

    while isinstance(layout, list_types):
        # 64-bit offsets that begin at zero, whatever the original list type
        layout = layout.to_ListOffsetArray64(True)
        offsets = layout.offsets
        offsets_arr += (offsets.data,)
        # drop any content past the last offset so that the next level (or
        # the flat values) has exactly offsets[-1] entries
        layout = layout.content[: offsets[-1]]

    if not isinstance(layout, ak.contents.numpyarray.NumpyArray):
        raise TypeError(
            "Only arrays containing variable-length lists (var *) or"
            " regular-length lists (# *) of numbers can be converted into a TensorFlow RaggedTensor"
        )

    # innermost level reached: hand back the flat values together with the
    # accumulated row splits
    return layout.data, offsets_arr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import awkward as ak | ||
|
||
# Short aliases for the two conversion functions under test.
to_raggedtensor = ak.operations.to_raggedtensor
from_raggedtensor = ak.operations.from_raggedtensor

# Skip every test in this module when TensorFlow cannot be imported.
tf = pytest.importorskip("tensorflow")

# Nine flat float values shared by the list-type fixtures below.
content = ak.contents.NumpyArray(
    np.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
)
# starts1/stops1 cut `content` into five lists of lengths 3, 0, 2, 1 and 3.
starts1 = ak.index.Index64(np.array([0, 3, 3, 5, 6]))
stops1 = ak.index.Index64(np.array([3, 3, 5, 6, 9]))
# starts2/stops2 group those five lists into two outer lists (3 + 2).
starts2 = ak.index.Index64(np.array([0, 3]))
stops2 = ak.index.Index64(np.array([3, 5]))

# The integers 0..29 shaped (2, 3, 5), plus the offsets that describe the
# same nesting over the flattened values.
array = np.arange(2 * 3 * 5).reshape(2, 3, 5)
content2 = ak.contents.NumpyArray(array.reshape(-1))
# six inner lists of five values each
inneroffsets = ak.index.Index64(np.array([0, 5, 10, 15, 20, 25, 30]))
# two outer lists of three inner lists each
outeroffsets = ak.index.Index64(np.array([0, 3, 6]))
|
||
|
||
def test_convert_to_raggedtensor():
    """Check `to_raggedtensor` on each supported input kind."""
    five_ragged_rows = [
        [1.1, 2.2, 3.3],
        [],
        [4.4, 5.5],
        [6.6],
        [7.7, 8.8, 9.9],
    ]
    flat_ints = [3, 1, 4, 1, 9, 2, 6]

    # ListArray -> RaggedTensor
    list_layout = ak.contents.ListArray(starts1, stops1, content)
    assert to_raggedtensor(list_layout).to_list() == five_ragged_rows

    # awkward.highlevel.Array -> RaggedTensor
    highlevel = ak.Array(list_layout)
    assert to_raggedtensor(highlevel).to_list() == five_ragged_rows

    # NumpyArray -> RaggedTensor (a single row holding every value)
    numpy_layout = content
    assert to_raggedtensor(numpy_layout).to_list() == [
        [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
    ]

    # RegularArray -> RaggedTensor (the ninth value does not fill a row)
    regular_layout = ak.contents.RegularArray(content, size=2)
    assert to_raggedtensor(regular_layout).to_list() == [
        [1.1, 2.2],
        [3.3, 4.4],
        [5.5, 6.6],
        [7.7, 8.8],
    ]

    # a one-dimensional awkward array
    flat_array = ak.Array(flat_ints)
    assert to_raggedtensor(flat_array).to_list() == [flat_ints]

    # an array with more than one ragged dimension
    doubly_ragged = [[[1.1, 2.2], [3.3]], [], [[4.4, 5.5]]]
    assert to_raggedtensor(ak.Array(doubly_ragged)).to_list() == doubly_ragged

    # a ListOffsetArray inside a ListOffsetArray
    nested_offsets_layout = ak.contents.ListOffsetArray(
        outeroffsets, ak.contents.ListOffsetArray(inneroffsets, content2)
    )
    assert to_raggedtensor(nested_offsets_layout).to_list() == [
        [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]],
        [[15, 16, 17, 18, 19], [20, 21, 22, 23, 24], [25, 26, 27, 28, 29]],
    ]

    # a ListArray inside a ListArray
    inner_lists = ak.contents.ListArray(starts1, stops1, content)
    nested_list_layout = ak.contents.ListArray(starts2, stops2, inner_lists)
    assert to_raggedtensor(nested_list_layout).to_list() == [
        [[1.1, 2.2, 3.3], [], [4.4, 5.5]],
        [[6.6], [7.7, 8.8, 9.9]],
    ]

    # a plain Python list
    assert to_raggedtensor(list(flat_ints)).to_list() == [flat_ints]
|
||
|
||
# float32 values to compare against the flat values of the RaggedTensor
# built in test_convert_from_raggedtensor
np_array1 = np.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=np.float32)

# Reference layout pieces: four lists of lengths 2, 1, 0 and 2 over np_array1.
offsets1 = ak.index.Index64(np.array([0, 2, 3, 3, 5]))
content1 = ak.contents.NumpyArray(np_array1)
|
||
|
||
def test_convert_from_raggedtensor():
    """Check `from_raggedtensor` for one and for two ragged dimensions."""
    row_splits = [0, 2, 3, 3, 5]

    # one ragged dimension
    single_ragged = tf.RaggedTensor.from_row_splits(
        values=[1.1, 2.2, 3.3, 4.4, 5.5], row_splits=row_splits
    )
    expected_layout = ak.contents.ListOffsetArray(offsets1, content1)

    converted_layout = ak.to_layout(
        from_raggedtensor(single_ragged), allow_record=False
    )
    # both buffers of the converted layout must match the inputs
    assert (converted_layout.content.data == np_array1).all()
    assert (converted_layout.offsets.data == row_splits).all()
    assert from_raggedtensor(single_ragged).to_list() == expected_layout.to_list()

    # two ragged dimensions
    double_ragged = tf.RaggedTensor.from_nested_row_splits(
        flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
        nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]),
    )
    expected_nested = [
        [[3, 1, 4, 1], [], [5, 9, 2]],
        [],
        [[6], []],
    ]
    assert from_raggedtensor(double_ragged).to_list() == expected_nested