diff --git a/README.md b/README.md index b3a72cd..6072e63 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ This gets transformed into: def signum(x: pl.Expr) -> pl.Expr: return pl.when(x > 0).then(1).otherwise(pl.when(x < 0).then(-1).otherwise(0)) ``` + ### Handling Multiple Statements polarIFy can also handle multiple statements like: @@ -125,6 +126,38 @@ print(result) # └─────┴─────────┘ ``` +### Displaying the transpiled polars expression + +You can also display the transpiled polars expression by calling the `transform_func_to_new_source` function: + +```python +from polarify import transform_func_to_new_source + +def signum(x): + s = 0 + if x > 0: + s = 1 + elif x < 0: + s = -1 + return s + + +print(f"Original function:\n{inspect.getsource(signum)}") +# Original function: +# def signum(x): +# s = 0 +# if x > 0: +# s = 1 +# elif x < 0: +# s = -1 +# return s +print(f"Transformed function:\n{transform_func_to_new_source(signum)}") +# Transformed function: +# def signum_polarified(x): +# import polars as pl +# return pl.when(x > 0).then(1).otherwise(pl.when(x < 0).then(-1).otherwise(0)) +``` + TODO: complicated example with nested functions ## ⚙️ How It Works diff --git a/pixi.toml b/pixi.toml index 7a0438b..31a28a1 100644 --- a/pixi.toml +++ b/pixi.toml @@ -2,7 +2,7 @@ # https://github.com/prefix-dev/pixi/issues/79 [project] name = "polarify" -version = "0.1.1" +version = "0.1.2" description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️" authors = ["Bela Stoyan ", "Pavel Zwerschke "] channels = ["conda-forge"] diff --git a/polarify/__init__.py b/polarify/__init__.py index ab72542..9d16ea3 100644 --- a/polarify/__init__.py +++ b/polarify/__init__.py @@ -2,14 +2,16 @@ import inspect from functools import wraps -from .main import parse_body +from .main import parse_body, transform_tree_into_expr def transform_func_to_new_source(func) -> str: source = inspect.getsource(func) tree = ast.parse(source) func_def: ast.FunctionDef = tree.body[0] # 
type: ignore - expr = parse_body(func_def.body) + root_node = parse_body(func_def.body) + + expr = transform_tree_into_expr(root_node) # Replace the body of the function with the parsed expr # Also import polars as pl since this is used in the generated code diff --git a/polarify/main.py b/polarify/main.py index c4e44be..0a533f4 100644 --- a/polarify/main.py +++ b/polarify/main.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import ast -from copy import copy -from typing import Union +from copy import copy, deepcopy +from dataclasses import dataclass # TODO: make walrus throw ValueError -# TODO: Switch - -Assignments = dict[str, ast.expr] +# TODO: match ... case -def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast.expr): +def build_polars_when_then_otherwise( + test: ast.expr, then: ast.expr, orelse: ast.expr +) -> ast.Call: when_node = ast.Call( func=ast.Attribute( value=ast.Name(id="pl", ctx=ast.Load()), attr="when", ctx=ast.Load() @@ -32,45 +34,47 @@ def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast # ruff: noqa: N802 class InlineTransformer(ast.NodeTransformer): - def __init__(self, assignments: Assignments): + def __init__(self, assignments: dict[str, ast.expr]): self.assignments = assignments @classmethod - def inline_expr(cls, expr: ast.expr, assignments: Assignments) -> ast.expr: - return cls(assignments).visit(expr) + def inline_expr(cls, expr: ast.expr, assignments: dict[str, ast.expr]) -> ast.expr: + expr = cls(assignments).visit(deepcopy(expr)) + assert isinstance(expr, ast.expr) + return expr - def visit_Name(self, node): + def visit_Name(self, node: ast.Name) -> ast.expr: if node.id in self.assignments: return self.visit(self.assignments[node.id]) else: return node - def visit_BinOp(self, node): + def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp: node.left = self.visit(node.left) node.right = self.visit(node.right) return node - def visit_UnaryOp(self, node): + def 
@dataclass
class UnresolvedState:
    """
    Assignments of an execution flow that has not returned yet.

    While a path through the function body is still open (no ``return`` has
    been handled for it), the variables bound so far are tracked here, each
    mapped to the expression that was inlined for it at assignment time.
    """

    # variable name -> inlined expression currently bound to that name
    assignments: dict[str, ast.expr]

    def handle_assign(self, stmt: ast.Assign) -> None:
        """
        Record the bindings made by ``stmt`` in ``self.assignments``.

        Plain names are bound to the inlined right-hand side. Tuple/list
        targets are unpacked element-wise (recursively) against a literal
        tuple/list value of the same length.

        Raises:
            ValueError: if a tuple/list target is assigned from a value that is
                not a literal tuple/list, if target and value lengths differ,
                or if a target is neither a name nor a tuple/list.
        """
        # Python evaluates the whole right-hand side before binding any target,
        # so every value is inlined against the state *before* this statement.
        # This keeps ``a, b = b, a`` and ``a = b = a + 1`` correct.
        before = dict(self.assignments)

        def _bind(target: ast.expr, value: ast.expr) -> None:
            if isinstance(target, ast.Name):
                self.assignments[target.id] = InlineTransformer.inline_expr(
                    value, before
                )
            elif isinstance(target, (ast.List, ast.Tuple)):
                if not isinstance(value, (ast.List, ast.Tuple)):
                    raise ValueError(
                        f"Assignment target is {type(target)}, but value is {type(value)}"
                    )
                # Raise instead of assert: asserts vanish under ``python -O``.
                if len(target.elts) != len(value.elts):
                    raise ValueError(
                        f"Cannot unpack {len(value.elts)} values"
                        f" into {len(target.elts)} targets"
                    )
                for sub_target, sub_value in zip(target.elts, value.elts):
                    _bind(sub_target, sub_value)
            else:
                raise ValueError(
                    f"Unsupported expression type inside assignment target: {type(target)}"
                )

        # ``a = b = value`` has several targets; each receives the same value.
        for target in stmt.targets:
            _bind(target, stmt.value)
