diff --git a/pygmt/tests/test_clib_to_numpy.py b/pygmt/tests/test_clib_to_numpy.py index b8484a39186..acdff37e488 100644 --- a/pygmt/tests/test_clib_to_numpy.py +++ b/pygmt/tests/test_clib_to_numpy.py @@ -8,6 +8,13 @@ import pytest from pygmt.clib.conversion import _to_numpy +try: + import pyarrow as pa + + _HAS_PYARROW = True +except ImportError: + _HAS_PYARROW = False + def _check_result(result, expected_dtype): """ @@ -122,6 +129,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype): # - BooleanDtype # - ArrowDtype: a special dtype used to store data in the PyArrow format. # +# PyArrow dtypes can be specified using the following formats: +# +# - Prefixed with the name of the dtype and "[pyarrow]" (e.g., "int8[pyarrow]") +# - Specified using ``ArrowDType`` (e.g., "pd.ArrowDtype(pa.int8())") +# # References: # 1. https://pandas.pydata.org/docs/reference/arrays.html # 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes @@ -207,3 +219,137 @@ def test_to_numpy_pandas_series_pandas_dtypes_numeric_with_na(dtype, expected_dt result = _to_numpy(series) _check_result(result, expected_dtype) npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype)) + + +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]"), + pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]"), + pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]"), + pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]"), + pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]"), + pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]"), + pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]"), + pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]"), + pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"), + pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"), + pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"), + ], +) +def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype): + """ + Test the _to_numpy function with pandas.Series of pandas numeric dtypes. + """ + series = pd.Series([1, 2, 3], dtype=dtype) + result = _to_numpy(series) + _check_result(result, expected_dtype) + npt.assert_array_equal(result, series) + + +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + pytest.param("int8[pyarrow]", np.float64, id="int8[pyarrow]"), + pytest.param("int16[pyarrow]", np.float64, id="int16[pyarrow]"), + pytest.param("int32[pyarrow]", np.float64, id="int32[pyarrow]"), + pytest.param("int64[pyarrow]", np.float64, id="int64[pyarrow]"), + pytest.param("uint8[pyarrow]", np.float64, id="uint8[pyarrow]"), + pytest.param("uint16[pyarrow]", np.float64, id="uint16[pyarrow]"), + pytest.param("uint32[pyarrow]", np.float64, id="uint32[pyarrow]"), + pytest.param("uint64[pyarrow]", np.float64, id="uint64[pyarrow]"), + pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"), + pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"), + pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"), + ], +) +def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): + """ + Test the _to_numpy function with pandas.Series of pandas numeric dtypes and NA. + """ + series = pd.Series([1, pd.NA, 3], dtype=dtype) + result = _to_numpy(series) + _check_result(result, expected_dtype) + npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype)) + + +######################################################################################## +# Test the _to_numpy function with PyArrow arrays. +# +# PyArrow provides the following dtypes: +# +# - Numeric dtypes: +# - int8, int16, int32, int64 +# - uint8, uint16, uint32, uint64 +# - float16, float32, float64 +# +# Reference: https://arrow.apache.org/docs/python/api/datatypes.html +######################################################################################## +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + pytest.param("int8", np.int8, id="int8"), + pytest.param("int16", np.int16, id="int16"), + pytest.param("int32", np.int32, id="int32"), + pytest.param("int64", np.int64, id="int64"), + pytest.param("uint8", np.uint8, id="uint8"), + pytest.param("uint16", np.uint16, id="uint16"), + pytest.param("uint32", np.uint32, id="uint32"), + pytest.param("uint64", np.uint64, id="uint64"), + pytest.param("float16", np.float16, id="float16"), + pytest.param("float32", np.float32, id="float32"), + pytest.param("float64", np.float64, id="float64"), + ], +) +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype): + """ + Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes. + """ + if dtype == "float16": + # float16 needs special handling + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html + array = pa.array(np.array([1.0, 2.0, 3.0], dtype=np.float16), type=pa.float16()) + else: + array = pa.array([1, 2, 3], type=dtype) + assert array.type == dtype + result = _to_numpy(array) + _check_result(result, expected_dtype) + npt.assert_array_equal(result, array) + + +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + pytest.param("int8", np.float64, id="int8"), + pytest.param("int16", np.float64, id="int16"), + pytest.param("int32", np.float64, id="int32"), + pytest.param("int64", np.float64, id="int64"), + pytest.param("uint8", np.float64, id="uint8"), + pytest.param("uint16", np.float64, id="uint16"), + pytest.param("uint32", np.float64, id="uint32"), + pytest.param("uint64", np.float64, id="uint64"), + pytest.param("float16", np.float16, id="float16"), + pytest.param("float32", np.float32, id="float32"), + pytest.param("float64", np.float64, id="float64"), + ], +) +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): + """ + Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes and NA. + """ + if dtype == "float16": + # float16 needs special handling + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html + array = pa.array( + np.array([1.0, None, 3.0], dtype=np.float16), type=pa.float16() + ) + else: + array = pa.array([1, None, 3], type=dtype) + assert array.type == dtype + result = _to_numpy(array) + _check_result(result, expected_dtype) + npt.assert_array_equal(result, array)