Skip to content

Commit

Permalink
clib.conversion._to_numpy: Add tests for pandas.Series and pyarrow.ar…
Browse files Browse the repository at this point in the history
…ray with pyarrow numeric dtypes
  • Loading branch information
seisman committed Nov 6, 2024
1 parent eceff7f commit 2a2ab7a
Showing 1 changed file with 146 additions and 0 deletions.
146 changes: 146 additions & 0 deletions pygmt/tests/test_clib_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
import pytest
from pygmt.clib.conversion import _to_numpy

try:
import pyarrow as pa

_HAS_PYARROW = True
except ImportError:
_HAS_PYARROW = False


def _check_result(result, expected_dtype):
"""
Expand Down Expand Up @@ -122,6 +129,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
# - BooleanDtype
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
#
# PyArrow dtypes can be specified using the following formats:
#
# - Prefixed with the name of the dtype and "[pyarrow]" (e.g., "int8[pyarrow]")
# - Specified using ``ArrowDType`` (e.g., "pd.ArrowDtype(pa.int8())")
#
# References:
# 1. https://pandas.pydata.org/docs/reference/arrays.html
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
Expand Down Expand Up @@ -207,3 +219,137 @@ def test_to_numpy_pandas_series_pandas_dtypes_numeric_with_na(dtype, expected_dt
result = _to_numpy(series)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype))


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]"),
pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]"),
pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]"),
pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]"),
pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]"),
pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]"),
pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]"),
pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]"),
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"),
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"),
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"),
],
)
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype):
"""
Test the _to_numpy function with pandas.Series of pandas numeric dtypes.
"""
series = pd.Series([1, 2, 3], dtype=dtype)
result = _to_numpy(series)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, series)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
pytest.param("int8[pyarrow]", np.float64, id="int8[pyarrow]"),
pytest.param("int16[pyarrow]", np.float64, id="int16[pyarrow]"),
pytest.param("int32[pyarrow]", np.float64, id="int32[pyarrow]"),
pytest.param("int64[pyarrow]", np.float64, id="int64[pyarrow]"),
pytest.param("uint8[pyarrow]", np.float64, id="uint8[pyarrow]"),
pytest.param("uint16[pyarrow]", np.float64, id="uint16[pyarrow]"),
pytest.param("uint32[pyarrow]", np.float64, id="uint32[pyarrow]"),
pytest.param("uint64[pyarrow]", np.float64, id="uint64[pyarrow]"),
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"),
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"),
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"),
],
)
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
"""
Test the _to_numpy function with pandas.Series of pandas numeric dtypes and NA.
"""
series = pd.Series([1, pd.NA, 3], dtype=dtype)
result = _to_numpy(series)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype))


########################################################################################
# Test the _to_numpy function with PyArrow arrays.
#
# PyArrow provides the following dtypes:
#
# - Numeric dtypes:
# - int8, int16, int32, int64
# - uint8, uint16, uint32, uint64
# - float16, float32, float64
#
# Reference: https://arrow.apache.org/docs/python/api/datatypes.html
########################################################################################
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
pytest.param("int8", np.int8, id="int8"),
pytest.param("int16", np.int16, id="int16"),
pytest.param("int32", np.int32, id="int32"),
pytest.param("int64", np.int64, id="int64"),
pytest.param("uint8", np.uint8, id="uint8"),
pytest.param("uint16", np.uint16, id="uint16"),
pytest.param("uint32", np.uint32, id="uint32"),
pytest.param("uint64", np.uint64, id="uint64"),
pytest.param("float16", np.float16, id="float16"),
pytest.param("float32", np.float32, id="float32"),
pytest.param("float64", np.float64, id="float64"),
],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype):
"""
Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes.
"""
if dtype == "float16":
# float16 needs special handling
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
array = pa.array(np.array([1.0, 2.0, 3.0], dtype=np.float16), type=pa.float16())
else:
array = pa.array([1, 2, 3], type=dtype)
assert array.type == dtype
result = _to_numpy(array)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, array)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
pytest.param("int8", np.float64, id="int8"),
pytest.param("int16", np.float64, id="int16"),
pytest.param("int32", np.float64, id="int32"),
pytest.param("int64", np.float64, id="int64"),
pytest.param("uint8", np.float64, id="uint8"),
pytest.param("uint16", np.float64, id="uint16"),
pytest.param("uint32", np.float64, id="uint32"),
pytest.param("uint64", np.float64, id="uint64"),
pytest.param("float16", np.float16, id="float16"),
pytest.param("float32", np.float32, id="float32"),
pytest.param("float64", np.float64, id="float64"),
],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
"""
Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes and NA.
"""
if dtype == "float16":
# float16 needs special handling
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
array = pa.array(
np.array([1.0, None, 3.0], dtype=np.float16), type=pa.float16()
)
else:
array = pa.array([1, None, 3], type=dtype)
assert array.type == dtype
result = _to_numpy(array)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, array)

0 comments on commit 2a2ab7a

Please sign in to comment.