Skip to content

Commit

Permalink
feat!: make base dataframe class public (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Dec 27, 2024
1 parent 4a9d132 commit 6137f55
Show file tree
Hide file tree
Showing 20 changed files with 159 additions and 161 deletions.
4 changes: 2 additions & 2 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
}


DF = t.TypeVar("DF", bound="_BaseDataFrame")
DF = t.TypeVar("DF", bound="BaseDataFrame")


class OpenAIMode(enum.Enum):
Expand Down Expand Up @@ -198,7 +198,7 @@ def cov(self, col1: str, col2: str) -> float:
STAT = t.TypeVar("STAT", bound=_BaseDataFrameStatFunctions)


class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
_na: t.Type[NA]
_stat: t.Type[STAT]
_group_data: t.Type[GROUP_DATA]
Expand Down
2 changes: 1 addition & 1 deletion sqlframe/base/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def wrapper(*args, **kwargs):
col_name = col_name.this
alias_name = f"{func.__name__}__{col_name or ''}__"
# BigQuery has restrictions on alias names so we constrain it to alphanumeric characters and underscores
return result.alias(re.sub("\W", "_", alias_name))
return result.alias(re.sub("\W", "_", alias_name)) # type: ignore
return result

wrapper.unsupported_engines = ( # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions sqlframe/base/mixins/dataframe_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SESSION,
STAT,
WRITER,
_BaseDataFrame,
BaseDataFrame,
)

if sys.version_info >= (3, 11):
Expand All @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)


class NoCachePersistSupportMixin(_BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
class NoCachePersistSupportMixin(BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
def cache(self) -> Self:
logger.warning("This engine does not support caching. Ignoring cache() call.")
return self
Expand All @@ -34,7 +34,7 @@ def persist(self) -> Self:


class TypedColumnsFromTempViewMixin(
_BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]
BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]
):
@property
def _typed_columns(self) -> t.List[Column]:
Expand Down
10 changes: 5 additions & 5 deletions sqlframe/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import IntEnum

if t.TYPE_CHECKING:
from sqlframe.base.dataframe import _BaseDataFrame
from sqlframe.base.dataframe import BaseDataFrame
from sqlframe.base.group import _BaseGroupedData


Expand Down Expand Up @@ -37,15 +37,15 @@ def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:

def decorator(func: t.Callable) -> t.Callable:
@functools.wraps(func)
def wrapper(self: _BaseDataFrame, *args, **kwargs) -> _BaseDataFrame:
def wrapper(self: BaseDataFrame, *args, **kwargs) -> BaseDataFrame:
if self.last_op == Operation.INIT:
self = self._convert_leaf_to_cte()
self.last_op = Operation.NO_OP
last_op = self.last_op
new_op = op if op != Operation.NO_OP else last_op
if new_op < last_op or (last_op == new_op == Operation.SELECT):
self = self._convert_leaf_to_cte()
df: t.Union[_BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs)
df: t.Union[BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
return df # type: ignore

Expand All @@ -69,15 +69,15 @@ def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:

def decorator(func: t.Callable) -> t.Callable:
@functools.wraps(func)
def wrapper(self: _BaseGroupedData, *args, **kwargs) -> _BaseDataFrame:
def wrapper(self: _BaseGroupedData, *args, **kwargs) -> BaseDataFrame:
if self._df.last_op == Operation.INIT:
self._df = self._df._convert_leaf_to_cte()
self._df.last_op = Operation.NO_OP
last_op = self._df.last_op
new_op = op if op != Operation.NO_OP else last_op
if new_op < last_op or (last_op == new_op == Operation.SELECT):
self._df = self._df._convert_leaf_to_cte()
df: _BaseDataFrame = func(self, *args, **kwargs)
df: BaseDataFrame = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
return df

Expand Down
4 changes: 2 additions & 2 deletions sqlframe/base/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sqlglot.schema import MappingSchema

from sqlframe.base.catalog import _BaseCatalog
from sqlframe.base.dataframe import _BaseDataFrame
from sqlframe.base.dataframe import BaseDataFrame
from sqlframe.base.normalize import normalize_dict
from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
from sqlframe.base.udf import _BaseUDFRegistration
Expand Down Expand Up @@ -64,7 +64,7 @@ def fetchdf(self) -> pd.DataFrame: ...
CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
DF = t.TypeVar("DF", bound=_BaseDataFrame)
DF = t.TypeVar("DF", bound=BaseDataFrame)
UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)

