From 43e3557b8d1b9bfa7b48e6bd1e7e1a5d6d8665f2 Mon Sep 17 00:00:00 2001 From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:27:34 -0700 Subject: [PATCH] feat: Add Activate to replace PySpark Imports (#155) feat: add pyspark replace --- README.md | 84 ++++++++---- docs/bigquery.md | 88 ++++++++++--- docs/configuration.md | 47 +++++++ docs/duckdb.md | 81 +++++++++--- docs/postgres.md | 85 +++++++++--- docs/snowflake.md | 99 +++++++++++--- docs/spark.md | 70 ++++++++-- docs/standalone.md | 54 ++++++-- setup.py | 1 + sqlframe/__init__.py | 83 ++++++++++++ sqlframe/base/session.py | 4 + sqlframe/bigquery/__init__.py | 13 +- sqlframe/bigquery/session.py | 3 +- sqlframe/duckdb/__init__.py | 15 ++- sqlframe/duckdb/column.py | 2 +- sqlframe/duckdb/session.py | 3 +- sqlframe/postgres/__init__.py | 13 +- sqlframe/postgres/session.py | 3 +- sqlframe/redshift/__init__.py | 13 +- sqlframe/redshift/session.py | 3 +- sqlframe/snowflake/__init__.py | 8 +- sqlframe/snowflake/session.py | 3 +- sqlframe/spark/__init__.py | 13 +- sqlframe/spark/session.py | 3 +- sqlframe/standalone/__init__.py | 8 +- sqlframe/standalone/session.py | 3 +- tests/conftest.py | 1 + .../engines/duck/test_duckdb_activate.py | 37 ++++++ .../postgres/test_postgres_activate.py | 37 ++++++ tests/unit/bigquery/__init__.py | 0 tests/unit/bigquery/test_activate.py | 51 +++++++ tests/unit/conftest.py | 124 ++++++++++++++++++ tests/unit/duck/__init__.py | 0 tests/unit/duck/test_activate.py | 41 ++++++ tests/unit/postgres/__init__.py | 0 tests/unit/postgres/test_activate.py | 41 ++++++ tests/unit/redshift/__init__.py | 0 tests/unit/redshift/test_activate.py | 41 ++++++ tests/unit/snowflake/__init__.py | 0 tests/unit/snowflake/test_activate.py | 41 ++++++ tests/unit/spark/__init__.py | 0 tests/unit/spark/test_activate.py | 41 ++++++ tests/unit/standalone/test_activate.py | 41 ++++++ tests/unit/test_activate.py | 37 ++++++ 44 files changed, 1185 insertions(+), 150 deletions(-) create mode 100644 tests/integration/engines/duck/test_duckdb_activate.py create mode 100644 tests/integration/engines/postgres/test_postgres_activate.py create mode 100644 tests/unit/bigquery/__init__.py create mode 100644 tests/unit/bigquery/test_activate.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/duck/__init__.py create mode 100644 tests/unit/duck/test_activate.py create mode 100644 tests/unit/postgres/__init__.py create mode 100644 tests/unit/postgres/test_activate.py create mode 100644 tests/unit/redshift/__init__.py create mode 100644 tests/unit/redshift/test_activate.py create mode 100644 tests/unit/snowflake/__init__.py create mode 100644 tests/unit/snowflake/test_activate.py create mode 100644 tests/unit/spark/__init__.py create mode 100644 tests/unit/spark/test_activate.py create mode 100644 tests/unit/standalone/test_activate.py create mode 100644 tests/unit/test_activate.py diff --git a/README.md b/README.md index 6da22a5..1dfd7a0 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,10 @@ SQLFrame also has a "Standalone" session that be used to generate SQL without an SQLFrame is great for: -* Users who want to run PySpark DataFrame code without having to use a Spark cluster +* Users who want a DataFrame API that leverages the full power of their engine to do the processing +* Users who want to run PySpark code quickly locally without the overhead of starting a Spark session * Users who want a SQL representation of their DataFrame code for debugging or sharing with others - * See [Spark 
Engine](https://sqlframe.readthedocs.io/en/stable/spark/) for more details -* Users who want a DataFrame API that leverages the full power of their engine to do the processing +* Users who want to run PySpark DataFrame code without the complexity of using Spark for processing ## Installation @@ -45,44 +45,72 @@ See specific engine documentation for additional setup instructions. ## Configuration SQLFrame generates consistently accurate yet complex SQL for engine execution. -However, when using df.sql(), it produces more human-readable SQL. +However, when using df.sql(optimize=True), it produces more human-readable SQL. For details on how to configure this output and leverage OpenAI to enhance the SQL, see [Generated SQL Configuration](https://sqlframe.readthedocs.io/en/stable/configuration/#generated-sql). SQLFrame by default uses the Spark dialect for input and output. This can be changed to make SQLFrame feel more like a native DataFrame API for the engine you are using. See [Input and Output Dialect Configuration](https://sqlframe.readthedocs.io/en/stable/configuration/#input-and-output-dialect). +## Activating SQLFrame + +SQLFrame can either replace pyspark imports or be used alongside them. +To replace pyspark imports, use the [activate function](https://sqlframe.readthedocs.io/en/stable/configuration/#activating-sqlframe) to set the engine to use. + +```python +from sqlframe import activate + +# Activate SQLFrame to run directly on DuckDB +activate(engine="duckdb") + +from pyspark.sql import SparkSession +session = SparkSession.builder.getOrCreate() +``` + +SQLFrame can also be directly imported which both maintains pyspark imports but also allows for a more engine-native DataFrame API: + +```python +from sqlframe.duckdb import DuckDBSession + +session = DuckDBSession.builder.getOrCreate() +``` + ## Example Usage ```python -from sqlframe.bigquery import BigQuerySession -from sqlframe.bigquery import functions as F -from sqlframe.bigquery import Window +from sqlframe import activate + +# Activate SQLFrame to run directly on BigQuery +activate(engine="bigquery") + +from pyspark.sql import SparkSession +from pyspark.sql import functions as F +from pyspark.sql import Window -session = BigQuerySession() +session = SparkSession.builder.getOrCreate() table_path = '"bigquery-public-data".samples.natality' # Top 5 years with the greatest year-over-year % change in new families with single child df = ( - session.table(table_path) - .where(F.col("ever_born") == 1) - .groupBy("year") - .agg(F.count("*").alias("num_single_child_families")) - .withColumn( - "last_year_num_single_child_families", - F.lag(F.col("num_single_child_families"), 1).over(Window.orderBy("year")) - ) - .withColumn( - "percent_change", - (F.col("num_single_child_families") - F.col("last_year_num_single_child_families")) - / F.col("last_year_num_single_child_families") - ) - .orderBy(F.abs(F.col("percent_change")).desc()) - .select( - F.col("year").alias("year"), - F.format_number("num_single_child_families", 0).alias("new families single child"), - F.format_number(F.col("percent_change") * 100, 2).alias("percent change"), - ) - .limit(5) + session.table(table_path) + .where(F.col("ever_born") == 1) + .groupBy("year") + .agg(F.count("*").alias("num_single_child_families")) + .withColumn( + "last_year_num_single_child_families", + F.lag(F.col("num_single_child_families"), 1).over(Window.orderBy("year")) + ) + .withColumn( + "percent_change", + (F.col("num_single_child_families") - F.col("last_year_num_single_child_families")) 
+ / F.col("last_year_num_single_child_families") + ) + .orderBy(F.abs(F.col("percent_change")).desc()) + .select( + F.col("year").alias("year"), + F.format_number("num_single_child_families", 0).alias("new families single child"), + F.format_number(F.col("percent_change") * 100, 2).alias("percent change"), + ) + .limit(5) ) ``` ```python diff --git a/docs/bigquery.md b/docs/bigquery.md index 51085c8..8c15e30 100644 --- a/docs/bigquery.md +++ b/docs/bigquery.md @@ -6,6 +6,46 @@ pip install "sqlframe[bigquery]" ``` +## Enabling SQLFrame + +SQLFrame can be used in two ways: + +* Directly importing the `sqlframe.bigquery` package +* Using the [activate](./configuration.md#activating-sqlframe) function to allow for continuing to use `pyspark.sql` but have it use SQLFrame behind the scenes. + +### Import + +If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.bigquery`. +In addition, many classes will have a `BigQuery` prefix. +For example, `BigQueryDataFrame` instead of `DataFrame`. + + +```python +# PySpark import +# from pyspark.sql import SparkSession +# from pyspark.sql import functions as F +# from pyspark.sql.dataframe import DataFrame +# SQLFrame import +from sqlframe.bigquery import BigQuerySession +from sqlframe.bigquery import functions as F +from sqlframe.bigquery import BigQueryDataFrame +``` + +### Activate + +If you would like to continue using `pyspark.sql` but have it use SQLFrame behind the scenes, you can use the [activate](./configuration.md#activating-sqlframe) function. + +```python +from sqlframe import activate +activate("bigquery", config={"default_dataset": "sqlframe.db1"}) + +from pyspark.sql import SparkSession +``` + +`SparkSession` will now be a SQLFrame `BigQuerySession` object and everything will be run on BigQuery directly. + +See [activate configuration](./configuration.md#activating-sqlframe) for information on how to pass in a connection and config options. + ## Creating a Session SQLFrame uses the [BigQuery DBAPI Connection](https://cloud.google.com/python/docs/reference/bigquery/latest/dbapi#class-googlecloudbigquerydbapiconnectionclientnone-bqstorageclientnone) to connect to BigQuery. @@ -13,7 +53,7 @@ A BigQuerySession, which implements the PySpark Session API, can be created by p By default, SQLFrame will create a connection by inferring it from the environment (for example using gcloud auth). Regardless of approach, it is recommended to configure `default_dataset` in the `BigQuerySession` constructor in order to make it easier to use the catalog methods (see example below). -=== "Without Providing Connection" +=== "Import + Without Providing Connection" ```python from sqlframe.bigquery import BigQuerySession @@ -21,7 +61,7 @@ Regardless of approach, it is recommended to configure `default_dataset` in the session = BigQuerySession(default_dataset="sqlframe.db1") ``` -=== "With Providing Connection" +=== "Import + With Providing Connection" ```python import google.auth @@ -43,23 +83,39 @@ Regardless of approach, it is recommended to configure `default_dataset` in the session = BigQuerySession(conn=conn, default_dataset="sqlframe.db1") ``` -## Imports +=== "Activate + Without Providing Connection" -If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.bigquery`. -In addition, many classes will have a `BigQuery` prefix. -For example, `BigQueryDataFrame` instead of `DataFrame`. 
+
+    ```python
+    from sqlframe import activate
+    activate("bigquery", config={"default_dataset": "sqlframe.db1"})
+
+    from pyspark.sql import SparkSession
+    session = SparkSession.builder.getOrCreate()
+    ```
+=== "Activate + With Providing Connection"

-```python
-# PySpark import
-# from pyspark.sql import SparkSession
-# from pyspark.sql import functions as F
-# from pyspark.sql.dataframe import DataFrame
-# SQLFrame import
-from sqlframe.bigquery import BigQuerySession
-from sqlframe.bigquery import functions as F
-from sqlframe.bigquery import BigQueryDataFrame
-```
+    ```python
+    import google.auth
+    from google.api_core import client_info
+    from google.oauth2 import service_account
+    from google.cloud.bigquery.dbapi import connect
+    from sqlframe import activate
+    creds = service_account.Credentials.from_service_account_file("path/to/credentials.json")
+
+    client = google.cloud.bigquery.Client(
+        project="my-project",
+        credentials=creds,
+        location="us-central1",
+        client_info=client_info.ClientInfo(user_agent="sqlframe"),
+    )
+
+    conn = connect(client=client)
+    activate("bigquery", conn=conn, config={"default_dataset": "sqlframe.db1"})
+
+    from pyspark.sql import SparkSession
+    session = SparkSession.builder.getOrCreate()
+    ```

 ## Using BigQuery Unique Functions

diff --git a/docs/configuration.md b/docs/configuration.md
index ddbac7c..761b365 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -24,6 +24,53 @@ In this configuration, you can use BigQuery syntax for elements such as date for
 
 SQLFrame supports multiple dialects, all of which can be specific as the `input_dialect` and `output_dialect`.
 
+## Activating SQLFrame
+
+SQLFrame can be activated in order to replace `pyspark` imports with `sqlframe` imports for the given engine.
+This allows you to use SQLFrame as a drop-in replacement for PySpark by just adding two lines of code.
+
+### Activate with Engine
+
+If you just provide an engine to `activate` then it will create a connection for that engine with default settings (if the engine supports it).
+
+```python
+
+from sqlframe import activate
+activate("duckdb")
+
+from pyspark.sql import SparkSession
+spark = SparkSession.builder.getOrCreate()
+# "spark" is now a SQLFrame DuckDBSession and will run directly on DuckDB
+```
+
+### Activate with Connection
+
+If you provide a connection to `activate` then it will use that connection for the engine.
+
+```python
+import duckdb
+from sqlframe import activate
+connection = duckdb.connect("file.duckdb")
+activate("duckdb", conn=connection)
+
+from pyspark.sql import SparkSession
+spark = SparkSession.builder.getOrCreate()
+# "spark" is a SQLFrame DuckDBSession and will run directly on DuckDB using `file.duckdb` for persistence
+```
+
+### Activate with Configuration
+
+If you provide a configuration to `activate` then it will use that configuration to create a connection for the engine.
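A connection and a configuration can also be passed together. The following is a minimal sketch combining the two (it assumes the same local `file.duckdb` database as the connection example above); the config-only form follows below.

```python
import duckdb

from sqlframe import activate

# Reuse an explicit DuckDB connection and set SQLFrame config options at the same time
connection = duckdb.connect("file.duckdb")
activate("duckdb", conn=connection, config={"sqlframe.input.dialect": "duckdb"})

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# "spark" runs directly on DuckDB using `file.duckdb`, and DataFrame code is parsed with the DuckDB dialect
```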
+ +```python +from sqlframe import activate +activate("duckdb", config={"sqlframe.input.dialect": "duckdb"}) + +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() +# "spark" is a SQLFrame DuckDBSession and will run directly on DuckDB with input dialect set to DuckDB +``` + ## Generated SQL ### Pretty diff --git a/docs/duckdb.md b/docs/duckdb.md index dd298f5..6614cd3 100644 --- a/docs/duckdb.md +++ b/docs/duckdb.md @@ -6,6 +6,46 @@ pip install "sqlframe[duckdb]" ``` +## Enabling SQLFrame + +SQLFrame can be used in two ways: + +* Directly importing the `sqlframe.duckdb` package +* Using the [activate](./configuration.md#activating-sqlframe) function to allow for continuing to use `pyspark.sql` but have it use SQLFrame behind the scenes. + +### Import + +If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.duckdb`. +In addition, many classes will have a `DuckDB` prefix. +For example, `DuckDBDataFrame` instead of `DataFrame`. + + +```python +# PySpark import +# from pyspark.sql import SparkSession +# from pyspark.sql import functions as F +# from pyspark.sql.dataframe import DataFrame +# SQLFrame import +from sqlframe.duckdb import DuckDBSession +from sqlframe.duckdb import functions as F +from sqlframe.duckdb import DuckDBDataFrame +``` + +### Activate + +If you would like to continue using `pyspark.sql` but have it use SQLFrame behind the scenes, you can use the [activate](./configuration.md#activating-sqlframe) function. + +```python +from sqlframe import activate +activate("duckdb") + +from pyspark.sql import SparkSession +``` + +`SparkSession` will now be a SQLFrame `DuckDBSession` object and everything will be run on DuckDB directly. + +See [activate configuration](./configuration.md#activating-sqlframe) for information on how to pass in a connection and config options. + ## Creating a Session SQLFrame uses the `duckdb` package to connect to DuckDB. @@ -13,7 +53,7 @@ A DuckDBSession, which implements the PySpark Session API, can be created by pas By default, SQLFrame will create a connection to an in-memory database. -=== "Without Providing Connection" +=== "Import + Without Providing Connection" ```python from sqlframe.duckdb import DuckDBSession @@ -21,7 +61,7 @@ By default, SQLFrame will create a connection to an in-memory database. session = DuckDBSession() ``` -=== "With Providing Connection" +=== "Import + With Providing Connection" ```python import duckdb @@ -30,23 +70,30 @@ By default, SQLFrame will create a connection to an in-memory database. conn = duckdb.connect(database=":memory:") session = DuckDBSession(conn=conn) ``` -## Imports -If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.duckdb`. -In addition, many classes will have a `DuckDB` prefix. -For example, `DuckDBDataFrame` instead of `DataFrame`. 
+=== "Activate + Without Providing Connection" + ```python + from sqlframe import activate + activate("duckdb") -```python -# PySpark import -# from pyspark.sql import SparkSession -# from pyspark.sql import functions as F -# from pyspark.sql.dataframe import DataFrame -# SQLFrame import -from sqlframe.duckdb import DuckDBSession -from sqlframe.duckdb import functions as F -from sqlframe.duckdb import DuckDBDataFrame -``` + from pyspark.sql import SparkSession + + session = SparkSession.builder.getOrCreate() + ``` + +=== "Activate + With Providing Connection" + + ```python + import duckdb + from sqlframe import activate + conn = duckdb.connect(database=":memory:") + activate("duckdb", conn=conn) + + from pyspark.sql import SparkSession + + session = SparkSession.builder.getOrCreate() + ``` ## Using DuckDB Unique Functions @@ -202,6 +249,8 @@ See something that you would like to see supported? [Open an issue](https://gith * sql * SQLFrame Specific: Get the SQL representation of the WindowSpec * [stat](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.stat.html) +* [toArrow](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.toArrow.html) + * SQLFrame Specific Argument: `batch_size` sets the number of rows to read per-batch and returns a `RecrodBatchReader` * [toDF](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.toDF.html) * [toPandas](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.toPandas.html) * [union](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.union.html) diff --git a/docs/postgres.md b/docs/postgres.md index 23138d4..be2f64c 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -6,26 +6,14 @@ pip install "sqlframe[postgres]" ``` -## Creating a Session +## Enabling SQLFrame -SQLFrame uses the `psycopg2` package to connect to Postgres. -A PostgresSession, which implements the PySpark Session API, is created by passing in a `psycopg2.Connection` object. - -```python -from psycopg2 import connect -from sqlframe.postgres import PostgresSession +SQLFrame can be used in two ways: -conn = connect( - dbname="postgres", - user="postgres", - password="password", - host="localhost", - port="5432", -) -session = PostgresSession(conn=conn) -``` +* Directly importing the `sqlframe.postgres` package +* Using the [activate](./configuration.md#activating-sqlframe) function to allow for continuing to use `pyspark.sql` but have it use SQLFrame behind the scenes. -## Imports +### Import If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.postgres`. In addition, many classes will have a `Postgres` prefix. @@ -43,6 +31,69 @@ from sqlframe.postgres import functions as F from sqlframe.postgres import PostgresDataFrame ``` +### Activate + +If you would like to continue using `pyspark.sql` but have it use SQLFrame behind the scenes, you can use the [activate](./configuration.md#activating-sqlframe) function. + +```python +from psycopg2 import connect +from sqlframe import activate +conn = connect( + dbname="postgres", + user="postgres", + password="password", + host="localhost", + port="5432", +) +activate("postgres", conn=conn) + +from pyspark.sql import SparkSession +``` + +`SparkSession` will now be a SQLFrame `PostgresSession` object and everything will be run on Postgres directly. 
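For instance, once activated, unchanged PySpark-style code along the following lines should execute on Postgres. This is a rough sketch rather than an example from the docs, and it assumes the connection set up above.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# These imports now resolve to SQLFrame's Postgres classes
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1, "a"), (2, "b")], schema=["id", "name"])
# The filter and aggregation are compiled to SQL and executed by Postgres itself
result = df.where(F.col("id") > 1).groupBy("name").agg(F.count("*").alias("cnt"))
print(result.collect())
```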
+ +See [activate configuration](./configuration.md#activating-sqlframe) for information on how to pass in a connection and config options. + +## Creating a Session + +SQLFrame uses the `psycopg2` package to connect to Postgres. +A PostgresSession, which implements the PySpark Session API, is created by passing in a `psycopg2.Connection` object. + +=== "Import" + + ```python + from psycopg2 import connect + from sqlframe.postgres import PostgresSession + + conn = connect( + dbname="postgres", + user="postgres", + password="password", + host="localhost", + port="5432", + ) + session = PostgresSession(conn=conn) + ``` + +=== "Activate" + + ```python + from sqlframe import activate + + conn = connect( + dbname="postgres", + user="postgres", + password="password", + host="localhost", + port="5432", + ) + activate("postgres", conn=conn) + + from pyspark.sql import SparkSession + session = SparkSession.builder.getOrCreate() + ``` + + ## Using Postgres Unique Functions Postgres may have a function that isn't represented within the PySpark API. diff --git a/docs/snowflake.md b/docs/snowflake.md index 620e0eb..48d52b8 100644 --- a/docs/snowflake.md +++ b/docs/snowflake.md @@ -6,30 +6,14 @@ pip install "sqlframe[snowflake]" ``` -## Creating a Session - -SQLFrame uses the [Snowflake Python Connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector) to connect to Snowflake. -A SnowflakeQuerySession, which implements the PySpark Session API, can be created by passing in a `snowflake.connector.connection.SnowflakeConnection` object. - -```python -import os +## Enabling SQLFrame -from snowflake.connector import connect -from sqlframe.snowflake import SnowflakeSession -from sqlframe.snowflake import functions as F +SQLFrame can be used in two ways: -connection = connect( - account=os.environ["SNOWFLAKE_ACCOUNT"], - user=os.environ["SNOWFLAKE_USER"], - password=os.environ["SNOWFLAKE_PASSWORD"], - warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], - database=os.environ["SNOWFLAKE_DATABASE"], - schema=os.environ["SNOWFLAKE_SCHEMA"], -) -session = SnowflakeSession(conn=connection) -``` +* Directly importing the `sqlframe.snowflake` package +* Using the [activate](./configuration.md#activating-sqlframe) function to allow for continuing to use `pyspark.sql` but have it use SQLFrame behind the scenes. -## Imports +### Import If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.snowflake`. In addition, many classes will have a `Snowflake` prefix. @@ -47,6 +31,79 @@ from sqlframe.snowflake import functions as F from sqlframe.snowflake import SnowflakeDataFrame ``` +### Activate + +If you would like to continue using `pyspark.sql` but have it use SQLFrame behind the scenes, you can use the [activate](./configuration.md#activating-sqlframe) function. + +```python +import os + +from snowflake.connector import connect +from sqlframe import activate +conn = connect( + account=os.environ["SNOWFLAKE_ACCOUNT"], + user=os.environ["SNOWFLAKE_USER"], + password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + database=os.environ["SNOWFLAKE_DATABASE"], + schema=os.environ["SNOWFLAKE_SCHEMA"], +) +activate("snowflake", conn=conn) + +from pyspark.sql import SparkSession +``` + +`SparkSession` will now be a SQLFrame `SnowflakeSession` object and everything will be run on Snowflake directly. + +See [activate configuration](./configuration.md#activating-sqlframe) for information on how to pass in a connection and config options. 
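Because activation rebinds the `pyspark.sql` names to SQLFrame's classes, the swap can be checked directly. A small sketch, assuming `activate("snowflake", conn=conn)` was called as above:

```python
from pyspark.sql import SparkSession

from sqlframe.snowflake import SnowflakeSession

# After activation, pyspark.sql.SparkSession is SQLFrame's SnowflakeSession class
assert SparkSession is SnowflakeSession

spark = SparkSession.builder.getOrCreate()
assert isinstance(spark, SnowflakeSession)
```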
+ + +## Creating a Session + +SQLFrame uses the [Snowflake Python Connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector) to connect to Snowflake. +A SnowflakeQuerySession, which implements the PySpark Session API, can be created by passing in a `snowflake.connector.connection.SnowflakeConnection` object. + +=== "Import" + + ```python + import os + + from snowflake.connector import connect + from sqlframe.snowflake import SnowflakeSession + from sqlframe.snowflake import functions as F + + connection = connect( + account=os.environ["SNOWFLAKE_ACCOUNT"], + user=os.environ["SNOWFLAKE_USER"], + password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + database=os.environ["SNOWFLAKE_DATABASE"], + schema=os.environ["SNOWFLAKE_SCHEMA"], + ) + session = SnowflakeSession(conn=connection) + ``` + +=== Activate + + ```python + import os + + from snowflake.connector import connect + from sqlframe import activate + conn = connect( + account=os.environ["SNOWFLAKE_ACCOUNT"], + user=os.environ["SNOWFLAKE_USER"], + password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + database=os.environ["SNOWFLAKE_DATABASE"], + schema=os.environ["SNOWFLAKE_SCHEMA"], + ) + activate("snowflake", conn=conn) + + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + ``` + ## Using Snowflake Unique Functions Snowflake may have a function that isn't represented within the PySpark API. diff --git a/docs/spark.md b/docs/spark.md index 4511442..3240f79 100644 --- a/docs/spark.md +++ b/docs/spark.md @@ -6,26 +6,19 @@ pip install "sqlframe[spark]" ``` -## Creating a Session +## Enabling SQLFrame -SQLFrame's SparkSession is created the same way you would normally create a SparkSession. -The configuration you apply to the builder will be applied the the SparkSession that SQLFrame will create. - -```python -from sqlframe.spark import SparkSession +SQLFrame can be used in two ways: -spark = SparkSession.builder.appName("MyApp").getOrCreate() +* Directly importing the `sqlframe.spark` package +* Using the [activate](./configuration.md#activating-sqlframe) function to allow for continuing to use `pyspark.sql` but have it use SQLFrame behind the scenes. -# Now you can use SQLFrame -``` - -## Imports +### Import If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.spark`. In addition, many classes will have a `Spark` prefix. For example, `SparkDataFrame` instead of `DataFrame`. - ```python # PySpark import # from pyspark.sql import SparkSession @@ -37,6 +30,59 @@ from sqlframe.spark import functions as F from sqlframe.spark import SparkDataFrame ``` +### Activate + +If you would like to continue using `pyspark.sql` but have it use SQLFrame behind the scenes, you can use the [activate](./configuration.md#activating-sqlframe) function. + +```python +from sqlframe import activate +activate("spark") + +from pyspark.sql import SparkSession +``` + +## Creating a Session + +SQLFrame's SparkSession is created the same way you would normally create a SparkSession. +The configuration you apply to the builder will be applied the the SparkSession that SQLFrame will create. 
+
+=== "Import"
+
+    ```python
+    from sqlframe.spark import SparkSession
+
+    spark = SparkSession.builder.appName("MyApp").getOrCreate()
+
+    # Now you can use SQLFrame
+    ```
+
+=== "Activate + Without Providing SparkSession"
+
+    ```python
+    from sqlframe import activate
+    activate("spark")
+
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.appName("MyApp").getOrCreate()
+
+    # Now you can use SQLFrame
+    ```
+
+=== "Activate + Providing SparkSession"
+
+    ```python
+    from pyspark.sql import SparkSession
+    from sqlframe import activate
+    activate("spark", conn=SparkSession.builder.appName("MyApp").getOrCreate())
+
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+
+    # Now you can use SQLFrame
+    ```
+
 ## Example Usage
 
 ```python
diff --git a/docs/standalone.md b/docs/standalone.md
index 7269842..d3a9f82 100644
--- a/docs/standalone.md
+++ b/docs/standalone.md
@@ -10,21 +10,14 @@ Also any operation that requires an actual connection to the database (like exec
 pip install sqlframe
 ```
 
-## Creating a Session
-
-Standalone supports defining both an `input_dialect` and an `output_dialect` which can be different from each other.
-`input_dialect` is the dialect used when using the DataFrame API and `output_dialect` is the dialect used when generating the SQL query.
-For example if you want Snowflake behavior of converting lowercase unquoted columns to uppercase, then you would set `input_dialect` to `snowflake`.
-If you plan on running the query against BigQuery, then you would set `output_dialect` to `bigquery`.
-Default is `spark` for both input and output dialects.
+## Enabling SQLFrame
 
-```python
-from sqlframe.standalone import StandaloneSession
+SQLFrame can be used in two ways:
 
-session = StandaloneSession.builder.config(map={"sqlframe.input.dialect": 'duckdb', "sqlframe.output.dialect": 'bigquery'}).getOrCreate()
-```
+* Directly importing the `sqlframe.standalone` package
+* Using the [activate](./configuration.md#activating-sqlframe) function to allow for continuing to use `pyspark.sql` but have it use SQLFrame behind the scenes.
 
-## Imports
+### Import
 
 If converting a PySpark pipeline, all `pyspark.sql` should be replaced with `sqlframe.standalone`.
 In addition, many classes will have a `Standalone` prefix.
@@ -42,6 +35,43 @@ from sqlframe.standalone import functions as F
 from sqlframe.standalone import StandaloneDataFrame
 ```
 
+### Activate
+
+If you would like to continue using `pyspark.sql` but have it use SQLFrame behind the scenes, you can use the [activate](./configuration.md#activating-sqlframe) function.
+
+```python
+from sqlframe import activate
+activate("standalone")
+
+from pyspark.sql import SparkSession
+```
+
+## Creating a Session
+
+Standalone supports defining both an `input_dialect` and an `output_dialect` which can be different from each other.
+`input_dialect` is the dialect used when using the DataFrame API and `output_dialect` is the dialect used when generating the SQL query.
+For example if you want Snowflake behavior of converting lowercase unquoted columns to uppercase, then you would set `input_dialect` to `snowflake`.
+If you plan on running the query against BigQuery, then you would set `output_dialect` to `bigquery`.
+Default is `spark` for both input and output dialects.
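As a rough illustration of how the two dialects interact (a sketch only; it assumes `createDataFrame` and `df.sql()` behave here as they do for the executable engines), Spark-style DataFrame code can be turned into BigQuery SQL without any connection; the standard session-creation examples follow below.

```python
from sqlframe.standalone import StandaloneSession
from sqlframe.standalone import functions as F

session = (
    StandaloneSession.builder
    .config(map={"sqlframe.input.dialect": "spark", "sqlframe.output.dialect": "bigquery"})
    .getOrCreate()
)

# DataFrame code is interpreted with Spark semantics (the input dialect)...
df = session.createDataFrame([(1, "a")], schema=["id", "name"]).where(F.col("id") == 1)

# ...while the generated SQL targets BigQuery (the output dialect)
print(df.sql())
```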
+ +=== "Import" + + ```python + from sqlframe.standalone import StandaloneSession + + session = StandaloneSession.builder.config(map={"sqlframe.input.dialect": 'duckdb', "sqlframe.output.dialect": 'bigquery'}).getOrCreate() + ``` + +=== "Activate" + + ```python + from sqlframe import activate + activate("standalone", config={"sqlframe.input.dialect": 'duckdb', "sqlframe.output.dialect": 'duckdb'}) + + from pyspark.sql import SparkSession + session = SparkSession.builder.getOrCreate() + ``` + ## Accessing Tables PySpark DataFrame API, and currently SQLFrame, requires that a table can be access to get it's schema information. diff --git a/setup.py b/setup.py index d758841..c042e82 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ ], "dev": [ "duckdb>=0.9,<1.1", + "findspark>=2,<3", "mypy>=1.10.0,<1.12", "openai>=1.30,<1.43", "pandas>=2,<3", diff --git a/sqlframe/__init__.py b/sqlframe/__init__.py index e69de29..0683233 100644 --- a/sqlframe/__init__.py +++ b/sqlframe/__init__.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import importlib +import sys +import typing as t +from unittest.mock import MagicMock + +if t.TYPE_CHECKING: + from sqlframe.base.session import CONN + +ENGINE_TO_PREFIX = { + "bigquery": "BigQuery", + "duckdb": "DuckDB", + "postgres": "Postgres", + "redshift": "Redshift", + "snowflake": "Snowflake", + "spark": "Spark", + "standalone": "Standalone", +} + +NAME_TO_FILE_OVERRIDE = { + "DataFrameNaFunctions": "dataframe", + "DataFrameStatFunctions": "dataframe", + "DataFrameReader": "readwriter", + "DataFrameWriter": "readwriter", + "GroupedData": "group", + "SparkSession": "session", + "WindowSpec": "window", + "UDFRegistration": "udf", +} + +ACTIVATE_CONFIG = {} + + +def activate( + engine: t.Optional[str] = None, + conn: t.Optional[CONN] = None, + config: t.Optional[t.Dict[str, t.Any]] = None, +) -> None: + import sqlframe + from sqlframe import testing + + pyspark_mock = MagicMock() + pyspark_mock.__file__ = "pyspark" + sys.modules["pyspark"] = pyspark_mock + pyspark_mock.testing = testing + sys.modules["pyspark.testing"] = testing + if conn: + ACTIVATE_CONFIG["sqlframe.conn"] = conn + for key, value in (config or {}).items(): + ACTIVATE_CONFIG[key] = value + if not engine: + return + engine = engine.lower() + if engine not in ENGINE_TO_PREFIX: + raise ValueError( + f"Unsupported engine {engine}. 
Supported engines are {', '.join(ENGINE_TO_PREFIX)}" + ) + prefix = ENGINE_TO_PREFIX[engine] + engine_module = importlib.import_module(f"sqlframe.{engine}") + + sys.modules["pyspark.sql"] = engine_module + pyspark_mock.sql = engine_module + types = engine_module.__dict__.copy() + resolved_files = set() + for name, obj in types.items(): + if name.startswith(prefix) or name in [ + "Column", + "Window", + "WindowSpec", + "functions", + "types", + ]: + name_without_prefix = name.replace(prefix, "") + if name_without_prefix == "Session": + name_without_prefix = "SparkSession" + setattr(engine_module, name_without_prefix, obj) + file = NAME_TO_FILE_OVERRIDE.get(name_without_prefix, name_without_prefix).lower() + engine_file = importlib.import_module(f"sqlframe.{engine}.{file}") + if engine_file not in resolved_files: + sys.modules[f"pyspark.sql.{file}"] = engine_file + resolved_files.add(engine_file) + setattr(engine_file, name_without_prefix, obj) diff --git a/sqlframe/base/session.py b/sqlframe/base/session.py index de6e4a8..77720a9 100644 --- a/sqlframe/base/session.py +++ b/sqlframe/base/session.py @@ -605,6 +605,10 @@ def session(self) -> _BaseSession: return _BaseSession(**self._session_kwargs) def getOrCreate(self) -> _BaseSession: + from sqlframe import ACTIVATE_CONFIG + + for k, v in ACTIVATE_CONFIG.items(): + self._set_config(k, v) self._set_session_properties() return self.session diff --git a/sqlframe/bigquery/__init__.py b/sqlframe/bigquery/__init__.py index e7f5c29..1d7ead1 100644 --- a/sqlframe/bigquery/__init__.py +++ b/sqlframe/bigquery/__init__.py @@ -1,23 +1,32 @@ from sqlframe.bigquery.catalog import BigQueryCatalog from sqlframe.bigquery.column import Column -from sqlframe.bigquery.dataframe import BigQueryDataFrame, BigQueryDataFrameNaFunctions +from sqlframe.bigquery.dataframe import ( + BigQueryDataFrame, + BigQueryDataFrameNaFunctions, + BigQueryDataFrameStatFunctions, +) from sqlframe.bigquery.group import BigQueryGroupedData from sqlframe.bigquery.readwriter import ( BigQueryDataFrameReader, BigQueryDataFrameWriter, ) from sqlframe.bigquery.session import BigQuerySession +from sqlframe.bigquery.types import Row +from sqlframe.bigquery.udf import BigQueryUDFRegistration from sqlframe.bigquery.window import Window, WindowSpec __all__ = [ "BigQueryCatalog", - "Column", "BigQueryDataFrame", "BigQueryDataFrameNaFunctions", "BigQueryGroupedData", "BigQueryDataFrameReader", "BigQueryDataFrameWriter", "BigQuerySession", + "BigQueryDataFrameStatFunctions", + "BigQueryUDFRegistration", + "Column", + "Row", "Window", "WindowSpec", ] diff --git a/sqlframe/bigquery/session.py b/sqlframe/bigquery/session.py index 0a8c5a4..f339ede 100644 --- a/sqlframe/bigquery/session.py +++ b/sqlframe/bigquery/session.py @@ -84,7 +84,6 @@ def session(self) -> BigQuerySession: return BigQuerySession(**self._session_kwargs) def getOrCreate(self) -> BigQuerySession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore builder = Builder() diff --git a/sqlframe/duckdb/__init__.py b/sqlframe/duckdb/__init__.py index d2efde7..b24112b 100644 --- a/sqlframe/duckdb/__init__.py +++ b/sqlframe/duckdb/__init__.py @@ -1,20 +1,29 @@ from sqlframe.duckdb.catalog import DuckDBCatalog -from sqlframe.duckdb.column import DuckDBColumn -from sqlframe.duckdb.dataframe import DuckDBDataFrame, DuckDBDataFrameNaFunctions +from sqlframe.duckdb.column import Column +from sqlframe.duckdb.dataframe import ( + DuckDBDataFrame, + DuckDBDataFrameNaFunctions, + 
DuckDBDataFrameStatFunctions, +) from sqlframe.duckdb.group import DuckDBGroupedData from sqlframe.duckdb.readwriter import DuckDBDataFrameReader, DuckDBDataFrameWriter from sqlframe.duckdb.session import DuckDBSession +from sqlframe.duckdb.types import Row +from sqlframe.duckdb.udf import DuckDBUDFRegistration from sqlframe.duckdb.window import Window, WindowSpec __all__ = [ + "Column", "DuckDBCatalog", - "DuckDBColumn", "DuckDBDataFrame", "DuckDBDataFrameNaFunctions", "DuckDBGroupedData", "DuckDBDataFrameReader", "DuckDBDataFrameWriter", "DuckDBSession", + "DuckDBDataFrameStatFunctions", + "DuckDBUDFRegistration", + "Row", "Window", "WindowSpec", ] diff --git a/sqlframe/duckdb/column.py b/sqlframe/duckdb/column.py index 6e50166..ec58008 100644 --- a/sqlframe/duckdb/column.py +++ b/sqlframe/duckdb/column.py @@ -1 +1 @@ -from sqlframe.base.column import Column as DuckDBColumn +from sqlframe.base.column import Column diff --git a/sqlframe/duckdb/session.py b/sqlframe/duckdb/session.py index 3576cfd..d9f12a0 100644 --- a/sqlframe/duckdb/session.py +++ b/sqlframe/duckdb/session.py @@ -69,7 +69,6 @@ def session(self) -> DuckDBSession: return DuckDBSession(**self._session_kwargs) def getOrCreate(self) -> DuckDBSession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore builder = Builder() diff --git a/sqlframe/postgres/__init__.py b/sqlframe/postgres/__init__.py index 901b616..397c5ba 100644 --- a/sqlframe/postgres/__init__.py +++ b/sqlframe/postgres/__init__.py @@ -1,23 +1,32 @@ from sqlframe.postgres.catalog import PostgresCatalog from sqlframe.postgres.column import Column -from sqlframe.postgres.dataframe import PostgresDataFrame, PostgresDataFrameNaFunctions +from sqlframe.postgres.dataframe import ( + PostgresDataFrame, + PostgresDataFrameNaFunctions, + PostgresDataFrameStatFunctions, +) from sqlframe.postgres.group import PostgresGroupedData from sqlframe.postgres.readwriter import ( PostgresDataFrameReader, PostgresDataFrameWriter, ) from sqlframe.postgres.session import PostgresSession +from sqlframe.postgres.types import Row +from sqlframe.postgres.udf import PostgresUDFRegistration from sqlframe.postgres.window import Window, WindowSpec __all__ = [ - "PostgresCatalog", "Column", + "PostgresCatalog", "PostgresDataFrame", "PostgresDataFrameNaFunctions", "PostgresGroupedData", "PostgresDataFrameReader", "PostgresDataFrameWriter", "PostgresSession", + "PostgresDataFrameStatFunctions", + "PostgresUDFRegistration", + "Row", "Window", "WindowSpec", ] diff --git a/sqlframe/postgres/session.py b/sqlframe/postgres/session.py index 445cbeb..f3b26df 100644 --- a/sqlframe/postgres/session.py +++ b/sqlframe/postgres/session.py @@ -79,7 +79,6 @@ def session(self) -> PostgresSession: return PostgresSession(**self._session_kwargs) def getOrCreate(self) -> PostgresSession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore builder = Builder() diff --git a/sqlframe/redshift/__init__.py b/sqlframe/redshift/__init__.py index 0a56288..d624baa 100644 --- a/sqlframe/redshift/__init__.py +++ b/sqlframe/redshift/__init__.py @@ -1,23 +1,32 @@ from sqlframe.redshift.catalog import RedshiftCatalog from sqlframe.redshift.column import Column -from sqlframe.redshift.dataframe import RedshiftDataFrame, RedshiftDataFrameNaFunctions +from sqlframe.redshift.dataframe import ( + RedshiftDataFrame, + RedshiftDataFrameNaFunctions, + RedshiftDataFrameStatFunctions, +) from sqlframe.redshift.group import RedshiftGroupedData 
from sqlframe.redshift.readwriter import ( RedshiftDataFrameReader, RedshiftDataFrameWriter, ) from sqlframe.redshift.session import RedshiftSession +from sqlframe.redshift.types import Row +from sqlframe.redshift.udf import RedshiftUDFRegistration from sqlframe.redshift.window import Window, WindowSpec __all__ = [ - "RedshiftCatalog", "Column", + "RedshiftCatalog", "RedshiftDataFrame", "RedshiftDataFrameNaFunctions", "RedshiftGroupedData", "RedshiftDataFrameReader", "RedshiftDataFrameWriter", "RedshiftSession", + "RedshiftDataFrameStatFunctions", + "RedshiftUDFRegistration", + "Row", "Window", "WindowSpec", ] diff --git a/sqlframe/redshift/session.py b/sqlframe/redshift/session.py index 1d2567c..ed8dd19 100644 --- a/sqlframe/redshift/session.py +++ b/sqlframe/redshift/session.py @@ -49,7 +49,6 @@ def session(self) -> RedshiftSession: return RedshiftSession(**self._session_kwargs) def getOrCreate(self) -> RedshiftSession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore builder = Builder() diff --git a/sqlframe/snowflake/__init__.py b/sqlframe/snowflake/__init__.py index 5dc3bdb..25cc77d 100644 --- a/sqlframe/snowflake/__init__.py +++ b/sqlframe/snowflake/__init__.py @@ -3,6 +3,7 @@ from sqlframe.snowflake.dataframe import ( SnowflakeDataFrame, SnowflakeDataFrameNaFunctions, + SnowflakeDataFrameStatFunctions, ) from sqlframe.snowflake.group import SnowflakeGroupedData from sqlframe.snowflake.readwriter import ( @@ -10,17 +11,22 @@ SnowflakeDataFrameWriter, ) from sqlframe.snowflake.session import SnowflakeSession +from sqlframe.snowflake.types import Row +from sqlframe.snowflake.udf import SnowflakeUDFRegistration from sqlframe.snowflake.window import Window, WindowSpec __all__ = [ - "SnowflakeCatalog", "Column", + "Row", + "SnowflakeCatalog", "SnowflakeDataFrame", "SnowflakeDataFrameNaFunctions", "SnowflakeGroupedData", "SnowflakeDataFrameReader", "SnowflakeDataFrameWriter", "SnowflakeSession", + "SnowflakeDataFrameStatFunctions", + "SnowflakeUDFRegistration", "Window", "WindowSpec", ] diff --git a/sqlframe/snowflake/session.py b/sqlframe/snowflake/session.py index a7b7fd2..0ef56a0 100644 --- a/sqlframe/snowflake/session.py +++ b/sqlframe/snowflake/session.py @@ -86,7 +86,6 @@ def session(self) -> SnowflakeSession: return SnowflakeSession(**self._session_kwargs) def getOrCreate(self) -> SnowflakeSession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore builder = Builder() diff --git a/sqlframe/spark/__init__.py b/sqlframe/spark/__init__.py index 15504a3..2c1b247 100644 --- a/sqlframe/spark/__init__.py +++ b/sqlframe/spark/__init__.py @@ -1,23 +1,32 @@ from sqlframe.spark.catalog import SparkCatalog from sqlframe.spark.column import Column -from sqlframe.spark.dataframe import SparkDataFrame, SparkDataFrameNaFunctions +from sqlframe.spark.dataframe import ( + SparkDataFrame, + SparkDataFrameNaFunctions, + SparkDataFrameStatFunctions, +) from sqlframe.spark.group import SparkGroupedData from sqlframe.spark.readwriter import ( SparkDataFrameReader, SparkDataFrameWriter, ) from sqlframe.spark.session import SparkSession +from sqlframe.spark.types import Row +from sqlframe.spark.udf import SparkUDFRegistration from sqlframe.spark.window import Window, WindowSpec __all__ = [ - "SparkCatalog", "Column", + "Row", + "SparkCatalog", "SparkDataFrame", "SparkDataFrameNaFunctions", "SparkGroupedData", "SparkDataFrameReader", "SparkDataFrameWriter", "SparkSession", + "SparkDataFrameStatFunctions", + 
"SparkUDFRegistration", "Window", "WindowSpec", ] diff --git a/sqlframe/spark/session.py b/sqlframe/spark/session.py index af2efd9..2052543 100644 --- a/sqlframe/spark/session.py +++ b/sqlframe/spark/session.py @@ -162,5 +162,4 @@ def session(self) -> SparkSession: return SparkSession(**self._session_kwargs) def getOrCreate(self) -> SparkSession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore diff --git a/sqlframe/standalone/__init__.py b/sqlframe/standalone/__init__.py index 74c3d06..f0b5b61 100644 --- a/sqlframe/standalone/__init__.py +++ b/sqlframe/standalone/__init__.py @@ -3,6 +3,7 @@ from sqlframe.standalone.dataframe import ( StandaloneDataFrame, StandaloneDataFrameNaFunctions, + StandaloneDataFrameStatFunctions, ) from sqlframe.standalone.group import StandaloneGroupedData from sqlframe.standalone.readwriter import ( @@ -10,17 +11,22 @@ StandaloneDataFrameWriter, ) from sqlframe.standalone.session import StandaloneSession +from sqlframe.standalone.types import Row +from sqlframe.standalone.udf import StandaloneUDFRegistration from sqlframe.standalone.window import Window, WindowSpec __all__ = [ - "StandaloneCatalog", "Column", + "Row", + "StandaloneCatalog", "StandaloneDataFrame", "StandaloneDataFrameNaFunctions", "StandaloneGroupedData", "StandaloneDataFrameReader", "StandaloneDataFrameWriter", "StandaloneSession", + "StandaloneDataFrameStatFunctions", + "StandaloneUDFRegistration", "Window", "WindowSpec", ] diff --git a/sqlframe/standalone/session.py b/sqlframe/standalone/session.py index 37b6006..e1bed5c 100644 --- a/sqlframe/standalone/session.py +++ b/sqlframe/standalone/session.py @@ -37,7 +37,6 @@ def session(self) -> StandaloneSession: return StandaloneSession() def getOrCreate(self) -> StandaloneSession: - self._set_session_properties() - return self.session + return super().getOrCreate() # type: ignore builder = Builder() diff --git a/tests/conftest.py b/tests/conftest.py index 59bd445..ecb108a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import time import pytest diff --git a/tests/integration/engines/duck/test_duckdb_activate.py b/tests/integration/engines/duck/test_duckdb_activate.py new file mode 100644 index 0000000..2abb3b2 --- /dev/null +++ b/tests/integration/engines/duck/test_duckdb_activate.py @@ -0,0 +1,37 @@ +import pytest + +from sqlframe import activate + + +@pytest.mark.forked +def test_activate_with_connection(): + import duckdb + + connector = duckdb.connect() + connector.execute('CREATE SCHEMA "memory"."activate_test"') + connector.execute('CREATE TABLE "memory"."activate_test"."test" (a INT)') + connector.execute('INSERT INTO "memory"."activate_test"."test" VALUES (1)') + activate("duckdb", conn=connector) + from pyspark.sql import SparkSession + + spark = SparkSession.builder.appName("test").getOrCreate() + + df = spark.table("memory.activate_test.test").select("`a`") + assert df.collect() == [(1,)] + + +@pytest.mark.forked +def test_activate_with_connection_and_input_dialect(): + import duckdb + + connector = duckdb.connect() + connector.execute('CREATE SCHEMA "memory"."activate_test"') + connector.execute('CREATE TABLE "memory"."activate_test"."test" (a INT)') + connector.execute('INSERT INTO "memory"."activate_test"."test" VALUES (1)') + activate("duckdb", conn=connector, config={"sqlframe.input.dialect": "duckdb"}) + from pyspark.sql import SparkSession + + spark = SparkSession.builder.appName("test").getOrCreate() + + df = 
spark.table("memory.activate_test.test").select('"a"') + assert df.collect() == [(1,)] diff --git a/tests/integration/engines/postgres/test_postgres_activate.py b/tests/integration/engines/postgres/test_postgres_activate.py new file mode 100644 index 0000000..a8bcd48 --- /dev/null +++ b/tests/integration/engines/postgres/test_postgres_activate.py @@ -0,0 +1,37 @@ +import pytest + +from sqlframe import activate + +pytest_plugins = ["tests.common_fixtures"] + + +@pytest.mark.forked +def test_activate_with_connection(function_scoped_postgres): + cursor = function_scoped_postgres.cursor() + cursor.execute('CREATE SCHEMA "activate_test"') + cursor.execute('CREATE TABLE "activate_test"."test" (a INT)') + cursor.execute('INSERT INTO "activate_test"."test" VALUES (1)') + activate("postgres", conn=function_scoped_postgres) + from pyspark.sql import SparkSession + + spark = SparkSession.builder.appName("test").getOrCreate() + + df = spark.table("activate_test.test").select("`a`") + assert df.collect() == [(1,)] + + +@pytest.mark.forked +def test_activate_with_connection_and_input_dialect(function_scoped_postgres): + cursor = function_scoped_postgres.cursor() + cursor.execute('CREATE SCHEMA "activate_test"') + cursor.execute('CREATE TABLE "activate_test"."test" (a INT)') + cursor.execute('INSERT INTO "activate_test"."test" VALUES (1)') + activate( + "postgres", conn=function_scoped_postgres, config={"sqlframe.input.dialect": "postgres"} + ) + from pyspark.sql import SparkSession + + spark = SparkSession.builder.appName("test").getOrCreate() + + df = spark.table("activate_test.test").select('"a"') + assert df.collect() == [(1,)] diff --git a/tests/unit/bigquery/__init__.py b/tests/unit/bigquery/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/bigquery/test_activate.py b/tests/unit/bigquery/test_activate.py new file mode 100644 index 0000000..9a1efec --- /dev/null +++ b/tests/unit/bigquery/test_activate.py @@ -0,0 +1,51 @@ +import pytest + +from sqlframe import activate +from sqlframe.bigquery import ( + BigQueryCatalog, + BigQueryDataFrame, + BigQueryDataFrameNaFunctions, + BigQueryDataFrameReader, + BigQueryDataFrameStatFunctions, + BigQueryDataFrameWriter, + BigQueryGroupedData, + BigQuerySession, + BigQueryUDFRegistration, +) +from sqlframe.bigquery import Column as BigQueryColumn +from sqlframe.bigquery import Row as BigQueryRow +from sqlframe.bigquery import Window as BigQueryWindow +from sqlframe.bigquery import WindowSpec as BigQueryWindowSpec +from sqlframe.bigquery import functions as BigQueryF +from sqlframe.bigquery import types as BigQueryTypes + + +@pytest.mark.forked +def test_activate_bigquery(): + check_pyspark_imports( + "bigquery", + sqlf_session=BigQuerySession, + sqlf_catalog=BigQueryCatalog, + sqlf_column=BigQueryColumn, + sqlf_dataframe=BigQueryDataFrame, + sqlf_grouped_data=BigQueryGroupedData, + sqlf_window=BigQueryWindow, + sqlf_window_spec=BigQueryWindowSpec, + sqlf_functions=BigQueryF, + sqlf_types=BigQueryTypes, + sqlf_udf_registration=BigQueryUDFRegistration, + sqlf_dataframe_reader=BigQueryDataFrameReader, + sqlf_dataframe_writer=BigQueryDataFrameWriter, + sqlf_dataframe_na_functions=BigQueryDataFrameNaFunctions, + sqlf_dataframe_stat_functions=BigQueryDataFrameStatFunctions, + sqlf_row=BigQueryRow, + ) + + +@pytest.mark.forked +def test_activate_bigquery_default_dataset(): + activate("bigquery", config={"default_dataset": "sqlframe.sqlframe_test"}) + from pyspark.sql import SparkSession + + spark = 
SparkSession.builder.appName("test").getOrCreate() + assert spark.default_dataset == "sqlframe.sqlframe_test" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..d616803 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,124 @@ +import sys + +import findspark +import pytest + +from sqlframe import activate + + +@pytest.fixture() +def check_pyspark_imports(temporary_sys_modules): + def _check_pyspark_imports( + engine_name, + sqlf_session, + sqlf_catalog, + sqlf_column, + sqlf_dataframe, + sqlf_grouped_data, + sqlf_window, + sqlf_window_spec, + sqlf_functions, + sqlf_types, + sqlf_udf_registration, + sqlf_dataframe_reader, + sqlf_dataframe_writer, + sqlf_dataframe_na_functions, + sqlf_dataframe_stat_functions, + sqlf_row, + ): + activate(engine=engine_name) + findspark.init() + # TODO: functions, types, udf + from pyspark.sql import ( + Catalog, + Column, + DataFrame, + DataFrameNaFunctions, + DataFrameReader, + DataFrameStatFunctions, + DataFrameWriter, + GroupedData, + Row, + SparkSession, + UDFRegistration, + Window, + WindowSpec, + types, + ) + from pyspark.sql import functions as F + + assert SparkSession == sqlf_session + assert Catalog == sqlf_catalog + assert Column == sqlf_column + assert GroupedData == sqlf_grouped_data + assert DataFrame == sqlf_dataframe + assert Window == sqlf_window + assert WindowSpec == sqlf_window_spec + assert F == sqlf_functions + assert types == sqlf_types + assert UDFRegistration == sqlf_udf_registration + assert DataFrameNaFunctions == sqlf_dataframe_na_functions + assert DataFrameStatFunctions == sqlf_dataframe_stat_functions + assert DataFrameReader == sqlf_dataframe_reader + assert DataFrameWriter == sqlf_dataframe_writer + assert Row == sqlf_row + + from pyspark.sql.session import SparkSession as SparkSession2 + + assert SparkSession2 == sqlf_session + + from pyspark.sql.catalog import Catalog as Catalog2 + + assert Catalog2 == sqlf_catalog + + from pyspark.sql.column import Column as Column2 + + assert Column2 == sqlf_column + + from pyspark.sql.dataframe import DataFrame as DataFrame2 + from pyspark.sql.dataframe import DataFrameNaFunctions as DataFrameNaFunctions2 + from pyspark.sql.dataframe import ( + DataFrameStatFunctions as DataFrameStatFunctions2, + ) + + assert DataFrame2 == sqlf_dataframe + assert DataFrameNaFunctions2 == sqlf_dataframe_na_functions + assert DataFrameStatFunctions2 == sqlf_dataframe_stat_functions + + from pyspark.sql.group import GroupedData as GroupedData2 + + assert GroupedData2 == sqlf_grouped_data + + from pyspark.sql.window import WindowSpec as WindowSpec2 + + assert WindowSpec2 == sqlf_window_spec + + from pyspark.sql.readwriter import DataFrameReader as DataFrameReader2 + from pyspark.sql.readwriter import DataFrameWriter as DataFrameWriter2 + + assert DataFrameReader2 == sqlf_dataframe_reader + assert DataFrameWriter2 == sqlf_dataframe_writer + + from pyspark.sql.window import Window, WindowSpec + + assert Window == sqlf_window + assert WindowSpec == sqlf_window_spec + + from pyspark.sql import functions as F + from pyspark.sql import types + + assert F == sqlf_functions + assert types == sqlf_types + assert types.Row == sqlf_row + + from pyspark.sql import UDFRegistration + + assert UDFRegistration == sqlf_udf_registration + + import pyspark.sql.functions as F + import pyspark.sql.types as types + + assert F == sqlf_functions + assert types == sqlf_types + + return _check_pyspark_imports diff --git a/tests/unit/duck/__init__.py b/tests/unit/duck/__init__.py new 
file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/duck/test_activate.py b/tests/unit/duck/test_activate.py
new file mode 100644
index 0000000..b7e192b
--- /dev/null
+++ b/tests/unit/duck/test_activate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from sqlframe.duckdb import Column as DuckDBColumn
+from sqlframe.duckdb import (
+    DuckDBCatalog,
+    DuckDBDataFrame,
+    DuckDBDataFrameNaFunctions,
+    DuckDBDataFrameReader,
+    DuckDBDataFrameStatFunctions,
+    DuckDBDataFrameWriter,
+    DuckDBGroupedData,
+    DuckDBSession,
+    DuckDBUDFRegistration,
+)
+from sqlframe.duckdb import Row as DuckDBRow
+from sqlframe.duckdb import Window as DuckDBWindow
+from sqlframe.duckdb import WindowSpec as DuckDBWindowSpec
+from sqlframe.duckdb import functions as DuckDBF
+from sqlframe.duckdb import types as DuckDBTypes
+
+
+@pytest.mark.forked
+def test_activate_duckdb(check_pyspark_imports):
+    check_pyspark_imports(
+        "duckdb",
+        sqlf_session=DuckDBSession,
+        sqlf_catalog=DuckDBCatalog,
+        sqlf_column=DuckDBColumn,
+        sqlf_dataframe=DuckDBDataFrame,
+        sqlf_grouped_data=DuckDBGroupedData,
+        sqlf_window=DuckDBWindow,
+        sqlf_window_spec=DuckDBWindowSpec,
+        sqlf_functions=DuckDBF,
+        sqlf_types=DuckDBTypes,
+        sqlf_udf_registration=DuckDBUDFRegistration,
+        sqlf_dataframe_reader=DuckDBDataFrameReader,
+        sqlf_dataframe_writer=DuckDBDataFrameWriter,
+        sqlf_dataframe_na_functions=DuckDBDataFrameNaFunctions,
+        sqlf_dataframe_stat_functions=DuckDBDataFrameStatFunctions,
+        sqlf_row=DuckDBRow,
+    )
diff --git a/tests/unit/postgres/__init__.py b/tests/unit/postgres/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/postgres/test_activate.py b/tests/unit/postgres/test_activate.py
new file mode 100644
index 0000000..874b77e
--- /dev/null
+++ b/tests/unit/postgres/test_activate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from sqlframe.postgres import Column as PostgresColumn
+from sqlframe.postgres import (
+    PostgresCatalog,
+    PostgresDataFrame,
+    PostgresDataFrameNaFunctions,
+    PostgresDataFrameReader,
+    PostgresDataFrameStatFunctions,
+    PostgresDataFrameWriter,
+    PostgresGroupedData,
+    PostgresSession,
+    PostgresUDFRegistration,
+)
+from sqlframe.postgres import Row as PostgresRow
+from sqlframe.postgres import Window as PostgresWindow
+from sqlframe.postgres import WindowSpec as PostgresWindowSpec
+from sqlframe.postgres import functions as PostgresF
+from sqlframe.postgres import types as PostgresTypes
+
+
+@pytest.mark.forked
+def test_activate_postgres(check_pyspark_imports):
+    check_pyspark_imports(
+        "postgres",
+        sqlf_session=PostgresSession,
+        sqlf_catalog=PostgresCatalog,
+        sqlf_column=PostgresColumn,
+        sqlf_dataframe=PostgresDataFrame,
+        sqlf_grouped_data=PostgresGroupedData,
+        sqlf_window=PostgresWindow,
+        sqlf_window_spec=PostgresWindowSpec,
+        sqlf_functions=PostgresF,
+        sqlf_types=PostgresTypes,
+        sqlf_udf_registration=PostgresUDFRegistration,
+        sqlf_dataframe_reader=PostgresDataFrameReader,
+        sqlf_dataframe_writer=PostgresDataFrameWriter,
+        sqlf_dataframe_na_functions=PostgresDataFrameNaFunctions,
+        sqlf_dataframe_stat_functions=PostgresDataFrameStatFunctions,
+        sqlf_row=PostgresRow,
+    )
diff --git a/tests/unit/redshift/__init__.py b/tests/unit/redshift/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/redshift/test_activate.py b/tests/unit/redshift/test_activate.py
new file mode 100644
index 0000000..7dfd804
--- /dev/null
+++ b/tests/unit/redshift/test_activate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from sqlframe.redshift import Column as RedshiftColumn
+from sqlframe.redshift import (
+    RedshiftCatalog,
+    RedshiftDataFrame,
+    RedshiftDataFrameNaFunctions,
+    RedshiftDataFrameReader,
+    RedshiftDataFrameStatFunctions,
+    RedshiftDataFrameWriter,
+    RedshiftGroupedData,
+    RedshiftSession,
+    RedshiftUDFRegistration,
+)
+from sqlframe.redshift import Row as RedshiftRow
+from sqlframe.redshift import Window as RedshiftWindow
+from sqlframe.redshift import WindowSpec as RedshiftWindowSpec
+from sqlframe.redshift import functions as RedshiftF
+from sqlframe.redshift import types as RedshiftTypes
+
+
+@pytest.mark.forked
+def test_activate_redshift(check_pyspark_imports):
+    check_pyspark_imports(
+        "redshift",
+        sqlf_session=RedshiftSession,
+        sqlf_catalog=RedshiftCatalog,
+        sqlf_column=RedshiftColumn,
+        sqlf_dataframe=RedshiftDataFrame,
+        sqlf_grouped_data=RedshiftGroupedData,
+        sqlf_window=RedshiftWindow,
+        sqlf_window_spec=RedshiftWindowSpec,
+        sqlf_functions=RedshiftF,
+        sqlf_types=RedshiftTypes,
+        sqlf_udf_registration=RedshiftUDFRegistration,
+        sqlf_dataframe_reader=RedshiftDataFrameReader,
+        sqlf_dataframe_writer=RedshiftDataFrameWriter,
+        sqlf_dataframe_na_functions=RedshiftDataFrameNaFunctions,
+        sqlf_dataframe_stat_functions=RedshiftDataFrameStatFunctions,
+        sqlf_row=RedshiftRow,
+    )
diff --git a/tests/unit/snowflake/__init__.py b/tests/unit/snowflake/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/snowflake/test_activate.py b/tests/unit/snowflake/test_activate.py
new file mode 100644
index 0000000..b0adb6d
--- /dev/null
+++ b/tests/unit/snowflake/test_activate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from sqlframe.snowflake import Column as SnowflakeColumn
+from sqlframe.snowflake import Row as SnowflakeRow
+from sqlframe.snowflake import (
+    SnowflakeCatalog,
+    SnowflakeDataFrame,
+    SnowflakeDataFrameNaFunctions,
+    SnowflakeDataFrameReader,
+    SnowflakeDataFrameStatFunctions,
+    SnowflakeDataFrameWriter,
+    SnowflakeGroupedData,
+    SnowflakeSession,
+    SnowflakeUDFRegistration,
+)
+from sqlframe.snowflake import Window as SnowflakeWindow
+from sqlframe.snowflake import WindowSpec as SnowflakeWindowSpec
+from sqlframe.snowflake import functions as SnowflakeF
+from sqlframe.snowflake import types as SnowflakeTypes
+
+
+@pytest.mark.forked
+def test_activate_snowflake(check_pyspark_imports):
+    check_pyspark_imports(
+        "snowflake",
+        sqlf_session=SnowflakeSession,
+        sqlf_catalog=SnowflakeCatalog,
+        sqlf_column=SnowflakeColumn,
+        sqlf_dataframe=SnowflakeDataFrame,
+        sqlf_grouped_data=SnowflakeGroupedData,
+        sqlf_window=SnowflakeWindow,
+        sqlf_window_spec=SnowflakeWindowSpec,
+        sqlf_functions=SnowflakeF,
+        sqlf_types=SnowflakeTypes,
+        sqlf_udf_registration=SnowflakeUDFRegistration,
+        sqlf_dataframe_reader=SnowflakeDataFrameReader,
+        sqlf_dataframe_writer=SnowflakeDataFrameWriter,
+        sqlf_dataframe_na_functions=SnowflakeDataFrameNaFunctions,
+        sqlf_dataframe_stat_functions=SnowflakeDataFrameStatFunctions,
+        sqlf_row=SnowflakeRow,
+    )
diff --git a/tests/unit/spark/__init__.py b/tests/unit/spark/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/spark/test_activate.py b/tests/unit/spark/test_activate.py
new file mode 100644
index 0000000..8826541
--- /dev/null
+++ b/tests/unit/spark/test_activate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from sqlframe.spark import Column as SparkColumn
+from sqlframe.spark import Row as SparkRow
+from sqlframe.spark import (
+    SparkCatalog,
+    SparkDataFrame,
+    SparkDataFrameNaFunctions,
+    SparkDataFrameReader,
+    SparkDataFrameStatFunctions,
+    SparkDataFrameWriter,
+    SparkGroupedData,
+    SparkSession,
+    SparkUDFRegistration,
+)
+from sqlframe.spark import Window as SparkWindow
+from sqlframe.spark import WindowSpec as SparkWindowSpec
+from sqlframe.spark import functions as SparkF
+from sqlframe.spark import types as SparkTypes
+
+
+@pytest.mark.forked
+def test_replace_pyspark_spark(check_pyspark_imports):
+    check_pyspark_imports(
+        "spark",
+        sqlf_session=SparkSession,
+        sqlf_catalog=SparkCatalog,
+        sqlf_column=SparkColumn,
+        sqlf_dataframe=SparkDataFrame,
+        sqlf_grouped_data=SparkGroupedData,
+        sqlf_window=SparkWindow,
+        sqlf_window_spec=SparkWindowSpec,
+        sqlf_functions=SparkF,
+        sqlf_types=SparkTypes,
+        sqlf_udf_registration=SparkUDFRegistration,
+        sqlf_dataframe_reader=SparkDataFrameReader,
+        sqlf_dataframe_writer=SparkDataFrameWriter,
+        sqlf_dataframe_na_functions=SparkDataFrameNaFunctions,
+        sqlf_dataframe_stat_functions=SparkDataFrameStatFunctions,
+        sqlf_row=SparkRow,
+    )
diff --git a/tests/unit/standalone/test_activate.py b/tests/unit/standalone/test_activate.py
new file mode 100644
index 0000000..4bf7c21
--- /dev/null
+++ b/tests/unit/standalone/test_activate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from sqlframe.standalone import Column as StandaloneColumn
+from sqlframe.standalone import Row as StandaloneRow
+from sqlframe.standalone import (
+    StandaloneCatalog,
+    StandaloneDataFrame,
+    StandaloneDataFrameNaFunctions,
+    StandaloneDataFrameReader,
+    StandaloneDataFrameStatFunctions,
+    StandaloneDataFrameWriter,
+    StandaloneGroupedData,
+    StandaloneSession,
+    StandaloneUDFRegistration,
+)
+from sqlframe.standalone import Window as StandaloneWindow
+from sqlframe.standalone import WindowSpec as StandaloneWindowSpec
+from sqlframe.standalone import functions as StandaloneF
+from sqlframe.standalone import types as StandaloneTypes
+
+
+@pytest.mark.forked
+def test_activate_standalone(check_pyspark_imports):
+    check_pyspark_imports(
+        "standalone",
+        sqlf_session=StandaloneSession,
+        sqlf_catalog=StandaloneCatalog,
+        sqlf_column=StandaloneColumn,
+        sqlf_dataframe=StandaloneDataFrame,
+        sqlf_grouped_data=StandaloneGroupedData,
+        sqlf_window=StandaloneWindow,
+        sqlf_window_spec=StandaloneWindowSpec,
+        sqlf_functions=StandaloneF,
+        sqlf_types=StandaloneTypes,
+        sqlf_udf_registration=StandaloneUDFRegistration,
+        sqlf_dataframe_reader=StandaloneDataFrameReader,
+        sqlf_dataframe_writer=StandaloneDataFrameWriter,
+        sqlf_dataframe_na_functions=StandaloneDataFrameNaFunctions,
+        sqlf_dataframe_stat_functions=StandaloneDataFrameStatFunctions,
+        sqlf_row=StandaloneRow,
+    )
diff --git a/tests/unit/test_activate.py b/tests/unit/test_activate.py
new file mode 100644
index 0000000..d456b79
--- /dev/null
+++ b/tests/unit/test_activate.py
@@ -0,0 +1,37 @@
+import sys
+from unittest.mock import MagicMock
+
+import findspark
+import pytest
+
+from sqlframe import activate
+from sqlframe import testing as SQLFrameTesting
+
+
+@pytest.mark.forked
+def test_activate_testing():
+    activate()
+    findspark.init()
+    from pyspark import testing
+
+    assert testing == SQLFrameTesting
+    assert testing.assertDataFrameEqual == SQLFrameTesting.assertDataFrameEqual
+    assert testing.assertSchemaEqual == SQLFrameTesting.assertSchemaEqual
+    from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
+
+    assert assertDataFrameEqual == SQLFrameTesting.assertDataFrameEqual
+    assert assertSchemaEqual == SQLFrameTesting.assertSchemaEqual
+    import pyspark.testing as testing
+
+    assert testing == SQLFrameTesting
+    assert testing.assertDataFrameEqual == SQLFrameTesting.assertDataFrameEqual
+
+
+@pytest.mark.forked
+def test_activate_no_engine():
+    activate()
+    findspark.init()
+    # A way that people check if pyspark is available
+    from pyspark import context
+
+    assert isinstance(context, MagicMock)
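
Note: each `test_activate_*` test above relies on a shared `check_pyspark_imports` pytest fixture, presumably provided by `tests/unit/conftest.py`, which is not shown in this hunk. The snippet below is only a minimal, hypothetical sketch of what such a fixture might do, not the patch's actual implementation: the fixture name and call signature come from the tests, but the specific assertions are assumptions for illustration.

```python
# Hypothetical sketch of a check_pyspark_imports-style fixture.
# Assumption: after activate(engine=...), pyspark modules resolve to SQLFrame objects.
import importlib

import pytest

from sqlframe import activate


@pytest.fixture
def check_pyspark_imports():
    def _check(engine, **sqlf_objects):
        # Redirect pyspark imports to the chosen SQLFrame engine.
        activate(engine=engine)

        # pyspark.sql should now expose the engine-specific SQLFrame classes.
        pyspark_sql = importlib.import_module("pyspark.sql")
        assert pyspark_sql.SparkSession is sqlf_objects["sqlf_session"]
        assert pyspark_sql.Column is sqlf_objects["sqlf_column"]
        assert pyspark_sql.DataFrame is sqlf_objects["sqlf_dataframe"]

        # Whole-module replacements, e.g. pyspark.sql.functions -> sqlframe.<engine>.functions.
        assert importlib.import_module("pyspark.sql.functions") is sqlf_objects["sqlf_functions"]
        assert importlib.import_module("pyspark.sql.types") is sqlf_objects["sqlf_types"]

    return _check
```

Because `activate` mutates global import state, the tests are marked with `@pytest.mark.forked` so each activation runs in its own process and cannot leak into other tests.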