fix: properly implement and test explain (#246)
eakmanrq authored Jan 17, 2025
1 parent b7e15bc commit cf6d67f
Showing 10 changed files with 85 additions and 5 deletions.
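In short: explain() on BaseDataFrame now delegates to a shared _get_explain_plan_rows() helper, and engines customize behavior either by setting _EXPLAIN_PREFIX (Snowflake) or by overriding explain() entirely (DuckDB prints a different result column; BigQuery raises NotImplementedError). A minimal usage sketch, assuming a DuckDB-backed session as described in the sqlframe README; the sample data below is illustrative and not part of this commit:

from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()  # in-memory DuckDB by default
df = session.createDataFrame(
    [(1, "Jack", "Shephard"), (2, "Kate", "Austen")],
    ["employee_id", "fname", "lname"],
)
df.explain()  # prints the engine's EXPLAIN plan text to stdout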
20 changes: 15 additions & 5 deletions sqlframe/base/dataframe.py
@@ -202,6 +202,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
_na: t.Type[NA]
_stat: t.Type[STAT]
_group_data: t.Type[GROUP_DATA]
_EXPLAIN_PREFIX = "EXPLAIN"

def __init__(
self,
@@ -1144,6 +1145,18 @@ def dropna(
final_df = filtered_df.select(*all_columns)
return final_df

def _get_explain_plan_rows(self) -> t.List[Row]:
sql_queries = self.sql(
pretty=False, optimize=False, as_list=True, dialect=self.session.execution_dialect
)
if len(sql_queries) > 1:
raise ValueError("Cannot explain a DataFrame with multiple queries")
sql_query = " ".join([self._EXPLAIN_PREFIX, sql_queries[0]])
results = self.session._collect(sql_query)
if len(results) != 1:
    raise ValueError(f"Expected exactly one row from the explain query, got {len(results)}")
return results

def explain(
self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
) -> None:
@@ -1212,11 +1225,8 @@ def explain(
...Statistics...
...
"""
sql_queries = self.sql(pretty=False, optimize=False, as_list=True)
if len(sql_queries) > 1:
raise ValueError("Cannot explain a DataFrame with multiple queries")
sql_query = "EXPLAIN " + sql_queries[0]
self.session._execute(sql_query)
results = self._get_explain_plan_rows()
print(results[0][0])

@operation(Operation.FROM)
def fillna(
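The engine-specific overrides below all follow the pattern this hunk sets up: set _EXPLAIN_PREFIX to change the EXPLAIN keyword, or reimplement explain() around the shared helper. A hedged sketch of that pattern (MyEngineDataFrame and its EXPLAIN VERBOSE prefix are hypothetical, not part of this commit):

class MyEngineDataFrame(BaseDataFrame):
    _EXPLAIN_PREFIX = "EXPLAIN VERBOSE"  # hypothetical engine-specific keyword

    def explain(
        self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
    ) -> None:
        results = self._get_explain_plan_rows()
        print(results[0][0])  # print whichever column holds this engine's plan text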
5 changes: 5 additions & 0 deletions sqlframe/bigquery/dataframe.py
@@ -72,3 +72,8 @@ def field_to_column(field: bigquery.SchemaField) -> CatalogColumn:
sql = self.session._to_sql(self.expression)
query_job = self.session._client.query(sql, job_config=job_config)
return [field_to_column(field) for field in query_job.schema]

def explain(
self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
) -> None:
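        # BigQuery exposes query plans through query job statistics rather than
        # an EXPLAIN statement, so there is no plan text to fetch and print.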
raise NotImplementedError("BigQuery does not support EXPLAIN")
6 changes: 6 additions & 0 deletions sqlframe/duckdb/dataframe.py
@@ -46,6 +46,12 @@ class DuckDBDataFrame(
_stat = DuckDBDataFrameStatFunctions
_group_data = DuckDBGroupedData

def explain(
self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
) -> None:
results = self._get_explain_plan_rows()
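        # DuckDB returns EXPLAIN output as (explain_key, explain_value) rows,
        # so the plan text lives in column index 1 rather than 0.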
print(results[0][1])

@t.overload
def toArrow(self) -> ArrowTable: ...

1 change: 1 addition & 0 deletions sqlframe/snowflake/dataframe.py
@@ -43,6 +43,7 @@ class SnowflakeDataFrame(
_na = SnowflakeDataFrameNaFunctions
_stat = SnowflakeDataFrameStatFunctions
_group_data = SnowflakeGroupedData
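    # Snowflake's EXPLAIN defaults to tabular output; USING TEXT returns the
    # plan as a single text block that explain() can print directly.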
_EXPLAIN_PREFIX = "EXPLAIN USING TEXT"

@property
def _typed_columns(self) -> t.List[CatalogColumn]:
5 changes: 5 additions & 0 deletions tests/integration/engines/bigquery/test_bigquery_dataframe.py
@@ -157,3 +157,8 @@ def test_schema_nested(bigquery_datatypes: BigQueryDataFrame):
assert struct_fields[8].dataType == types.TimestampType()
assert struct_fields[9].name == "boolean_col"
assert struct_fields[9].dataType == types.BooleanType()


def test_explain(bigquery_employee: BigQueryDataFrame):
with pytest.raises(NotImplementedError):
bigquery_employee.explain()
8 changes: 8 additions & 0 deletions tests/integration/engines/databricks/test_databricks_dataframe.py
@@ -167,3 +167,11 @@ def test_schema_nested(databricks_datatypes: DatabricksDataFrame):
assert struct_fields[9].dataType == types.TimestampType()
assert struct_fields[10].name == "boolean_col"
assert struct_fields[10].dataType == types.BooleanType()


def test_explain(databricks_employee: DatabricksDataFrame, capsys):
databricks_employee.explain()
output = capsys.readouterr().out.strip()
assert "== Physical Plan ==" in output
assert "LocalTableScan" in output
assert "== Photon Explanation ==" in output
14 changes: 14 additions & 0 deletions tests/integration/engines/duck/test_duckdb_dataframe.py
@@ -252,3 +252,17 @@ def test_to_arrow_batch(duckdb_employee: DuckDBDataFrame):
assert fifth_batch.column(4).to_pylist() == [100]
with pytest.raises(StopIteration):
record_batch_reader.read_next_batch()


def test_explain(duckdb_employee: DuckDBDataFrame, capsys):
duckdb_employee.explain()
assert (
capsys.readouterr().out.strip()
== """
┌───────────────────────────┐
│ COLUMN_DATA_SCAN │
│ ──────────────────── │
│ ~5 Rows │
└───────────────────────────┘
""".strip()
)
8 changes: 8 additions & 0 deletions tests/integration/engines/postgres/test_postgres_dataframe.py
@@ -120,3 +120,11 @@ def test_schema_nested(postgres_datatypes: PostgresDataFrame):
assert struct_fields[6].dataType == types.TimestampType()
assert struct_fields[7].name == "boolean_col"
assert struct_fields[7].dataType == types.BooleanType()


def test_explain(postgres_employee: PostgresDataFrame, capsys):
postgres_employee.explain()
assert (
capsys.readouterr().out.strip()
== """Values Scan on "*VALUES*" (cost=0.00..0.06 rows=5 width=76)""".strip()
)
16 changes: 16 additions & 0 deletions tests/integration/engines/snowflake/test_snowflake_dataframe.py
@@ -156,3 +156,19 @@ def test_schema_nested(snowflake_datatypes: SnowflakeDataFrame):
assert struct_fields[9].dataType == types.TimestampType()
assert struct_fields[10].name == "boolean_col"
assert struct_fields[10].dataType == types.BooleanType()


def test_explain(snowflake_employee: SnowflakeDataFrame, capsys):
snowflake_employee.explain()
assert (
capsys.readouterr().out.strip()
== """
GlobalStats:
partitionsTotal=0
partitionsAssigned=0
bytesAssigned=0
Operations:
1:0 ->Result A1.EMPLOYEE_ID, A1.FNAME, A1.LNAME, A1.AGE, A1.STORE_ID
1:1 ->ValuesClause (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)
""".strip()
)
7 changes: 7 additions & 0 deletions tests/integration/engines/spark/test_spark_dataframe.py
@@ -163,3 +163,10 @@ def test_schema_nested(spark_datatypes: SparkDataFrame):
assert struct_fields[9].dataType == types.TimestampType()
assert struct_fields[10].name == "boolean_col"
assert struct_fields[10].dataType == types.BooleanType()


def test_explain(spark_employee: SparkDataFrame, capsys):
spark_employee.explain()
output = capsys.readouterr().out.strip()
assert "== Physical Plan ==" in output
assert "LocalTableScan" in output
