diff --git a/atomlib/atoms.py b/atomlib/atoms.py index 159dd31..4f54d7f 100644 --- a/atomlib/atoms.py +++ b/atomlib/atoms.py @@ -137,9 +137,7 @@ def _with_columns_stacked(df: polars.DataFrame, cols: t.Sequence[str], out_col: i = df.get_column_index(cols[0]) dtype = df[cols[0]].dtype - # https://github.com/pola-rs/polars/issues/18369 - arr = [] if len(df) == 0 else numpy.array(tuple(df[c].to_numpy() for c in cols)).T - + arr = numpy.array(tuple(df[c].to_numpy() for c in cols)).T return df.drop(cols).insert_column(i, polars.Series(out_col, arr, polars.Array(dtype, len(cols)))) @@ -402,7 +400,7 @@ def concat(cls: t.Type[HasAtomsT], # this method is tricky. It needs to accept raw Atoms, as well as HasAtoms of the # same type as ``cls``. if _is_abstract(cls): - raise TypeError(f"concat() must be called on a concrete class.") + raise TypeError("concat() must be called on a concrete class.") if isinstance(atoms, HasAtoms): atoms = (atoms,) @@ -420,7 +418,7 @@ def concat(cls: t.Type[HasAtomsT], cols = reduce(operator.and_, (df.schema.keys() for df in dfs)) schema = OrderedDict((col, dfs[0].schema[col]) for col in cols) if len(schema) == 0: - raise ValueError(f"Atoms have no columns in common") + raise ValueError("Atoms have no columns in common") dfs = [_select_schema(df, schema) for df in dfs] how = 'vertical' @@ -752,7 +750,7 @@ def add_atom(self, elem: t.Union[int, str], /, if hasattr(x, '__len__') and len(x) > 1: # type: ignore (x, y, z) = to_vec3(x) elif y is None or z is None: - raise ValueError(f"Must specify vector of positions or x, y, & z.") + raise ValueError("Must specify vector of positions or x, y, & z.") sym = get_sym(elem) if isinstance(elem, int) else elem d: t.Dict[str, t.Any] = {'x': x, 'y': y, 'z': z, 'symbol': sym, **kwargs} @@ -963,8 +961,7 @@ def with_velocity(self, pts: t.Optional[ArrayLike] = None, assert pts.shape[-1] == 3 all_pts[selection] = pts - # https://github.com/pola-rs/polars/issues/18369 - all_pts = numpy.broadcast_to(all_pts, (len(self), 3)) if len(self) else [] + all_pts = numpy.broadcast_to(all_pts, (len(self), 3)) return self.with_columns(polars.Series('velocity', all_pts, polars.Array(polars.Float64, 3))) diff --git a/atomlib/elem.py b/atomlib/elem.py index bfd145c..92ca416 100644 --- a/atomlib/elem.py +++ b/atomlib/elem.py @@ -11,6 +11,8 @@ except ImportError: from polars.exceptions import PolarsPanicError as PanicException # type: ignore +from polars.exceptions import InvalidOperationError + from .types import ElemLike ELEMENTS = { @@ -42,6 +44,7 @@ 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og', ] assert len(ELEMENTS) == len(ELEMENT_SYMBOLS) +ELEMENT_SYMBOLS_POLARS = polars.Series([ELEMENT_SYMBOLS], dtype=polars.List(polars.Utf8)) DATA_PATH = files('atomlib.data') _ELEMENT_MASSES: t.Optional[numpy.ndarray] = None @@ -82,11 +85,14 @@ def get_elem(sym: t.Union[int, str, polars.Series]): return sym if isinstance(sym, polars.Series): - elem = sym.str.extract(_SYM_RE, 0).str.to_lowercase() \ - .replace_strict(ELEMENTS, default=255, return_dtype=polars.UInt8) \ - .alias('elem') - - if (invalid := sym.filter(sym.is_not_null() & (elem > 118)).to_list()): + # TODO: this is a mess + elem = sym.cast(polars.Utf8).str.extract(_SYM_RE, 0).str.to_lowercase() \ + .replace_strict( + old=list(ELEMENTS.keys()), new=list(ELEMENTS.values()), + default=None, return_dtype=polars.Int8 + ).alias('elem') + + if (invalid := sym.filter(sym.is_not_null() & elem.is_null()).to_list()): raise ValueError(f"Invalid element symbol(s) '{', '.join(map(str, invalid))}'") return elem @@ -124,13 +130,17 @@ def get_sym(elem: polars.Series) -> polars.Series: def get_sym(elem: t.Union[int, polars.Series]): if isinstance(elem, polars.Series): - try: - return elem.map_elements(_get_sym, return_dtype=polars.Utf8, skip_nulls=True) \ - .alias('symbol') - except PanicException: - # attempt to recreate the error in Python - _ = [_get_sym(t.cast(int, e)) for e in elem.to_list() if e is not None] - raise + sym = elem.cast(polars.Int64).replace_strict( + list(range(1, len(ELEMENT_SYMBOLS)+1)), + ELEMENT_SYMBOLS, + default=None, + return_dtype=polars.Utf8, + ).alias('symbol') + + if (invalid := elem.filter(elem.is_not_null() & sym.is_null()).unique().to_list()): + raise ValueError(f"Invalid atomic number(s) {', '.join(map(str, invalid))}") + + return sym return _get_sym(elem) diff --git a/atomlib/io/test_xyz.py b/atomlib/io/test_xyz.py index 04ea416..8131e18 100644 --- a/atomlib/io/test_xyz.py +++ b/atomlib/io/test_xyz.py @@ -28,7 +28,7 @@ def test_xyz_invalid(): O 1.36 4.08 4.08 120 1.36 4.08 4.08 """ - with pytest.raises(ValueError, match="Invalid atomic number 120"): + with pytest.raises(ValueError, match=re.escape("Invalid atomic number(s) 120")): XYZ.from_file(StringIO(xyz_in)) xyz_in = \ diff --git a/atomlib/test_elem.py b/atomlib/test_elem.py index 29cb71a..9703317 100644 --- a/atomlib/test_elem.py +++ b/atomlib/test_elem.py @@ -5,6 +5,7 @@ import pytest import numpy import polars +from polars.testing import assert_series_equal from .elem import get_elem, get_elems, get_sym, get_mass from .elem import get_radius, get_ionic_radius @@ -49,7 +50,10 @@ def test_get_elem_series(): def test_get_elem_series_nulls(): sym = polars.Series(['Al', None, 'Ag', 'Na']) - assert (get_elem(sym) == polars.Series([13, None, 47, 11])).all() + assert_series_equal( + get_elem(sym), + polars.Series('elem', [13, None, 47, 11], polars.Int8) + ) def test_get_sym_series(): @@ -62,7 +66,10 @@ def test_get_sym_series(): def test_get_sym_series_nulls(): elem = polars.Series((74, 102, 62, None, 19)) - assert (get_sym(elem) == polars.Series(["W", "No", "Sm", None, "K"])).all() + assert_series_equal( + get_sym(elem), + polars.Series('symbol', ["W", "No", "Sm", None, "K"], polars.Utf8) + ) def test_get_elem_fail(): @@ -83,7 +90,7 @@ def test_get_sym_fail(): with pytest.raises(ValueError, match="Invalid atomic number 255"): get_sym(255) - with pytest.raises(ValueError, match="Invalid atomic number 255"): + with pytest.raises(ValueError, match=re.escape("Invalid atomic number(s) 255")): get_sym(polars.Series([12, 14, 255, 1])) diff --git a/pyproject.toml b/pyproject.toml index e88c21f..99be3b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "click~=8.1", # for cli "numpy>=1.22,<2.3.0", # tested on 2.0.0 "scipy~=1.8", - "polars~=1.7.1", + "polars~=1.9.0", "matplotlib~=3.5", "requests~=2.28", "lxml~=5.0",