Skip to content

Commit

Permalink
Add support for match ... case (#60)
Browse files Browse the repository at this point in the history
Co-authored-by: Bela Stoyan <[email protected]>
Co-authored-by: Pavel Zwerschke <[email protected]>
  • Loading branch information
3 people authored May 30, 2024
1 parent 97fb1da commit 9aa94fa
Show file tree
Hide file tree
Showing 8 changed files with 2,241 additions and 1,348 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:

release:
name: Publish package
if: github.event_name == 'push' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true'
if: github.event_name == 'push' && github.repository == 'Quantco/polarify' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true'
needs: [build]
runs-on: ubuntu-latest
permissions:
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,17 @@ polarIFy is still in an early stage of development and doesn't support the full
- assignments (like `x = 1`)
- polars expressions (like `pl.col("x")`, TODO)
- side-effect free functions that return a polars expression (can be generated by `@polarify`) (TODO)
- `match` statements

### Unsupported operations

- `for` loops
- `while` loops
- `break` statements
- `:=` walrus operator
- `match ... case` statements (TODO)
- dictionary mappings in `match` statements
- list matching in `match` statements
- star patterns in `match statements
- functions with side-effects (`print`, `pl.write_csv`, ...)

## 🚀 Benchmarks
Expand Down
2,981 changes: 1,668 additions & 1,313 deletions pixi.lock

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ lint = "pre-commit run --all"

[environments]
default = ["test"]
pl014 = ["pl014", "py39", "test"]
pl015 = ["pl015", "py39", "test"]
pl016 = ["pl016", "py39", "test"]
pl017 = ["pl017", "py39", "test"]
pl018 = ["pl018", "py39", "test"]
pl019 = ["pl019", "py39", "test"]
pl020 = ["pl020", "py39", "test"]
pl014 = ["pl014", "py310", "test"]
pl015 = ["pl015", "py310", "test"]
pl016 = ["pl016", "py310", "test"]
pl017 = ["pl017", "py310", "test"]
pl018 = ["pl018", "py310", "test"]
pl019 = ["pl019", "py310", "test"]
pl020 = ["pl020", "py310", "test"]
py39 = ["py39", "test"]
py310 = ["py310", "test"]
py311 = ["py311", "test"]
Expand Down
235 changes: 210 additions & 25 deletions polarify/main.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,74 @@
from __future__ import annotations

import ast
import sys
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import dataclass

PY_39 = sys.version_info <= (3, 9)

# TODO: make walrus throw ValueError
# TODO: match ... case


def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast.expr) -> ast.Call:
when_node = ast.Call(
func=ast.Attribute(value=ast.Name(id="pl", ctx=ast.Load()), attr="when", ctx=ast.Load()),
args=[test],
keywords=[],
)
@dataclass
class UnresolvedCase:
"""
An unresolved case in a conditional statement. (if, match, etc.)
Each case consists of a test expression and a state.
The value of the state is not yet resolved.
"""

then_node = ast.Call(
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
args=[then],
keywords=[],
)
test: ast.expr
state: State

def __init__(self, test: ast.expr, then: State):
self.test = test
self.state = then


@dataclass
class ResolvedCase:
"""
A resolved case in a conditional statement. (if, match, etc.)
Each case consists of a test expression and a state.
The value of the state is resolved.
"""

test: ast.expr
state: ast.expr

def __init__(self, test: ast.expr, then: ast.expr):
self.test = test
self.state = then

def __iter__(self):
return iter([self.test, self.state])


def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.Call:
nodes: list[ast.Call] = []

assert body or orelse, "No when-then cases provided."

for test, then in body:
when_node = ast.Call(
func=ast.Attribute(
value=nodes[-1] if nodes else ast.Name(id="pl", ctx=ast.Load()),
attr="when",
ctx=ast.Load(),
),
args=[test],
keywords=[],
)
then_node = ast.Call(
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
args=[then],
keywords=[],
)
nodes.append(then_node)
final_node = ast.Call(
func=ast.Attribute(value=then_node, attr="otherwise", ctx=ast.Load()),
func=ast.Attribute(value=nodes[-1], attr="otherwise", ctx=ast.Load()),
args=[orelse],
keywords=[],
)
Expand Down Expand Up @@ -63,7 +110,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Call:
test = self.visit(node.test)
body = self.visit(node.body)
orelse = self.visit(node.orelse)
return build_polars_when_then_otherwise(test, body, orelse)
return build_polars_when_then_otherwise([ResolvedCase(test, body)], orelse)

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node
Expand Down Expand Up @@ -122,11 +169,11 @@ class ReturnState:
@dataclass
class ConditionalState:
"""
A conditional state, with a test expression and two branches.
A list of conditional states.
Each case consists of a test expression and a state.
"""

test: ast.expr
then: State
body: Sequence[UnresolvedCase]
orelse: State


Expand All @@ -139,25 +186,106 @@ class State:

node: UnresolvedState | ReturnState | ConditionalState

def translate_match(
self,
subj: ast.expr | Sequence[ast.expr] | ast.Tuple,
pattern: ast.pattern,
guard: ast.expr | None = None,
):
"""
Translate a match_case statement into a regular AST expression.
translate_match takes a subject, a pattern and a guard.
patterns can be a MatchValue, MatchAs, MatchOr, or MatchSequence.
subjects can be a single expression (e.g x or (2 * x + 1)) or a list of expressions.
translate_match is called per each case in a match statement.
"""

if isinstance(pattern, ast.MatchValue):
equality_ast = ast.Compare(
left=subj,
ops=[ast.Eq()],
comparators=[pattern.value],
)

if guard is not None:
return ast.BinOp(
left=guard,
op=ast.BitAnd(),
right=equality_ast,
)

return equality_ast
elif isinstance(pattern, ast.MatchAs):
if pattern.name is not None:
self.handle_assign(
ast.Assign(
targets=[ast.Name(id=pattern.name, ctx=ast.Store())],
value=subj,
)
)
return guard
elif isinstance(pattern, ast.MatchOr):
return ast.BinOp(
left=self.translate_match(subj, pattern.patterns[0], guard),
op=ast.BitOr(),
right=(
self.translate_match(subj, ast.MatchOr(patterns=pattern.patterns[1:]))
if pattern.patterns[2:]
else self.translate_match(subj, pattern.patterns[1])
),
)
elif isinstance(pattern, ast.MatchSequence):
if isinstance(pattern.patterns[-1], ast.MatchStar):
raise ValueError("starred patterns are not supported.")

if isinstance(subj, ast.Tuple):
# TODO: Use polars list operations in the future
left = self.translate_match(subj.elts[0], pattern.patterns[0], guard)
right = (
self.translate_match(
ast.Tuple(elts=subj.elts[1:]),
ast.MatchSequence(patterns=pattern.patterns[1:]),
)
if pattern.patterns[2:]
else self.translate_match(subj.elts[1], pattern.patterns[1])
)

return (
left or right
if left is None or right is None
else ast.BinOp(left=left, op=ast.BitAnd(), right=right)
)
raise ValueError("Matching lists is not supported.")
else:
raise ValueError(
f"Incompatible match and subject types: {type(pattern)} and {type(subj)}."
)

def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
if isinstance(expr, ast.AnnAssign):
expr = ast.Assign(targets=[expr.target], value=expr.value)

if isinstance(self.node, UnresolvedState):
self.node.handle_assign(expr)
elif isinstance(self.node, ConditionalState):
self.node.then.handle_assign(expr)
for case in self.node.body:
case.state.handle_assign(expr)
self.node.orelse.handle_assign(expr)

def handle_if(self, stmt: ast.If):
if isinstance(self.node, UnresolvedState):
self.node = ConditionalState(
test=InlineTransformer.inline_expr(stmt.test, self.node.assignments),
then=parse_body(stmt.body, copy(self.node.assignments)),
body=[
UnresolvedCase(
InlineTransformer.inline_expr(stmt.test, self.node.assignments),
parse_body(stmt.body, copy(self.node.assignments)),
)
],
orelse=parse_body(stmt.orelse, copy(self.node.assignments)),
)
elif isinstance(self.node, ConditionalState):
self.node.then.handle_if(stmt)
for case in self.node.body:
case.state.handle_if(stmt)
self.node.orelse.handle_if(stmt)

def handle_return(self, value: ast.expr):
Expand All @@ -166,9 +294,58 @@ def handle_return(self, value: ast.expr):
expr=InlineTransformer.inline_expr(value, self.node.assignments)
)
elif isinstance(self.node, ConditionalState):
self.node.then.handle_return(value)
for case in self.node.body:
case.state.handle_return(value)
self.node.orelse.handle_return(value)

def handle_match(self, stmt: ast.Match):
def is_catch_all(case: ast.match_case) -> bool:
# We check if the case is a catch-all pattern without a guard
# If it has a guard, we treat it as a regular case
return (
isinstance(case.pattern, ast.MatchAs)
and case.pattern.name is None
and case.guard is None
)

def ignore_case(case: ast.match_case) -> bool:
# if the length of the pattern is not equal to the length of the subject, python ignores the case
return (
isinstance(case.pattern, ast.MatchSequence)
and isinstance(stmt.subject, ast.Tuple)
and len(stmt.subject.elts) != len(case.pattern.patterns)
) or (isinstance(case.pattern, ast.MatchValue) and isinstance(stmt.subject, ast.Tuple))

if isinstance(self.node, UnresolvedState):
# We can always rewrite catch-all patterns to orelse since python throws a SyntaxError if the catch-all pattern is not the last case.
orelse = next(
iter([case.body for case in stmt.cases if is_catch_all(case)]),
[],
)
self.node = ConditionalState(
body=[
UnresolvedCase(
# translate_match transforms the match statement case into regular AST expressions so that the InlineTransformer can handle assignments correctly
# Note that by the time parse_body is called this has mutated the assignments
InlineTransformer.inline_expr(
self.translate_match(stmt.subject, case.pattern, case.guard),
self.node.assignments,
),
parse_body(case.body, copy(self.node.assignments)),
)
for case in stmt.cases
if not is_catch_all(case) and not ignore_case(case)
],
orelse=parse_body(
orelse,
copy(self.node.assignments),
),
)
elif isinstance(self.node, ConditionalState):
for case in self.node.body:
case.state.handle_match(stmt)
self.node.orelse.handle_match(stmt)


def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State:
if assignments is None:
Expand All @@ -182,9 +359,11 @@ def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | Non
elif isinstance(stmt, ast.Return):
if stmt.value is None:
raise ValueError("return needs a value")

state.handle_return(stmt.value)
break
elif isinstance(stmt, ast.Match):
assert not PY_39
state.handle_match(stmt)
else:
raise ValueError(f"Unsupported statement type: {type(stmt)}")
return state
Expand All @@ -194,9 +373,15 @@ def transform_tree_into_expr(node: State) -> ast.expr:
if isinstance(node.node, ReturnState):
return node.node.expr
elif isinstance(node.node, ConditionalState):
if not node.node.body:
# this happens if none of the cases will ever match or exist
# in these cases we just need to return the orelse body
return transform_tree_into_expr(node.node.orelse)
return build_polars_when_then_otherwise(
node.node.test,
transform_tree_into_expr(node.node.then),
[
ResolvedCase(case.test, transform_tree_into_expr(case.state))
for case in node.node.body
],
transform_tree_into_expr(node.node.orelse),
)
else:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "polarify"
description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️"
version = "0.1.5"
version = "0.2.0"
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.9"
Expand Down
Loading

0 comments on commit 9aa94fa

Please sign in to comment.