_MISSING = "MISSING"
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/bigquery/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlframe.base.catalog import Column as CatalogColumn
from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand All @@ -30,7 +30,7 @@ class BigQueryDataFrameStatFunctions(_BaseDataFrameStatFunctions["BigQueryDataFr

class BigQueryDataFrame(
NoCachePersistSupportMixin,
_BaseDataFrame[
BaseDataFrame[
"BigQuerySession",
"BigQueryDataFrameWriter",
"BigQueryDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/databricks/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlframe.base.catalog import Column as CatalogColumn
from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand All @@ -31,7 +31,7 @@ class DatabricksDataFrameStatFunctions(_BaseDataFrameStatFunctions["DatabricksDa

class DatabricksDataFrame(
NoCachePersistSupportMixin,
_BaseDataFrame[
BaseDataFrame[
"DatabricksSession",
"DatabricksDataFrameWriter",
"DatabricksDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing as t

from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand Down Expand Up @@ -34,7 +34,7 @@ class DuckDBDataFrameStatFunctions(_BaseDataFrameStatFunctions["DuckDBDataFrame"
class DuckDBDataFrame(
NoCachePersistSupportMixin,
TypedColumnsFromTempViewMixin,
_BaseDataFrame[
BaseDataFrame[
"DuckDBSession",
"DuckDBDataFrameWriter",
"DuckDBDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/postgres/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing as t

from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand Down Expand Up @@ -39,7 +39,7 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr
class PostgresDataFrame(
NoCachePersistSupportMixin,
TypedColumnsFromTempViewMixin,
_BaseDataFrame[
BaseDataFrame[
"PostgresSession",
"PostgresDataFrameWriter",
"PostgresDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/redshift/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing as t

from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand All @@ -30,7 +30,7 @@ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFr

class RedshiftDataFrame(
NoCachePersistSupportMixin,
_BaseDataFrame[
BaseDataFrame[
"RedshiftSession",
"RedshiftDataFrameWriter",
"RedshiftDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/snowflake/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from sqlframe.base.catalog import Column as CatalogColumn
from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand All @@ -32,7 +32,7 @@ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeData

class SnowflakeDataFrame(
NoCachePersistSupportMixin,
_BaseDataFrame[
BaseDataFrame[
"SnowflakeSession",
"SnowflakeDataFrameWriter",
"SnowflakeDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/spark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlframe.base.catalog import Column
from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand All @@ -31,7 +31,7 @@ class SparkDataFrameStatFunctions(_BaseDataFrameStatFunctions["SparkDataFrame"])

class SparkDataFrame(
NoCachePersistSupportMixin,
_BaseDataFrame[
BaseDataFrame[
"SparkSession",
"SparkDataFrameWriter",
"SparkDataFrameNaFunctions",
Expand Down
4 changes: 2 additions & 2 deletions sqlframe/standalone/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing as t

from sqlframe.base.dataframe import (
_BaseDataFrame,
BaseDataFrame,
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
Expand All @@ -23,7 +23,7 @@ class StandaloneDataFrameStatFunctions(_BaseDataFrameStatFunctions["StandaloneDa


class StandaloneDataFrame(
_BaseDataFrame[
BaseDataFrame[
"StandaloneSession",
"StandaloneDataFrameWriter",
"StandaloneDataFrameNaFunctions",
Expand Down
6 changes: 3 additions & 3 deletions sqlframe/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from itertools import zip_longest

from sqlframe.base import types
from sqlframe.base.dataframe import _BaseDataFrame
from sqlframe.base.dataframe import BaseDataFrame
from sqlframe.base.exceptions import (
DataFrameDiffError,
SchemaDiffError,
Expand Down Expand Up @@ -64,8 +64,8 @@ def red(s: str) -> str:

# Source: https://github.com/apache/spark/blob/master/python/pyspark/testing/utils.py#L519
def assertDataFrameEqual(
actual: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]],
expected: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]],
actual: t.Union[BaseDataFrame, pd.DataFrame, t.List[types.Row]],
expected: t.Union[BaseDataFrame, pd.DataFrame, t.List[types.Row]],
checkRowOrder: bool = False,
rtol: float = 1e-5,
atol: float = 1e-8,
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/engines/test_engine_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from sqlframe.base.types import Row

if t.TYPE_CHECKING:
from sqlframe.base.dataframe import _BaseDataFrame
from sqlframe.base.dataframe import BaseDataFrame

pytest_plugins = ["tests.integration.fixtures"]


def test_collect(get_engine_df: t.Callable[[str], _BaseDataFrame], get_func):
def test_collect(get_engine_df: t.Callable[[str], BaseDataFrame], get_func):
employee = get_engine_df("employee")
col = get_func("col", employee.session)
results = employee.select(col("fname"), col("lname")).collect()
Expand All @@ -24,7 +24,7 @@ def test_collect(get_engine_df: t.Callable[[str], _BaseDataFrame], get_func):


def test_show(
get_engine_df: t.Callable[[str], _BaseDataFrame],
get_engine_df: t.Callable[[str], BaseDataFrame],
get_func,
capsys,
caplog,
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_show(


def test_show_limit(
get_engine_df: t.Callable[[str], _BaseDataFrame], capsys, is_snowflake: t.Callable
get_engine_df: t.Callable[[str], BaseDataFrame], capsys, is_snowflake: t.Callable
):
employee = get_engine_df("employee")
employee.show(1)
Expand Down
Loading

0 comments on commit 6137f55

Please sign in to comment.