From dc3db3b1046197451f7adc0000a545370145a9a8 Mon Sep 17 00:00:00 2001 From: Kyle Goodrick Date: Mon, 6 Jan 2025 14:21:35 -0700 Subject: [PATCH 1/2] Add polars LazyFrame extension for high-quality query plan graph output. --- marimo/__init__.py | 3 ++ marimo/_polars/lazyframe.py | 63 +++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 marimo/_polars/lazyframe.py diff --git a/marimo/__init__.py b/marimo/__init__.py index ded4cb59fd6..c62eeffc8c5 100644 --- a/marimo/__init__.py +++ b/marimo/__init__.py @@ -116,6 +116,9 @@ from marimo._plugins.stateless.tabs import tabs from marimo._plugins.stateless.tree import tree from marimo._plugins.stateless.video import video +from marimo._polars import ( + lazyframe, # noqa: F401 # Import is required for lazyframe registration +) from marimo._runtime import output from marimo._runtime.capture import ( capture_stderr, diff --git a/marimo/_polars/lazyframe.py b/marimo/_polars/lazyframe.py new file mode 100644 index 00000000000..cb1ad000545 --- /dev/null +++ b/marimo/_polars/lazyframe.py @@ -0,0 +1,63 @@ +import re +from typing import TYPE_CHECKING, Any + +from marimo import mermaid +from marimo._dependencies.dependencies import DependencyManager +from marimo._output.hypertext import Html + +if TYPE_CHECKING: + import polars as pl + +if DependencyManager.polars.has(): + import polars as pl + + @pl.api.register_lazyframe_namespace("mo") + class Marimo: + def __init__(self, ldf: pl.LazyFrame): + self._ldf: pl.LazyFrame = ldf + + def show_graph(self, **kwargs: Any) -> Html | str: + # We are specifying raw_output already, so we need to remove if it is passed + raw_output = kwargs.pop("raw_output", False) + + dot = self._ldf.show_graph(raw_output=True, **kwargs) + + if raw_output: + return dot + + return self._dot_to_mermaid_html(dot) + + # _dot_to_mermaid_html is a separate method from show_graph so that in the future + # polars can be updated to output mermaid directly when calling native show_graph + # inside a marimo environment. + # In order to do that we need a function that will not recursively call show_graph + @classmethod + def _dot_to_mermaid_html(self, dot: str) -> Html: + return mermaid(self._polars_dot_to_mermaid(dot)) + + @staticmethod + def _parse_node_label(label: str) -> str: + # replace escaped newlines + label = label.replace(r"\n", "\n") + # replace escaped quotes + label = label.replace('\\"', "#quot;") + return label + + @classmethod + def _polars_dot_to_mermaid(cls, dot: str) -> str: + edge_regex = r"(?P\w+) -- (?P\w+)" + node_regex = r"(?P\w+)(\s+)?\[label=\"(?P