Add dest int test
rbiseck3 committed Jan 8, 2025
1 parent ffd643e commit 3056d8d
Showing 3 changed files with 149 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/e2e.yml
@@ -142,6 +142,9 @@ jobs:
DATABRICKS_CATALOG: ${{secrets.DATABRICKS_CATALOG}}
DATABRICKS_CLIENT_ID: ${{secrets.DATABRICKS_CLIENT_ID}}
DATABRICKS_CLIENT_SECRET: ${{secrets.DATABRICKS_CLIENT_SECRET}}
+DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_SERVER_HOSTNAME }}
+DATABRICKS_HTTP_PATH: ${{ secrets.DATABRICKS_HTTP_PATH }}
+DATABRICKS_ACCESS_TOKEN: ${{ secrets.DATABRICKS_ACCESS_TOKEN }}
S3_INGEST_TEST_ACCESS_KEY: ${{ secrets.S3_INGEST_TEST_ACCESS_KEY }}
S3_INGEST_TEST_SECRET_KEY: ${{ secrets.S3_INGEST_TEST_SECRET_KEY }}
GCP_INGEST_SERVICE_KEY: ${{ secrets.GCP_INGEST_SERVICE_KEY }}
141 changes: 141 additions & 0 deletions test/integration/connectors/sql/test_databricks_delta_tables.py
@@ -0,0 +1,141 @@
import json
import os
import time
from contextlib import contextmanager
from pathlib import Path
from uuid import uuid4

import pytest
from databricks.sql import connect
from databricks.sql.client import Connection as DeltaTableConnection
from databricks.sql.client import Cursor as DeltaTableCursor
from pydantic import BaseModel, SecretStr

from test.integration.connectors.utils.constants import DESTINATION_TAG, env_setup_path
from test.integration.utils import requires_env
from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
from unstructured_ingest.v2.logger import logger
from unstructured_ingest.v2.processes.connectors.sql.databricks_delta_tables import (
    CONNECTOR_TYPE,
    DatabrickDeltaTablesAccessConfig,
    DatabrickDeltaTablesConnectionConfig,
    DatabrickDeltaTablesUploader,
    DatabrickDeltaTablesUploaderConfig,
    DatabrickDeltaTablesUploadStager,
)

CATALOG = "utic-dev-tech-fixtures"


class EnvData(BaseModel):
    server_hostname: str
    http_path: str
    access_token: SecretStr


def get_env_data() -> EnvData:
    return EnvData(
        server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],
        http_path=os.environ["DATABRICKS_HTTP_PATH"],
        access_token=os.environ["DATABRICKS_ACCESS_TOKEN"],
    )


def get_destination_schema(new_table_name: str) -> str:
    p = Path(env_setup_path / "sql" / "databricks_delta_tables" / "destination" / "schema.sql")
    with p.open() as f:
        data_lines = f.readlines()
    data_lines[0] = data_lines[0].replace("elements", new_table_name)
    data = "".join([line.strip() for line in data_lines])
    return data


@contextmanager
def get_connection() -> DeltaTableConnection:
    env_data = get_env_data()
    with connect(
        server_hostname=env_data.server_hostname,
        http_path=env_data.http_path,
        access_token=env_data.access_token.get_secret_value(),
    ) as connection:
        yield connection


@contextmanager
def get_cursor() -> DeltaTableCursor:
    with get_connection() as connection:
        with connection.cursor() as cursor:
            cursor.execute(f"USE CATALOG '{CATALOG}'")
            yield cursor


@pytest.fixture
def destination_table() -> str:
    random_id = str(uuid4())[:8]
    table_name = f"elements_{random_id}"
    destination_schema = get_destination_schema(new_table_name=table_name)
    with get_cursor() as cursor:
        logger.info(f"creating table: {table_name}")
        cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
        cursor.execute(destination_schema)

    yield table_name
    with get_cursor() as cursor:
        logger.info(f"dropping table: {table_name}")
        cursor.execute(f"DROP TABLE IF EXISTS {table_name}")


def validate_destination(expected_num_elements: int, table_name: str, retries=30, interval=1):
    with get_cursor() as cursor:
        for i in range(retries):
            cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
            count = cursor.fetchone()[0]
            if count == expected_num_elements:
                break
            logger.info(f"retry attempt {i}: expected {expected_num_elements} != count {count}")
            time.sleep(interval)
        assert (
            count == expected_num_elements
        ), f"dest check failed: got {count}, expected {expected_num_elements}"


@pytest.mark.asyncio
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, "sql")
@requires_env("DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_ACCESS_TOKEN")
async def test_databricks_delta_tables_destination(
    upload_file: Path, temp_dir: Path, destination_table: str
):
    env_data = get_env_data()
    mock_file_data = FileData(
        identifier="mock file data",
        connector_type=CONNECTOR_TYPE,
        source_identifiers=SourceIdentifiers(filename=upload_file.name, fullpath=upload_file.name),
    )
    stager = DatabrickDeltaTablesUploadStager()
    staged_path = stager.run(
        elements_filepath=upload_file,
        file_data=mock_file_data,
        output_dir=temp_dir,
        output_filename=upload_file.name,
    )

    assert staged_path.suffix == upload_file.suffix

    uploader = DatabrickDeltaTablesUploader(
        connection_config=DatabrickDeltaTablesConnectionConfig(
            access_config=DatabrickDeltaTablesAccessConfig(
                token=env_data.access_token.get_secret_value()
            ),
            http_path=env_data.http_path,
            server_hostname=env_data.server_hostname,
        ),
        upload_config=DatabrickDeltaTablesUploaderConfig(
            catalog=CATALOG, database="default", table_name=destination_table
        ),
    )
    with staged_path.open("r") as f:
        staged_data = json.load(f)
    expected_num_elements = len(staged_data)
    uploader.precheck()
    uploader.run(path=staged_path, file_data=mock_file_data)
    validate_destination(expected_num_elements=expected_num_elements, table_name=destination_table)
12 changes: 5 additions & 7 deletions unstructured_ingest/v2/processes/connectors/sql/databricks_delta_tables.py
@@ -1,6 +1,6 @@
import json
from contextlib import contextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generator, Optional

import numpy as np
@@ -25,7 +25,7 @@

if TYPE_CHECKING:
    from databricks.sdk.core import oauth_service_principal
-    from databricks.sql import Connection as DeltaTableConnection
+    from databricks.sql.client import Connection as DeltaTableConnection
    from databricks.sql.client import Cursor as DeltaTableCursor

CONNECTOR_TYPE = "databricks_delta_tables"
@@ -110,14 +110,12 @@ class DatabrickDeltaTablesUploadStager(SQLUploadStager):
class DatabrickDeltaTablesUploaderConfig(SQLUploaderConfig):
    catalog: str = Field(description="Name of the catalog in the Databricks Unity Catalog service")
    database: str = Field(description="Database name", default="default")
-    table: str = Field(description="Table name")
+    table_name: str = Field(description="Table name")


@dataclass
class DatabrickDeltaTablesUploader(SQLUploader):
-    upload_config: DatabrickDeltaTablesUploaderConfig = field(
-        default_factory=DatabrickDeltaTablesUploaderConfig
-    )
+    upload_config: DatabrickDeltaTablesUploaderConfig
    connection_config: DatabrickDeltaTablesConnectionConfig
    connector_type: str = CONNECTOR_TYPE

@@ -191,7 +189,7 @@ def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
f"writing a total of {len(df)} elements via"
f" document batches to destination"
f" table named {self.upload_config.table_name}"
f" with batch size {self.upload_config.batch_size}"
# f" with batch size {self.upload_config.batch_size}"
)
# TODO: currently variable binding not supporting for list types,
# update once that gets resolved in SDK
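
Aside on the TODO above: the Databricks SQL connector currently cannot bind list-typed parameters, so list values have to be flattened before the INSERT is issued. A minimal sketch of one possible workaround, assuming a pandas DataFrame of staged elements (the helper name jsonify_list_columns is hypothetical and not part of this commit):

import json

import pandas as pd


def jsonify_list_columns(df: pd.DataFrame) -> pd.DataFrame:
    # JSON-encode list/tuple cells so every value can be bound as a plain
    # string until list binding is supported by the SDK.
    return df.applymap(
        lambda value: json.dumps(list(value)) if isinstance(value, (list, tuple)) else value
    )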
Expand Down
