From 4e3abe4616aa5c7023f257fe2f947b193cc7f61d Mon Sep 17 00:00:00 2001 From: Thomas Petit-Jean Date: Tue, 21 Jan 2025 11:52:36 +0100 Subject: [PATCH] feat: accept pd.DataFrame in session.createDataFrame --- sqlframe/base/session.py | 28 ++++++++++++++++++------- tests/integration/test_int_dataframe.py | 13 ++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/sqlframe/base/session.py b/sqlframe/base/session.py index f2ddee8..340b632 100644 --- a/sqlframe/base/session.py +++ b/sqlframe/base/session.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import datetime import logging import sys @@ -202,13 +203,16 @@ def range(self, *args): def createDataFrame( self, - data: t.Sequence[ - t.Union[ - t.Dict[str, ColumnLiterals], - t.List[ColumnLiterals], - t.Tuple[ColumnLiterals, ...], - ColumnLiterals, - ] + data: t.Union[ + t.Sequence[ + t.Union[ + t.Dict[str, ColumnLiterals], + t.List[ColumnLiterals], + t.Tuple[ColumnLiterals, ...], + ColumnLiterals, + ], + ], + pd.DataFrame, ], schema: t.Optional[SchemaInput] = None, samplingRatio: t.Optional[float] = None, @@ -229,11 +233,18 @@ def createDataFrame( ): raise NotImplementedError("Only schema of either list or string of list supported") + with contextlib.suppress(ImportError): + from pandas import DataFrame as pd_DataFrame + + if isinstance(data, pd_DataFrame): + data = data.to_dict("records") # type: ignore + column_mapping: t.Mapping[str, t.Optional[exp.DataType]] if schema is not None: column_mapping = get_column_mapping_from_schema_input( schema, dialect=self.input_dialect ) + elif data: if isinstance(data[0], Row): column_mapping = {col_name.strip(): None for col_name in data[0].__fields__} @@ -375,7 +386,8 @@ def sql( dialect = Dialect.get_or_raise(dialect or self.input_dialect) expression = ( sqlglot.parse_one( - normalize_string(sqlQuery, from_dialect=dialect, is_query=True), read=dialect + normalize_string(sqlQuery, from_dialect=dialect, is_query=True), + read=dialect, ) if isinstance(sqlQuery, str) else sqlQuery diff --git a/tests/integration/test_int_dataframe.py b/tests/integration/test_int_dataframe.py index 89d0144..5e24e41 100644 --- a/tests/integration/test_int_dataframe.py +++ b/tests/integration/test_int_dataframe.py @@ -2,6 +2,7 @@ import typing as t +import pandas as pd import pytest from _pytest.fixtures import FixtureRequest from pyspark.sql import DataFrame as PySparkDataFrame @@ -29,6 +30,18 @@ def test_empty_df( compare_frames(df_empty, dfs_empty, no_empty=False) +def test_dataframe_from_pandas( + pyspark_employee: PySparkDataFrame, + compare_frames: t.Callable, +): + compare_frames( + pyspark_employee, + pyspark_employee.sparkSession.createDataFrame( + pyspark_employee.toPandas(), + ), + ) + + def test_simple_select( pyspark_employee: PySparkDataFrame, get_df: t.Callable[[str], BaseDataFrame],