From bd070da875e142458bef7433b5340773d088ad75 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Fri, 22 Sep 2023 17:32:21 -0400 Subject: [PATCH] Add e2e tests for parameterized queries Signed-off-by: Jesse Whitehouse --- tests/e2e/common/parameterized_query_tests.py | 185 ++++++++++++++++++ tests/e2e/test_driver.py | 3 +- 2 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/common/parameterized_query_tests.py diff --git a/tests/e2e/common/parameterized_query_tests.py b/tests/e2e/common/parameterized_query_tests.py new file mode 100644 index 00000000..6bda9fe9 --- /dev/null +++ b/tests/e2e/common/parameterized_query_tests.py @@ -0,0 +1,185 @@ +import pytest +import databricks.sql as sql +from databricks.sql import Error +import pytz + +from decimal import Decimal + +from databricks.sql.utils import DbSqlParameter, DbSqlType +from datetime import datetime + +from databricks.sql.client import Connection + + +class PySQLParameterizedQueryTestSuiteMixin: + """Namespace for tests of server-side parameterized queries""" + + def test_parameterized_query_named_and_inferred_e2e(self): + """Verify that named parameters passed to the database as a dict are returned of the correct type + All types are inferred. + """ + + conn: Connection + + query = """ + SELECT + :p_bool AS col_bool, + :p_int AS col_int, + :p_double AS col_double, + :p_date as col_date, + :p_timestamp as col_timestamp, + :p_str AS col_str + """ + + named_parameters = { + "p_bool": True, + "p_int": 1234, + "p_double": 3.14, + "p_date": datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC), + "p_timestamp": datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC), + "p_str": "Hello", + } + with self.connection() as conn: + cursor = conn.cursor() + cursor.execute(query, parameters=named_parameters) + result = cursor.fetchone() + + assert result.col_bool == True + assert result.col_int == 1234 + + # For equality, we use Decimal to quantize these values + assert Decimal(result.col_double).quantize(Decimal("0.00")) == Decimal( + 3.14 + ).quantize(Decimal("0.00")) + + assert result.col_date == datetime( + 2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC + ) + assert result.col_timestamp == datetime( + 2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC + ) + assert result.col_str == "Hello" + + def test_parameterized_query_named_dict_and_inferred_e2e(self): + """Verify that named parameters passed to the database as a list are returned of the correct type + All types are inferred. + """ + + conn: Connection + + query = """ + SELECT + :p_bool AS col_bool, + :p_int AS col_int, + :p_double AS col_double, + :p_date as col_date, + :p_timestamp as col_timestamp, + :p_str AS col_str + """ + + named_parameters = [ + DbSqlParameter( + name="p_bool", + value=True, + ), + DbSqlParameter( + name="p_int", + value=1234, + ), + DbSqlParameter( + name="p_double", + value=3.14, + ), + DbSqlParameter( + name="p_date", + value=datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC), + ), + DbSqlParameter( + name="p_timestamp", + value=datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC), + ), + DbSqlParameter(name="p_str", value="Hello"), + ] + + with self.connection() as conn: + cursor = conn.cursor() + cursor.execute(query, parameters=named_parameters) + result = cursor.fetchone() + + assert result.col_bool == True + assert result.col_int == 1234 + + # For equality, we use Decimal to quantize these values + assert Decimal(result.col_double).quantize(Decimal("0.00")) == Decimal( + 3.14 + ).quantize(Decimal("0.00")) + + assert result.col_date == datetime( + 2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC + ) + assert result.col_timestamp == datetime( + 2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC + ) + assert result.col_str == "Hello" + + def test_parameterized_query_named_dict_and_not_inferred_e2e(self): + """Verify that named parameters passed to the database are returned of the correct type + All types are explicitly set. + """ + + conn: Connection + + query = """ + SELECT + :p_bool AS col_bool, + :p_int AS col_int, + :p_double AS col_double, + :p_date as col_date, + :p_timestamp as col_timestamp, + :p_str AS col_str + """ + + named_parameters = [ + DbSqlParameter(name="p_bool", value=True, type=DbSqlType.BOOLEAN), + DbSqlParameter(name="p_int", value=1234, type=DbSqlType.INTEGER), + DbSqlParameter(name="p_double", value=3.14, type=DbSqlType.FLOAT), + DbSqlParameter( + name="p_date", + value=datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC), + type=DbSqlType.DATE, + ), + DbSqlParameter( + name="p_timestamp", + value=datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC), + type=DbSqlType.TIMESTAMP, + ), + DbSqlParameter(name="p_str", value="Hello", type=DbSqlType.STRING), + ] + + with self.connection() as conn: + cursor = conn.cursor() + cursor.execute(query, parameters=named_parameters) + result = cursor.fetchone() + + assert result.col_bool == True + assert result.col_int == 1234 + + # For equality, we use Decimal to quantize these values + assert Decimal(result.col_double).quantize(Decimal("0.00")) == Decimal( + 3.14 + ).quantize(Decimal("0.00")) + + # Observe that passing a datetime object with timezone information and the type set to DbSqlType.DATE + # strips away the time and tz info + assert ( + result.col_date + == datetime( + 2023, + 9, + 6, + ).date() + ) + assert result.col_timestamp == datetime( + 2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC + ) + assert result.col_str == "Hello" diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 90bf5c3d..85b6024a 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -28,6 +28,7 @@ from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin +from tests.e2e.common.parameterized_query_tests import PySQLParameterizedQueryTestSuiteMixin log = logging.getLogger(__name__) @@ -142,7 +143,7 @@ def test_cloud_fetch(self): # Exclude Retry tests because they require specific setups, and LargeQueries too slow for core # tests class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, TimestampTestsMixin, - PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin): + PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLParameterizedQueryTestSuiteMixin): validate_row_value_type = True validate_result = True