diff --git a/docs/bigquery.md b/docs/bigquery.md
index 72c2fc2..70c3a13 100644
--- a/docs/bigquery.md
+++ b/docs/bigquery.md
@@ -599,3 +599,215 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [rowsBetween](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.WindowSpec.rowsBetween.html)
 * sql
     * SQLFrame Specific: Get the SQL representation of the WindowSpec
+
+
+## Extra Functionality not Present in PySpark
+
+SQLFrame supports the following extra functionality not in PySpark.
+
+### Table Class
+
+SQLFrame provides a `Table` class that supports extra DML operations like `update`, `delete` and `merge`. This class is returned when using the `table` function from the `DataFrameReader` class.
+
+```python
+from google.api_core import client_info
+from google.cloud import bigquery
+from google.cloud.bigquery.dbapi import connect
+from google.oauth2 import service_account
+from sqlframe.bigquery import BigQuerySession
+from sqlframe.base.table import WhenMatched, WhenNotMatched, WhenNotMatchedBySource
+
+creds = service_account.Credentials.from_service_account_file("path/to/credentials.json")
+
+client = bigquery.Client(
+    project="my-project",
+    credentials=creds,
+    location="us-central1",
+    client_info=client_info.ClientInfo(user_agent="sqlframe"),
+)
+
+conn = connect(client=client)
+session = BigQuerySession(conn=conn, default_dataset="sqlframe.db1")
+
+df_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 37, "store_id": 1},
+        {"id": 2, "fname": "John", "lname": "Locke", "age": 65, "store_id": 2},
+        {"id": 3, "fname": "Kate", "lname": "Austen", "age": 37, "store_id": 3},
+        {"id": 4, "fname": "Claire", "lname": "Littleton", "age": 27, "store_id": 1},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+    ]
+)
+
+df_employee.write.mode("overwrite").saveAsTable("employee")
+
+table_employee = session.table("employee")  # This object is of type BigQueryTable
+```
+
+#### Update Statement
+The `update` method of the `Table` class is equivalent to the `UPDATE table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+update_expr = table_employee.update(
+    set_={"age": table_employee["age"] + 1},
+    where=table_employee["id"] == 1,
+)
+
+# Executes the update statement
+update_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 1  | Jack   | Shephard  | 38  | 1        |
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
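+
+A `LazyExpression` wraps the underlying SQLGlot expression, so you can also render the SQL before running it. A minimal sketch (based on the `LazyExpression` class added in this change; the exact text depends on how SQLGlot renders the expression):
+
+```python
+# Printing the `LazyExpression` shows the SQL it would run, without executing it
+print(update_expr)
+```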
+#### Delete Statement
+The `delete` method of the `Table` class is equivalent to the `DELETE FROM table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+delete_expr = table_employee.delete(
+    where=table_employee["id"] == 1,
+)
+
+# Executes the delete statement
+delete_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Merge Statement
+
+The `merge` method of the `Table` class is equivalent to the `MERGE INTO table_name` statement used in some SQL engines.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1, "delete": False},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5, "delete": False},
+        {"id": 5, "fname": "Ugo", "lname": "Reyes", "age": 29, "store_id": 3, "delete": True},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5, "delete": False},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenMatched(condition=df_new_employee["delete"]).delete(),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 3  | Kate    | Austen    | 37  | 3        |
+| 4  | Claire  | Littleton | 27  | 1        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
+
+
+Some engines, like BigQuery, also support an extra `WHEN NOT MATCHED BY SOURCE THEN DELETE` clause inside the `merge` statement.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+        WhenNotMatchedBySource().delete(),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 5  | Hugo    | Reyes     | 29  | 3        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
diff --git a/docs/databricks.md b/docs/databricks.md
index 2ab042a..1a49b10 100644
--- a/docs/databricks.md
+++ b/docs/databricks.md
@@ -153,3 +155,211 @@ print(session.catalog.listColumns(table_path))
 +----------------+-------------------+
 """
 ```
+
+## Extra Functionality not Present in PySpark
+
+SQLFrame supports the following extra functionality not in PySpark.
+
+### Table Class
+
+SQLFrame provides a `Table` class that supports extra DML operations like `update`, `delete` and `merge`. This class is returned when using the `table` function from the `DataFrameReader` class.
+
+```python
+import os
+
+from databricks.sql import connect
+from sqlframe.databricks import DatabricksSession
+from sqlframe.base.table import WhenMatched, WhenNotMatched, WhenNotMatchedBySource
+
+conn = connect(
+    server_hostname="dbc-xxxxxxxx-xxxx.cloud.databricks.com",
+    http_path="/sql/1.0/warehouses/xxxxxxxxxxxxxxxx",
+    access_token=os.environ["ACCESS_TOKEN"],  # Replace this with how you get your Databricks access token
+    auth_type="access_token",
+    catalog="catalog",
+    schema="schema",
+)
+session = DatabricksSession(conn=conn)
+
+df_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 37, "store_id": 1},
+        {"id": 2, "fname": "John", "lname": "Locke", "age": 65, "store_id": 2},
+        {"id": 3, "fname": "Kate", "lname": "Austen", "age": 37, "store_id": 3},
+        {"id": 4, "fname": "Claire", "lname": "Littleton", "age": 27, "store_id": 1},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+    ]
+)
+
+df_employee.write.mode("overwrite").saveAsTable("employee")
+
+table_employee = session.table("employee")  # This object is of type DatabricksTable
+```
+
+#### Update Statement
+The `update` method of the `Table` class is equivalent to the `UPDATE table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+update_expr = table_employee.update(
+    set_={"age": table_employee["age"] + 1},
+    where=table_employee["id"] == 1,
+)
+
+# Executes the update statement
+update_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 1  | Jack   | Shephard  | 38  | 1        |
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Delete Statement
+The `delete` method of the `Table` class is equivalent to the `DELETE FROM table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+delete_expr = table_employee.delete(
+    where=table_employee["id"] == 1,
+)
+
+# Executes the delete statement
+delete_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Merge Statement
+
+The `merge` method of the `Table` class is equivalent to the `MERGE INTO table_name` statement used in some SQL engines.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1, "delete": False},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5, "delete": False},
+        {"id": 5, "fname": "Ugo", "lname": "Reyes", "age": 29, "store_id": 3, "delete": True},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5, "delete": False},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenMatched(condition=df_new_employee["delete"]).delete(),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 3  | Kate    | Austen    | 37  | 3        |
+| 4  | Claire  | Littleton | 27  | 1        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
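+
+Databricks also accepts `UPDATE SET *` and `INSERT *` in `MERGE`, which SQLFrame exposes through the `update_all` and `insert_all` clause methods (this change sets `_merge_support_star = True` for `DatabricksTable`). A minimal sketch, assuming the source columns line up with the target table:
+
+```python
+# Copy every column from the source row on match; insert the full row otherwise
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched().update_all(),
+        WhenNotMatched().insert_all(),
+    ],
+)
+merge_expr.execute()
+```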
+
+Some engines, like Databricks, also support an extra `WHEN NOT MATCHED BY SOURCE THEN DELETE` clause inside the `merge` statement.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+        WhenNotMatchedBySource().delete(),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 5  | Hugo    | Reyes     | 29  | 3        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
diff --git a/docs/duckdb.md b/docs/duckdb.md
index 4b24354..d0f664d 100644
--- a/docs/duckdb.md
+++ b/docs/duckdb.md
@@ -564,3 +564,90 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [rowsBetween](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.WindowSpec.rowsBetween.html)
 * sql
     * SQLFrame Specific: Get the SQL representation of the WindowSpec
+
+## Extra Functionality not Present in PySpark
+
+SQLFrame supports the following extra functionality not in PySpark.
+
+### Table Class
+
+SQLFrame provides a `Table` class that supports extra DML operations like `update` and `delete`. This class is returned when using the `table` function from the `DataFrameReader` class.
+
+```python
+import duckdb
+from sqlframe.duckdb import DuckDBSession
+
+conn = duckdb.connect(database=":memory:")
+session = DuckDBSession(conn=conn)
+
+df_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 37, "store_id": 1},
+        {"id": 2, "fname": "John", "lname": "Locke", "age": 65, "store_id": 2},
+        {"id": 3, "fname": "Kate", "lname": "Austen", "age": 37, "store_id": 3},
+        {"id": 4, "fname": "Claire", "lname": "Littleton", "age": 27, "store_id": 1},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+    ]
+)
+
+df_employee.write.mode("overwrite").saveAsTable("employee")
+
+table_employee = session.table("employee")  # This object is of type DuckDBTable
+```
+
+#### Update Statement
+The `update` method of the `Table` class is equivalent to the `UPDATE table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+update_expr = table_employee.update(
+    set_={"age": table_employee["age"] + 1},
+    where=table_employee["id"] == 1,
+)
+
+# Executes the update statement
+update_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 1  | Jack   | Shephard  | 38  | 1        |
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Delete Statement
+The `delete` method of the `Table` class is equivalent to the `DELETE FROM table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+delete_expr = table_employee.delete(
+    where=table_employee["id"] == 1,
+)
+
+# Executes the delete statement
+delete_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
diff --git a/docs/postgres.md b/docs/postgres.md
index be2f64c..44844f2 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -552,3 +552,208 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [rowsBetween](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.WindowSpec.rowsBetween.html)
 * sql
     * SQLFrame Specific: Get the SQL representation of the WindowSpec
+
+## Extra Functionality not Present in PySpark
+
+SQLFrame supports the following extra functionality not in PySpark.
+
+### Table Class
+
+SQLFrame provides a `Table` class that supports extra DML operations like `update`, `delete` and `merge`. This class is returned when using the `table` function from the `DataFrameReader` class.
+
+```python
+from psycopg2 import connect
+from sqlframe.postgres import PostgresSession
+from sqlframe.base.table import WhenMatched, WhenNotMatched, WhenNotMatchedBySource
+
+conn = connect(
+    dbname="postgres",
+    user="postgres",
+    password="password",
+    host="localhost",
+    port="5432",
+)
+session = PostgresSession(conn=conn)
+
+df_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 37, "store_id": 1},
+        {"id": 2, "fname": "John", "lname": "Locke", "age": 65, "store_id": 2},
+        {"id": 3, "fname": "Kate", "lname": "Austen", "age": 37, "store_id": 3},
+        {"id": 4, "fname": "Claire", "lname": "Littleton", "age": 27, "store_id": 1},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+    ]
+)
+
+df_employee.write.mode("overwrite").saveAsTable("employee")
+
+table_employee = session.table("employee")  # This object is of type PostgresTable
+```
+
+#### Update Statement
+The `update` method of the `Table` class is equivalent to the `UPDATE table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+update_expr = table_employee.update(
+    set_={"age": table_employee["age"] + 1},
+    where=table_employee["id"] == 1,
+)
+
+# Executes the update statement
+update_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 1  | Jack   | Shephard  | 38  | 1        |
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Delete Statement
+The `delete` method of the `Table` class is equivalent to the `DELETE FROM table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+delete_expr = table_employee.delete(
+    where=table_employee["id"] == 1,
+)
+
+# Executes the delete statement
+delete_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Merge Statement
+
+The `merge` method of the `Table` class is equivalent to the `MERGE INTO table_name` statement used in some SQL engines.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1, "delete": False},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5, "delete": False},
+        {"id": 5, "fname": "Ugo", "lname": "Reyes", "age": 29, "store_id": 3, "delete": True},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5, "delete": False},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenMatched(condition=df_new_employee["delete"]).delete(),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 3  | Kate    | Austen    | 37  | 3        |
+| 4  | Claire  | Littleton | 27  | 1        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
+
+
+Some engines, like Postgres, also support an extra `WHEN NOT MATCHED BY SOURCE THEN DELETE` clause inside the `merge` statement.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+        WhenNotMatchedBySource().delete(),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 5  | Hugo    | Reyes     | 29  | 3        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
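+
+`WhenNotMatchedBySource` also exposes an `update` method (see the `WhenNotMatchedBySource` class added in this change), so target rows missing from the source can be modified instead of deleted. A minimal sketch, assuming your engine and version accept `UPDATE` in `WHEN NOT MATCHED BY SOURCE`:
+
+```python
+# Bump the age of any employee that no longer appears in the source
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenNotMatchedBySource().update(set_={"age": table_employee["age"] + 1}),
+    ],
+)
+merge_expr.execute()
+```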
diff --git a/docs/redshift.md b/docs/redshift.md
index 0b31c0a..dfa6a09 100644
--- a/docs/redshift.md
+++ b/docs/redshift.md
@@ -160,3 +160,98 @@ print(session.catalog.listColumns(table_path))
 +------+---------------------------+----------------+
 """
 ```
+
+## Extra Functionality not Present in PySpark
+
+SQLFrame supports the following extra functionality not in PySpark.
+
+### Table Class
+
+SQLFrame provides a `Table` class that supports extra DML operations like `update` and `delete`. This class is returned when using the `table` function from the `DataFrameReader` class.
+
+```python
+import os
+
+from redshift_connector import connect
+from sqlframe.redshift import RedshiftSession
+
+conn = connect(
+    user="user",
+    password=os.environ["PASSWORD"],  # Replace this with how you get your password
+    database="database",
+    host="xxxxx.xxxxxx.region.redshift-serverless.amazonaws.com",
+    port=5439,
+)
+session = RedshiftSession(conn=conn)
+
+df_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 37, "store_id": 1},
+        {"id": 2, "fname": "John", "lname": "Locke", "age": 65, "store_id": 2},
+        {"id": 3, "fname": "Kate", "lname": "Austen", "age": 37, "store_id": 3},
+        {"id": 4, "fname": "Claire", "lname": "Littleton", "age": 27, "store_id": 1},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+    ]
+)
+
+df_employee.write.mode("overwrite").saveAsTable("employee")
+
+table_employee = session.table("employee")  # This object is of type RedshiftTable
+```
+
+#### Update Statement
+The `update` method of the `Table` class is equivalent to the `UPDATE table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+update_expr = table_employee.update(
+    set_={"age": table_employee["age"] + 1},
+    where=table_employee["id"] == 1,
+)
+
+# Executes the update statement
+update_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 1  | Jack   | Shephard  | 38  | 1        |
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Delete Statement
+The `delete` method of the `Table` class is equivalent to the `DELETE FROM table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+delete_expr = table_employee.delete(
+    where=table_employee["id"] == 1,
+)
+
+# Executes the delete statement
+delete_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
diff --git a/docs/snowflake.md b/docs/snowflake.md
index ed856ef..8e99e54 100644
--- a/docs/snowflake.md
+++ b/docs/snowflake.md
@@ -604,3 +604,156 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [rowsBetween](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.WindowSpec.rowsBetween.html)
 * sql
     * SQLFrame Specific: Get the SQL representation of the WindowSpec
+
+## Extra Functionality not Present in PySpark
+
+SQLFrame supports the following extra functionality not in PySpark.
+
+### Table Class
+
+SQLFrame provides a `Table` class that supports extra DML operations like `update`, `delete` and `merge`. This class is returned when using the `table` function from the `DataFrameReader` class.
+
+```python
+import os
+
+from snowflake.connector import connect
+from sqlframe.snowflake import SnowflakeSession
+from sqlframe.base.table import WhenMatched, WhenNotMatched
+
+connection = connect(
+    account=os.environ["SQLFRAME_SNOWFLAKE_ACCOUNT"],
+    user=os.environ["SQLFRAME_SNOWFLAKE_USER"],
+    password=os.environ["SQLFRAME_SNOWFLAKE_PASSWORD"],
+    warehouse=os.environ["SQLFRAME_SNOWFLAKE_WAREHOUSE"],
+    database=os.environ["SQLFRAME_SNOWFLAKE_DATABASE"],
+    schema=os.environ["SQLFRAME_SNOWFLAKE_SCHEMA"],
+)
+session = SnowflakeSession(conn=connection)
+
+df_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 37, "store_id": 1},
+        {"id": 2, "fname": "John", "lname": "Locke", "age": 65, "store_id": 2},
+        {"id": 3, "fname": "Kate", "lname": "Austen", "age": 37, "store_id": 3},
+        {"id": 4, "fname": "Claire", "lname": "Littleton", "age": 27, "store_id": 1},
+        {"id": 5, "fname": "Hugo", "lname": "Reyes", "age": 29, "store_id": 3},
+    ]
+)
+
+df_employee.write.mode("overwrite").saveAsTable("employee")
+
+table_employee = session.table("employee")  # This object is of type SnowflakeTable
+```
+
+#### Update Statement
+The `update` method of the `Table` class is equivalent to the `UPDATE table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+update_expr = table_employee.update(
+    set_={"age": table_employee["age"] + 1},
+    where=table_employee["id"] == 1,
+)
+
+# Executes the update statement
+update_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 1  | Jack   | Shephard  | 38  | 1        |
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Delete Statement
+The `delete` method of the `Table` class is equivalent to the `DELETE FROM table_name` statement used in standard SQL.
+
+```python
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+delete_expr = table_employee.delete(
+    where=table_employee["id"] == 1,
+)
+
+# Executes the delete statement
+delete_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+--------+-----------+-----+----------+
+| id | fname  | lname     | age | store_id |
++----+--------+-----------+-----+----------+
+| 2  | John   | Locke     | 65  | 2        |
+| 3  | Kate   | Austen    | 37  | 3        |
+| 4  | Claire | Littleton | 27  | 1        |
+| 5  | Hugo   | Reyes     | 29  | 3        |
++----+--------+-----------+-----+----------+
+```
+#### Merge Statement
+
+The `merge` method of the `Table` class is equivalent to the `MERGE INTO table_name` statement used in some SQL engines.
+
+```python
+df_new_employee = session.createDataFrame(
+    [
+        {"id": 1, "fname": "Jack", "lname": "Shephard", "age": 38, "store_id": 1, "delete": False},
+        {"id": 2, "fname": "Cate", "lname": "Austen", "age": 39, "store_id": 5, "delete": False},
+        {"id": 5, "fname": "Ugo", "lname": "Reyes", "age": 29, "store_id": 3, "delete": True},
+        {"id": 6, "fname": "Sun-Hwa", "lname": "Kwon", "age": 27, "store_id": 5, "delete": False},
+    ]
+)
+
+# Generates a `LazyExpression` object which can be executed using the `execute` method
+merge_expr = table_employee.merge(
+    df_new_employee,
+    condition=table_employee["id"] == df_new_employee["id"],
+    clauses=[
+        WhenMatched(condition=table_employee["fname"] == df_new_employee["fname"]).update(
+            set_={
+                "age": df_new_employee["age"],
+            }
+        ),
+        WhenMatched(condition=df_new_employee["delete"]).delete(),
+        WhenNotMatched().insert(
+            values={
+                "id": df_new_employee["id"],
+                "fname": df_new_employee["fname"],
+                "lname": df_new_employee["lname"],
+                "age": df_new_employee["age"],
+                "store_id": df_new_employee["store_id"],
+            }
+        ),
+    ],
+)
+
+# Executes the merge statement
+merge_expr.execute()
+
+# Show the result
+table_employee.show()
+```
+
+Output:
+```
++----+---------+-----------+-----+----------+
+| id | fname   | lname     | age | store_id |
++----+---------+-----------+-----+----------+
+| 1  | Jack    | Shephard  | 38  | 1        |
+| 2  | John    | Locke     | 65  | 2        |
+| 3  | Kate    | Austen    | 37  | 3        |
+| 4  | Claire  | Littleton | 27  | 1        |
+| 6  | Sun-Hwa | Kwon      | 27  | 5        |
++----+---------+-----------+-----+----------+
+```
diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py
index ed6b24c..faa57b6 100644
--- a/sqlframe/base/dataframe.py
+++ b/sqlframe/base/dataframe.py
@@ -872,6 +872,68 @@ def crossJoin(self, other: DF) -> Self:
         """
         return self.join.__wrapped__(self, other, how="cross")  # type: ignore
 
+    def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
+        # If the two dataframes being joined come from the same branch, we then check if they have any columns that
+        # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
+        # the two columns since they would end up with the same table name. We do this by checking for the unique
+        # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
+        # it comes from the other df and we change the table name to the other df's table name.
+        # See `test_self_join` for an example of this.
+        if self.branch_id == other_df.branch_id:
+            other_df_unique_uuids = other_df.known_uuids - self.known_uuids
+            for col in join_columns:
+                for col_expr in col.expression.find_all(exp.Column):
+                    if (
+                        "join_on_uuid" in col_expr.meta
+                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
+                    ):
+                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
+
+    @staticmethod
+    def _handle_join_column_names_only(
+        join_columns: t.List[Column],
+        join_expression: exp.Select,
+        other_df: DF,
+        table_names: t.List[str],
+    ):
+        potential_ctes = [
+            cte
+            for cte in join_expression.ctes
+            if cte.alias_or_name in table_names and cte.alias_or_name != other_df.latest_cte_name
+        ]
+        # Determine the table to reference for the left side of the join by checking each of the left side
+        # tables to see if they have the column being referenced.
+        join_column_pairs = []
+        for join_column in join_columns:
+            num_matching_ctes = 0
+            for cte in potential_ctes:
+                if join_column.alias_or_name in cte.this.named_selects:
+                    left_column = join_column.copy().set_table_name(cte.alias_or_name)
+                    right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
+                    join_column_pairs.append((left_column, right_column))
+                    num_matching_ctes += 1
+                    # We only want to match one table to the column and that should be matched left -> right
+                    # so we break after the first match
+                    break
+            if num_matching_ctes == 0:
+                raise ValueError(
+                    f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
+                )
+        join_clause = functools.reduce(
+            lambda x, y: x & y,
+            [left_column == right_column for left_column, right_column in join_column_pairs],
+        )
+        return join_column_pairs, join_clause
+
+    def _normalize_join_clause(
+        self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
+    ) -> Column:
+        join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
+        if len(join_columns) > 1:
+            join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
+        join_clause = join_columns[0]
+        return join_clause
+
     @operation(Operation.FROM)
     def join(
         self,
@@ -895,21 +957,8 @@
         self_columns = self._get_outer_select_columns(join_expression)
         other_columns = self._get_outer_select_columns(other_df.expression)
         join_columns = self._ensure_and_normalize_cols(on)
-        # If the two dataframes being joined come from the same branch, we then check if they have any columns that
-        # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
-        # the two columns since they would end up with the same table name. We do this by checking for the unique
-        # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
-        # it comes from the other df and we change the table name to the other df's table name.
-        # See `test_self_join` for an example of this.
-        if self.branch_id == other_df.branch_id:
-            other_df_unique_uuids = other_df.known_uuids - self.known_uuids
-            for col in join_columns:
-                for col_expr in col.expression.find_all(exp.Column):
-                    if (
-                        "join_on_uuid" in col_expr.meta
-                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
-                    ):
-                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
+        self._handle_self_join(other_df, join_columns)
+
+        # Determines the join clause and select columns to be used based on what type of columns were provided for
+        # the join. The columns returned change based on how the on expression is provided.
         if how != "cross":
@@ -923,38 +972,9 @@
                 table.alias_or_name
                 for table in get_tables_from_expression_with_join(join_expression)
             ]
-            potential_ctes = [
-                cte
-                for cte in join_expression.ctes
-                if cte.alias_or_name in table_names
-                and cte.alias_or_name != other_df.latest_cte_name
-            ]
-            # Determine the table to reference for the left side of the join by checking each of the left side
-            # tables and see if they have the column being referenced.
-            join_column_pairs = []
-            for join_column in join_columns:
-                num_matching_ctes = 0
-                for cte in potential_ctes:
-                    if join_column.alias_or_name in cte.this.named_selects:
-                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
-                        right_column = join_column.copy().set_table_name(
-                            other_df.latest_cte_name
-                        )
-                        join_column_pairs.append((left_column, right_column))
-                        num_matching_ctes += 1
-                        # We only want to match one table to the column and that should be matched left -> right
-                        # so we break after the first match
-                        break
-                if num_matching_ctes == 0:
-                    raise ValueError(
-                        f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
-                    )
-            join_clause = functools.reduce(
-                lambda x, y: x & y,
-                [
-                    left_column == right_column
-                    for left_column, right_column in join_column_pairs
-                ],
+
+            join_column_pairs, join_clause = self._handle_join_column_names_only(
+                join_columns, join_expression, other_df, table_names
             )
             join_column_names = [
                 coalesce(
@@ -989,10 +1009,7 @@
         * There is no deduplication of the results.
         * The left join dataframe columns go first and right come after. No sort preference is given to join columns
         """
-        join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
-        if len(join_columns) > 1:
-            join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
-        join_clause = join_columns[0]
+        join_clause = self._normalize_join_clause(join_columns, join_expression)
         select_column_names = [
             column.alias_or_name for column in self_columns + other_columns
         ]
diff --git a/sqlframe/base/mixins/table_mixins.py b/sqlframe/base/mixins/table_mixins.py
new file mode 100644
index 0000000..001b6e9
--- /dev/null
+++ b/sqlframe/base/mixins/table_mixins.py
@@ -0,0 +1,335 @@
+import functools
+import logging
+import typing as t
+
+from sqlglot import exp
+
+try:
+    from sqlglot.expressions import Whens
+except ImportError:
+    Whens = None  # type: ignore
+from sqlglot.helper import object_to_dict
+
+from sqlframe.base.column import Column
+from sqlframe.base.table import (
+    DF,
+    Clause,
+    LazyExpression,
+    WhenMatched,
+    WhenNotMatched,
+    WhenNotMatchedBySource,
+    _BaseTable,
+)
+
+if t.TYPE_CHECKING:
+    from sqlframe.base._typing import ColumnOrLiteral
+
+
+logger = logging.getLogger(__name__)
+
+
+def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
+    def decorator(func: t.Callable) -> t.Callable:
+        @functools.wraps(func)
+        def wrapper(self: _BaseTable, *args, **kwargs) -> t.Any:
+            if len(self.expression.ctes) > 0:
+                return func(self, *args, **kwargs)  # type: ignore
+            self_class = self.__class__
+            self = self._convert_leaf_to_cte()
+            self = self_class(**object_to_dict(self))
+            return func(self, *args, **kwargs)  # type: ignore
+
+        wrapper.__wrapped__ = func  # type: ignore
+        return wrapper
+
+    return decorator
+
+
+class _BaseTableMixins(_BaseTable, t.Generic[DF]):
+    def _ensure_where_condition(
+        self, where: t.Optional[t.Union[Column, str, bool]] = None
+    ) -> exp.Expression:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+
+        if where is None:
+            logger.warning("Empty value for `where` clause. Defaults to `True`.")
+            condition: exp.Expression = exp.Boolean(this=True)
+        else:
+            condition_list = self._ensure_and_normalize_cols(where, self.expression)
+            if len(condition_list) > 1:
+                condition_list = [functools.reduce(lambda x, y: x & y, condition_list)]
+            for col_expr in condition_list[0].expression.find_all(exp.Column):
+                if col_expr.table == self.expression.args["from"].this.alias_or_name:
+                    col_expr.set("table", exp.to_identifier(self_name))
+            condition = condition_list[0].expression
+            if isinstance(condition, exp.Alias):
+                condition = condition.this
+        return condition
+
+
+class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
+    @ensure_cte()
+    def update(
+        self,
+        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
+        where: t.Optional[t.Union[Column, str, bool]] = None,
+    ) -> LazyExpression:
+        self_expr = self.expression.ctes[0].this.args["from"].this
+
+        condition = self._ensure_where_condition(where)
+        update_set = self._ensure_and_normalize_update_set(set_)
+        update_expr = exp.Update(
+            this=self_expr,
+            expressions=[
+                exp.EQ(
+                    this=key,
+                    expression=val,
+                )
+                for key, val in update_set.items()
+            ],
+            where=exp.Where(this=condition),
+        )
+
+        return LazyExpression(update_expr, self.session)
+
+    def _ensure_and_normalize_update_set(
+        self,
+        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
+    ) -> t.Dict[str, exp.Expression]:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+        update_set = {}
+        for key, val in set_.items():
+            key_column: Column = self._ensure_and_normalize_col(key)
+            key_expr = list(key_column.expression.find_all(exp.Column))
+            if len(key_expr) > 1:
+                raise ValueError("Can only update a single column at a time.")
+            key = key_expr[0].alias_or_name
+
+            val_column: Column = self._ensure_and_normalize_col(val)
+            for col_expr in val_column.expression.find_all(exp.Column):
+                if col_expr.table == self.expression.args["from"].this.alias_or_name:
+                    col_expr.set("table", exp.to_identifier(self_name))
+                else:
+                    raise ValueError(
+                        f"Column `{col_expr.alias_or_name}` does not exist in the table."
+ ) + + update_set[key] = val_column.expression + return update_set + + +class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]): + @ensure_cte() + def delete( + self, + where: t.Optional[t.Union[Column, str, bool]] = None, + ) -> LazyExpression: + self_expr = self.expression.ctes[0].this.args["from"].this + + condition = self._ensure_where_condition(where) + delete_expr = exp.Delete( + this=self_expr, + where=exp.Where(this=condition), + ) + + return LazyExpression(delete_expr, self.session) + + +class MergeSupportMixin(_BaseTable, t.Generic[DF]): + _merge_supported_clauses: t.Iterable[ + t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]] + ] + _merge_support_star: bool + + @ensure_cte() + def merge( + self, + other_df: DF, + condition: t.Union[str, t.List[str], Column, t.List[Column], bool], + clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]], + ) -> LazyExpression: + self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name + self_expr = self.expression.ctes[0].this.args["from"].this + + other_df = other_df._convert_leaf_to_cte() + + if condition is None: + raise ValueError("condition cannot be None") + + condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df) + other_name = self._create_hash_from_expression(other_df.expression) + other_expr = exp.Subquery( + this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name)) + ) + + for col_expr in condition_columns.expression.find_all(exp.Column): + if col_expr.table == self.expression.args["from"].this.alias_or_name: + col_expr.set("table", exp.to_identifier(self_name)) + if col_expr.table == other_df.latest_cte_name: + col_expr.set("table", exp.to_identifier(other_name)) + + merge_expressions = [] + for clause in clauses: + if not isinstance(clause, tuple(self._merge_supported_clauses)): + raise ValueError( + f"Unsupported clause type {type(clause.clause)} for merge operation" + ) + expression = None + + if clause.clause.condition is not None: + cond_clause = self._ensure_and_normalize_condition( + clause.clause.condition, other_df, True + ) + for col_expr in cond_clause.expression.find_all(exp.Column): + if col_expr.table == self.expression.args["from"].this.alias_or_name: + col_expr.set("table", exp.to_identifier(self_name)) + if col_expr.table == other_df.latest_cte_name: + col_expr.set("table", exp.to_identifier(other_name)) + else: + cond_clause = None + if clause.clause.clause_type == Clause.UPDATE: + update_set = self._ensure_and_normalize_assignments( + clause.clause.assignments, other_df + ) + expression = exp.When( + matched=clause.clause.matched, + source=clause.clause.by_source, + condition=cond_clause.expression if cond_clause else None, + then=exp.Update( + expressions=[ + exp.EQ( + this=key, + expression=val, + ) + for key, val in update_set.items() + ] + ), + ) + if clause.clause.clause_type == Clause.UPDATE_ALL: + if not self._support_star: + raise ValueError("Merge operation does not support UPDATE_ALL") + expression = exp.When( + matched=clause.clause.matched, + source=clause.clause.by_source, + condition=cond_clause.expression if cond_clause else None, + then=exp.Update(expressions=[exp.Star()]), + ) + elif clause.clause.clause_type == Clause.INSERT: + insert_values = self._ensure_and_normalize_assignments( + clause.clause.assignments, other_df + ) + expression = exp.When( + matched=clause.clause.matched, + source=clause.clause.by_source, + condition=cond_clause.expression if cond_clause 
+                    then=exp.Insert(
+                        this=exp.Tuple(expressions=[key for key in insert_values.keys()]),
+                        expression=exp.Tuple(expressions=[val for val in insert_values.values()]),
+                    ),
+                )
+            elif clause.clause.clause_type == Clause.INSERT_ALL:
+                if not self._merge_support_star:
+                    raise ValueError("Merge operation does not support INSERT_ALL")
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.Insert(expression=exp.Star()),
+                )
+            elif clause.clause.clause_type == Clause.DELETE:
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.var("DELETE"),
+                )
+
+            if expression:
+                merge_expressions.append(expression)
+
+        if Whens is None:
+            merge_expr = exp.merge(
+                *merge_expressions,
+                into=self_expr,
+                using=other_expr,
+                on=condition_columns.expression,
+            )
+        else:
+            merge_expr = exp.merge(
+                Whens(expressions=merge_expressions),
+                into=self_expr,
+                using=other_expr,
+                on=condition_columns.expression,
+            )
+
+        return LazyExpression(merge_expr, self.session)
+
+    def _ensure_and_normalize_condition(
+        self,
+        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
+        other_df: DF,
+        clause: t.Optional[bool] = False,
+    ):
+        join_expression = self._add_ctes_to_expression(
+            self.expression, other_df.expression.copy().ctes
+        )
+        condition = self._ensure_and_normalize_cols(condition, self.expression)
+        self._handle_self_join(other_df, condition)
+
+        if isinstance(condition[0].expression, exp.Column) and not clause:
+            table_names = [
+                table.alias_or_name
+                for table in [
+                    self.expression.args["from"].this,
+                    other_df.expression.args["from"].this,
+                ]
+            ]
+
+            join_column_pairs, join_clause = self._handle_join_column_names_only(
+                condition, join_expression, other_df, table_names
+            )
+        else:
+            join_clause = self._normalize_join_clause(condition, join_expression)
+        return join_clause
+
+    def _ensure_and_normalize_assignments(
+        self,
+        assignments: t.Dict[
+            t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
+        ],
+        other_df,
+    ) -> t.Dict[exp.Column, exp.Expression]:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+        other_name = self._create_hash_from_expression(other_df.expression)
+        update_set = {}
+        for key, val in assignments.items():
+            key_column: Column = self._ensure_and_normalize_col(key)
+            key_expr = list(key_column.expression.find_all(exp.Column))
+            if len(key_expr) > 1:
+                raise ValueError(f"Target expression `{key_expr}` should be a single column.")
+            column_key = exp.column(key_expr[0].alias_or_name)
+
+            val = self._ensure_and_normalize_col(val)
+            val = self._ensure_and_normalize_cols(val, other_df.expression)[0]
+            if self.branch_id == other_df.branch_id:
+                other_df_unique_uuids = other_df.known_uuids - self.known_uuids
+                for col_expr in val.expression.find_all(exp.Column):
+                    if (
+                        "join_on_uuid" in col_expr.meta
+                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
+                    ):
+                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
+
+            for col_expr in val.expression.find_all(exp.Column):
+                if not col_expr.table or col_expr.table == other_df.latest_cte_name:
+                    col_expr.set("table", exp.to_identifier(other_name))
+                elif col_expr.table == self.expression.args["from"].this.alias_or_name:
+                    col_expr.set("table", exp.to_identifier(self_name))
+                else:
+                    raise ValueError(
+                        f"Column `{col_expr.alias_or_name}` does not 
exist in any of the tables." + ) + if isinstance(val.expression, exp.Alias): + val.expression = val.expression.this + update_set[column_key] = val.expression + return update_set diff --git a/sqlframe/base/readerwriter.py b/sqlframe/base/readerwriter.py index 57b8057..c05b58c 100644 --- a/sqlframe/base/readerwriter.py +++ b/sqlframe/base/readerwriter.py @@ -21,19 +21,20 @@ if t.TYPE_CHECKING: from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths from sqlframe.base.column import Column - from sqlframe.base.session import DF, _BaseSession + from sqlframe.base.session import DF, TABLE, _BaseSession from sqlframe.base.types import StructType SESSION = t.TypeVar("SESSION", bound=_BaseSession) else: SESSION = t.TypeVar("SESSION") DF = t.TypeVar("DF") + TABLE = t.TypeVar("TABLE") logger = logging.getLogger(__name__) -class _BaseDataFrameReader(t.Generic[SESSION, DF]): +class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]): def __init__(self, spark: SESSION): self._session = spark self.state_format_to_read: t.Optional[str] = None @@ -42,7 +43,7 @@ def __init__(self, spark: SESSION): def session(self) -> SESSION: return self._session - def table(self, tableName: str) -> DF: + def table(self, tableName: str) -> TABLE: tableName = normalize_string(tableName, from_dialect="input", is_table=True) if df := self.session.temp_views.get(tableName): return df @@ -50,7 +51,7 @@ def table(self, tableName: str) -> DF: self.session.catalog.add_table(table) columns = self.session.catalog.get_columns_from_schema(table) - return self.session._create_df( + return self.session._create_table( exp.Select() .from_(tableName, dialect=self.session.input_dialect) .select(*columns, dialect=self.session.input_dialect) diff --git a/sqlframe/base/session.py b/sqlframe/base/session.py index 3dd47fe..2710092 100644 --- a/sqlframe/base/session.py +++ b/sqlframe/base/session.py @@ -27,6 +27,7 @@ from sqlframe.base.dataframe import BaseDataFrame from sqlframe.base.normalize import normalize_dict from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter +from sqlframe.base.table import _BaseTable from sqlframe.base.udf import _BaseUDFRegistration from sqlframe.base.util import ( get_column_mapping_from_schema_input, @@ -65,17 +66,19 @@ def fetchdf(self) -> pd.DataFrame: ... 
READER = t.TypeVar("READER", bound=_BaseDataFrameReader) WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter) DF = t.TypeVar("DF", bound=BaseDataFrame) +TABLE = t.TypeVar("TABLE", bound=_BaseTable) UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration) _MISSING = "MISSING" -class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]): +class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]): _instance = None _reader: t.Type[READER] _writer: t.Type[WRITER] _catalog: t.Type[CATALOG] _df: t.Type[DF] + _table: t.Type[TABLE] _udf_registration: t.Type[UDF_REGISTRATION] SANITIZE_COLUMN_NAMES = False @@ -158,12 +161,15 @@ def _sanitize_column_name(self, name: str) -> str: return name.replace("(", "_").replace(")", "_") return name - def table(self, tableName: str) -> DF: + def table(self, tableName: str) -> TABLE: return self.read.table(tableName) def _create_df(self, *args, **kwargs) -> DF: return self._df(self, *args, **kwargs) + def _create_table(self, *args, **kwargs) -> TABLE: + return self._table(self, *args, **kwargs) + def __new__(cls, *args, **kwargs): if _BaseSession._instance is None: _BaseSession._instance = super().__new__(cls) diff --git a/sqlframe/base/table.py b/sqlframe/base/table.py new file mode 100644 index 0000000..bce1b15 --- /dev/null +++ b/sqlframe/base/table.py @@ -0,0 +1,238 @@ +import sys +import typing as t +from enum import IntEnum +from uuid import uuid4 + +from sqlglot import exp +from sqlglot.expressions import _to_s +from sqlglot.helper import object_to_dict + +from sqlframe.base.dataframe import DF, SESSION, BaseDataFrame + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +if t.TYPE_CHECKING: + from sqlframe.base._typing import ColumnOrLiteral + from sqlframe.base.column import Column + from sqlframe.base.types import Row + + +class Clause(IntEnum): + UPDATE = 1 + UPDATE_ALL = 2 + DELETE = 3 + INSERT = 4 + INSERT_ALL = 5 + + +class MergeClause: + def __init__( + self, + clause_type: Clause, + condition: t.Optional[t.Union[str, t.List[str], "Column", t.List["Column"], bool]] = None, + assignments: t.Optional[ + t.Dict[t.Union["Column", str], t.Union["Column", "ColumnOrLiteral", exp.Expression]] + ] = None, + matched: bool = True, + by_source: bool = False, + ): + self.clause_type = clause_type + self.condition = condition + self.assignments = assignments + self.matched = matched + self.by_source = by_source + + +class WhenMatched: + def __init__( + self, + condition: t.Optional[t.Union[str, t.List[str], "Column", t.List["Column"], bool]] = None, + ): + self._condition = condition + self._clause: t.Union[MergeClause, None] = None + + def update( + self, + set_: t.Dict[t.Union["Column", str], t.Union["Column", "ColumnOrLiteral", exp.Expression]], + ) -> Self: + if self._clause: + clause = self._clause.clause_type.name.lower() + raise ValueError(f"WhenMatched already has an '{clause}' clause") + self._clause = MergeClause( + Clause.UPDATE, + self._condition, + {k: v for k, v in set_.items()}, + matched=True, + by_source=False, + ) + return self + + def update_all(self) -> Self: + if self._clause: + clause = self._clause.clause_type.name.lower() + raise ValueError(f"WhenMatched already has an '{clause}' clause") + self._clause = MergeClause( + Clause.UPDATE_ALL, + self._condition, + {}, + matched=True, + by_source=False, + ) + return self + + def delete(self) -> Self: + if self._clause: + clause = 
self._clause.clause_type.name.lower() + raise ValueError(f"WhenMatched already has an '{clause}' clause") + self._clause = MergeClause(Clause.DELETE, self._condition, matched=True, by_source=False) + return self + + @property + def clause(self): + return self._clause + + +class WhenNotMatched: + def __init__( + self, + condition: t.Optional[t.Union[str, t.List[str], "Column", t.List["Column"], bool]] = None, + ): + self._condition = condition + self._clause: t.Union[MergeClause, None] = None + + def insert( + self, + values: t.Dict[ + t.Union["Column", str], t.Union["Column", "ColumnOrLiteral", exp.Expression] + ], + ) -> Self: + if self._clause: + clause = self._clause.clause_type.name.lower() + raise ValueError(f"WhenNotMatched already has an '{clause}' clause") + self._clause = MergeClause( + Clause.INSERT, + self._condition, + {k: v for k, v in values.items()}, + matched=False, + by_source=False, + ) + return self + + def insert_all(self) -> Self: + if self._clause: + clause = self._clause.clause_type.name.lower() + raise ValueError(f"WhenNotMatched already has an '{clause}' clause") + self._clause = MergeClause( + Clause.INSERT_ALL, + self._condition, + {}, + matched=False, + by_source=False, + ) + return self + + @property + def clause(self): + return self._clause + + +class WhenNotMatchedBySource(object): + def __init__( + self, + condition: t.Optional[t.Union[str, t.List[str], "Column", t.List["Column"], bool]] = None, + ): + self._condition = condition + self._clause: t.Union[MergeClause, None] = None + + def update( + self, + set_: t.Dict[t.Union["Column", str], t.Union["Column", "ColumnOrLiteral", exp.Expression]], + ) -> Self: + if self._clause: + clause = self._clause.clause_type.name.lower() + raise ValueError(f"WhenNotMatchedBySource already has an '{clause}' clause") + self._clause = MergeClause( + Clause.UPDATE, + self._condition, + {k: v for k, v in set_.items()}, + matched=False, + by_source=True, + ) + return self + + def delete(self) -> Self: + if self._clause: + clause = self._clause.clause_type.name.lower() + raise ValueError(f"WhenNotMatchedBySource already has an '{clause}' clause") + self._clause = MergeClause(Clause.DELETE, self._condition, matched=False, by_source=True) + return self + + @property + def clause(self): + return self._clause + + +class LazyExpression: + def __init__( + self, + expression: exp.Expression, + session: SESSION, + ): + self._expression = expression + self._session = session + + def execute(self) -> t.List["Row"]: + return self._session._collect(self._expression) + + @property + def expression(self) -> exp.Expression: + return self._expression + + def __str__(self) -> str: + return self._expression.sql() + + def __repr__(self) -> str: + return _to_s(self._expression) + + +class _BaseTable(BaseDataFrame, t.Generic[DF]): + _df: t.Type[DF] + + def copy(self, **kwargs): + kwargs["join_on_uuid"] = str(uuid4()) + return self._df(**object_to_dict(self, **kwargs)) + + def __copy__(self): + return self.copy() + + def table_copy(self): + return self.__class__(**object_to_dict(self)) + + def alias(self, name: str, **kwargs) -> Self: + df = BaseDataFrame.alias(self, name, **kwargs) + new_df = self.__class__(**object_to_dict(df)) + return new_df + + def update( + self, + set_: t.Dict[t.Union["Column", str], t.Union["Column", "ColumnOrLiteral", exp.Expression]], + where: t.Optional[t.Union["Column", str, bool]] = None, + ) -> LazyExpression: + raise NotImplementedError() + + def merge( + self, + source: DF, + condition: t.Union[str, t.List[str], 
"Column", t.List["Column"], bool], + clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]], + ) -> LazyExpression: + raise NotImplementedError() + + def delete( + self, + where: t.Optional[t.Union["Column", str, bool]] = None, + ) -> LazyExpression: + raise NotImplementedError() diff --git a/sqlframe/bigquery/catalog.py b/sqlframe/bigquery/catalog.py index dd0b7e5..d4358ca 100644 --- a/sqlframe/bigquery/catalog.py +++ b/sqlframe/bigquery/catalog.py @@ -51,6 +51,7 @@ def currentDatabase(self) -> str: from_dialect=self.session.execution_dialect, to_dialect=self.session.output_dialect, is_schema=True, + quote_identifiers=True, ) return to_schema(current_database, dialect=self.session.output_dialect).db diff --git a/sqlframe/bigquery/readwriter.py b/sqlframe/bigquery/readwriter.py index 76cdc28..9e9e7b3 100644 --- a/sqlframe/bigquery/readwriter.py +++ b/sqlframe/bigquery/readwriter.py @@ -13,11 +13,12 @@ if t.TYPE_CHECKING: from sqlframe.bigquery.session import BigQuerySession # noqa from sqlframe.bigquery.dataframe import BigQueryDataFrame # noqa + from sqlframe.bigquery.table import BigQueryTable # noqa class BigQueryDataFrameReader( PandasLoaderMixin["BigQuerySession", "BigQueryDataFrame"], - _BaseDataFrameReader["BigQuerySession", "BigQueryDataFrame"], + _BaseDataFrameReader["BigQuerySession", "BigQueryDataFrame", "BigQueryTable"], ): pass diff --git a/sqlframe/bigquery/session.py b/sqlframe/bigquery/session.py index 363d793..84444f3 100644 --- a/sqlframe/bigquery/session.py +++ b/sqlframe/bigquery/session.py @@ -9,6 +9,7 @@ BigQueryDataFrameReader, BigQueryDataFrameWriter, ) +from sqlframe.bigquery.table import BigQueryTable from sqlframe.bigquery.udf import BigQueryUDFRegistration if t.TYPE_CHECKING: @@ -25,6 +26,7 @@ class BigQuerySession( BigQueryDataFrameReader, BigQueryDataFrameWriter, BigQueryDataFrame, + BigQueryTable, BigQueryConnection, BigQueryUDFRegistration, ], @@ -33,6 +35,7 @@ class BigQuerySession( _reader = BigQueryDataFrameReader _writer = BigQueryDataFrameWriter _df = BigQueryDataFrame + _table = BigQueryTable _udf_registration = BigQueryUDFRegistration QUALIFY_INFO_SCHEMA_WITH_DATABASE = True diff --git a/sqlframe/bigquery/table.py b/sqlframe/bigquery/table.py new file mode 100644 index 0000000..5f47ecb --- /dev/null +++ b/sqlframe/bigquery/table.py @@ -0,0 +1,24 @@ +from sqlframe.base.mixins.table_mixins import ( + DeleteSupportMixin, + MergeSupportMixin, + UpdateSupportMixin, +) +from sqlframe.base.table import ( + WhenMatched, + WhenNotMatched, + WhenNotMatchedBySource, + _BaseTable, +) +from sqlframe.bigquery.dataframe import BigQueryDataFrame + + +class BigQueryTable( + BigQueryDataFrame, + UpdateSupportMixin["BigQueryDataFrame"], + DeleteSupportMixin["BigQueryDataFrame"], + MergeSupportMixin["BigQueryDataFrame"], + _BaseTable["BigQueryDataFrame"], +): + _df = BigQueryDataFrame + _merge_supported_clauses = [WhenMatched, WhenNotMatched, WhenNotMatchedBySource] + _merge_support_star = False diff --git a/sqlframe/databricks/readwriter.py b/sqlframe/databricks/readwriter.py index 777f157..18f064b 100644 --- a/sqlframe/databricks/readwriter.py +++ b/sqlframe/databricks/readwriter.py @@ -23,11 +23,12 @@ if t.TYPE_CHECKING: from sqlframe.databricks.session import DatabricksSession # noqa from sqlframe.databricks.dataframe import DatabricksDataFrame # noqa + from sqlframe.databricks.table import DatabricksTable # noqa class DatabricksDataFrameReader( PandasLoaderMixin["DatabricksSession", "DatabricksDataFrame"], - 
_BaseDataFrameReader["DatabricksSession", "DatabricksDataFrame"], + _BaseDataFrameReader["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"], ): pass diff --git a/sqlframe/databricks/session.py b/sqlframe/databricks/session.py index a442d2f..7bf35fc 100644 --- a/sqlframe/databricks/session.py +++ b/sqlframe/databricks/session.py @@ -10,6 +10,7 @@ DatabricksDataFrameReader, DatabricksDataFrameWriter, ) +from sqlframe.databricks.table import DatabricksTable from sqlframe.databricks.udf import DatabricksUDFRegistration if t.TYPE_CHECKING: @@ -24,6 +25,7 @@ class DatabricksSession( DatabricksDataFrameReader, DatabricksDataFrameWriter, DatabricksDataFrame, + DatabricksTable, DatabricksConnection, DatabricksUDFRegistration, ], @@ -32,6 +34,7 @@ class DatabricksSession( _reader = DatabricksDataFrameReader _writer = DatabricksDataFrameWriter _df = DatabricksDataFrame + _table = DatabricksTable _udf_registration = DatabricksUDFRegistration def __init__( diff --git a/sqlframe/databricks/table.py b/sqlframe/databricks/table.py new file mode 100644 index 0000000..b5e5388 --- /dev/null +++ b/sqlframe/databricks/table.py @@ -0,0 +1,24 @@ +from sqlframe.base.mixins.table_mixins import ( + DeleteSupportMixin, + MergeSupportMixin, + UpdateSupportMixin, +) +from sqlframe.base.table import ( + WhenMatched, + WhenNotMatched, + WhenNotMatchedBySource, + _BaseTable, +) +from sqlframe.databricks.dataframe import DatabricksDataFrame + + +class DatabricksTable( + DatabricksDataFrame, + UpdateSupportMixin["DatabricksDataFrame"], + DeleteSupportMixin["DatabricksDataFrame"], + MergeSupportMixin["DatabricksDataFrame"], + _BaseTable["DatabricksDataFrame"], +): + _df = DatabricksDataFrame + _merge_supported_clauses = [WhenMatched, WhenNotMatched, WhenNotMatchedBySource] + _merge_support_star = True diff --git a/sqlframe/duckdb/readwriter.py b/sqlframe/duckdb/readwriter.py index 73aae08..2160764 100644 --- a/sqlframe/duckdb/readwriter.py +++ b/sqlframe/duckdb/readwriter.py @@ -16,11 +16,14 @@ from sqlframe.base.types import StructType from sqlframe.duckdb.dataframe import DuckDBDataFrame from sqlframe.duckdb.session import DuckDBSession # noqa + from sqlframe.duckdb.table import DuckDBTable # noqa logger = logging.getLogger(__name__) -class DuckDBDataFrameReader(_BaseDataFrameReader["DuckDBSession", "DuckDBDataFrame"]): +class DuckDBDataFrameReader( + _BaseDataFrameReader["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"] +): def load( self, path: t.Optional[PathOrPaths] = None, diff --git a/sqlframe/duckdb/session.py b/sqlframe/duckdb/session.py index abdce5a..160a9f2 100644 --- a/sqlframe/duckdb/session.py +++ b/sqlframe/duckdb/session.py @@ -11,6 +11,7 @@ DuckDBDataFrameReader, DuckDBDataFrameWriter, ) +from sqlframe.duckdb.table import DuckDBTable from sqlframe.duckdb.udf import DuckDBUDFRegistration if t.TYPE_CHECKING: @@ -26,6 +27,7 @@ class DuckDBSession( DuckDBDataFrameReader, DuckDBDataFrameWriter, DuckDBDataFrame, + DuckDBTable, DuckDBPyConnection, DuckDBUDFRegistration, ] @@ -34,6 +36,7 @@ class DuckDBSession( _reader = DuckDBDataFrameReader _writer = DuckDBDataFrameWriter _df = DuckDBDataFrame + _table = DuckDBTable _udf_registration = DuckDBUDFRegistration def __init__(self, conn: t.Optional[DuckDBPyConnection] = None, *args, **kwargs): diff --git a/sqlframe/duckdb/table.py b/sqlframe/duckdb/table.py new file mode 100644 index 0000000..9e9cf45 --- /dev/null +++ b/sqlframe/duckdb/table.py @@ -0,0 +1,16 @@ +from sqlframe.base.mixins.table_mixins import ( + DeleteSupportMixin, + MergeSupportMixin, + 
diff --git a/sqlframe/duckdb/readwriter.py b/sqlframe/duckdb/readwriter.py
index 73aae08..2160764 100644
--- a/sqlframe/duckdb/readwriter.py
+++ b/sqlframe/duckdb/readwriter.py
@@ -16,11 +16,14 @@
     from sqlframe.base.types import StructType
     from sqlframe.duckdb.dataframe import DuckDBDataFrame
     from sqlframe.duckdb.session import DuckDBSession  # noqa
+    from sqlframe.duckdb.table import DuckDBTable  # noqa
 
 logger = logging.getLogger(__name__)
 
 
-class DuckDBDataFrameReader(_BaseDataFrameReader["DuckDBSession", "DuckDBDataFrame"]):
+class DuckDBDataFrameReader(
+    _BaseDataFrameReader["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"]
+):
     def load(
         self,
         path: t.Optional[PathOrPaths] = None,
diff --git a/sqlframe/duckdb/session.py b/sqlframe/duckdb/session.py
index abdce5a..160a9f2 100644
--- a/sqlframe/duckdb/session.py
+++ b/sqlframe/duckdb/session.py
@@ -11,6 +11,7 @@
     DuckDBDataFrameReader,
     DuckDBDataFrameWriter,
 )
+from sqlframe.duckdb.table import DuckDBTable
 from sqlframe.duckdb.udf import DuckDBUDFRegistration
 
 if t.TYPE_CHECKING:
@@ -26,6 +27,7 @@ class DuckDBSession(
         DuckDBDataFrameReader,
         DuckDBDataFrameWriter,
         DuckDBDataFrame,
+        DuckDBTable,
         DuckDBPyConnection,
         DuckDBUDFRegistration,
     ]
@@ -34,6 +36,7 @@
     _reader = DuckDBDataFrameReader
     _writer = DuckDBDataFrameWriter
     _df = DuckDBDataFrame
+    _table = DuckDBTable
     _udf_registration = DuckDBUDFRegistration
 
     def __init__(self, conn: t.Optional[DuckDBPyConnection] = None, *args, **kwargs):
diff --git a/sqlframe/duckdb/table.py b/sqlframe/duckdb/table.py
new file mode 100644
index 0000000..9e9cf45
--- /dev/null
+++ b/sqlframe/duckdb/table.py
@@ -0,0 +1,16 @@
+from sqlframe.base.mixins.table_mixins import (
+    DeleteSupportMixin,
+    MergeSupportMixin,
+    UpdateSupportMixin,
+)
+from sqlframe.base.table import _BaseTable
+from sqlframe.duckdb.dataframe import DuckDBDataFrame
+
+
+class DuckDBTable(
+    DuckDBDataFrame,
+    UpdateSupportMixin["DuckDBDataFrame"],
+    DeleteSupportMixin["DuckDBDataFrame"],
+    _BaseTable["DuckDBDataFrame"],
+):
+    _df = DuckDBDataFrame
diff --git a/sqlframe/postgres/readwriter.py b/sqlframe/postgres/readwriter.py
index 47d03c7..02143f9 100644
--- a/sqlframe/postgres/readwriter.py
+++ b/sqlframe/postgres/readwriter.py
@@ -13,11 +13,12 @@
 if t.TYPE_CHECKING:
     from sqlframe.postgres.session import PostgresSession  # noqa
     from sqlframe.postgres.dataframe import PostgresDataFrame  # noqa
+    from sqlframe.postgres.table import PostgresTable  # noqa
 
 
 class PostgresDataFrameReader(
     PandasLoaderMixin["PostgresSession", "PostgresDataFrame"],
-    _BaseDataFrameReader["PostgresSession", "PostgresDataFrame"],
+    _BaseDataFrameReader["PostgresSession", "PostgresDataFrame", "PostgresTable"],
 ):
     pass
diff --git a/sqlframe/postgres/session.py b/sqlframe/postgres/session.py
index cf3dcc7..b57ce39 100644
--- a/sqlframe/postgres/session.py
+++ b/sqlframe/postgres/session.py
@@ -11,6 +11,7 @@
     PostgresDataFrameReader,
     PostgresDataFrameWriter,
 )
+from sqlframe.postgres.table import PostgresTable
 from sqlframe.postgres.udf import PostgresUDFRegistration
 
 if t.TYPE_CHECKING:
@@ -27,6 +28,7 @@ class PostgresSession(
         PostgresDataFrameReader,
         PostgresDataFrameWriter,
         PostgresDataFrame,
+        PostgresTable,
         psycopg2_connection,
         PostgresUDFRegistration,
     ],
@@ -35,6 +37,7 @@
     _reader = PostgresDataFrameReader
     _writer = PostgresDataFrameWriter
     _df = PostgresDataFrame
+    _table = PostgresTable
     _udf_registration = PostgresUDFRegistration
 
     def __init__(self, conn: t.Optional[psycopg2_connection] = None):
diff --git a/sqlframe/postgres/table.py b/sqlframe/postgres/table.py
new file mode 100644
index 0000000..f0e3cd0
--- /dev/null
+++ b/sqlframe/postgres/table.py
@@ -0,0 +1,24 @@
+from sqlframe.base.mixins.table_mixins import (
+    DeleteSupportMixin,
+    MergeSupportMixin,
+    UpdateSupportMixin,
+)
+from sqlframe.base.table import (
+    WhenMatched,
+    WhenNotMatched,
+    WhenNotMatchedBySource,
+    _BaseTable,
+)
+from sqlframe.postgres.dataframe import PostgresDataFrame
+
+
+class PostgresTable(
+    PostgresDataFrame,
+    UpdateSupportMixin["PostgresDataFrame"],
+    DeleteSupportMixin["PostgresDataFrame"],
+    MergeSupportMixin["PostgresDataFrame"],
+    _BaseTable["PostgresDataFrame"],
+):
+    _df = PostgresDataFrame
+    _merge_supported_clauses = [WhenMatched, WhenNotMatched, WhenNotMatchedBySource]
+    _merge_support_star = False
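Note that `DuckDBTable` imports `MergeSupportMixin` but does not mix it in: DuckDB tables get `update`/`delete`, while `merge` falls through to the `_BaseTable` stub and raises `NotImplementedError` (which is why the merge tests below skip DuckDB). A small sketch of the resulting behavior, with an illustrative table:

```python
from sqlframe.duckdb.session import DuckDBSession

session = DuckDBSession()
session.createDataFrame([{"id": 1, "age": 37}]).write.saveAsTable("employee")
tbl = session.table("employee")

# Supported via UpdateSupportMixin (returns a LazyExpression).
tbl.update(set_={"age": tbl["age"] + 1}, where=tbl["id"] == 1).execute()

# Unsupported: resolves to _BaseTable.merge.
try:
    tbl.merge(tbl, condition=tbl["id"] == tbl["id"], clauses=[])
except NotImplementedError:
    print("merge is not implemented for DuckDB")
```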
diff --git a/sqlframe/redshift/readwriter.py b/sqlframe/redshift/readwriter.py
index 3f7a573..5b2c3ab 100644
--- a/sqlframe/redshift/readwriter.py
+++ b/sqlframe/redshift/readwriter.py
@@ -13,11 +13,12 @@
 if t.TYPE_CHECKING:
     from sqlframe.redshift.session import RedshiftSession  # noqa
     from sqlframe.redshift.dataframe import RedshiftDataFrame  # noqa
+    from sqlframe.redshift.table import RedshiftTable  # noqa
 
 
 class RedshiftDataFrameReader(
     PandasLoaderMixin["RedshiftSession", "RedshiftDataFrame"],
-    _BaseDataFrameReader["RedshiftSession", "RedshiftDataFrame"],
+    _BaseDataFrameReader["RedshiftSession", "RedshiftDataFrame", "RedshiftTable"],
 ):
     pass
diff --git a/sqlframe/redshift/session.py b/sqlframe/redshift/session.py
index 3103988..a1e6614 100644
--- a/sqlframe/redshift/session.py
+++ b/sqlframe/redshift/session.py
@@ -10,6 +10,7 @@
     RedshiftDataFrameReader,
     RedshiftDataFrameWriter,
 )
+from sqlframe.redshift.table import RedshiftTable
 from sqlframe.redshift.udf import RedshiftUDFRegistration
 
 if t.TYPE_CHECKING:
@@ -24,6 +25,7 @@ class RedshiftSession(
         RedshiftDataFrameReader,
         RedshiftDataFrameWriter,
         RedshiftDataFrame,
+        RedshiftTable,
         RedshiftConnection,
         RedshiftUDFRegistration,
     ],
@@ -32,6 +34,7 @@
     _reader = RedshiftDataFrameReader
     _writer = RedshiftDataFrameWriter
     _df = RedshiftDataFrame
+    _table = RedshiftTable
     _udf_registration = RedshiftUDFRegistration
 
     def __init__(self, conn: t.Optional[RedshiftConnection] = None):
diff --git a/sqlframe/redshift/table.py b/sqlframe/redshift/table.py
new file mode 100644
index 0000000..7bb5de5
--- /dev/null
+++ b/sqlframe/redshift/table.py
@@ -0,0 +1,15 @@
+from sqlframe.base.mixins.table_mixins import (
+    DeleteSupportMixin,
+    UpdateSupportMixin,
+)
+from sqlframe.base.table import _BaseTable
+from sqlframe.redshift.dataframe import RedshiftDataFrame
+
+
+class RedshiftTable(
+    RedshiftDataFrame,
+    UpdateSupportMixin["RedshiftDataFrame"],
+    DeleteSupportMixin["RedshiftDataFrame"],
+    _BaseTable["RedshiftDataFrame"],
+):
+    _df = RedshiftDataFrame
diff --git a/sqlframe/snowflake/readwriter.py b/sqlframe/snowflake/readwriter.py
index 4d62b75..38a7c4b 100644
--- a/sqlframe/snowflake/readwriter.py
+++ b/sqlframe/snowflake/readwriter.py
@@ -13,11 +13,12 @@
 if t.TYPE_CHECKING:
     from sqlframe.snowflake.session import SnowflakeSession  # noqa
     from sqlframe.snowflake.dataframe import SnowflakeDataFrame  # noqa
+    from sqlframe.snowflake.table import SnowflakeTable  # noqa
 
 
 class SnowflakeDataFrameReader(
     PandasLoaderMixin["SnowflakeSession", "SnowflakeDataFrame"],
-    _BaseDataFrameReader["SnowflakeSession", "SnowflakeDataFrame"],
+    _BaseDataFrameReader["SnowflakeSession", "SnowflakeDataFrame", "SnowflakeTable"],
 ):
     pass
diff --git a/sqlframe/snowflake/session.py b/sqlframe/snowflake/session.py
index ddfd742..7d9517f 100644
--- a/sqlframe/snowflake/session.py
+++ b/sqlframe/snowflake/session.py
@@ -18,6 +18,7 @@
     SnowflakeDataFrameReader,
     SnowflakeDataFrameWriter,
 )
+from sqlframe.snowflake.table import SnowflakeTable
 
 if t.TYPE_CHECKING:
     from snowflake.connector import SnowflakeConnection
@@ -51,6 +52,7 @@ class SnowflakeSession(
         SnowflakeDataFrameReader,
         SnowflakeDataFrameWriter,
         SnowflakeDataFrame,
+        SnowflakeTable,
         SnowflakeConnection,
         SnowflakeUDFRegistration,
     ],
@@ -59,6 +61,7 @@
     _reader = SnowflakeDataFrameReader
     _writer = SnowflakeDataFrameWriter
     _df = SnowflakeDataFrame
+    _table = SnowflakeTable
     _udf_registration = SnowflakeUDFRegistration
 
     def __init__(self, conn: t.Optional[SnowflakeConnection] = None):
diff --git a/sqlframe/snowflake/table.py b/sqlframe/snowflake/table.py
new file mode 100644
index 0000000..6d61a27
--- /dev/null
+++ b/sqlframe/snowflake/table.py
@@ -0,0 +1,23 @@
+from sqlframe.base.mixins.table_mixins import (
+    DeleteSupportMixin,
+    MergeSupportMixin,
+    UpdateSupportMixin,
+)
+from sqlframe.base.table import (
+    WhenMatched,
+    WhenNotMatched,
+    _BaseTable,
+)
+from sqlframe.snowflake.dataframe import SnowflakeDataFrame
+
+
+class SnowflakeTable(
+    SnowflakeDataFrame,
+    UpdateSupportMixin["SnowflakeDataFrame"],
+    DeleteSupportMixin["SnowflakeDataFrame"],
+    MergeSupportMixin["SnowflakeDataFrame"],
+    _BaseTable["SnowflakeDataFrame"],
+):
+    _df = SnowflakeDataFrame
+    _merge_supported_clauses = [WhenMatched, WhenNotMatched]
+    _merge_support_star = False
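`SnowflakeTable` lists only `WhenMatched` and `WhenNotMatched` in `_merge_supported_clauses`, so a `WhenNotMatchedBySource` clause should presumably be rejected by `MergeSupportMixin` before any SQL is generated (the mixin itself is not part of this diff, so the exact error is an assumption). Illustrative setup with placeholder credentials:

```python
import snowflake.connector
from sqlframe.base.table import WhenNotMatchedBySource
from sqlframe.snowflake.session import SnowflakeSession

conn = snowflake.connector.connect(account="...", user="...", password="...")
session = SnowflakeSession(conn=conn)

target = session.table("db1.employee")
source = session.table("db1.employee_updates")

# Expected to be rejected: WhenNotMatchedBySource is not in
# SnowflakeTable._merge_supported_clauses (assumed validation in the mixin).
target.merge(
    source,
    condition=target["id"] == source["id"],
    clauses=[WhenNotMatchedBySource().delete()],
)
```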
+13,12 @@
 if t.TYPE_CHECKING:
     from sqlframe.spark.dataframe import SparkDataFrame
     from sqlframe.spark.session import SparkSession
+    from sqlframe.spark.table import SparkTable
 
 
 class SparkDataFrameReader(
     PandasLoaderMixin["SparkSession", "SparkDataFrame"],
-    _BaseDataFrameReader["SparkSession", "SparkDataFrame"],
+    _BaseDataFrameReader["SparkSession", "SparkDataFrame", "SparkTable"],
 ):
     pass
diff --git a/sqlframe/spark/session.py b/sqlframe/spark/session.py
index bcf5dc0..3e24d6a 100644
--- a/sqlframe/spark/session.py
+++ b/sqlframe/spark/session.py
@@ -13,6 +13,7 @@
     SparkDataFrameReader,
     SparkDataFrameWriter,
 )
+from sqlframe.spark.table import SparkTable
 from sqlframe.spark.types import Row
 from sqlframe.spark.udf import SparkUDFRegistration
 
@@ -32,6 +33,7 @@ class SparkSession(
         SparkDataFrameReader,
         SparkDataFrameWriter,
         SparkDataFrame,
+        SparkTable,
         PySparkSession,
         SparkUDFRegistration,
     ],
@@ -40,6 +42,7 @@
     _reader = SparkDataFrameReader
     _writer = SparkDataFrameWriter
     _df = SparkDataFrame
+    _table = SparkTable
     _udf_registration = SparkUDFRegistration
 
     def __init__(self, conn: t.Optional[PySparkSession] = None, *args, **kwargs):
diff --git a/sqlframe/spark/table.py b/sqlframe/spark/table.py
new file mode 100644
index 0000000..b7fbeff
--- /dev/null
+++ b/sqlframe/spark/table.py
@@ -0,0 +1,6 @@
+from sqlframe.base.table import _BaseTable
+from sqlframe.spark.dataframe import SparkDataFrame
+
+
+class SparkTable(SparkDataFrame, _BaseTable["SparkDataFrame"]):
+    _df = SparkDataFrame
diff --git a/sqlframe/standalone/readwriter.py b/sqlframe/standalone/readwriter.py
index 6343271..f5e05d0 100644
--- a/sqlframe/standalone/readwriter.py
+++ b/sqlframe/standalone/readwriter.py
@@ -9,9 +9,12 @@
 if t.TYPE_CHECKING:
     from sqlframe.standalone.dataframe import StandaloneDataFrame
     from sqlframe.standalone.session import StandaloneSession
+    from sqlframe.standalone.table import StandaloneTable
 
 
-class StandaloneDataFrameReader(_BaseDataFrameReader["StandaloneSession", "StandaloneDataFrame"]):
+class StandaloneDataFrameReader(
+    _BaseDataFrameReader["StandaloneSession", "StandaloneDataFrame", "StandaloneTable"]
+):
     pass
diff --git a/sqlframe/standalone/session.py b/sqlframe/standalone/session.py
index 9ef49b9..327efd3 100644
--- a/sqlframe/standalone/session.py
+++ b/sqlframe/standalone/session.py
@@ -9,6 +9,7 @@
     StandaloneDataFrameReader,
     StandaloneDataFrameWriter,
 )
+from sqlframe.standalone.table import StandaloneTable
 from sqlframe.standalone.udf import StandaloneUDFRegistration
 
 
@@ -18,6 +19,7 @@ class StandaloneSession(
         StandaloneDataFrameReader,
         StandaloneDataFrameWriter,
         StandaloneDataFrame,
+        StandaloneTable,
         object,
         StandaloneUDFRegistration,
     ]
@@ -26,6 +28,7 @@
     _reader = StandaloneDataFrameReader
     _writer = StandaloneDataFrameWriter
     _df = StandaloneDataFrame
+    _table = StandaloneTable
     _udf_registration = StandaloneUDFRegistration
 
     @property
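`SparkTable` and `StandaloneTable` subclass `_BaseTable` without any of the DML mixins, so `update`, `delete`, and `merge` all resolve to the base stubs and raise `NotImplementedError`; the integration tests below skip those engines for exactly this reason. A quick check of that method resolution:

```python
from sqlframe.base.table import _BaseTable
from sqlframe.spark.table import SparkTable
from sqlframe.standalone.table import StandaloneTable

# No mixin overrides these, so attribute lookup lands on the base stubs.
assert SparkTable.update is _BaseTable.update
assert SparkTable.merge is _BaseTable.merge
assert StandaloneTable.delete is _BaseTable.delete
```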
-0,0 +1,413 @@
+from __future__ import annotations
+
+import datetime
+import typing as t
+
+import pytest
+from pyspark.sql import SparkSession as PySparkSession
+
+from sqlframe.base.table import WhenMatched, WhenNotMatched, _BaseTable
+from sqlframe.base.types import Row
+from sqlframe.bigquery.session import BigQuerySession
+from sqlframe.databricks import DatabricksSession
+from sqlframe.duckdb.session import DuckDBSession
+from sqlframe.postgres.session import PostgresSession
+from sqlframe.redshift.session import RedshiftSession
+from sqlframe.snowflake.session import SnowflakeSession
+from sqlframe.spark.session import SparkSession
+from sqlframe.standalone.session import StandaloneSession
+
+if t.TYPE_CHECKING:
+    from sqlframe.base.dataframe import BaseDataFrame
+
+pytest_plugins = ["tests.integration.fixtures"]
+
+
+@pytest.fixture
+def merge_data() -> t.Tuple[t.List[t.Any], str]:
+    return (
+        [
+            (
+                "3368c22d-edd8-4ae9-a0b8-f0956b9ffa88",
+                1,
+                "Jack",
+                "Shephard",
+                37,
+                1,
+                "1999-01-01",
+                "2900-01-01",
+            ),
+            (
+                "6db4c8ce-464d-4772-a686-e6bd9e10286f",
+                2,
+                "John",
+                "Locke",
+                65,
+                1,
+                "1999-01-01",
+                "2900-01-01",
+            ),
+            (
+                "7984b6b1-cb4c-49a0-bc83-61064f4d77e6",
+                3,
+                "Kate",
+                "Austen",
+                37,
+                2,
+                "1999-01-01",
+                "2900-01-01",
+            ),
+            (
+                "8ee4d1bf-c3c7-42a0-8727-233881d6212e",
+                4,
+                "Claire",
+                "Littleton",
+                27,
+                2,
+                "1999-01-01",
+                "2900-01-01",
+            ),
+            (
+                "0c67fff1-9471-4360-8624-982b009cf315",
+                5,
+                "Hugo",
+                "Reyes",
+                29,
+                100,
+                "1999-01-01",
+                "2900-01-01",
+            ),
+        ],
+        "s_id STRING, employee_id INTEGER, fname STRING, lname STRING, age INTEGER, "
+        "store_id INTEGER, start_date DATE, end_date DATE",
+    )
+
+
+@pytest.fixture
+def cleanup_employee_df(
+    get_engine_df: t.Callable[[str], BaseDataFrame],
+) -> t.Iterator[BaseDataFrame]:
+    df = get_engine_df("employee")
+    df.session._execute("DROP TABLE IF EXISTS update_employee")
+    df.session._execute("DROP TABLE IF EXISTS merge_employee")
+    df.session._execute("DROP TABLE IF EXISTS delete_employee")
+    yield df
+    df.session._execute("DROP TABLE IF EXISTS update_employee")
+    df.session._execute("DROP TABLE IF EXISTS merge_employee")
+    df.session._execute("DROP TABLE IF EXISTS delete_employee")
+
+
+def test_update_table(cleanup_employee_df: BaseDataFrame, caplog):
+    session = cleanup_employee_df.session
+    if isinstance(
+        session,
+        (
+            StandaloneSession,
+            PySparkSession,
+            SparkSession,
+        ),
+    ):
+        pytest.skip("Engine doesn't support update")
+    df_employee = cleanup_employee_df
+    df_employee.write.saveAsTable("update_employee")
+    df = session.read.table("update_employee")
+    assert isinstance(df, _BaseTable)
+    update_expr = df.update(
+        set_={"age": df["age"] + 1},
+        where=df["employee_id"] == 1,
+    )
+    result = update_expr.execute()
+    # Postgres, RedshiftSession and BigQuery don't support returning the number of affected rows
+    if not isinstance(session, (PostgresSession, BigQuerySession, RedshiftSession)):
+        assert result[0][0] == 1
+
+    df2 = session.read.table("update_employee")
+    assert (
+        df2.where(df2["employee_id"] == 1).select("age").collect()[0]["age"]
+        == df_employee.where(df_employee["employee_id"] == 1).select("age").collect()[0]["age"] + 1
+    )
+
+
+def test_delete_table(cleanup_employee_df: BaseDataFrame, caplog):
+    session = cleanup_employee_df.session
+    if isinstance(
+        session,
+        (
+            StandaloneSession,
+            PySparkSession,
+            SparkSession,
+        ),
+    ):
+        pytest.skip("Engine doesn't support delete")
+    df_employee = cleanup_employee_df
+    df_employee.write.saveAsTable("delete_employee")
+    df = session.read.table("delete_employee")
session.read.table("delete_employee") + assert isinstance(df, _BaseTable) + delete_expr = df.delete(where=df["age"] > 28) + result = delete_expr.execute() + # Postgres, RedshiftSession and BigQuery don't support returning the number of affected rows + if not isinstance(session, (PostgresSession, BigQuerySession, RedshiftSession)): + assert result[0][0] == 4 + + df2 = session.read.table("delete_employee") + assert df2.collect() == [ + Row(employee_id=4, fname="Claire", lname="Littleton", age=27, store_id=2) + ] + + +def test_merge_table_simple(cleanup_employee_df: BaseDataFrame, caplog): + session = cleanup_employee_df.session + if isinstance( + session, + ( + StandaloneSession, + PySparkSession, + RedshiftSession, + SparkSession, + DuckDBSession, + ), + ): + pytest.skip("Engine doesn't support merge") + df_employee = cleanup_employee_df + df_employee.write.saveAsTable("merge_employee") + df = session.read.table("merge_employee") + assert isinstance(df, _BaseTable) + + df2 = session.createDataFrame( + [ + (1, "Jack", "Shephard", 38, 2), + (6, "Mary", "Sue", 21, 45), + ], + ["employee_id", "fname", "lname", "age", "store_id"], + ) + + merge_expr = df.merge( + df2, + condition=df.employee_id == df2.employee_id, + clauses=[ + WhenMatched(condition=df.fname == df2.fname).update( + set_={ + "lname": df2.lname, + "age": df2.age, + "store_id": df2.store_id, + } + ), + WhenNotMatched().insert( + values={ + "employee_id": df2.employee_id, + "fname": df2.fname, + "lname": df2.lname, + "age": df2.age, + "store_id": df2.store_id, + } + ), + ], + ) + result = merge_expr.execute() + # Postgres and BigQuery don't support returning the number of affected rows + if not isinstance(session, (PostgresSession, BigQuerySession)): + if isinstance(session, SnowflakeSession): + assert (result[0][0] + result[0][1]) == 2 + else: + assert result[0][0] == 2 + + df_merged = session.read.table("merge_employee") + assert sorted(df_merged.collect()) == [ + Row(employee_id=1, fname="Jack", lname="Shephard", age=38, store_id=2), + Row(employee_id=2, fname="John", lname="Locke", age=65, store_id=1), + Row(employee_id=3, fname="Kate", lname="Austen", age=37, store_id=2), + Row(employee_id=4, fname="Claire", lname="Littleton", age=27, store_id=2), + Row(employee_id=5, fname="Hugo", lname="Reyes", age=29, store_id=100), + Row(employee_id=6, fname="Mary", lname="Sue", age=21, store_id=45), + ] + + +def test_merge_table(cleanup_employee_df: BaseDataFrame, merge_data, get_func, caplog): + session = cleanup_employee_df.session + col = get_func("col", session) + expr = get_func("expr", session) + lit = get_func("lit", session) + if isinstance( + session, + ( + StandaloneSession, + PySparkSession, + RedshiftSession, + SparkSession, + DuckDBSession, + ), + ): + pytest.skip("Engine doesn't support merge") + + if isinstance(session, DatabricksSession): + uuid_func = "uuid" + elif isinstance(session, PostgresSession): + uuid_func = "gen_random_uuid" + elif isinstance(session, SnowflakeSession): + uuid_func = "uuid_string" + elif isinstance(session, BigQuerySession): + uuid_func = "generate_uuid" + else: + pytest.skip("Cannot generate uuids in this engine") + + data, schema = merge_data + df_employee = session.createDataFrame(data, schema) + df_employee.write.saveAsTable("merge_employee") + df = session.read.table("merge_employee") + assert isinstance(df, _BaseTable) + + start_date = "2024-01-01" + end_date = "2900-01-01" + df2 = session.createDataFrame( + [ + (1, "Jack", "Shephard", 38, 2), + (6, "Mary", "Sue", 21, 45), + ], + 
["employee_id", "fname", "lname", "age", "store_id"], + ) + + source = df2.select( + expr(f"{uuid_func}()").alias("s_id"), + col("employee_id").alias("employee_id"), + col("fname").alias("fname"), + col("lname").alias("lname"), + col("age").alias("age"), + col("store_id").alias("store_id"), + lit(end_date).alias("end_date"), + ) + + updates = ( + source.alias("src") + .join( + df.alias("tg"), + on=(col("tg.employee_id") == col("src.employee_id")), + ) + .where( + f"tg.end_date = '{end_date}' " + f"AND (" + f"src.fname <> tg.fname OR src.lname <> tg.lname OR src.age <> tg.age OR src.store_id <> tg.store_id" + f")" + ) + ) + + staging = updates.select( + lit(0).alias("__key__"), + col("src.*"), + lit(start_date).cast("date").alias("start_date"), + ).unionByName( + source.alias("src").select( + lit(1).alias("__key__"), + col("src.*"), + lit(start_date).cast("date").alias("start_date"), + ) + ) + + merge_expr = df.alias("tg").merge( + staging.alias("st"), + condition=(col("st.employee_id") == col("tg.employee_id")) & (col("st.__key__") == lit(1)), + clauses=[ + WhenMatched( + condition=(col("tg.end_date") == lit(end_date)) + & ( + (col("st.fname") != col("tg.fname")) + | (col("st.lname") != col("tg.lname")) + | (col("st.age") != col("tg.age")) + | (col("st.store_id") != col("tg.store_id")) + ) + ).update( + set_={ + "end_date": col("st.start_date"), + } + ), + WhenNotMatched().insert( + values={ + "s_id": col("st.s_id"), + "employee_id": col("st.employee_id"), + "fname": col("st.fname"), + "lname": col("st.lname"), + "age": col("st.age"), + "store_id": col("st.store_id"), + "start_date": lit(start_date), + "end_date": lit(end_date), + } + ), + ], + ) + + result = merge_expr.execute() + if not isinstance(session, (PostgresSession, BigQuerySession)): + if isinstance(session, SnowflakeSession): + assert (result[0][0] + result[0][1]) == 3 + else: + assert result[0][0] == 3 + + df_merged = session.read.table("merge_employee").select( + "employee_id", "fname", "lname", "age", "store_id", "start_date", "end_date" + ) + + assert df_merged.count() == 7 + assert sorted(df_merged.collect()) == [ + Row( + employee_id=1, + fname="Jack", + lname="Shephard", + age=37, + store_id=1, + start_date=datetime.date(1999, 1, 1), + end_date=datetime.date(2024, 1, 1), + ), + Row( + employee_id=1, + fname="Jack", + lname="Shephard", + age=38, + store_id=2, + start_date=datetime.date(2024, 1, 1), + end_date=datetime.date(2900, 1, 1), + ), + Row( + employee_id=2, + fname="John", + lname="Locke", + age=65, + store_id=1, + start_date=datetime.date(1999, 1, 1), + end_date=datetime.date(2900, 1, 1), + ), + Row( + employee_id=3, + fname="Kate", + lname="Austen", + age=37, + store_id=2, + start_date=datetime.date(1999, 1, 1), + end_date=datetime.date(2900, 1, 1), + ), + Row( + employee_id=4, + fname="Claire", + lname="Littleton", + age=27, + store_id=2, + start_date=datetime.date(1999, 1, 1), + end_date=datetime.date(2900, 1, 1), + ), + Row( + employee_id=5, + fname="Hugo", + lname="Reyes", + age=29, + store_id=100, + start_date=datetime.date(1999, 1, 1), + end_date=datetime.date(2900, 1, 1), + ), + Row( + employee_id=6, + fname="Mary", + lname="Sue", + age=21, + store_id=45, + start_date=datetime.date(2024, 1, 1), + end_date=datetime.date(2900, 1, 1), + ), + ]