Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

from_arrow with List columns #511

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions torcharrow/_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def _arrowtype_to_dtype(t: pa.DataType, nullable: bool) -> dt.DType:
[dt.Field(f.name, _arrowtype_to_dtype(f.type, f.nullable)) for f in t],
nullable,
)
if pa.types.is_list(t):
return dt.List(_arrowtype_to_dtype(t.value_type, nullable), nullable=nullable)
raise NotImplementedError(f"Unsupported Arrow type: {str(t)}")


Expand Down
2 changes: 2 additions & 0 deletions torcharrow/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,8 @@ def cast_as(dtype):


def get_underlying_dtype(dtype: DType) -> DType:
    """Return *dtype* with nullability stripped at every nesting level.

    For scalar dtypes this simply clears the ``nullable`` flag; for list
    dtypes the flag is cleared on the list itself and, recursively, on its
    item dtype so arbitrarily nested lists come back fully non-nullable.
    """
    if is_list(dtype):
        # Recurse into the item dtype so nested lists, e.g.
        # List(List(int64)), are handled too — a single-level replace
        # would leave inner nullability flags set.
        return replace(
            dtype,
            nullable=False,
            item_dtype=get_underlying_dtype(dtype.item_dtype),
        )
    return replace(dtype, nullable=False)


Expand Down
35 changes: 24 additions & 11 deletions torcharrow/test/test_arrow_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TestArrowInterop(unittest.TestCase):
(pa.float64(), dt.Float64(True)),
(pa.string(), dt.String(True)),
(pa.large_string(), dt.String(True)),
(pa.list_(pa.int64()), dt.List(dt.Int64(True), True))
)

unsupported_types: Tuple[pa.DataType, ...] = (
Expand All @@ -44,7 +45,6 @@ class TestArrowInterop(unittest.TestCase):
pa.binary(),
pa.large_binary(),
pa.decimal128(38),
pa.list_(pa.int64()),
pa.large_list(pa.int64()),
pa.map_(pa.int64(), pa.int64()),
pa.dictionary(pa.int64(), pa.int64()),
Expand Down Expand Up @@ -420,19 +420,32 @@ def base_test_table_memory_reclaimed(self):
del df
self.assertEqual(pa.total_allocated_bytes(), initial_memory)

def base_test_table_unsupported_types(self):
def base_test_from_arrow_table_with_list(self):
    """Round-trip an Arrow table with list columns through ta.from_arrow.

    Verifies that field names, converted dtypes (including nullability),
    and the column data all survive the conversion.
    """
    schema = pa.schema(
        [
            pa.field("f1", pa.int64(), nullable=True),
            pa.field("f2", pa.string(), nullable=True),
            pa.field("f3", pa.list_(pa.int8()), nullable=True),
            pa.field("f4", pa.list_(pa.int8()), nullable=False),
        ]
    )
    columns = {
        "f1": [1, 2, 3],
        "f2": ["foo", "bar", None],
        "f3": [[1, 2], [3, 4, 5], [6]],
        "f4": [[1, 2], [3, 4, 5], [6]],
    }
    pt = pa.table(columns, schema=schema)

    df = ta.from_arrow(pt, device=self.device)

    for idx, field in enumerate(df.dtype.fields):
        arrow_field = pt.schema.field(idx)
        self.assertEqual(field.name, arrow_field.name)
        # The torcharrow dtype must match what the converter derives
        # from the Arrow field's type and nullability.
        self.assertEqual(
            field.dtype,
            _arrowtype_to_dtype(arrow_field.type, arrow_field.nullable),
        )
        self.assertEqual(list(df[field.name]), pt[idx].to_pylist())


def base_test_nullability(self):
pydata = [1, 2, 3]
Expand Down
4 changes: 2 additions & 2 deletions torcharrow/test/test_arrow_interop_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def test_table_ownership_transferred(self):
def test_table_memory_reclaimed(self):
    # Thin CPU-device wrapper delegating to the shared base test.
    return self.base_test_table_memory_reclaimed()

def test_table_unsupported_types(self):
return self.base_test_table_unsupported_types()
def test_from_arrow_table_with_list(self):
    # Thin CPU-device wrapper delegating to the shared base test.
    return self.base_test_from_arrow_table_with_list()

def test_nullability(self):
    # Thin CPU-device wrapper delegating to the shared base test.
    return self.base_test_nullability()
Expand Down
12 changes: 12 additions & 0 deletions torcharrow/velox_rt/list_column_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def _from_pysequence(device: str, data: List[List], dtype: dt.List):
for i in data:
col._append(i)
return col._finalize()

@staticmethod
def _from_arrow(device: str, array, dtype: dt.List):
    """Construct a CPU list column from a pyarrow ``Array``.

    Converts the Arrow array into plain nested Python lists and delegates
    to ``_from_pysequence`` — simple, at the cost of copying the data.
    """
    import pyarrow as pa

    assert isinstance(array, pa.Array)
    return ListColumnCpu._from_pysequence(device, array.to_pylist(), dtype)

def _append_null(self):
if self._finalized:
Expand Down Expand Up @@ -293,3 +302,6 @@ def vmap(self, fun: Callable[[Column], Column]):
Dispatcher.register(
    (dt.List.typecode + "_from_pysequence", "cpu"), ListColumnCpu._from_pysequence
)
# Route from_arrow construction of List-typed columns on the CPU device
# to ListColumnCpu._from_arrow.
Dispatcher.register(
    (dt.List.typecode + "_from_arrow", "cpu"), ListColumnCpu._from_arrow
)