Skip to content

Commit

Permalink
feat: accept pd.DataFrame in session.createDataFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomzoy committed Jan 21, 2025
1 parent cf6d67f commit 4e3abe4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
28 changes: 20 additions & 8 deletions sqlframe/base/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import contextlib
import datetime
import logging
import sys
Expand Down Expand Up @@ -202,13 +203,16 @@ def range(self, *args):

def createDataFrame(
self,
data: t.Sequence[
t.Union[
t.Dict[str, ColumnLiterals],
t.List[ColumnLiterals],
t.Tuple[ColumnLiterals, ...],
ColumnLiterals,
]
data: t.Union[
t.Sequence[
t.Union[
t.Dict[str, ColumnLiterals],
t.List[ColumnLiterals],
t.Tuple[ColumnLiterals, ...],
ColumnLiterals,
],
],
pd.DataFrame,
],
schema: t.Optional[SchemaInput] = None,
samplingRatio: t.Optional[float] = None,
Expand All @@ -229,11 +233,18 @@ def createDataFrame(
):
raise NotImplementedError("Only schema of either list or string of list supported")

with contextlib.suppress(ImportError):
from pandas import DataFrame as pd_DataFrame

if isinstance(data, pd_DataFrame):
data = data.to_dict("records") # type: ignore

column_mapping: t.Mapping[str, t.Optional[exp.DataType]]
if schema is not None:
column_mapping = get_column_mapping_from_schema_input(
schema, dialect=self.input_dialect
)

elif data:
if isinstance(data[0], Row):
column_mapping = {col_name.strip(): None for col_name in data[0].__fields__}
Expand Down Expand Up @@ -375,7 +386,8 @@ def sql(
dialect = Dialect.get_or_raise(dialect or self.input_dialect)
expression = (
sqlglot.parse_one(
normalize_string(sqlQuery, from_dialect=dialect, is_query=True), read=dialect
normalize_string(sqlQuery, from_dialect=dialect, is_query=True),
read=dialect,
)
if isinstance(sqlQuery, str)
else sqlQuery
Expand Down
13 changes: 13 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t

import pandas as pd
import pytest
from _pytest.fixtures import FixtureRequest
from pyspark.sql import DataFrame as PySparkDataFrame
Expand Down Expand Up @@ -29,6 +30,18 @@ def test_empty_df(
compare_frames(df_empty, dfs_empty, no_empty=False)


def test_dataframe_from_pandas(
pyspark_employee: PySparkDataFrame,
compare_frames: t.Callable,
):
compare_frames(
pyspark_employee,
pyspark_employee.sparkSession.createDataFrame(
pyspark_employee.toPandas(),
),
)


def test_simple_select(
pyspark_employee: PySparkDataFrame,
get_df: t.Callable[[str], BaseDataFrame],
Expand Down

0 comments on commit 4e3abe4

Please sign in to comment.