Skip to content

Commit

Permalink
clib.conversion._to_numpy: Add tests for pandas.Series with pandas string dtype (#3607)

Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Nov 15, 2024
1 parent d982275 commit 3d08919
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
8 changes: 7 additions & 1 deletion pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions to convert data types into ctypes friendly formats.
"""

import contextlib
import ctypes as ctp
import warnings
from collections.abc import Sequence
Expand Down Expand Up @@ -160,7 +161,7 @@ def _to_numpy(data: Any) -> np.ndarray:
dtypes: dict[str, type | str] = {
# For string dtypes.
"large_string": np.str_, # pa.large_string and pa.large_utf8
"string": np.str_, # pa.string and pa.utf8
"string": np.str_, # pa.string, pa.utf8, pd.StringDtype
"string_view": np.str_, # pa.string_view
# For datetime dtypes.
"date32[day][pyarrow]": "datetime64[D]",
Expand All @@ -180,6 +181,11 @@ def _to_numpy(data: Any) -> np.ndarray:
else:
vec_dtype = str(getattr(data, "dtype", getattr(data, "type", "")))
array = np.ascontiguousarray(data, dtype=dtypes.get(vec_dtype))

# Check if a np.object_ array can be converted to np.str_.
if array.dtype == np.object_:
with contextlib.suppress(TypeError, ValueError):
return np.ascontiguousarray(array, dtype=np.str_)
return array


Expand Down
6 changes: 3 additions & 3 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ def virtualfile_from_vectors(
# 2 columns contains coordinates like longitude, latitude, or datetime string
# types.
for col, array in enumerate(arrays[2:]):
if pd.api.types.is_string_dtype(array.dtype):
if np.issubdtype(array.dtype, np.str_):
columns = col + 2
break

Expand Down Expand Up @@ -1506,9 +1506,9 @@ def virtualfile_from_vectors(
strings = string_arrays[0]
elif len(string_arrays) > 1:
strings = np.array(
[" ".join(vals) for vals in zip(*string_arrays, strict=True)]
[" ".join(vals) for vals in zip(*string_arrays, strict=True)],
dtype=np.str_,
)
strings = np.asanyarray(a=strings, dtype=np.str_)
self.put_strings(
dataset, family="GMT_IS_VECTOR|GMT_IS_DUPLICATE", strings=strings
)
Expand Down
26 changes: 26 additions & 0 deletions pygmt/tests/test_clib_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from packaging.version import Version
from pygmt.clib.conversion import _to_numpy
from pygmt.helpers.testing import skip_if_no

try:
import pyarrow as pa
Expand Down Expand Up @@ -174,6 +175,31 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
npt.assert_array_equal(result, series)


@pytest.mark.parametrize(
    "dtype",
    [
        None,
        np.str_,
        "U10",
        "string[python]",
        pytest.param("string[pyarrow]", marks=skip_if_no(package="pyarrow")),
        pytest.param("string[pyarrow_numpy]", marks=skip_if_no(package="pyarrow")),
    ],
)
def test_to_numpy_pandas_series_pandas_dtypes_string(dtype):
    """
    Test the _to_numpy function with pandas.Series of pandas string types.

    pandas accepts several equivalent spellings for string data, so each one
    is exercised here and must convert to a NumPy string array.
    Reference: https://pandas.pydata.org/docs/reference/api/pandas.StringDtype.html
    """
    # Same sample data for every dtype spelling under test.
    series = pd.Series(["abc", "defg", "12345"], dtype=dtype)
    converted = _to_numpy(series)
    _check_result(converted, np.str_)
    npt.assert_array_equal(converted, series)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
Expand Down

0 comments on commit 3d08919

Please sign in to comment.