Skip to content

Commit

Permalink
Add test cases, support annotated assignment, fix list assignments (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw authored Mar 15, 2024
1 parent c1dd2cb commit d39cd4a
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 30 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
[pypi]: https://pypi.org/project/polarify
[pypi-badge]: https://img.shields.io/pypi/v/polarify.svg?style=flat-square&logo=pypi&logoColor=white
[python-version-badge]: https://img.shields.io/pypi/pyversions/polarify?style=flat-square&logoColor=white&logo=python
[codecov-badge]: https://codecov.io/gh/quantco/polarify/branch/main/graph/badge.svg
[codecov-badge]: https://img.shields.io/codecov/c/github/quantco/polarify?style=flat-square&logo=codecov
[codecov]: https://codecov.io/gh/quantco/polarify

Welcome to **polarIFy**, a Python function decorator that simplifies the way you write logical statements for Polars. With polarIFy, you can use Python's language structures like `if / elif / else` statements and transform them into `pl.when(..).then(..).otherwise(..)` statements. This makes your code more readable and less cumbersome to write. 🎉
Expand Down
32 changes: 7 additions & 25 deletions polarify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def generic_visit(self, node):
@dataclass
class UnresolvedState:
"""
When a execution flow is not finished (i.e., not returned) in a function, we need to keep track
When an execution flow is not finished (i.e., not returned) in a function, we need to keep track
of the assignments.
"""

Expand All @@ -101,8 +101,7 @@ def _handle_assign(stmt: ast.Assign, assignments: dict[str, ast.expr]):
)
assert len(t.elts) == len(stmt.value.elts)
for sub_t, sub_v in zip(t.elts, stmt.value.elts):
diff = _handle_assign(ast.Assign(targets=[sub_t], value=sub_v), assignments)
assignments.update(diff)
_handle_assign(ast.Assign(targets=[sub_t], value=sub_v), assignments)
else:
raise ValueError(
f"Unsupported expression type inside assignment target: {type(t)}"
Expand Down Expand Up @@ -140,7 +139,10 @@ class State:

node: UnresolvedState | ReturnState | ConditionalState

def handle_assign(self, expr: ast.Assign):
def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
if isinstance(expr, ast.AnnAssign):
expr = ast.Assign(targets=[expr.target], value=expr.value)

if isinstance(self.node, UnresolvedState):
self.node.handle_assign(expr)
elif isinstance(self.node, ConditionalState):
Expand All @@ -167,33 +169,13 @@ def handle_return(self, value: ast.expr):
self.node.then.handle_return(value)
self.node.orelse.handle_return(value)

def check_all_branches_return(self):
if isinstance(self.node, UnresolvedState):
return False
elif isinstance(self.node, ReturnState):
return True
else:
return (
self.node.then.check_all_branches_return()
and self.node.orelse.check_all_branches_return()
)


def is_returning_body(stmts: list[ast.stmt]) -> bool:
for s in stmts:
if isinstance(s, ast.Return):
return True
elif isinstance(s, ast.If):
return bool(is_returning_body(s.body) and is_returning_body(s.orelse))
return False


def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State:
if assignments is None:
assignments = {}
state = State(UnresolvedState(assignments))
for stmt in full_body:
if isinstance(stmt, ast.Assign):
if isinstance(stmt, (ast.Assign, ast.AnnAssign)):
state.handle_assign(stmt)
elif isinstance(stmt, ast.If):
state.handle_if(stmt)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "polarify"
description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️"
version = "0.1.4"
version = "0.1.5"
readme = "README.md"
license = "MIT"
requires-python = ">=3.9"
Expand Down
89 changes: 89 additions & 0 deletions tests/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,71 @@ def walrus_expr(x):
return s * y


def return_nothing(x):
if x > 0:
return
else:
return 1


def no_return(x):
s = x


def return_end(x):
s = x
return


def annotated_assign(x):
s: int = 15
return s + x


def conditional_assign(x):
s = 1
if x > 0:
s = 2
b = 3
return b


def return_constant(x):
return 1


def return_constant_2(x):
return 1 + 2


def return_unconditional_constant(x):
if x > 0:
s = 1
else:
s = 2
return 1


def return_constant_additional_assignments(x):
s = 2
return 1


def return_conditional_constant(x):
if x > 0:
return 1
return 0


def multiple_if(x):
s = 1
if x > 0:
s = 2
if x > 1:
s = 3
return s


def multiple_if_else(x):
if x > 0:
s = 1
Expand Down Expand Up @@ -179,6 +244,16 @@ def multiple_equals(x):
return x + a + b


def tuple_assignments(x):
a, b = 1, x
return x + a + b


def list_assignments(x):
[a, b] = 1, x
return x + a + b


functions = [
signum,
early_return,
Expand All @@ -199,14 +274,28 @@ def multiple_equals(x):
signum_no_default,
nested_partial_return_with_assignments,
multiple_equals,
tuple_assignments,
list_assignments,
annotated_assign,
conditional_assign,
multiple_if,
return_unconditional_constant,
return_conditional_constant,
]

xfail_functions = [
walrus_expr,
# our test setup does not work with literal expressions
return_constant,
return_constant_2,
return_constant_additional_assignments,
]

unsupported_functions = [
# function, match string in error message
(chained_compare_expr, "Polars can't handle chained comparisons"),
(bool_op, "ast.BoolOp"), # TODO: make error message more specific
(return_end, "return needs a value"),
(no_return, "Not all branches return"),
(return_nothing, "return needs a value"),
]
6 changes: 3 additions & 3 deletions tests/test_parse_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
params=functions
+ [pytest.param(f, marks=pytest.mark.xfail(reason="not implemented")) for f in xfail_functions],
)
def test_funcs(request):
def funcs(request):
original_func = request.param
transformed_func = polarify(original_func)
original_func_unparsed = inspect.getsource(original_func)
Expand All @@ -41,9 +41,9 @@ def test_funcs(request):
chunked=False if pl_version < Version("0.18.1") else None,
)
)
def test_transform_function(df: polars.DataFrame, test_funcs):
def test_transform_function(df: polars.DataFrame, funcs):
x = polars.col("x")
transformed_func, original_func = test_funcs
transformed_func, original_func = funcs

if pl_version < Version("0.19.0"):
df_with_transformed_func = df.select(transformed_func(x).alias("apply"))
Expand Down

0 comments on commit d39cd4a

Please sign in to comment.