From 8483236c7daa8909014e4a7354bdc5dedf20cd12 Mon Sep 17 00:00:00 2001
From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com>
Date: Fri, 17 May 2024 20:41:07 -0700
Subject: [PATCH] add action workflow (#1)

---
 .github/workflows/main.workflow.yaml        | 32 +++++++++++++++++++
 Makefile                                    |  2 +-
 setup.py                                    | 17 +++++++---
 sqlframe/base/dataframe.py                  |  4 +--
 sqlframe/base/decorators.py                 | 12 +++----
 sqlframe/base/mixins/readwriter_mixins.py   |  3 ++
 sqlframe/base/operations.py                 | 14 ++++----
 sqlframe/base/session.py                    |  6 ++--
 sqlframe/base/transforms.py                 |  4 ++-
 sqlframe/base/util.py                       |  7 ++--
 sqlframe/bigquery/catalog.py                |  5 +--
 sqlframe/bigquery/session.py                |  5 +--
 sqlframe/duckdb/readwriter.py               |  2 +-
 tests/common_fixtures.py                    | 23 +++++++------
 tests/conftest.py                           | 12 +++++++
 .../engines/{duckdb => duck}/__init__.py    |  0
 .../{duckdb => duck}/test_duckdb_catalog.py |  0
 .../{duckdb => duck}/test_duckdb_session.py |  0
 .../engines/test_engine_session.py          |  4 +--
 19 files changed, 107 insertions(+), 45 deletions(-)
 create mode 100644 .github/workflows/main.workflow.yaml
 rename tests/integration/engines/{duckdb => duck}/__init__.py (100%)
 rename tests/integration/engines/{duckdb => duck}/test_duckdb_catalog.py (100%)
 rename tests/integration/engines/{duckdb => duck}/test_duckdb_session.py (100%)

diff --git a/.github/workflows/main.workflow.yaml b/.github/workflows/main.workflow.yaml
new file mode 100644
index 0000000..2d92d96
--- /dev/null
+++ b/.github/workflows/main.workflow.yaml
@@ -0,0 +1,32 @@
+name: SQLFrame
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types:
+      - synchronize
+      - opened
+jobs:
+  run-tests:
+    runs-on: ubuntu-latest
+    env:
+      PYTEST_XDIST_AUTO_NUM_WORKERS: 4
+    strategy:
+      matrix:
+        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+      - name: Install Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: make install-dev
+      - name: Run Style
+        run: make style
+      - name: Setup Postgres
+        uses: ikalnytskyi/action-setup-postgres@v6
+      - name: Run tests
+        run: make local-test
diff --git a/Makefile b/Makefile
index f564fa2..76059ad 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 install-dev:
-	pip install -e ".[dev]"
+	pip install -e ".[dev,duckdb,postgres,redshift,bigquery,snowflake,spark]"

 install-pre-commit:
 	pre-commit install
diff --git a/setup.py b/setup.py
index 8ff52ca..f46bf70 100644
--- a/setup.py
+++ b/setup.py
@@ -26,16 +26,13 @@
         "bigquery": [
             "google-cloud-bigquery[pandas]",
             "google-cloud-bigquery-storage",
+            "pandas",
         ],
         "dev": [
             "duckdb",
-            "mkdocs==1.4.2",
-            "mkdocs-include-markdown-plugin==4.0.3",
-            "mkdocs-material==9.0.5",
-            "mkdocs-material-extensions==1.1.1",
             "mypy",
             "pandas",
-            "pymdown-extensions",
+            "pandas-stubs",
             "psycopg",
             "pyarrow",
             "pyspark",
@@ -47,17 +44,27 @@
             "typing_extensions",
             "types-psycopg2",
         ],
+        "docs": [
+            "mkdocs==1.4.2",
+            "mkdocs-include-markdown-plugin==4.0.3",
+            "mkdocs-material==9.0.5",
+            "mkdocs-material-extensions==1.1.1",
+            "pymdown-extensions",
+        ],
         "duckdb": [
             "duckdb",
             "pandas",
         ],
         "postgres": [
+            "pandas",
             "psycopg2",
         ],
         "redshift": [
+            "pandas",
             "redshift_connector",
         ],
         "snowflake": [
+            "pandas",
             "snowflake-connector-python[pandas,secure-local-storage]",
         ],
         "spark": [
diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py
index 5a822ae..e2384e7 100644
--- a/sqlframe/base/dataframe.py
+++ b/sqlframe/base/dataframe.py
@@ -662,7 +662,7 @@ def crossJoin(self, other: DF) -> Self:
         | 16|  Bob|    85|
         +---+-----+------+
         """
-        return self.join.__wrapped__(self, other, how="cross")
+        return self.join.__wrapped__(self, other, how="cross")  # type: ignore

     @operation(Operation.FROM)
     def join(
@@ -769,7 +769,7 @@ def join(
         new_df = self.copy(expression=join_expression)
         new_df.pending_join_hints.extend(self.pending_join_hints)
         new_df.pending_hints.extend(other_df.pending_hints)
-        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
+        new_df = new_df.select.__wrapped__(new_df, *select_column_names)  # type: ignore
         return new_df

     @operation(Operation.ORDER_BY)
diff --git a/sqlframe/base/decorators.py b/sqlframe/base/decorators.py
index 2f69bd7..25e144b 100644
--- a/sqlframe/base/decorators.py
+++ b/sqlframe/base/decorators.py
@@ -11,7 +11,7 @@
     from sqlframe.base.catalog import _BaseCatalog


-def normalize(normalize_kwargs: t.List[str]):
+def normalize(normalize_kwargs: t.List[str]) -> t.Callable[[t.Callable], t.Callable]:
     """
     Decorator used around DataFrame methods to indicate what type of operation is being performed from the
     ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -23,9 +23,9 @@
     in cases where there is overlap in names.
     """

-    def decorator(func: t.Callable):
+    def decorator(func: t.Callable) -> t.Callable:
         @functools.wraps(func)
-        def wrapper(self: _BaseCatalog, *args, **kwargs):
+        def wrapper(self: _BaseCatalog, *args, **kwargs) -> _BaseCatalog:
             kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
             for kwarg in normalize_kwargs:
                 if kwarg in kwargs:
@@ -43,9 +43,9 @@ def wrapper(self: _BaseCatalog, *args, **kwargs):
     return decorator


-def func_metadata(unsupported_engines: t.Optional[t.Union[str, t.List[str]]] = None):
-    def _metadata(func):
-        func.unsupported_engines = ensure_list(unsupported_engines) if unsupported_engines else []
+def func_metadata(unsupported_engines: t.Optional[t.Union[str, t.List[str]]] = None) -> t.Callable:
+    def _metadata(func: t.Callable) -> t.Callable:
+        func.unsupported_engines = ensure_list(unsupported_engines) if unsupported_engines else []  # type: ignore
         return func

     return _metadata
diff --git a/sqlframe/base/mixins/readwriter_mixins.py b/sqlframe/base/mixins/readwriter_mixins.py
index ff8803b..41711e3 100644
--- a/sqlframe/base/mixins/readwriter_mixins.py
+++ b/sqlframe/base/mixins/readwriter_mixins.py
@@ -108,6 +108,9 @@ def _write(self, path: str, mode: t.Optional[str], format: str, **options):  # type: ignore
                 raise NotImplementedError("Append mode is not supported for parquet.")
             pandas_df.to_parquet(path, **kwargs)
         elif format == "json":
+            # Pandas versions are inconsistent on how to handle True/False index so we just remove it
+            # since in all versions it will not result in an index column in the output.
+            del kwargs["index"]
             kwargs["mode"] = mode
             kwargs["orient"] = "records"
             pandas_df.to_json(path, lines=True, **kwargs)
diff --git a/sqlframe/base/operations.py b/sqlframe/base/operations.py
index 89f9bb0..ffc035f 100644
--- a/sqlframe/base/operations.py
+++ b/sqlframe/base/operations.py
@@ -23,7 +23,7 @@ class Operation(IntEnum):
     LIMIT = 7


-def operation(op: Operation):
+def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
     """
     Decorator used around DataFrame methods to indicate what type of operation is being performed from the
     ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -35,9 +35,9 @@ def operation(op: Operation):
     in cases where there is overlap in names.
    """

-    def decorator(func: t.Callable):
+    def decorator(func: t.Callable) -> t.Callable:
         @functools.wraps(func)
-        def wrapper(self: _BaseDataFrame, *args, **kwargs):
+        def wrapper(self: _BaseDataFrame, *args, **kwargs) -> _BaseDataFrame:
             if self.last_op == Operation.INIT:
                 self = self._convert_leaf_to_cte()
                 self.last_op = Operation.NO_OP
@@ -47,7 +47,7 @@ def wrapper(self: _BaseDataFrame, *args, **kwargs):
                 self = self._convert_leaf_to_cte()
             df: t.Union[_BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs)
             df.last_op = new_op  # type: ignore
-            return df
+            return df  # type: ignore

         wrapper.__wrapped__ = func  # type: ignore
         return wrapper
@@ -55,7 +55,7 @@ def wrapper(self: _BaseDataFrame, *args, **kwargs):
     return decorator


-def group_operation(op: Operation):
+def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
     """
     Decorator used around DataFrame methods to indicate what type of operation is being performed from the
     ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -67,9 +67,9 @@
     in cases where there is overlap in names.
     """

-    def decorator(func: t.Callable):
+    def decorator(func: t.Callable) -> t.Callable:
         @functools.wraps(func)
-        def wrapper(self: _BaseGroupedData, *args, **kwargs):
+        def wrapper(self: _BaseGroupedData, *args, **kwargs) -> _BaseDataFrame:
             if self._df.last_op == Operation.INIT:
                 self._df = self._df._convert_leaf_to_cte()
                 self._df.last_op = Operation.NO_OP
diff --git a/sqlframe/base/session.py b/sqlframe/base/session.py
index c93218f..392cb8f 100644
--- a/sqlframe/base/session.py
+++ b/sqlframe/base/session.py
@@ -11,9 +11,9 @@
 from functools import cached_property

 import sqlglot
-from more_itertools import take
 from sqlglot import Dialect, exp
 from sqlglot.expressions import parse_identifier
+from sqlglot.helper import seq_get
 from sqlglot.optimizer import optimize
 from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 from sqlglot.optimizer.qualify_columns import (
@@ -211,10 +211,10 @@ def get_default_data_type(value: t.Any) -> t.Optional[str]:
                    row_types.append((row_name, default_type))
                return "struct<" + ", ".join(f"{k}: {v}" for (k, v) in row_types) + ">"
            elif isinstance(value, dict):
-                sample_row = take(1, value.items())
+                sample_row = seq_get(list(value.items()), 0)
                if not sample_row:
                    return None
-                key, value = sample_row[0]
+                key, value = sample_row
                default_key = get_default_data_type(key)
                default_value = get_default_data_type(value)
                if not default_key or not default_value:
diff --git a/sqlframe/base/transforms.py b/sqlframe/base/transforms.py
index 1123178..99c4caf 100644
--- a/sqlframe/base/transforms.py
+++ b/sqlframe/base/transforms.py
@@ -5,7 +5,9 @@
 from sqlglot import expressions as exp


-def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
+def replace_id_value(
+    node: exp.Expression, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]
+) -> exp.Expression:
     if isinstance(node, exp.Identifier) and node in replacement_mapping:
         node = node.replace(replacement_mapping[node].copy())
     return node
diff --git a/sqlframe/base/util.py b/sqlframe/base/util.py
index f5e6918..6ee9268 100644
--- a/sqlframe/base/util.py
+++ b/sqlframe/base/util.py
@@ -9,6 +9,7 @@
 from sqlglot.schema import ensure_column_mapping as sqlglot_ensure_column_mapping

 if t.TYPE_CHECKING:
+    from pandas.core.frame import DataFrame as PandasDataFrame
     from pyspark.sql.dataframe import SparkSession as PySparkSession

     from sqlframe.base import types
@@ -97,7 +98,7 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
     return [left_table] + other_tables


-def to_csv(options: t.Dict[str, OptionalPrimitiveType], equality_char: str = "="):
+def to_csv(options: t.Dict[str, OptionalPrimitiveType], equality_char: str = "=") -> str:
     return ", ".join(
         [f"{k}{equality_char}{v}" for k, v in (options or {}).items() if v is not None]
     )
@@ -116,7 +117,7 @@ def ensure_column_mapping(schema: t.Union[str, StructType]) -> t.Dict:


 # SO: https://stackoverflow.com/questions/37513355/converting-pandas-dataframe-into-spark-dataframe-error
-def get_equivalent_spark_type(pandas_type):
+def get_equivalent_spark_type(pandas_type) -> types.DataType:
     """
     This method will retrieve the corresponding spark type given a pandas
     type.
@@ -139,7 +140,7 @@ def get_equivalent_spark_type(pandas_type):
     return type_map.get(str(pandas_type).lower(), types.StringType())


-def pandas_to_spark_schema(pandas_df):
+def pandas_to_spark_schema(pandas_df: PandasDataFrame) -> types.StructType:
     """
     This method will return a spark dataframe schema given a pandas dataframe.

diff --git a/sqlframe/bigquery/catalog.py b/sqlframe/bigquery/catalog.py
index 0e77441..8a9e30f 100644
--- a/sqlframe/bigquery/catalog.py
+++ b/sqlframe/bigquery/catalog.py
@@ -3,7 +3,6 @@
 import fnmatch
 import typing as t

-from google.cloud.bigquery import StandardSqlDataType
 from sqlglot import exp

 from sqlframe.base.catalog import CatalogMetadata, Column, Function
@@ -16,8 +15,10 @@
 from sqlframe.base.util import schema_, to_schema

 if t.TYPE_CHECKING:
-    from sqlframe.bigquery.session import BigQuerySession  # noqa
+    from google.cloud.bigquery import StandardSqlDataType
+
     from sqlframe.bigquery.dataframe import BigQueryDataFrame  # noqa
+    from sqlframe.bigquery.session import BigQuerySession  # noqa


 class BigQueryCatalog(
diff --git a/sqlframe/bigquery/session.py b/sqlframe/bigquery/session.py
index 5cdfa5a..aba638e 100644
--- a/sqlframe/bigquery/session.py
+++ b/sqlframe/bigquery/session.py
@@ -11,9 +11,10 @@
 )

 if t.TYPE_CHECKING:
-    from google.cloud import bigquery
+    from google.cloud.bigquery.client import Client as BigQueryClient
     from google.cloud.bigquery.dbapi.connection import Connection as BigQueryConnection
 else:
+    BigQueryClient = t.Any
     BigQueryConnection = t.Any


@@ -48,7 +49,7 @@ def __init__(
         self.default_dataset = default_dataset

     @property
-    def _client(self) -> bigquery.client.Client:
+    def _client(self) -> BigQueryClient:
         assert self._connection
         return self._connection._client

diff --git a/sqlframe/duckdb/readwriter.py b/sqlframe/duckdb/readwriter.py
index 302ca48..f3f6f8b 100644
--- a/sqlframe/duckdb/readwriter.py
+++ b/sqlframe/duckdb/readwriter.py
@@ -87,7 +87,7 @@ def _write(self, path: str, mode: t.Optional[str], **options):  # type: ignore
             return
         if mode == "append":
             raise NotImplementedError("Append mode not supported")
-        options = to_csv(options, equality_char=" ")
+        options = to_csv(options, equality_char=" ")  # type: ignore
         sqls = self._df.sql(pretty=False, optimize=False, as_list=True)
         for i, sql in enumerate(sqls):
             if i < len(sqls) - 1:
diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py
index 9f9e796..8166de7 100644
--- a/tests/common_fixtures.py
+++ b/tests/common_fixtures.py
@@ -19,7 +19,9 @@
 from sqlframe.standalone.session import StandaloneSession

 if t.TYPE_CHECKING:
-    from google.cloud.bigquery.dbapi.connection import Connection as BigQueryConnection
+    from google.cloud.bigquery.dbapi.connection import (
+        Connection as BigQueryConnection,
+    )
     from redshift_connector.core import Connection as RedshiftConnection
     from snowflake.connector import SnowflakeConnection

@@ -36,6 +38,7 @@ def pyspark_session(tmp_path_factory) -> PySparkSession:
         .config("spark.sql.warehouse.dir", data_dir)
         .config("spark.driver.extraJavaOptions", f"-Dderby.system.home={derby_dir}")
         .config("spark.sql.shuffle.partitions", 1)
+        .config("spark.sql.session.timeZone", "America/Los_Angeles")
         .master("local[1]")
         .appName("Unit-tests")
         .getOrCreate()
@@ -60,11 +63,11 @@ def spark_session(pyspark_session: PySparkSession) -> SparkSession:


 @pytest.fixture(scope="function")
-def duckdb_session(monkeypatch: pytest.MonkeyPatch) -> DuckDBSession:
-    import duckdb
+def duckdb_session() -> DuckDBSession:
+    from duckdb import connect

     # https://github.com/duckdb/duckdb/issues/11404
-    connection = duckdb.connect()
+    connection = connect()
     connection.sql("set TimeZone = 'UTC'")
     return DuckDBSession(conn=connection)

@@ -74,12 +77,12 @@ def function_scoped_postgres(postgresql_proc):
     import psycopg2

     janitor = DatabaseJanitor(
-        postgresql_proc.user,
-        postgresql_proc.host,
-        postgresql_proc.port,
-        postgresql_proc.dbname,
-        postgresql_proc.version,
-        postgresql_proc.password,
+        user=postgresql_proc.user,
+        host=postgresql_proc.host,
+        port=postgresql_proc.port,
+        dbname=postgresql_proc.dbname,
+        version=postgresql_proc.version,
+        password=postgresql_proc.password,
     )
     with janitor:
         conn = psycopg2.connect(
diff --git a/tests/conftest.py b/tests/conftest.py
index 4150350..b3083b7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,20 @@
 from __future__ import annotations

+import time
+
 import pytest


+@pytest.fixture(scope="session", autouse=True)
+def set_tz():
+    import os
+
+    os.environ["TZ"] = "US/Pacific"
+    time.tzset()
+    yield
+    del os.environ["TZ"]
+
+
 @pytest.fixture(scope="function", autouse=True)
 def rescope_sparksession_singleton():
     from sqlframe.base.session import _BaseSession
diff --git a/tests/integration/engines/duckdb/__init__.py b/tests/integration/engines/duck/__init__.py
similarity index 100%
rename from tests/integration/engines/duckdb/__init__.py
rename to tests/integration/engines/duck/__init__.py
diff --git a/tests/integration/engines/duckdb/test_duckdb_catalog.py b/tests/integration/engines/duck/test_duckdb_catalog.py
similarity index 100%
rename from tests/integration/engines/duckdb/test_duckdb_catalog.py
rename to tests/integration/engines/duck/test_duckdb_catalog.py
diff --git a/tests/integration/engines/duckdb/test_duckdb_session.py b/tests/integration/engines/duck/test_duckdb_session.py
similarity index 100%
rename from tests/integration/engines/duckdb/test_duckdb_session.py
rename to tests/integration/engines/duck/test_duckdb_session.py
diff --git a/tests/integration/engines/test_engine_session.py b/tests/integration/engines/test_engine_session.py
index 2c728df..d187c1e 100644
--- a/tests/integration/engines/test_engine_session.py
+++ b/tests/integration/engines/test_engine_session.py
@@ -29,8 +29,8 @@ def test_session(cleanup_session: _BaseSession):
         cola_type = exp.DataType.build("DECIMAL", dialect=session.output_dialect)
     else:
         cola_type = exp.DataType.build("INT", dialect=session.output_dialect)
-    cola_name = "COLA" if session.output_dialect == "snowflake" else "cola"
-    colb_name = "COLB" if session.output_dialect == "snowflake" else "colb"
+    cola_name = '"COLA"' if session.output_dialect == "snowflake" else '"cola"'
+    colb_name = '"COLB"' if session.output_dialect == "snowflake" else '"colb"'
     assert columns == {
         cola_name: cola_type,
         colb_name: exp.DataType.build("VARCHAR", dialect=session.output_dialect)
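
Note on the __wrapped__ calls that appear in the dataframe.py hunks above: functools.wraps (plus the explicit wrapper.__wrapped__ = func assignment in operations.py) keeps the undecorated function reachable on the wrapper, so crossJoin can reuse join's body without re-running the @operation bookkeeping. A minimal standalone sketch of that pattern, using hypothetical Frame/operation names rather than SQLFrame's real classes:

import functools

def operation(func):
    # Hypothetical stand-in for the @operation decorator in operations.py:
    # wrap the method, record that it ran, and keep the undecorated function
    # reachable through functools.wraps' __wrapped__ attribute.
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.ops.append(func.__name__)
        return func(self, *args, **kwargs)
    return wrapper

class Frame:
    # Hypothetical toy class; SQLFrame's real DataFrame tracks Operation enums instead.
    def __init__(self):
        self.ops = []

    @operation
    def join(self, other, how="inner"):
        return f"join({how})"

    @operation
    def crossJoin(self, other):
        # Reuse join's body without triggering its decorator bookkeeping,
        # mirroring self.join.__wrapped__(self, other, how="cross") in dataframe.py.
        return self.join.__wrapped__(self, other, how="cross")

df = Frame()
print(df.crossJoin(object()))  # join(cross)
print(df.ops)                  # ['crossJoin'] -- join's wrapper never ran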