Skip to content

Commit

Permalink
Polars version update
Browse files (browse the repository at this point in the history)
  • Loading branch information
hexane360 committed Oct 4, 2024
1 parent aba34c1 commit 5351c33
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
13 changes: 5 additions & 8 deletions atomlib/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def _with_columns_stacked(df: polars.DataFrame, cols: t.Sequence[str], out_col:
i = df.get_column_index(cols[0])
dtype = df[cols[0]].dtype

# https://github.com/pola-rs/polars/issues/18369
arr = [] if len(df) == 0 else numpy.array(tuple(df[c].to_numpy() for c in cols)).T

arr = numpy.array(tuple(df[c].to_numpy() for c in cols)).T
return df.drop(cols).insert_column(i, polars.Series(out_col, arr, polars.Array(dtype, len(cols))))


Expand Down Expand Up @@ -402,7 +400,7 @@ def concat(cls: t.Type[HasAtomsT],
# this method is tricky. It needs to accept raw Atoms, as well as HasAtoms of the
# same type as ``cls``.
if _is_abstract(cls):
raise TypeError(f"concat() must be called on a concrete class.")
raise TypeError("concat() must be called on a concrete class.")

if isinstance(atoms, HasAtoms):
atoms = (atoms,)
Expand All @@ -420,7 +418,7 @@ def concat(cls: t.Type[HasAtomsT],
cols = reduce(operator.and_, (df.schema.keys() for df in dfs))
schema = OrderedDict((col, dfs[0].schema[col]) for col in cols)
if len(schema) == 0:
raise ValueError(f"Atoms have no columns in common")
raise ValueError("Atoms have no columns in common")

dfs = [_select_schema(df, schema) for df in dfs]
how = 'vertical'
Expand Down Expand Up @@ -752,7 +750,7 @@ def add_atom(self, elem: t.Union[int, str], /,
if hasattr(x, '__len__') and len(x) > 1: # type: ignore
(x, y, z) = to_vec3(x)
elif y is None or z is None:
raise ValueError(f"Must specify vector of positions or x, y, & z.")
raise ValueError("Must specify vector of positions or x, y, & z.")

sym = get_sym(elem) if isinstance(elem, int) else elem
d: t.Dict[str, t.Any] = {'x': x, 'y': y, 'z': z, 'symbol': sym, **kwargs}
Expand Down Expand Up @@ -963,8 +961,7 @@ def with_velocity(self, pts: t.Optional[ArrayLike] = None,
assert pts.shape[-1] == 3
all_pts[selection] = pts

# https://github.com/pola-rs/polars/issues/18369
all_pts = numpy.broadcast_to(all_pts, (len(self), 3)) if len(self) else []
all_pts = numpy.broadcast_to(all_pts, (len(self), 3))
return self.with_columns(polars.Series('velocity', all_pts, polars.Array(polars.Float64, 3)))


Expand Down
34 changes: 22 additions & 12 deletions atomlib/elem.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
except ImportError:
from polars.exceptions import PolarsPanicError as PanicException # type: ignore

from polars.exceptions import InvalidOperationError

Check warning on line 14 in atomlib/elem.py (the same annotation was reported three times)

View workflow job for this annotation

GitHub Actions / Type Check

Import "InvalidOperationError" is not accessed (reportUnusedImport)

from .types import ElemLike

ELEMENTS = {
Expand Down Expand Up @@ -42,6 +44,7 @@
'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og',
]
assert len(ELEMENTS) == len(ELEMENT_SYMBOLS)
ELEMENT_SYMBOLS_POLARS = polars.Series([ELEMENT_SYMBOLS], dtype=polars.List(polars.Utf8))

DATA_PATH = files('atomlib.data')
_ELEMENT_MASSES: t.Optional[numpy.ndarray] = None
Expand Down Expand Up @@ -82,11 +85,14 @@ def get_elem(sym: t.Union[int, str, polars.Series]):
return sym

if isinstance(sym, polars.Series):
elem = sym.str.extract(_SYM_RE, 0).str.to_lowercase() \
.replace_strict(ELEMENTS, default=255, return_dtype=polars.UInt8) \
.alias('elem')

if (invalid := sym.filter(sym.is_not_null() & (elem > 118)).to_list()):
# TODO: this is a mess
elem = sym.cast(polars.Utf8).str.extract(_SYM_RE, 0).str.to_lowercase() \
.replace_strict(
old=list(ELEMENTS.keys()), new=list(ELEMENTS.values()),
default=None, return_dtype=polars.Int8
).alias('elem')

if (invalid := sym.filter(sym.is_not_null() & elem.is_null()).to_list()):
raise ValueError(f"Invalid element symbol(s) '{', '.join(map(str, invalid))}'")

return elem
Expand Down Expand Up @@ -124,13 +130,17 @@ def get_sym(elem: polars.Series) -> polars.Series:

def get_sym(elem: t.Union[int, polars.Series]):
if isinstance(elem, polars.Series):
try:
return elem.map_elements(_get_sym, return_dtype=polars.Utf8, skip_nulls=True) \
.alias('symbol')
except PanicException:
# attempt to recreate the error in Python
_ = [_get_sym(t.cast(int, e)) for e in elem.to_list() if e is not None]
raise
sym = elem.cast(polars.Int64).replace_strict(
list(range(1, len(ELEMENT_SYMBOLS)+1)),
ELEMENT_SYMBOLS,
default=None,
return_dtype=polars.Utf8,
).alias('symbol')

if (invalid := elem.filter(elem.is_not_null() & sym.is_null()).unique().to_list()):
raise ValueError(f"Invalid atomic number(s) {', '.join(map(str, invalid))}")

return sym

return _get_sym(elem)

Expand Down
2 changes: 1 addition & 1 deletion atomlib/io/test_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_xyz_invalid():
O 1.36 4.08 4.08
120 1.36 4.08 4.08
"""
with pytest.raises(ValueError, match="Invalid atomic number 120"):
with pytest.raises(ValueError, match=re.escape("Invalid atomic number(s) 120")):
XYZ.from_file(StringIO(xyz_in))

xyz_in = \
Expand Down
13 changes: 10 additions & 3 deletions atomlib/test_elem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import numpy
import polars
from polars.testing import assert_series_equal

from .elem import get_elem, get_elems, get_sym, get_mass
from .elem import get_radius, get_ionic_radius
Expand Down Expand Up @@ -49,7 +50,10 @@ def test_get_elem_series():
def test_get_elem_series_nulls():
sym = polars.Series(['Al', None, 'Ag', 'Na'])

assert (get_elem(sym) == polars.Series([13, None, 47, 11])).all()
assert_series_equal(
get_elem(sym),
polars.Series('elem', [13, None, 47, 11], polars.Int8)
)


def test_get_sym_series():
Expand All @@ -62,7 +66,10 @@ def test_get_sym_series():
def test_get_sym_series_nulls():
elem = polars.Series((74, 102, 62, None, 19))

assert (get_sym(elem) == polars.Series(["W", "No", "Sm", None, "K"])).all()
assert_series_equal(
get_sym(elem),
polars.Series('symbol', ["W", "No", "Sm", None, "K"], polars.Utf8)
)


def test_get_elem_fail():
Expand All @@ -83,7 +90,7 @@ def test_get_sym_fail():
with pytest.raises(ValueError, match="Invalid atomic number 255"):
get_sym(255)

with pytest.raises(ValueError, match="Invalid atomic number 255"):
with pytest.raises(ValueError, match=re.escape("Invalid atomic number(s) 255")):
get_sym(polars.Series([12, 14, 255, 1]))


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"click~=8.1", # for cli
"numpy>=1.22,<2.3.0", # tested on 2.0.0
"scipy~=1.8",
"polars~=1.7.1",
"polars~=1.9.0",
"matplotlib~=3.5",
"requests~=2.28",
"lxml~=5.0",
Expand Down

0 comments on commit 5351c33

Please sign in to comment.