From fd6fa149ee9b5afca1924aa3d2f146cd73027b93 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Tue, 8 Oct 2024 17:22:33 +0200 Subject: [PATCH 01/12] basic databricks support working --- .gitignore | 3 + raster_loader/__init__.py | 4 + raster_loader/cli/databricks.py | 171 ++++++++++++++++++ raster_loader/errors.py | 8 + raster_loader/io/databricks.py | 297 ++++++++++++++++++++++++++++++++ setup.cfg | 2 + 6 files changed, 485 insertions(+) create mode 100644 raster_loader/cli/databricks.py create mode 100644 raster_loader/io/databricks.py diff --git a/.gitignore b/.gitignore index a0905f7..57511b3 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,6 @@ carto_credentials.json .idea/codeStyles/codeStyleConfig.xml .idea/codeStyles/Project.xml .idea/.gitignore + +# Vim +*.swp diff --git a/raster_loader/__init__.py b/raster_loader/__init__.py index 36e7278..3b86888 100644 --- a/raster_loader/__init__.py +++ b/raster_loader/__init__.py @@ -6,9 +6,13 @@ from raster_loader.io.snowflake import ( SnowflakeConnection, ) +from raster_loader.io.databricks import ( + DatabricksConnection, +) __all__ = [ "__version__", "BigQueryConnection", "SnowflakeConnection", + "DatabricksConnection" ] diff --git a/raster_loader/cli/databricks.py b/raster_loader/cli/databricks.py new file mode 100644 index 0000000..f6b557e --- /dev/null +++ b/raster_loader/cli/databricks.py @@ -0,0 +1,171 @@ +import click +from functools import wraps, partial + +from raster_loader.utils import get_default_table_name +from raster_loader.io.databricks import DatabricksConnection + + +def catch_exception(func=None, *, handle=Exception): + if not func: + return partial(catch_exception, handle=handle) + + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except handle as e: + raise click.ClickException(str(e)) + + return wrapper + + +@click.group(context_settings=dict(help_option_names=["-h", "--help"])) +def databricks(args=None): + """ + Manage Databricks resources. + """ + pass + + +@databricks.command(help="Upload a raster file to Databricks.") +@click.option("--host", help="The Databricks host URL.", required=True) +@click.option("--token", help="The Databricks access token.", required=True) +@click.option("--cluster-id", help="The Databricks cluster ID.", required=True) # New option +@click.option( + "--file_path", help="The path to the raster file.", required=False, default=None +) +@click.option( + "--file_url", help="The URL to the raster file.", required=False, default=None +) +@click.option("--catalog", help="The name of the catalog.", required=True) +@click.option("--schema", help="The name of the schema.", required=True) +@click.option("--table", help="The name of the table.", default=None) +@click.option( + "--band", + help="Band(s) within raster to upload. " + "Could repeat --band to specify multiple bands.", + default=[1], + multiple=True, +) +@click.option( + "--band_name", + help="Column name(s) used to store band (Default: band_). " + "Could repeat --band_name to specify multiple bands column names. 
" + "List of column names HAVE to pair with --band list in the same order.", + default=[None], + multiple=True, +) +@click.option( + "--chunk_size", help="The number of blocks to upload in each chunk.", default=10000 +) +@click.option( + "--overwrite", + help="Overwrite existing data in the table if it already exists.", + default=False, + is_flag=True, +) +@click.option( + "--append", + help="Append records into a table if it already exists.", + default=False, + is_flag=True, +) +@click.option( + "--cleanup-on-failure", + help="Clean up resources if the upload fails. Useful for non-interactive scripts.", + default=False, + is_flag=True, +) +@catch_exception() +def upload( + host, + token, + cluster_id, # Accept cluster ID + file_path, + file_url, + catalog, + schema, + table, + band, + band_name, + chunk_size, + overwrite=False, + append=False, + cleanup_on_failure=False, +): + from raster_loader.io.common import ( + get_number_of_blocks, + print_band_information, + get_block_dims, + ) + import os + from urllib.parse import urlparse + + if file_path is None and file_url is None: + raise ValueError("Need either a --file_path or --file_url") + + if file_path and file_url: + raise ValueError("Only one of --file_path or --file_url must be provided.") + + is_local_file = file_path is not None + + # Check that band and band_name are the same length if band_name provided + if band_name != (None,): + if len(band) != len(band_name): + raise ValueError("Must supply the same number of band_names as bands") + else: + band_name = [None] * len(band) + + # Pair band and band_name in a list of tuples + bands_info = list(zip(band, band_name)) + + # Create default table name if not provided + if table is None: + table = get_default_table_name( + file_path if is_local_file else urlparse(file_url).path, band + ) + + connector = DatabricksConnection( + host=host, + token=token, + cluster_id=cluster_id, # Pass cluster_id to DatabricksConnection + catalog=catalog, + schema=schema, + ) + + source = file_path if is_local_file else file_url + + # Introspect raster file + num_blocks = get_number_of_blocks(source) + file_size_mb = 0 + if is_local_file: + file_size_mb = os.path.getsize(file_path) / 1024 / 1024 + + click.echo("Preparing to upload raster file to Databricks...") + click.echo(f"File Path: {source}") + click.echo(f"File Size: {file_size_mb} MB") + print_band_information(source) + click.echo(f"Source Band(s): {band}") + click.echo(f"Band Name(s): {band_name}") + click.echo(f"Number of Blocks: {num_blocks}") + click.echo(f"Block Dimensions: {get_block_dims(source)}") + click.echo(f"Catalog: {catalog}") + click.echo(f"Schema: {schema}") + click.echo(f"Table: {table}") + click.echo(f"Number of Records Per Batch: {chunk_size}") + + click.echo("Uploading Raster to Databricks") + + connector.upload_raster( + source, + table, + bands_info, + chunk_size, + overwrite=overwrite, + append=append, + cleanup_on_failure=cleanup_on_failure, + ) + + click.echo("Raster file uploaded to Databricks") + return 0 + diff --git a/raster_loader/errors.py b/raster_loader/errors.py index 5d3bcec..02f05cd 100644 --- a/raster_loader/errors.py +++ b/raster_loader/errors.py @@ -15,6 +15,13 @@ def import_error_snowflake(): # pragma: no cover ) raise ImportError(msg) +def import_error_databricks(): # pragma: no cover + msg = ( + "Databricks client is not installed.\n" + "Please install Databricks dependencies to use this function.\n" + 'run `pip install -U raster-loader"[databricks]"` to install from pypi.' 
+ ) + raise ImportError(msg) class IncompatibleRasterException(Exception): def __init__(self): @@ -31,3 +38,4 @@ def __init__(self): def error_not_google_compatible(): # pragma: no cover raise IncompatibleRasterException() + diff --git a/raster_loader/io/databricks.py b/raster_loader/io/databricks.py new file mode 100644 index 0000000..bf98917 --- /dev/null +++ b/raster_loader/io/databricks.py @@ -0,0 +1,297 @@ +import json +import pandas as pd + +from typing import Iterable, List, Tuple + +from raster_loader.errors import ( + IncompatibleRasterException, + import_error_databricks, +) + +from raster_loader.utils import ask_yes_no_question, batched + +from raster_loader.io.common import ( + rasterio_metadata, + rasterio_windows_to_records, + get_number_of_blocks, + check_metadata_is_compatible, + update_metadata, +) +from raster_loader.io.datawarehouse import DataWarehouseConnection + +try: + from databricks.connect import DatabricksSession + from pyspark.sql.types import ( + StructType, + StructField, + StringType, + BinaryType, + IntegerType, + LongType, + ) +except ImportError: # pragma: no cover + _has_databricks = False +else: + _has_databricks = True + +class DatabricksConnection(DataWarehouseConnection): + def __init__(self, host, token, cluster_id, catalog, schema): + if not _has_databricks: + import_error_databricks() + + self.host = host + self.token = token + self.cluster_id = cluster_id + self.catalog = catalog + self.schema = schema + + self.client = self.get_connection() + + def get_connection(self): + # Initialize DatabricksSession + session = DatabricksSession.builder.remote(host=self.host, token=self.token, cluster_id=self.cluster_id).getOrCreate() + session.conf.set("spark.databricks.session.timeout", "6h") + return session + + def get_table_fqn(self, table): + return f"`{self.catalog}`.{self.schema}.{table}" + + def execute(self, sql): + # NOTE: if you get empty sql statement errors check runtime v databricks-connect version + # https://community.databricks.com/t5/data-engineering/parse-empty-statement-error-when-trying-to-use-spark-sql-via/td-p/80770 + return self.client.sql(sql) + + def execute_to_dataframe(self, sql): + df = self.execute(sql) + return df.toPandas() + + def create_schema_if_not_exists(self): + self.execute(f"CREATE SCHEMA IF NOT EXISTS `{self.catalog}`.{self.schema}") + + def create_table_if_not_exists(self, table): + self.execute( + f""" + CREATE TABLE IF NOT EXISTS `{self.catalog}`.{self.schema}.{table} ( + BLOCK BIGINT, + METADATA STRING, + {self.band_columns} + ) USING DELTA + """ + ) + + def band_rename_function(self, band_name: str): + return band_name.upper() + + def write_metadata( + self, + metadata, + append_records, + table, + ): + # Create a DataFrame with the metadata + schema = StructType( + [ + StructField("BLOCK", LongType(), True), + StructField("METADATA", StringType(), True), + ] + ) + + data = [(0, json.dumps(metadata))] + + metadata_df = self.client.createDataFrame(data, schema) + + # Write to table + fqn = self.get_table_fqn(table) + metadata_df.write.format("delta").mode("append").saveAsTable(fqn) + + return True + + def get_metadata(self, table): + fqn = self.get_table_fqn(table) + query = f""" + SELECT METADATA + FROM {fqn} + WHERE BLOCK = 0 + """ + result = self.execute_to_dataframe(query) + if result.empty: + return None + return json.loads(result.iloc[0]["METADATA"]) + + def check_if_table_exists(self, table): + sql = f""" + SELECT * + FROM `{self.catalog}`.INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = '{self.schema}' + 
AND TABLE_NAME = '{table}' + """ + df = self.execute(sql) + # If the count is greater than 0, the table exists + return df.count() > 0 + + def check_if_table_is_empty(self, table): + fqn = self.get_table_fqn(table) + df = self.client.table(fqn) + return df.count() == 0 + + def upload_records( + self, + records: Iterable, + table: str, + overwrite: bool, + ): + fqn = self.get_table_fqn(table) + records_list = [] + for record in records: + # Remove 'METADATA' from records, as it's handled separately + if "METADATA" in record: + del record["METADATA"] + records_list.append(record) + + data_df = pd.DataFrame(records_list) + spark_df = self.client.createDataFrame(data_df) + + if overwrite: + mode = "overwrite" + else: + mode = "append" + + spark_df.write.format("delta").mode(mode).saveAsTable(fqn) + + return True + + def upload_raster( + self, + file_path: str, + table: str, + bands_info: List[Tuple[int, str]] = None, + chunk_size: int = None, + overwrite: bool = False, + append: bool = False, + cleanup_on_failure: bool = False, + ) -> bool: + print("Loading raster file to Databricks...") + + bands_info = bands_info or [(1, None)] + + append_records = False + + try: + if ( + self.check_if_table_exists(table) + and not self.check_if_table_is_empty(table) + and not overwrite + ): + append_records = append or ask_yes_no_question( + f"Table `{self.catalog}`.{self.schema}.{table} already exists " + "and is not empty. Append records? [yes/no] " + ) + + if not append_records: + exit() + + # Prepare band columns + self.band_columns = ", ".join( + [ + f"{self.band_rename_function(band_name or f'band_{band}')} BINARY" + for band, band_name in bands_info + ] + ) + + # Create schema and table if not exists + self.create_schema_if_not_exists() + self.create_table_if_not_exists(table) + + metadata = rasterio_metadata(file_path, bands_info, self.band_rename_function) + + records_gen = rasterio_windows_to_records( + file_path, + self.band_rename_function, + bands_info, + ) + + total_blocks = get_number_of_blocks(file_path) + + if chunk_size is None: + ret = self.upload_records(records_gen, table, overwrite) + if not ret: + raise IOError("Error uploading to Databricks.") + else: + from tqdm.auto import tqdm + + print(f"Writing {total_blocks} blocks to Databricks...") + with tqdm(total=total_blocks) as pbar: + if total_blocks < chunk_size: + chunk_size = total_blocks + isFirstBatch = True + for records in batched(records_gen, chunk_size): + ret = self.upload_records( + records, table, overwrite and isFirstBatch + ) + pbar.update(len(records)) + if not ret: + raise IOError("Error uploading to Databricks.") + isFirstBatch = False + + print("Writing metadata to Databricks...") + if append_records: + old_metadata = self.get_metadata(table) + check_metadata_is_compatible(metadata, old_metadata) + update_metadata(metadata, old_metadata) + + self.write_metadata(metadata, append_records, table) + + except IncompatibleRasterException as e: + raise IOError(f"Error uploading to Databricks: {e.message}") + + except KeyboardInterrupt: + delete = cleanup_on_failure or ask_yes_no_question( + "Would you like to delete the partially uploaded table? [yes/no] " + ) + + if delete: + self.delete_table(table) + + raise KeyboardInterrupt + + except Exception as e: + delete = cleanup_on_failure or ask_yes_no_question( + ( + "Error uploading to Databricks. " + "Would you like to delete the partially uploaded table? 
[yes/no] " + ) + ) + + if delete: + self.delete_table(table) + + raise IOError(f"Error uploading to Databricks: {e}") + + print("Done.") + return True + + def delete_table(self, table): + fqn = self.get_table_fqn(table) + self.execute(f"DROP TABLE IF EXISTS {fqn}") + + def get_records(self, table: str, limit=10) -> pd.DataFrame: + fqn = self.get_table_fqn(table) + query = f"SELECT * FROM {fqn} LIMIT {limit}" + df = self.execute_to_dataframe(query) + return df + + def insert_in_table( + self, + rows: List[dict], + table: str, + ) -> bool: + fqn = self.get_table_fqn(table) + data_df = pd.DataFrame(rows) + spark_df = self.client.createDataFrame(data_df) + spark_df.write.format("delta").mode("append").saveAsTable(fqn) + return True + + + def quote_name(self, name): + return f"`{name}`" + diff --git a/setup.cfg b/setup.cfg index 26a05f1..e683fff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,7 @@ install_requires = shapely>=1.7.1 quadbin>=0.2.0 tqdm>=4.64.1 + databricks-connect==15.1.1 zip_safe = False [options.entry_points] @@ -48,6 +49,7 @@ console_scripts = raster_loader.cli = bigquery = raster_loader.cli.bigquery:bigquery snowflake = raster_loader.cli.snowflake:snowflake + databricks = raster_loader.cli.databricks:databricks info = raster_loader.cli.info:info [options.extras_require] From fdededce4a77fffe536534d18681c42e6e1115c2 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Wed, 9 Oct 2024 16:14:16 +0200 Subject: [PATCH 02/12] cleanup --- Makefile | 2 +- README.md | 17 ++++++++++++ docs/source/user_guide/cli.rst | 31 +++++++++++++++++++++- docs/source/user_guide/installation.rst | 3 ++- docs/source/user_guide/use_with_python.rst | 8 +++++- setup.cfg | 6 ++++- 6 files changed, 62 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index fa7ff8c..f698795 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ init: [ -d $(VENV) ] || python3 -m venv $(VENV) $(BIN)/pip install -r requirements-dev.txt $(BIN)/pre-commit install - $(BIN)/pip install -e .[snowflake,bigquery] + $(BIN)/pip install -e .[snowflake,bigquery,databricks] lint: $(BIN)/black raster_loader setup.py diff --git a/README.md b/README.md index a0b3c8b..5036f89 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ pip install -U raster-loader pip install -U raster-loader"[bigquery]" pip install -U raster-loader"[snowflake]" +pip install -U raster-loader"[databricks]" ``` ### Installing from source @@ -31,6 +32,22 @@ cd raster-loader pip install . ``` +### Installing for Development + +It is reccomended to use a virtualenv when developing. + +``` +python3 -m venv venv +source venv/bin/activate +pip install -e . +``` + +Doing `which carto` should return something like `/my/local/filesystem/raster-loader/venv/bin/carto` instead of the system-wide installation. + +The `-e` flag passed to the `pip install` program will set the project and its dependencies in development mode. Changes to the project files +will be reflected in the `carto` command immedietly without the need to re-run any setup steps. + + ## Usage There are two ways you can use Raster Loader: diff --git a/docs/source/user_guide/cli.rst b/docs/source/user_guide/cli.rst index c1e6bb5..4c048e7 100644 --- a/docs/source/user_guide/cli.rst +++ b/docs/source/user_guide/cli.rst @@ -42,10 +42,25 @@ Snowflake: To use the snowflake utilities, use the ``carto snowflake`` command. This command has several subcommands, which are described below. 
+ +Using the Raster Loader with Databricks +----------------------------------------- + +Before you can upload a raster file, you need to have set up the following in +Databricks: + +#. A Databricks instance host, e.g. `https://dbc-abcde12345-678f.cloud.databricks.com` +#. A cluster ID (the cluster MUST be running) +#. A Personal Access Token (PAT). See `Databricks PAT Docs `_. +#. A catalog +#. A schema (in the same catalog) + +To use the databricks utilities, use the ``carto databricks`` command. This command has +several subcommands, which are described below. + Uploading a raster layer ------------------------ -To upload a raster file, use the ``carto [bigquery|snowflake] upload`` command. +To upload a raster file, use the ``carto [bigquery|snowflake|databricks] upload`` command. The input raster must be a ``GoogleMapsCompatible`` raster. You can make your raster compatible by converting it with the following GDAL command: @@ -98,6 +113,20 @@ The same operation, performed with Snowflake, would be: Authentication parameters are explicitly required in this case for Snowflake, since they are not set up in the environment. +The same operation, performed with Databricks, would be: + +.. code-block:: bash + + carto databricks upload \ + --host 'https://dbc-12345abc-123f.cloud.databricks.com' \ + --token \ + --cluster-id '0123-456789-abc12345xyz' \ + --catalog 'main' \ + --schema default \ + --file_path \ + /path/to/my/raster/file.tif \ + --table mydatabrickstable + If no band is specified, the first band of the raster will be uploaded. If the ``--band`` flag is set, the specified band will be uploaded. For example, the following command uploads the second band of the raster: diff --git a/docs/source/user_guide/installation.rst b/docs/source/user_guide/installation.rst index 03c05a2..b18b234 100644 --- a/docs/source/user_guide/installation.rst +++ b/docs/source/user_guide/installation.rst @@ -22,13 +22,14 @@ To install from source: In most cases, it is recommended to install Raster Loader in a virtual environment. Use venv_ to create and manage your virtual environment. -The above will install the dependencies required to work with both BigQuery and Snowflake and. In case you only want to work with one of them, you can install the +The above will install the dependencies required to work with BigQuery, Snowflake and Databricks. In case you only want to work with one of them, you can install the dependencies for each of them separately: .. code-block:: bash pip install -U raster-loader"[bigquery]" pip install -U raster-loader"[snowflake]" + pip install -U raster-loader"[databricks]" After installing the Raster Loader package, you will have access to the :ref:`carto CLI `. To make sure the installation was successful, run the diff --git a/docs/source/user_guide/use_with_python.rst b/docs/source/user_guide/use_with_python.rst index 1902973..758790b 100644 --- a/docs/source/user_guide/use_with_python.rst +++ b/docs/source/user_guide/use_with_python.rst @@ -18,6 +18,12 @@ For BigQuery, use ``BigQueryConnection``: from raster_loader import BigQueryConnection +For Databricks, use ``DatabricksConnection``: + +.. code-block:: python + + from raster_loader import DatabricksConnection + Then, create a connection object with the appropriate parameters. For Snowflake: @@ -48,7 +54,7 @@ For example: ..
code-block:: python - connector.upload_raster( + connection.upload_raster( file_path = 'path/to/raster.tif', fqn = 'database.schema.tablename', ) diff --git a/setup.cfg b/setup.cfg index e683fff..198c7bc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,6 +10,7 @@ keywords = data warehouse bigquery snowflake + databricks author = CARTO url = https://github.com/cartodb/raster-loader license = BSD 3-Clause @@ -40,7 +41,6 @@ install_requires = shapely>=1.7.1 quadbin>=0.2.0 tqdm>=4.64.1 - databricks-connect==15.1.1 zip_safe = False [options.entry_points] @@ -62,9 +62,13 @@ bigquery = google-auth>=2.28.0 snowflake = snowflake-connector-python>=2.6.0 +databricks = + databricks-connect==15.1.1 + all = %(bigquery)s %(snowflake)s + %(databricks)s [flake8] max-line-length = 88 From fd4fce52d322972d555805f71350942038ed7cf7 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Wed, 9 Oct 2024 17:18:45 +0200 Subject: [PATCH 03/12] tests for databricks cli --- raster_loader/tests/databricks/__init__.py | 0 raster_loader/tests/databricks/test_cli.py | 243 +++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 raster_loader/tests/databricks/__init__.py create mode 100644 raster_loader/tests/databricks/test_cli.py diff --git a/raster_loader/tests/databricks/__init__.py b/raster_loader/tests/databricks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/raster_loader/tests/databricks/test_cli.py b/raster_loader/tests/databricks/test_cli.py new file mode 100644 index 0000000..f3d703b --- /dev/null +++ b/raster_loader/tests/databricks/test_cli.py @@ -0,0 +1,243 @@ +import os +from unittest.mock import patch + +from click.testing import CliRunner +import pandas as pd + +from raster_loader.cli import main + + +here = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +fixtures = os.path.join(here, "fixtures") +tiff = os.path.join(fixtures, "mosaic_cog.tif") + + +@patch( + "raster_loader.io.databricks.DatabricksConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.databricks.DatabricksConnection.__init__", return_value=None) +def test_databricks_upload(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--file_path", + f"{tiff}", + "--catalog", + "catalog", + "--schema", + "schema", + "--table", + "table", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + print(result.output) + assert result.exit_code == 0 + + +@patch( + "raster_loader.io.databricks.DatabricksConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.databricks.DatabricksConnection.__init__", return_value=None) +def test_databricks_file_path_or_url_check(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--catalog", + "catalog", + "--schema", + "schema", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + assert result.exit_code == 1 + assert "Error: Need either a --file_path or --file_url" in result.output + + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--file_path", + f"{tiff}", + "--file_url", + "http://example.com/raster.tif", + "--catalog", + "catalog", + "--schema", + "schema", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + "--chunk_size", + 1, + "--band", + 1, + ], + ) + 
assert result.exit_code == 1 + assert "Only one of --file_path or --file_url must be provided" in result.output + + +@patch( + "raster_loader.io.databricks.DatabricksConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.databricks.DatabricksConnection.__init__", return_value=None) +def test_databricks_upload_multiple_bands(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--file_path", + f"{tiff}", + "--catalog", + "catalog", + "--schema", + "schema", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + "--chunk_size", + 1, + "--band", + 1, + "--band", + 2, + ], + ) + assert result.exit_code == 0 + + +def test_databricks_fail_upload_multiple_bands_misaligned_with_band_names( + *args, **kwargs +): + runner = CliRunner() + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--file_path", + f"{tiff}", + "--catalog", + "catalog", + "--schema", + "schema", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + "--chunk_size", + 1, + "--band", + 1, + "--band_name", + "band_1", + "--band", + 2, + ], + ) + assert result.exit_code == 1 + assert "Error: Must supply the same number of band_names as bands" in result.output + + +@patch( + "raster_loader.io.databricks.DatabricksConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.databricks.DatabricksConnection.__init__", return_value=None) +def test_databricks_upload_multiple_bands_aligned_with_band_names(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--file_path", + f"{tiff}", + "--catalog", + "catalog", + "--schema", + "schema", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + "--chunk_size", + 1, + "--band", + 1, + "--band_name", + "band_1", + "--band_name", + "band_2", + "--band", + 2, + ], + ) + assert result.exit_code == 0 + + +@patch( + "raster_loader.io.databricks.DatabricksConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.databricks.DatabricksConnection.__init__", return_value=None) +def test_databricks_upload_no_table_name(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "databricks", + "upload", + "--file_path", + f"{tiff}", + "--catalog", + "catalog", + "--schema", + "schema", + "--host", + "https://databricks-host", + "--token", + "token", + "--cluster-id", + "cluster-1234", + ], + ) + assert result.exit_code == 0 + assert "Table: mosaic_cog_band__1___" in result.output + From 5c0dc4a51f9ee6112a7e1a6711a5f4719e139dd6 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 09:47:36 +0200 Subject: [PATCH 04/12] updatereadme --- README.md | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 5036f89..5f93686 100644 --- a/README.md +++ b/README.md @@ -32,21 +32,6 @@ cd raster-loader pip install . ``` -### Installing for Development - -It is reccomended to use a virtualenv when developing. - -``` -python3 -m venv venv -source venv/bin/activate -pip install -e . -``` - -Doing `which carto` should return something like `/my/local/filesystem/raster-loader/venv/bin/carto` instead of the system-wide installation. - -The `-e` flag passed to the `pip install` program will set the project and its dependencies in development mode. 
Changes to the project files -will be reflected in the `carto` command immedietly without the need to re-run any setup steps. - ## Usage @@ -167,6 +152,22 @@ project. [ROADMAP.md](ROADMAP.md) contains a list of features and improvements planned for future versions of Raster Loader. +### Installing for Development + +It is reccomended to use a virtualenv when developing. + +``` +python3 -m venv venv +source venv/bin/activate +pip install -e .[all] +``` + +Doing `which carto` should return something like `/my/local/filesystem/raster-loader/venv/bin/carto` instead of the system-wide installation. + +The `-e` flag passed to the `pip install` program will set the project and its dependencies in development mode. Changes to the project files +will be reflected in the `carto` command immedietly without the need to re-run any setup steps. + + ## Releasing ### 1. Create and merge a release PR updating the CHANGELOG From 8b0be62e58c87214762935b9a612d31508d0e9f8 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 09:58:35 +0200 Subject: [PATCH 05/12] take schema and catalog off the DatabricksConnection class --- raster_loader/cli/databricks.py | 29 +++------ raster_loader/io/databricks.py | 105 +++++++++++--------------------- 2 files changed, 44 insertions(+), 90 deletions(-) diff --git a/raster_loader/cli/databricks.py b/raster_loader/cli/databricks.py index f6b557e..71f0e5c 100644 --- a/raster_loader/cli/databricks.py +++ b/raster_loader/cli/databricks.py @@ -31,19 +31,14 @@ def databricks(args=None): @click.option("--host", help="The Databricks host URL.", required=True) @click.option("--token", help="The Databricks access token.", required=True) @click.option("--cluster-id", help="The Databricks cluster ID.", required=True) # New option -@click.option( - "--file_path", help="The path to the raster file.", required=False, default=None -) -@click.option( - "--file_url", help="The URL to the raster file.", required=False, default=None -) +@click.option("--file_path", help="The path to the raster file.", required=False, default=None) +@click.option("--file_url", help="The URL to the raster file.", required=False, default=None) @click.option("--catalog", help="The name of the catalog.", required=True) @click.option("--schema", help="The name of the schema.", required=True) @click.option("--table", help="The name of the table.", default=None) @click.option( "--band", - help="Band(s) within raster to upload. " - "Could repeat --band to specify multiple bands.", + help="Band(s) within raster to upload. 
" "Could repeat --band to specify multiple bands.", default=[1], multiple=True, ) @@ -55,9 +50,7 @@ def databricks(args=None): default=[None], multiple=True, ) -@click.option( - "--chunk_size", help="The number of blocks to upload in each chunk.", default=10000 -) +@click.option("--chunk_size", help="The number of blocks to upload in each chunk.", default=10000) @click.option( "--overwrite", help="Overwrite existing data in the table if it already exists.", @@ -121,16 +114,10 @@ def upload( # Create default table name if not provided if table is None: - table = get_default_table_name( - file_path if is_local_file else urlparse(file_url).path, band - ) + table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band) connector = DatabricksConnection( - host=host, - token=token, - cluster_id=cluster_id, # Pass cluster_id to DatabricksConnection - catalog=catalog, - schema=schema, + host=host, token=token, cluster_id=cluster_id # Pass cluster_id to DatabricksConnection ) source = file_path if is_local_file else file_url @@ -156,9 +143,10 @@ def upload( click.echo("Uploading Raster to Databricks") + fqn = f"`{catalog}`.{schema}.{table}" connector.upload_raster( source, - table, + fqn, bands_info, chunk_size, overwrite=overwrite, @@ -168,4 +156,3 @@ def upload( click.echo("Raster file uploaded to Databricks") return 0 - diff --git a/raster_loader/io/databricks.py b/raster_loader/io/databricks.py index bf98917..6e6396b 100644 --- a/raster_loader/io/databricks.py +++ b/raster_loader/io/databricks.py @@ -34,44 +34,41 @@ else: _has_databricks = True + class DatabricksConnection(DataWarehouseConnection): - def __init__(self, host, token, cluster_id, catalog, schema): + def __init__(self, host, token, cluster_id): if not _has_databricks: import_error_databricks() self.host = host self.token = token self.cluster_id = cluster_id - self.catalog = catalog - self.schema = schema self.client = self.get_connection() def get_connection(self): # Initialize DatabricksSession - session = DatabricksSession.builder.remote(host=self.host, token=self.token, cluster_id=self.cluster_id).getOrCreate() + session = DatabricksSession.builder.remote( + host=self.host, token=self.token, cluster_id=self.cluster_id + ).getOrCreate() session.conf.set("spark.databricks.session.timeout", "6h") return session - def get_table_fqn(self, table): - return f"`{self.catalog}`.{self.schema}.{table}" - def execute(self, sql): - # NOTE: if you get empty sql statement errors check runtime v databricks-connect version - # https://community.databricks.com/t5/data-engineering/parse-empty-statement-error-when-trying-to-use-spark-sql-via/td-p/80770 return self.client.sql(sql) def execute_to_dataframe(self, sql): df = self.execute(sql) return df.toPandas() - def create_schema_if_not_exists(self): - self.execute(f"CREATE SCHEMA IF NOT EXISTS `{self.catalog}`.{self.schema}") + def create_schema_if_not_exists(self, fqn): + schema_name = fqn.split(".")[1] # Extract schema from FQN + self.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") - def create_table_if_not_exists(self, table): + def create_table_if_not_exists(self, fqn): self.execute( f""" - CREATE TABLE IF NOT EXISTS `{self.catalog}`.{self.schema}.{table} ( + CREATE TABLE IF NOT EXISTS {fqn} ( BLOCK BIGINT, METADATA STRING, {self.band_columns} @@ -86,9 +83,8 @@ def write_metadata( self, metadata, append_records, - table, + fqn, ): - # Create a DataFrame with the metadata schema = StructType( [ StructField("BLOCK", LongType(), True), @@ -101,13 +97,11 @@ def 
write_metadata( metadata_df = self.client.createDataFrame(data, schema) # Write to table - fqn = self.get_table_fqn(table) metadata_df.write.format("delta").mode("append").saveAsTable(fqn) return True - def get_metadata(self, table): - fqn = self.get_table_fqn(table) + def get_metadata(self, fqn): query = f""" SELECT METADATA FROM {fqn} @@ -118,32 +112,28 @@ def get_metadata(self, table): return None return json.loads(result.iloc[0]["METADATA"]) - def check_if_table_exists(self, table): + def check_if_table_exists(self, fqn): + schema_name, table_name = fqn.split(".")[1:3] # Extract schema and table sql = f""" SELECT * - FROM `{self.catalog}`.INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = '{self.schema}' - AND TABLE_NAME = '{table}' + FROM {schema_name}.INFORMATION_SCHEMA.TABLES + WHERE TABLE_NAME = '{table_name}' """ df = self.execute(sql) - # If the count is greater than 0, the table exists return df.count() > 0 - def check_if_table_is_empty(self, table): - fqn = self.get_table_fqn(table) + def check_if_table_is_empty(self, fqn): df = self.client.table(fqn) return df.count() == 0 def upload_records( self, records: Iterable, - table: str, + fqn: str, overwrite: bool, ): - fqn = self.get_table_fqn(table) records_list = [] for record in records: - # Remove 'METADATA' from records, as it's handled separately if "METADATA" in record: del record["METADATA"] records_list.append(record) @@ -151,11 +141,7 @@ def upload_records( data_df = pd.DataFrame(records_list) spark_df = self.client.createDataFrame(data_df) - if overwrite: - mode = "overwrite" - else: - mode = "append" - + mode = "overwrite" if overwrite else "append" spark_df.write.format("delta").mode(mode).saveAsTable(fqn) return True @@ -163,7 +149,7 @@ def upload_records( def upload_raster( self, file_path: str, - table: str, + fqn: str, bands_info: List[Tuple[int, str]] = None, chunk_size: int = None, overwrite: bool = False, @@ -173,18 +159,12 @@ def upload_raster( print("Loading raster file to Databricks...") bands_info = bands_info or [(1, None)] - append_records = False - + try: - if ( - self.check_if_table_exists(table) - and not self.check_if_table_is_empty(table) - and not overwrite - ): + if self.check_if_table_exists(fqn) and not self.check_if_table_is_empty(fqn) and not overwrite: append_records = append or ask_yes_no_question( - f"Table `{self.catalog}`.{self.schema}.{table} already exists " - "and is not empty. Append records? [yes/no] " + f"Table {fqn} already exists and is not empty. Append records? 
[yes/no] " ) if not append_records: @@ -192,15 +172,12 @@ def upload_raster( # Prepare band columns self.band_columns = ", ".join( - [ - f"{self.band_rename_function(band_name or f'band_{band}')} BINARY" - for band, band_name in bands_info - ] + [f"{self.band_rename_function(band_name or f'band_{band}')} BINARY" for band, band_name in bands_info] ) # Create schema and table if not exists - self.create_schema_if_not_exists() - self.create_table_if_not_exists(table) + self.create_schema_if_not_exists(fqn) + self.create_table_if_not_exists(fqn) metadata = rasterio_metadata(file_path, bands_info, self.band_rename_function) @@ -213,7 +190,7 @@ def upload_raster( total_blocks = get_number_of_blocks(file_path) if chunk_size is None: - ret = self.upload_records(records_gen, table, overwrite) + ret = self.upload_records(records_gen, fqn, overwrite) if not ret: raise IOError("Error uploading to Databricks.") else: @@ -225,9 +202,7 @@ def upload_raster( chunk_size = total_blocks isFirstBatch = True for records in batched(records_gen, chunk_size): - ret = self.upload_records( - records, table, overwrite and isFirstBatch - ) + ret = self.upload_records(records, fqn, overwrite and isFirstBatch) pbar.update(len(records)) if not ret: raise IOError("Error uploading to Databricks.") @@ -235,11 +210,11 @@ def upload_raster( print("Writing metadata to Databricks...") if append_records: - old_metadata = self.get_metadata(table) + old_metadata = self.get_metadata(fqn) check_metadata_is_compatible(metadata, old_metadata) update_metadata(metadata, old_metadata) - self.write_metadata(metadata, append_records, table) + self.write_metadata(metadata, append_records, fqn) except IncompatibleRasterException as e: raise IOError(f"Error uploading to Databricks: {e.message}") @@ -250,32 +225,27 @@ def upload_raster( ) if delete: - self.delete_table(table) + self.delete_table(fqn) raise KeyboardInterrupt except Exception as e: delete = cleanup_on_failure or ask_yes_no_question( - ( - "Error uploading to Databricks. " - "Would you like to delete the partially uploaded table? [yes/no] " - ) + ("Error uploading to Databricks. " "Would you like to delete the partially uploaded table? 
[yes/no] ") ) if delete: - self.delete_table(table) + self.delete_table(fqn) raise IOError(f"Error uploading to Databricks: {e}") print("Done.") return True - def delete_table(self, table): - fqn = self.get_table_fqn(table) + def delete_table(self, fqn): self.execute(f"DROP TABLE IF EXISTS {fqn}") - def get_records(self, table: str, limit=10) -> pd.DataFrame: - fqn = self.get_table_fqn(table) + def get_records(self, fqn: str, limit=10) -> pd.DataFrame: query = f"SELECT * FROM {fqn} LIMIT {limit}" df = self.execute_to_dataframe(query) return df @@ -283,15 +253,12 @@ def get_records(self, table: str, limit=10) -> pd.DataFrame: def insert_in_table( self, rows: List[dict], - table: str, + fqn: str, ) -> bool: - fqn = self.get_table_fqn(table) data_df = pd.DataFrame(rows) spark_df = self.client.createDataFrame(data_df) spark_df.write.format("delta").mode("append").saveAsTable(fqn) return True - def quote_name(self, name): return f"`{name}`" - From ac17318683af59ad8cd7a2228125afc3f35c6017 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 10:21:50 +0200 Subject: [PATCH 06/12] lint --- raster_loader/__init__.py | 2 +- raster_loader/cli/databricks.py | 2 +- raster_loader/errors.py | 3 ++- raster_loader/io/databricks.py | 4 +++- raster_loader/tests/databricks/test_cli.py | 1 - 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/raster_loader/__init__.py b/raster_loader/__init__.py index 3b86888..c85af99 100644 --- a/raster_loader/__init__.py +++ b/raster_loader/__init__.py @@ -14,5 +14,5 @@ "__version__", "BigQueryConnection", "SnowflakeConnection", - "DatabricksConnection" + "DatabricksConnection", ] diff --git a/raster_loader/cli/databricks.py b/raster_loader/cli/databricks.py index 71f0e5c..be76c47 100644 --- a/raster_loader/cli/databricks.py +++ b/raster_loader/cli/databricks.py @@ -155,4 +155,4 @@ def upload( ) click.echo("Raster file uploaded to Databricks") - return 0 + exit(0) diff --git a/raster_loader/errors.py b/raster_loader/errors.py index 02f05cd..138c391 100644 --- a/raster_loader/errors.py +++ b/raster_loader/errors.py @@ -15,6 +15,7 @@ def import_error_snowflake(): # pragma: no cover ) raise ImportError(msg) + def import_error_databricks(): # pragma: no cover msg = ( "Databricks client is not installed.\n" @@ -23,6 +24,7 @@ def import_error_databricks(): # pragma: no cover ) raise ImportError(msg) + class IncompatibleRasterException(Exception): def __init__(self): self.message = ( @@ -38,4 +40,3 @@ def __init__(self): def error_not_google_compatible(): # pragma: no cover raise IncompatibleRasterException() - diff --git a/raster_loader/io/databricks.py b/raster_loader/io/databricks.py index 6e6396b..1573b44 100644 --- a/raster_loader/io/databricks.py +++ b/raster_loader/io/databricks.py @@ -179,7 +179,9 @@ def upload_raster( self.create_schema_if_not_exists(fqn) self.create_table_if_not_exists(fqn) - metadata = rasterio_metadata(file_path, bands_info, self.band_rename_function) + metadata = rasterio_metadata( + file_path, bands_info, self.band_rename_function + ) records_gen = rasterio_windows_to_records( file_path, diff --git a/raster_loader/tests/databricks/test_cli.py b/raster_loader/tests/databricks/test_cli.py index f3d703b..a9957c9 100644 --- a/raster_loader/tests/databricks/test_cli.py +++ b/raster_loader/tests/databricks/test_cli.py @@ -240,4 +240,3 @@ def test_databricks_upload_no_table_name(*args, **kwargs): ) assert result.exit_code == 0 assert "Table: mosaic_cog_band__1___" in result.output - From 4a454f4b1dfe4e4b9bf9204a28c5145bfb598fcd Mon Sep 
17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 10:44:53 +0200 Subject: [PATCH 07/12] lint --- Makefile | 2 +- raster_loader/cli/databricks.py | 27 ++++++++++++++++------ raster_loader/io/databricks.py | 27 +++++++++++++++------- raster_loader/tests/databricks/test_cli.py | 1 - 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/Makefile b/Makefile index f698795..ce2b9bc 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ init: [ -d $(VENV) ] || python3 -m venv $(VENV) $(BIN)/pip install -r requirements-dev.txt $(BIN)/pre-commit install - $(BIN)/pip install -e .[snowflake,bigquery,databricks] + $(BIN)/pip install -e .[all] lint: $(BIN)/black raster_loader setup.py diff --git a/raster_loader/cli/databricks.py b/raster_loader/cli/databricks.py index be76c47..b96b718 100644 --- a/raster_loader/cli/databricks.py +++ b/raster_loader/cli/databricks.py @@ -30,15 +30,22 @@ def databricks(args=None): @databricks.command(help="Upload a raster file to Databricks.") @click.option("--host", help="The Databricks host URL.", required=True) @click.option("--token", help="The Databricks access token.", required=True) -@click.option("--cluster-id", help="The Databricks cluster ID.", required=True) # New option -@click.option("--file_path", help="The path to the raster file.", required=False, default=None) -@click.option("--file_url", help="The URL to the raster file.", required=False, default=None) +@click.option( + "--cluster-id", help="The Databricks cluster ID.", required=True +) # New option +@click.option( + "--file_path", help="The path to the raster file.", required=False, default=None +) +@click.option( + "--file_url", help="The URL to the raster file.", required=False, default=None +) @click.option("--catalog", help="The name of the catalog.", required=True) @click.option("--schema", help="The name of the schema.", required=True) @click.option("--table", help="The name of the table.", default=None) @click.option( "--band", - help="Band(s) within raster to upload. " "Could repeat --band to specify multiple bands.", + help="Band(s) within raster to upload. 
" + "Could repeat --band to specify multiple bands.", default=[1], multiple=True, ) @@ -50,7 +57,9 @@ def databricks(args=None): default=[None], multiple=True, ) -@click.option("--chunk_size", help="The number of blocks to upload in each chunk.", default=10000) +@click.option( + "--chunk_size", help="The number of blocks to upload in each chunk.", default=10000 +) @click.option( "--overwrite", help="Overwrite existing data in the table if it already exists.", @@ -114,10 +123,14 @@ def upload( # Create default table name if not provided if table is None: - table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band) + table = get_default_table_name( + file_path if is_local_file else urlparse(file_url).path, band + ) connector = DatabricksConnection( - host=host, token=token, cluster_id=cluster_id # Pass cluster_id to DatabricksConnection + host=host, + token=token, + cluster_id=cluster_id, # Pass cluster_id to DatabricksConnection ) source = file_path if is_local_file else file_url diff --git a/raster_loader/io/databricks.py b/raster_loader/io/databricks.py index 1573b44..a5dee17 100644 --- a/raster_loader/io/databricks.py +++ b/raster_loader/io/databricks.py @@ -25,8 +25,6 @@ StructType, StructField, StringType, - BinaryType, - IntegerType, LongType, ) except ImportError: # pragma: no cover @@ -116,7 +114,7 @@ def check_if_table_exists(self, fqn): schema_name, table_name = fqn.split(".")[1:3] # Extract schema and table sql = f""" SELECT * - FROM {schema_name}.INFORMATION_SCHEMA.TABLES + FROM {schema_name}.INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table_name}' """ df = self.execute(sql) @@ -162,9 +160,14 @@ def upload_raster( append_records = False try: - if self.check_if_table_exists(fqn) and not self.check_if_table_is_empty(fqn) and not overwrite: + if ( + self.check_if_table_exists(fqn) + and not self.check_if_table_is_empty(fqn) + and not overwrite + ): append_records = append or ask_yes_no_question( - f"Table {fqn} already exists and is not empty. Append records? [yes/no] " + f"Table {fqn} already exists and is not empty. " + f"Append records? [yes/no] " ) if not append_records: @@ -172,7 +175,10 @@ def upload_raster( # Prepare band columns self.band_columns = ", ".join( - [f"{self.band_rename_function(band_name or f'band_{band}')} BINARY" for band, band_name in bands_info] + [ + f"{self.band_rename_function(band_name or f'band_{band}')} BINARY" + for band, band_name in bands_info + ] ) # Create schema and table if not exists @@ -204,7 +210,9 @@ def upload_raster( chunk_size = total_blocks isFirstBatch = True for records in batched(records_gen, chunk_size): - ret = self.upload_records(records, fqn, overwrite and isFirstBatch) + ret = self.upload_records( + records, fqn, overwrite and isFirstBatch + ) pbar.update(len(records)) if not ret: raise IOError("Error uploading to Databricks.") @@ -233,7 +241,10 @@ def upload_raster( except Exception as e: delete = cleanup_on_failure or ask_yes_no_question( - ("Error uploading to Databricks. " "Would you like to delete the partially uploaded table? [yes/no] ") + ( + "Error uploading to Databricks. " + "Would you like to delete the partially uploaded table? 
[yes/no] " + ) ) if delete: diff --git a/raster_loader/tests/databricks/test_cli.py b/raster_loader/tests/databricks/test_cli.py index a9957c9..25b79d5 100644 --- a/raster_loader/tests/databricks/test_cli.py +++ b/raster_loader/tests/databricks/test_cli.py @@ -2,7 +2,6 @@ from unittest.mock import patch from click.testing import CliRunner -import pandas as pd from raster_loader.cli import main From eb85b95d4d3ffed77bd0b7a266f042b28522c984 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 11:43:33 +0200 Subject: [PATCH 08/12] Revert "take schema and catalog off the DatabricksConnection class" This reverts commit 8b0be62e58c87214762935b9a612d31508d0e9f8. --- raster_loader/cli/databricks.py | 5 +- raster_loader/io/databricks.py | 81 ++++++++++++++++++++------------- setup.cfg | 2 +- 3 files changed, 54 insertions(+), 34 deletions(-) diff --git a/raster_loader/cli/databricks.py b/raster_loader/cli/databricks.py index b96b718..33609dc 100644 --- a/raster_loader/cli/databricks.py +++ b/raster_loader/cli/databricks.py @@ -131,6 +131,8 @@ def upload( host=host, token=token, cluster_id=cluster_id, # Pass cluster_id to DatabricksConnection + catalog=catalog, + schema=schema, ) source = file_path if is_local_file else file_url @@ -156,10 +158,9 @@ def upload( click.echo("Uploading Raster to Databricks") - fqn = f"`{catalog}`.{schema}.{table}" connector.upload_raster( source, - fqn, + table, bands_info, chunk_size, overwrite=overwrite, diff --git a/raster_loader/io/databricks.py b/raster_loader/io/databricks.py index a5dee17..e3a620c 100644 --- a/raster_loader/io/databricks.py +++ b/raster_loader/io/databricks.py @@ -34,13 +34,15 @@ class DatabricksConnection(DataWarehouseConnection): - def __init__(self, host, token, cluster_id): + def __init__(self, host, token, cluster_id, catalog, schema): if not _has_databricks: import_error_databricks() self.host = host self.token = token self.cluster_id = cluster_id + self.catalog = catalog + self.schema = schema self.client = self.get_connection() @@ -52,6 +54,9 @@ def get_connection(self): session.conf.set("spark.databricks.session.timeout", "6h") return session + def get_table_fqn(self, table): + return f"`{self.catalog}`.{self.schema}.{table}" + def execute(self, sql): return self.client.sql(sql) @@ -59,14 +64,13 @@ def execute_to_dataframe(self, sql): df = self.execute(sql) return df.toPandas() - def create_schema_if_not_exists(self, fqn): - schema_name = fqn.split(".")[1] # Extract schema from FQN - self.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") + def create_schema_if_not_exists(self): + self.execute(f"CREATE SCHEMA IF NOT EXISTS `{self.catalog}`.{self.schema}") - def create_table_if_not_exists(self, fqn): + def create_table_if_not_exists(self, table): self.execute( f""" - CREATE TABLE IF NOT EXISTS {fqn} ( + CREATE TABLE IF NOT EXISTS `{self.catalog}`.{self.schema}.{table} ( BLOCK BIGINT, METADATA STRING, {self.band_columns} @@ -81,8 +85,9 @@ def write_metadata( self, metadata, append_records, - fqn, + table, ): + # Create a DataFrame with the metadata schema = StructType( [ StructField("BLOCK", LongType(), True), @@ -95,11 +100,13 @@ def write_metadata( metadata_df = self.client.createDataFrame(data, schema) # Write to table + fqn = self.get_table_fqn(table) metadata_df.write.format("delta").mode("append").saveAsTable(fqn) return True - def get_metadata(self, fqn): + def get_metadata(self, table): + fqn = self.get_table_fqn(table) query = f""" SELECT METADATA FROM {fqn} @@ -110,28 +117,32 @@ def get_metadata(self, fqn): 
return None return json.loads(result.iloc[0]["METADATA"]) - def check_if_table_exists(self, fqn): - schema_name, table_name = fqn.split(".")[1:3] # Extract schema and table + def check_if_table_exists(self, table): sql = f""" SELECT * - FROM {schema_name}.INFORMATION_SCHEMA.TABLES - WHERE TABLE_NAME = '{table_name}' + FROM `{self.catalog}`.INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = '{self.schema}' + AND TABLE_NAME = '{table}' """ df = self.execute(sql) + # If the count is greater than 0, the table exists return df.count() > 0 - def check_if_table_is_empty(self, fqn): + def check_if_table_is_empty(self, table): + fqn = self.get_table_fqn(table) df = self.client.table(fqn) return df.count() == 0 def upload_records( self, records: Iterable, - fqn: str, + table: str, overwrite: bool, ): + fqn = self.get_table_fqn(table) records_list = [] for record in records: + # Remove 'METADATA' from records, as it's handled separately if "METADATA" in record: del record["METADATA"] records_list.append(record) @@ -139,7 +150,11 @@ def upload_records( data_df = pd.DataFrame(records_list) spark_df = self.client.createDataFrame(data_df) - mode = "overwrite" if overwrite else "append" + if overwrite: + mode = "overwrite" + else: + mode = "append" + spark_df.write.format("delta").mode(mode).saveAsTable(fqn) return True @@ -147,7 +162,7 @@ def upload_records( def upload_raster( self, file_path: str, - fqn: str, + table: str, bands_info: List[Tuple[int, str]] = None, chunk_size: int = None, overwrite: bool = False, @@ -157,17 +172,18 @@ def upload_raster( print("Loading raster file to Databricks...") bands_info = bands_info or [(1, None)] + append_records = False try: if ( - self.check_if_table_exists(fqn) - and not self.check_if_table_is_empty(fqn) + self.check_if_table_exists(table) + and not self.check_if_table_is_empty(table) and not overwrite ): append_records = append or ask_yes_no_question( - f"Table {fqn} already exists and is not empty. " - f"Append records? [yes/no] " + f"Table `{self.catalog}`.{self.schema}.{table} already exists " + "and is not empty. Append records? 
[yes/no] " ) if not append_records: @@ -182,8 +198,8 @@ def upload_raster( ) # Create schema and table if not exists - self.create_schema_if_not_exists(fqn) - self.create_table_if_not_exists(fqn) + self.create_schema_if_not_exists() + self.create_table_if_not_exists(table) metadata = rasterio_metadata( file_path, bands_info, self.band_rename_function @@ -198,7 +214,7 @@ def upload_raster( total_blocks = get_number_of_blocks(file_path) if chunk_size is None: - ret = self.upload_records(records_gen, fqn, overwrite) + ret = self.upload_records(records_gen, table, overwrite) if not ret: raise IOError("Error uploading to Databricks.") else: @@ -211,7 +227,7 @@ def upload_raster( isFirstBatch = True for records in batched(records_gen, chunk_size): ret = self.upload_records( - records, fqn, overwrite and isFirstBatch + records, table, overwrite and isFirstBatch ) pbar.update(len(records)) if not ret: @@ -220,11 +236,11 @@ def upload_raster( print("Writing metadata to Databricks...") if append_records: - old_metadata = self.get_metadata(fqn) + old_metadata = self.get_metadata(table) check_metadata_is_compatible(metadata, old_metadata) update_metadata(metadata, old_metadata) - self.write_metadata(metadata, append_records, fqn) + self.write_metadata(metadata, append_records, table) except IncompatibleRasterException as e: raise IOError(f"Error uploading to Databricks: {e.message}") @@ -235,7 +251,7 @@ def upload_raster( ) if delete: - self.delete_table(fqn) + self.delete_table(table) raise KeyboardInterrupt @@ -248,17 +264,19 @@ def upload_raster( ) if delete: - self.delete_table(fqn) + self.delete_table(table) raise IOError(f"Error uploading to Databricks: {e}") print("Done.") return True - def delete_table(self, fqn): + def delete_table(self, table): + fqn = self.get_table_fqn(table) self.execute(f"DROP TABLE IF EXISTS {fqn}") - def get_records(self, fqn: str, limit=10) -> pd.DataFrame: + def get_records(self, table: str, limit=10) -> pd.DataFrame: + fqn = self.get_table_fqn(table) query = f"SELECT * FROM {fqn} LIMIT {limit}" df = self.execute_to_dataframe(query) return df @@ -266,8 +284,9 @@ def get_records(self, fqn: str, limit=10) -> pd.DataFrame: def insert_in_table( self, rows: List[dict], - fqn: str, + table: str, ) -> bool: + fqn = self.get_table_fqn(table) data_df = pd.DataFrame(rows) spark_df = self.client.createDataFrame(data_df) spark_df.write.format("delta").mode("append").saveAsTable(fqn) diff --git a/setup.cfg b/setup.cfg index 198c7bc..7cd5c7d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,7 +63,7 @@ bigquery = snowflake = snowflake-connector-python>=2.6.0 databricks = - databricks-connect==15.1.1 + databricks-connect==13.3.3 all = %(bigquery)s From ddf21e2df5a517850649cc2ec0927cd5580b2214 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 11:51:41 +0200 Subject: [PATCH 09/12] satisty ci for databricks-connect python3.9 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7cd5c7d..0cee9ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,7 +63,7 @@ bigquery = snowflake = snowflake-connector-python>=2.6.0 databricks = - databricks-connect==13.3.3 + databricks-connect==13.0.1 all = %(bigquery)s From cd149a565d4554df4b4bcbd606aa7b452910208e Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 11:57:13 +0200 Subject: [PATCH 10/12] readme --- README.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5f93686..15dfc3d 100644 --- a/README.md +++ b/README.md 
@@ -154,15 +154,12 @@ versions of Raster Loader. ### Installing for Development -It is reccomended to use a virtualenv when developing. - ``` -python3 -m venv venv -source venv/bin/activate -pip install -e .[all] +make init +source env/bin/activate ``` -Doing `which carto` should return something like `/my/local/filesystem/raster-loader/venv/bin/carto` instead of the system-wide installation. +Doing `which carto` should return something like `/my/local/filesystem/raster-loader/env/bin/carto` instead of the system-wide installation. The `-e` flag passed to the `pip install` program will set the project and its dependencies in development mode. Changes to the project files will be reflected in the `carto` command immedietly without the need to re-run any setup steps. From 2d161371322bac400b23d1d7a89e734f1e1e9e05 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 14:41:01 +0200 Subject: [PATCH 11/12] Set databricks batch size to 400 by default because max RPC message size is 2GB --- raster_loader/cli/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/raster_loader/cli/databricks.py b/raster_loader/cli/databricks.py index 33609dc..8492a46 100644 --- a/raster_loader/cli/databricks.py +++ b/raster_loader/cli/databricks.py @@ -58,7 +58,7 @@ def databricks(args=None): multiple=True, ) @click.option( - "--chunk_size", help="The number of blocks to upload in each chunk.", default=10000 + "--chunk_size", help="The number of blocks to upload in each chunk.", default=400 ) @click.option( "--overwrite", From 5e4bd59d951d00cbcfba695b5b2fd61227221e21 Mon Sep 17 00:00:00 2001 From: Dean Sherwin Date: Thu, 10 Oct 2024 15:24:44 +0200 Subject: [PATCH 12/12] include CRS and transform in raster metadata --- raster_loader/io/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/raster_loader/io/common.py b/raster_loader/io/common.py index 6875061..5a08ad5 100644 --- a/raster_loader/io/common.py +++ b/raster_loader/io/common.py @@ -242,6 +242,8 @@ def rasterio_metadata( metadata["num_blocks"] = int(width * height / block_width / block_height) metadata["num_pixels"] = width * height metadata["pixel_resolution"] = pixel_resolution + metadata["crs"] = raster_crs + metadata["transform"] = raster_dataset.transform return metadata