Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polars extension #3356

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions marimo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@
from marimo._plugins.stateless.tabs import tabs
from marimo._plugins.stateless.tree import tree
from marimo._plugins.stateless.video import video
from marimo._polars import (
lazyframe, # noqa: F401 # Import is required for lazyframe registration
)
Comment on lines +119 to +121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If registration happens on import of the polars module, then this can be removed.

https://github.com/marimo-team/marimo/blob/main/marimo/_output/formatters/df_formatters.py

As far as I can tell there's no need to add lazyframe to the marimo public API?

from marimo._runtime import output
from marimo._runtime.capture import (
capture_stderr,
Expand Down
63 changes: 63 additions & 0 deletions marimo/_polars/lazyframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import re
from typing import TYPE_CHECKING, Any

from marimo import mermaid
from marimo._dependencies.dependencies import DependencyManager
from marimo._output.hypertext import Html

if TYPE_CHECKING:
import polars as pl

if DependencyManager.polars.has():
import polars as pl

@pl.api.register_lazyframe_namespace("mo")
class Marimo:
def __init__(self, ldf: pl.LazyFrame):
self._ldf: pl.LazyFrame = ldf

def show_graph(self, **kwargs: Any) -> Html | str:
# We are specifying raw_output already, so we need to remove if it is passed
raw_output = kwargs.pop("raw_output", False)

dot = self._ldf.show_graph(raw_output=True, **kwargs)

if raw_output:
return dot

return self._dot_to_mermaid_html(dot)

# _dot_to_mermaid_html is a separate method from show_graph so that in the future
# polars can be updated to output mermaid directly when calling native show_graph
# inside a marimo environment.
# In order to do that we need a function that will not recursively call show_graph
@classmethod
def _dot_to_mermaid_html(self, dot: str) -> Html:
return mermaid(self._polars_dot_to_mermaid(dot))

@staticmethod
def _parse_node_label(label: str) -> str:
# replace escaped newlines
label = label.replace(r"\n", "\n")
# replace escaped quotes
label = label.replace('\\"', "#quot;")
return label

@classmethod
def _polars_dot_to_mermaid(cls, dot: str) -> str:
edge_regex = r"(?P<node1>\w+) -- (?P<node2>\w+)"
node_regex = r"(?P<node>\w+)(\s+)?\[label=\"(?P<label>.*)\"]"

nodes = [n for n in re.finditer(node_regex, dot)]
edges = [e for e in re.finditer(edge_regex, dot)]

return "\n".join(
[
"graph TD",
*[
f'\t{n["node"]}["{cls._parse_node_label(n["label"])}"]'
for n in nodes
],
*[f'\t{e["node1"]} --- {e["node2"]}' for e in edges],
]
)
69 changes: 69 additions & 0 deletions tests/_polars/test_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

from marimo._dependencies.dependencies import DependencyManager
from marimo._output.hypertext import Html

HAS_DEPS = DependencyManager.polars.has()


@pytest.fixture
def simple_lf():
import polars as pl

return (
pl.LazyFrame(
{
"a": [1, 2, 3],
"b": [4, 5, 6],
}
)
.filter(pl.col("a") > 1)
.group_by("a")
.agg(pl.col("b").sum())
)


@pytest.mark.skipif(not HAS_DEPS, reason="polars is required")
def test_show_graph(simple_lf):
lf = simple_lf

assert type(lf.mo.show_graph()) is Html
assert type(lf.mo.show_graph(raw_output=True)) is str


@pytest.mark.skipif(not HAS_DEPS, reason="polars is required")
def test_dot_to_html(simple_lf):
lf = simple_lf

dot = lf.show_graph(raw_output=True)

assert type(lf.mo._dot_to_mermaid_html(dot)) is Html


@pytest.mark.skipif(not HAS_DEPS, reason="polars is required")
def test_parse_node_label(simple_lf):
lf = simple_lf

assert lf.mo._parse_node_label(r"\"") == "#quot;"
assert lf.mo._parse_node_label(r"\n") == "\n"
assert lf.mo._parse_node_label(r"\"\n\"") == "#quot;\n#quot;"


@pytest.mark.skipif(not HAS_DEPS, reason="polars is required")
def test_dot_to_mermaid(simple_lf):
lf = simple_lf
dot = lf.show_graph(raw_output=True)

mermaid_str = lf.mo._polars_dot_to_mermaid(dot)

assert type(mermaid_str) is str
assert mermaid_str == (
"""graph TD
p2["TABLE
π 2/2;
σ [(col(#quot;a#quot;)) > (1)]"]
p1["AGG [col(#quot;b#quot;).sum()]
BY
[col(#quot;a#quot;)]"]
p1 --- p2"""
)
Loading