Databricks tests (#218)
* Added Databricks tests

* Fixed test

* Added unit tests

* Added unit tests

* Fixed typos

* Fixed style

* Revert part of the syntax

* Revert "removing quotes from schema"

* Revert "removing quotes from schema"

* "Fixed" struct and map return type. Changed tests to work with the fixed semantics.

* Fixed test

* Fixed style

* Added disable_pandas option to default Databricks connection.

* Changed 'test_transform_keys' for Databricks.

* Added comments to window-related functions.

* Added comments to window-related functions.
zerodarkzone authored Dec 15, 2024
1 parent 61fda5d commit 4de9375
Showing 18 changed files with 957 additions and 82 deletions.
3 changes: 3 additions & 0 deletions Makefile
@@ -22,6 +22,9 @@ duckdb-test:
snowflake-test:
pytest -n auto -m "snowflake"

databricks-test:
pytest -n auto -m "databricks"

style:
pre-commit run --all-files

2 changes: 0 additions & 2 deletions docs/databricks.md
@@ -1,5 +1,3 @@
from test import auth_type

# Databricks (In Development)

## Installation
5 changes: 5 additions & 0 deletions sqlframe/base/function_alternatives.py
@@ -1220,6 +1220,11 @@ def get_json_object_cast_object(col: ColumnOrName, path: str) -> Column:
return get_json_object(col_func(col).cast("variant"), path)


def get_json_object_using_function(col: ColumnOrName, path: str) -> Column:
lit = get_func_from_session("lit")
return Column.invoke_anonymous_function(col, "GET_JSON_OBJECT", lit(path))


def create_map_with_cast(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
from sqlframe.base.functions import create_map

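The new alternative is wired up as Databricks' get_json_object later in this commit (sqlframe/databricks/functions.py). A minimal sketch of the intended rendering, assuming an existing databricks-sql connection conn and sqlframe's PySpark-style API:

from sqlframe.databricks.session import DatabricksSession
from sqlframe.databricks import functions as F

session = DatabricksSession(conn)  # conn: an existing databricks.sql Connection
df = session.createDataFrame([('{"a": 1}',)], ["payload"])
# Expected to compile to a GET_JSON_OBJECT(payload, '$.a') call in the
# generated Databricks SQL, with the path passed as a string literal.
df.select(F.get_json_object("payload", "$.a").alias("a")).show()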
20 changes: 10 additions & 10 deletions sqlframe/base/functions.py
@@ -2173,7 +2173,7 @@ def current_database() -> Column:
current_schema = current_database


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def current_timezone() -> Column:
return Column.invoke_anonymous_function(None, "current_timezone")

@@ -2261,7 +2261,7 @@ def get(col: ColumnOrName, index: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_anonymous_function(col, "get", index)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def get_active_spark_context() -> SparkContext:
"""Raise RuntimeError if SparkContext is not initialized,
otherwise, returns the active SparkContext."""
@@ -2778,7 +2778,7 @@ def isnotnull(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "isnotnull")


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def java_method(*cols: ColumnOrName) -> Column:
"""
Calls a method with reflection.
@@ -3050,7 +3050,7 @@ def ln(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, expression.Ln)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def localtimestamp() -> Column:
"""
Returns the current timestamp without time zone at the start of query evaluation
@@ -3080,7 +3080,7 @@ def localtimestamp() -> Column:
return Column.invoke_anonymous_function(None, "localtimestamp")


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def make_dt_interval(
days: t.Optional[ColumnOrName] = None,
hours: t.Optional[ColumnOrName] = None,
@@ -3227,7 +3227,7 @@ def make_timestamp(
)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def make_timestamp_ltz(
years: ColumnOrName,
months: ColumnOrName,
@@ -3354,7 +3354,7 @@ def make_timestamp_ntz(
)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def make_ym_interval(
years: t.Optional[ColumnOrName] = None,
months: t.Optional[ColumnOrName] = None,
@@ -3922,7 +3922,7 @@ def printf(format: ColumnOrName, *cols: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(format, "printf", *cols)


@meta(unsupported_engines=["*", "spark"])
@meta(unsupported_engines=["*", "spark", "databricks"])
def product(col: ColumnOrName) -> Column:
"""
Aggregate function: returns the product of the values in a group.
@@ -3961,7 +3961,7 @@ def product(col: ColumnOrName) -> Column:
reduce = aggregate


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def reflect(*cols: ColumnOrName) -> Column:
"""
Calls a method with reflection.
@@ -5046,7 +5046,7 @@ def to_str(value: t.Any) -> t.Optional[str]:
return str(value)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["*", "databricks"])
def to_timestamp_ltz(
timestamp: ColumnOrName,
format: t.Optional[ColumnOrName] = None,
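Every functions.py hunk above is the same one-line change: "databricks" joins the function's unsupported_engines list. Assuming @meta(unsupported_engines=...) works by keeping such functions out of each engine's functions module (an assumption about sqlframe's internals, not verified here), the effect would be observable as:

from sqlframe.databricks import functions as F

# Expected to be unavailable for Databricks after this change; the exact
# failure mode (missing attribute vs. runtime error) depends on how the
# @meta decorator is enforced.
print(hasattr(F, "reflect"))
print(hasattr(F, "product"))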
14 changes: 12 additions & 2 deletions sqlframe/databricks/catalog.py
@@ -26,7 +26,6 @@


class DatabricksCatalog(
SetCurrentCatalogFromUseMixin["DatabricksSession", "DatabricksDataFrame"],
GetCurrentCatalogFromFunctionMixin["DatabricksSession", "DatabricksDataFrame"],
GetCurrentDatabaseFromFunctionMixin["DatabricksSession", "DatabricksDataFrame"],
ListDatabasesFromInfoSchemaMixin["DatabricksSession", "DatabricksDataFrame"],
@@ -38,6 +37,15 @@ class DatabricksCatalog(
CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
UPPERCASE_INFO_SCHEMA = True

def setCurrentCatalog(self, catalogName: str) -> None:
self.session._collect(
exp.Use(
kind=exp.Var(this=exp.to_identifier("CATALOG")),
this=exp.parse_identifier(catalogName, dialect=self.session.input_dialect),
),
quote_identifiers=False,
)

def listFunctions(
self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
) -> t.List[Function]:
@@ -106,7 +114,9 @@ def listFunctions(
)
functions = [
Function(
name=normalize_string(x["function"], from_dialect="execution", to_dialect="output"),
name=normalize_string(
x["function"].split(".")[-1], from_dialect="execution", to_dialect="output"
),
catalog=normalize_string(
schema.catalog, from_dialect="execution", to_dialect="output"
),
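Together, the catalog changes swap the mixin-provided setCurrentCatalog for one that emits an unquoted USE CATALOG statement, and trim fully qualified function names down to the bare name. A hedged sketch of the intended behavior (conn as above; outputs illustrative):

from sqlframe.databricks.session import DatabricksSession

session = DatabricksSession(conn)  # conn: an existing databricks.sql Connection
session.catalog.setCurrentCatalog("main")  # expected SQL: USE CATALOG main
for fn in session.catalog.listFunctions(dbName="db1"):
    print(fn.name)  # e.g. "add" rather than "main.db1.add"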
5 changes: 3 additions & 2 deletions sqlframe/databricks/dataframe.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import sys
import typing as t

from sqlframe.base.catalog import Column as CatalogColumn
@@ -52,7 +51,9 @@ def _typed_columns(self) -> t.List[CatalogColumn]:
columns.append(
CatalogColumn(
name=normalize_string(
row.col_name, from_dialect="execution", to_dialect="output"
row.col_name,
from_dialect="execution",
to_dialect="output",
),
dataType=normalize_string(
row.data_type,
1 change: 1 addition & 0 deletions sqlframe/databricks/functions.py
@@ -19,4 +19,5 @@
arrays_overlap_renamed as arrays_overlap,
_is_string_using_typeof_string_lcase as _is_string,
try_element_at_zero_based as try_element_at,
get_json_object_using_function as get_json_object,
)
15 changes: 14 additions & 1 deletion sqlframe/databricks/session.py
@@ -44,7 +44,20 @@ def __init__(
from databricks import sql

if not hasattr(self, "_conn"):
super().__init__(conn or sql.connect(server_hostname, http_path, access_token))
super().__init__(
conn or sql.connect(server_hostname, http_path, access_token, disable_pandas=True)
)

@classmethod
def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
if (
value
and isinstance(value, list)
and all(isinstance(item, tuple) for item in value)
and all(len(item) == 2 for item in value)
):
return dict(value)
return None

class Builder(_BaseSession.Builder):
DEFAULT_EXECUTION_DIALECT = "databricks"
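With disable_pandas=True the databricks-sql driver is assumed to return MAP columns as lists of key/value tuples; the new _try_get_map classmethod converts exactly that shape into a Python dict and returns None for anything else. The checks below follow directly from the method body (values invented):

from sqlframe.databricks.session import DatabricksSession

assert DatabricksSession._try_get_map([("a", 1), ("b", 2)]) == {"a": 1, "b": 2}
assert DatabricksSession._try_get_map([("a",)]) is None  # tuples must be pairs
assert DatabricksSession._try_get_map("not a list") is None
assert DatabricksSession._try_get_map([]) is None  # falsy input -> None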
27 changes: 27 additions & 0 deletions tests/common_fixtures.py
@@ -12,6 +12,7 @@

from sqlframe.base.session import _BaseSession
from sqlframe.bigquery.session import BigQuerySession
from sqlframe.databricks.session import DatabricksSession
from sqlframe.duckdb.session import DuckDBSession
from sqlframe.postgres.session import PostgresSession
from sqlframe.redshift.session import RedshiftSession
@@ -22,6 +23,7 @@
from sqlframe.standalone.session import StandaloneSession

if t.TYPE_CHECKING:
from databricks.sql import Connection as DatabricksConnection
from google.cloud.bigquery.dbapi.connection import (
Connection as BigQueryConnection,
)
@@ -231,6 +233,31 @@ def snowflake_session(snowflake_connection: SnowflakeConnection) -> SnowflakeSession:
return session


@pytest.fixture(scope="session")
def databricks_connection() -> DatabricksConnection:
from databricks.sql import connect

conn = connect(
server_hostname=os.environ["SQLFRAME_DATABRICKS_SERVER_HOSTNAME"],
http_path=os.environ["SQLFRAME_DATABRICKS_HTTP_PATH"],
access_token=os.environ["SQLFRAME_DATABRICKS_ACCESS_TOKEN"],
auth_type="access_token",
catalog=os.environ["SQLFRAME_DATABRICKS_CATALOG"],
schema=os.environ["SQLFRAME_DATABRICKS_SCHEMA"],
_disable_pandas=True,
)
return conn


@pytest.fixture
def databricks_session(databricks_connection: DatabricksConnection) -> DatabricksSession:
session = DatabricksSession(databricks_connection)
session._execute("CREATE SCHEMA IF NOT EXISTS db1")
session._execute("CREATE TABLE IF NOT EXISTS db1.table1 (id INTEGER, name VARCHAR(100))")
session._execute("CREATE OR REPLACE FUNCTION db1.add(x INT, y INT) RETURNS INT RETURN x + y")
return session
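The connection fixture reads its credentials from the SQLFRAME_DATABRICKS_* environment variables, and tests opt in via the "databricks" marker that the new Makefile target selects. A minimal sketch of such a test; the body is hypothetical, not part of this PR:

import pytest

@pytest.mark.databricks
def test_add_function(databricks_session):
    # Uses the db1.add UDF created by the session fixture above.
    row = databricks_session.sql("SELECT db1.add(1, 2) AS total").collect()[0]
    assert row["total"] == 3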


@pytest.fixture(scope="module")
def _employee_data() -> EmployeeData:
return [