+ """ + + node: UnresolvedState | ReturnState | ConditionalState + + def handle_assign(self, expr: ast.Assign): + if isinstance(self.node, UnresolvedState): + self.node.handle_assign(expr) + elif isinstance(self.node, ConditionalState): + self.node.then.handle_assign(expr) + self.node.orelse.handle_assign(expr) + + def handle_if(self, stmt: ast.If): + if isinstance(self.node, UnresolvedState): + self.node = ConditionalState( + test=InlineTransformer.inline_expr(stmt.test, self.node.assignments), + then=parse_body(stmt.body, copy(self.node.assignments)), + orelse=parse_body(stmt.orelse, copy(self.node.assignments)), ) - for sub_t, sub_v in zip(t.elts, stmt.value.elts): - diff = handle_assign( - ast.Assign(targets=[sub_t], value=sub_v), assignments - ) - assignments.update(diff) - diff_assignments.update(diff) - else: - raise ValueError( - f"Unsupported expression type inside assignment target: {type(t)}" + elif isinstance(self.node, ConditionalState): + self.node.then.handle_if(stmt) + self.node.orelse.handle_if(stmt) + + def handle_return(self, value: ast.expr): + if isinstance(self.node, UnresolvedState): + self.node = ReturnState( + expr=InlineTransformer.inline_expr(value, self.node.assignments) ) - return diff_assignments - - -def handle_non_returning_if(stmt: ast.If, assignments: Assignments) -> Assignments: - assignments = copy(assignments) - assert not is_returning_body(stmt.orelse) and not is_returning_body(stmt.body) - test = InlineTransformer.inline_expr(stmt.test, assignments) - - diff_assignments = {} - all_vars_changed_in_body = get_all_vars_changed_in_body(stmt.body, assignments) - all_vars_changed_in_orelse = get_all_vars_changed_in_body(stmt.orelse, assignments) - - def updated_or_default_assignments(var: str, diff: Assignments) -> ast.expr: - if var in diff: - return diff[var] - elif var in assignments: - return assignments[var] + elif isinstance(self.node, ConditionalState): + self.node.then.handle_return(value) + 
def is_returning_body(stmts: list[ast.stmt]) -> bool:
    """
    Tell whether every execution path through ``stmts`` reaches a ``return``.

    A body returns if it contains a top-level ``return`` statement, or an
    ``if`` statement whose ``body`` and ``orelse`` both return. Statements of
    any other kind are ignored.

    Args:
        stmts: the statements of a function body or of an ``if`` branch.

    Returns:
        ``True`` if a returning statement is found, ``False`` otherwise
        (including for an empty list, e.g. a missing ``else`` branch).
    """
    for stmt in stmts:
        if isinstance(stmt, ast.Return):
            return True
        if (
            isinstance(stmt, ast.If)
            and is_returning_body(stmt.body)
            and is_returning_body(stmt.orelse)
        ):
            return True
        # An ``if`` that does not return on both branches decides nothing:
        # a later statement may still return, so keep scanning.
    return False
def transform_tree_into_expr(node: State) -> ast.expr:
    """
    Collapse a state tree into a single expression.

    A ``ReturnState`` contributes its stored expression. A ``ConditionalState``
    is converted recursively: both branches are transformed first and then
    combined with the test via ``build_polars_when_then_otherwise``.

    Raises:
        ValueError: if the tree contains a node that is neither a
            ``ReturnState`` nor a ``ConditionalState``, i.e. a branch that
            never returned.
    """
    current = node.node

    if isinstance(current, ReturnState):
        return current.expr

    if not isinstance(current, ConditionalState):
        raise ValueError("Not all branches return")

    then_expr = transform_tree_into_expr(current.then)
    orelse_expr = transform_tree_into_expr(current.orelse)
    return build_polars_when_then_otherwise(current.test, then_expr, orelse_expr)
func_match + with pytest.raises(ValueError, match=match): + polarify(func)