From d39cd4a1ff852fb1e60a368fed27c25901d1b8ef Mon Sep 17 00:00:00 2001 From: Pavel Zwerschke Date: Fri, 15 Mar 2024 19:46:05 +0100 Subject: [PATCH] Add test cases, support annotated assignment, fix list assignments (#59) --- README.md | 2 +- polarify/main.py | 32 ++++----------- pyproject.toml | 2 +- tests/functions.py | 89 ++++++++++++++++++++++++++++++++++++++++ tests/test_parse_body.py | 6 +-- 5 files changed, 101 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 66b1a94..31463ba 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ [pypi]: https://pypi.org/project/polarify [pypi-badge]: https://img.shields.io/pypi/v/polarify.svg?style=flat-square&logo=pypi&logoColor=white [python-version-badge]: https://img.shields.io/pypi/pyversions/polarify?style=flat-square&logoColor=white&logo=python -[codecov-badge]: https://codecov.io/gh/quantco/polarify/branch/main/graph/badge.svg +[codecov-badge]: https://img.shields.io/codecov/c/github/quantco/polarify?style=flat-square&logo=codecov [codecov]: https://codecov.io/gh/quantco/polarify Welcome to **polarIFy**, a Python function decorator that simplifies the way you write logical statements for Polars. With polarIFy, you can use Python's language structures like `if / elif / else` statements and transform them into `pl.when(..).then(..).otherwise(..)` statements. This makes your code more readable and less cumbersome to write. 🎉 diff --git a/polarify/main.py b/polarify/main.py index dad210e..4f97ecd 100644 --- a/polarify/main.py +++ b/polarify/main.py @@ -82,7 +82,7 @@ def generic_visit(self, node): @dataclass class UnresolvedState: """ - When a execution flow is not finished (i.e., not returned) in a function, we need to keep track + When an execution flow is not finished (i.e., not returned) in a function, we need to keep track of the assignments. """ @@ -101,8 +101,7 @@ def _handle_assign(stmt: ast.Assign, assignments: dict[str, ast.expr]): ) assert len(t.elts) == len(stmt.value.elts) for sub_t, sub_v in zip(t.elts, stmt.value.elts): - diff = _handle_assign(ast.Assign(targets=[sub_t], value=sub_v), assignments) - assignments.update(diff) + _handle_assign(ast.Assign(targets=[sub_t], value=sub_v), assignments) else: raise ValueError( f"Unsupported expression type inside assignment target: {type(t)}" @@ -140,7 +139,10 @@ class State: node: UnresolvedState | ReturnState | ConditionalState - def handle_assign(self, expr: ast.Assign): + def handle_assign(self, expr: ast.Assign | ast.AnnAssign): + if isinstance(expr, ast.AnnAssign): + expr = ast.Assign(targets=[expr.target], value=expr.value) + if isinstance(self.node, UnresolvedState): self.node.handle_assign(expr) elif isinstance(self.node, ConditionalState): @@ -167,33 +169,13 @@ def handle_return(self, value: ast.expr): self.node.then.handle_return(value) self.node.orelse.handle_return(value) - def check_all_branches_return(self): - if isinstance(self.node, UnresolvedState): - return False - elif isinstance(self.node, ReturnState): - return True - else: - return ( - self.node.then.check_all_branches_return() - and self.node.orelse.check_all_branches_return() - ) - - -def is_returning_body(stmts: list[ast.stmt]) -> bool: - for s in stmts: - if isinstance(s, ast.Return): - return True - elif isinstance(s, ast.If): - return bool(is_returning_body(s.body) and is_returning_body(s.orelse)) - return False - def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State: if assignments is None: assignments = {} state = State(UnresolvedState(assignments)) for stmt in full_body: - if isinstance(stmt, ast.Assign): + if isinstance(stmt, (ast.Assign, ast.AnnAssign)): state.handle_assign(stmt) elif isinstance(stmt, ast.If): state.handle_if(stmt) diff --git a/pyproject.toml b/pyproject.toml index fbd8f4d..5e25ee9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "polarify" description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️" -version = "0.1.4" +version = "0.1.5" readme = "README.md" license = "MIT" requires-python = ">=3.9" diff --git a/tests/functions.py b/tests/functions.py index 35e5fb3..6b17dc7 100644 --- a/tests/functions.py +++ b/tests/functions.py @@ -107,6 +107,71 @@ def walrus_expr(x): return s * y +def return_nothing(x): + if x > 0: + return + else: + return 1 + + +def no_return(x): + s = x + + +def return_end(x): + s = x + return + + +def annotated_assign(x): + s: int = 15 + return s + x + + +def conditional_assign(x): + s = 1 + if x > 0: + s = 2 + b = 3 + return b + + +def return_constant(x): + return 1 + + +def return_constant_2(x): + return 1 + 2 + + +def return_unconditional_constant(x): + if x > 0: + s = 1 + else: + s = 2 + return 1 + + +def return_constant_additional_assignments(x): + s = 2 + return 1 + + +def return_conditional_constant(x): + if x > 0: + return 1 + return 0 + + +def multiple_if(x): + s = 1 + if x > 0: + s = 2 + if x > 1: + s = 3 + return s + + def multiple_if_else(x): if x > 0: s = 1 @@ -179,6 +244,16 @@ def multiple_equals(x): return x + a + b +def tuple_assignments(x): + a, b = 1, x + return x + a + b + + +def list_assignments(x): + [a, b] = 1, x + return x + a + b + + functions = [ signum, early_return, @@ -199,14 +274,28 @@ def multiple_equals(x): signum_no_default, nested_partial_return_with_assignments, multiple_equals, + tuple_assignments, + list_assignments, + annotated_assign, + conditional_assign, + multiple_if, + return_unconditional_constant, + return_conditional_constant, ] xfail_functions = [ walrus_expr, + # our test setup does not work with literal expressions + return_constant, + return_constant_2, + return_constant_additional_assignments, ] unsupported_functions = [ # function, match string in error message (chained_compare_expr, "Polars can't handle chained comparisons"), (bool_op, "ast.BoolOp"), # TODO: make error message more specific + (return_end, "return needs a value"), + (no_return, "Not all branches return"), + (return_nothing, "return needs a value"), ] diff --git a/tests/test_parse_body.py b/tests/test_parse_body.py index d5073a2..78a7356 100644 --- a/tests/test_parse_body.py +++ b/tests/test_parse_body.py @@ -21,7 +21,7 @@ params=functions + [pytest.param(f, marks=pytest.mark.xfail(reason="not implemented")) for f in xfail_functions], ) -def test_funcs(request): +def funcs(request): original_func = request.param transformed_func = polarify(original_func) original_func_unparsed = inspect.getsource(original_func) @@ -41,9 +41,9 @@ def test_funcs(request): chunked=False if pl_version < Version("0.18.1") else None, ) ) -def test_transform_function(df: polars.DataFrame, test_funcs): +def test_transform_function(df: polars.DataFrame, funcs): x = polars.col("x") - transformed_func, original_func = test_funcs + transformed_func, original_func = funcs if pl_version < Version("0.19.0"): df_with_transformed_func = df.select(transformed_func(x).alias("apply"))