From ad24938d0ae0de33291774f04ffe6383e1d5559d Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 6 Oct 2023 13:15:07 -0400
Subject: [PATCH] fix(c/driver/postgresql): only clear schema option if needed
(#1174)
Fixes #1109.
---
c/driver/postgresql/postgresql_test.cc | 49 +++++++++++++++++++
c/driver/postgresql/statement.cc | 4 +-
.../tests/test_dbapi.py | 15 ++++++
3 files changed, 67 insertions(+), 1 deletion(-)
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 303cc0597a..932e685e4e 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -904,6 +904,55 @@ class PostgresStatementTest : public ::testing::Test,
};
ADBCV_TEST_STATEMENT(PostgresStatementTest)
+TEST_F(PostgresStatementTest, SqlIngestSchema) {
+ const std::string schema_name = "testschema";
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
+ "CREATE SCHEMA IF NOT EXISTS testschema", &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+ IsOkStatus(&error));
+
+ std::string drop = "DROP TABLE IF EXISTS testschema.schematable";
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, drop.c_str(), &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+ IsOkStatus(&error));
+
+ {
+ adbc_validation::Handle schema;
+ adbc_validation::Handle batch;
+
+ ArrowSchemaInit(&schema.value);
+ ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), adbc_validation::IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_INT64),
+ adbc_validation::IsOkErrno());
+ ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "ints"),
+ adbc_validation::IsOkErrno());
+
+ ASSERT_THAT((adbc_validation::MakeBatch(
+ &schema.value, &batch.value, static_cast(nullptr),
+ {-1, 0, 1, std::nullopt})),
+ adbc_validation::IsOkErrno());
+
+ ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+ "schematable", &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE,
+ ADBC_INGEST_OPTION_MODE_CREATE, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA,
+ schema_name.c_str(), &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+ IsOkStatus(&error));
+ }
+}
+
TEST_F(PostgresStatementTest, SqlIngestTemporaryTable) {
ASSERT_THAT(quirks()->DropTempTable(&connection, "temptable", &error),
IsOkStatus(&error));
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 24086ce364..1f08fce16a 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -1300,14 +1300,16 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
prepared_ = false;
} else if (std::strcmp(key, ADBC_INGEST_OPTION_TEMPORARY) == 0) {
if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
+ // https://github.com/apache/arrow-adbc/issues/1109: only clear the
+ // schema if enabling since Python always sets the flag explicitly
ingest_.temporary = true;
+ ingest_.db_schema.clear();
} else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) {
ingest_.temporary = false;
} else {
SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
return ADBC_STATUS_INVALID_ARGUMENT;
}
- ingest_.db_schema.clear();
prepared_ = false;
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
int64_t int_value = std::atol(value);
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index 1e31747269..2a132bd4a7 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -269,6 +269,21 @@ def test_ingest(postgres: dbapi.Connection) -> None:
cur.adbc_ingest("foo", table, catalog_name="main")
+def test_ingest_schema(postgres: dbapi.Connection) -> None:
+ table = pyarrow.Table.from_pydict({"numbers": [1, 2], "letters": ["a", "b"]})
+
+ with postgres.cursor() as cur:
+ cur.execute("CREATE SCHEMA IF NOT EXISTS testschema")
+ cur.execute("DROP TABLE IF EXISTS testschema.foo")
+
+ postgres.commit()
+
+ cur.adbc_ingest("foo", table, mode="create", db_schema_name="testschema")
+
+ cur.execute("SELECT * FROM testschema.foo ORDER BY numbers")
+ assert cur.fetch_arrow_table() == table
+
+
def test_ingest_temporary(postgres: dbapi.Connection) -> None:
table = pyarrow.Table.from_pydict(
{