From ad24938d0ae0de33291774f04ffe6383e1d5559d Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 6 Oct 2023 13:15:07 -0400 Subject: [PATCH] fix(c/driver/postgresql): only clear schema option if needed (#1174) Fixes #1109. --- c/driver/postgresql/postgresql_test.cc | 49 +++++++++++++++++++ c/driver/postgresql/statement.cc | 4 +- .../tests/test_dbapi.py | 15 ++++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 303cc0597a..932e685e4e 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -904,6 +904,55 @@ class PostgresStatementTest : public ::testing::Test, }; ADBCV_TEST_STATEMENT(PostgresStatementTest) +TEST_F(PostgresStatementTest, SqlIngestSchema) { + const std::string schema_name = "testschema"; + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, + "CREATE SCHEMA IF NOT EXISTS testschema", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + std::string drop = "DROP TABLE IF EXISTS testschema.schematable"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, drop.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + { + adbc_validation::Handle schema; + adbc_validation::Handle batch; + + ArrowSchemaInit(&schema.value); + ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_INT64), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "ints"), + adbc_validation::IsOkErrno()); + + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &batch.value, static_cast(nullptr), + {-1, 0, 1, std::nullopt})), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "schematable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_CREATE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA, + schema_name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + } +} + TEST_F(PostgresStatementTest, SqlIngestTemporaryTable) { ASSERT_THAT(quirks()->DropTempTable(&connection, "temptable", &error), IsOkStatus(&error)); diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 24086ce364..1f08fce16a 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -1300,14 +1300,16 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, prepared_ = false; } else if (std::strcmp(key, ADBC_INGEST_OPTION_TEMPORARY) == 0) { if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + // https://github.com/apache/arrow-adbc/issues/1109: only clear the + // schema if enabling since Python always sets the flag explicitly ingest_.temporary = true; + ingest_.db_schema.clear(); } else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { ingest_.temporary = false; } else { SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); return ADBC_STATUS_INVALID_ARGUMENT; } - ingest_.db_schema.clear(); prepared_ = false; } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { int64_t int_value = std::atol(value); diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py index 1e31747269..2a132bd4a7 100644 --- a/python/adbc_driver_postgresql/tests/test_dbapi.py +++ b/python/adbc_driver_postgresql/tests/test_dbapi.py @@ -269,6 +269,21 @@ def test_ingest(postgres: dbapi.Connection) -> None: cur.adbc_ingest("foo", table, catalog_name="main") +def test_ingest_schema(postgres: dbapi.Connection) -> None: + table = pyarrow.Table.from_pydict({"numbers": [1, 2], "letters": ["a", "b"]}) + + with postgres.cursor() as cur: + cur.execute("CREATE SCHEMA IF NOT EXISTS testschema") + cur.execute("DROP TABLE IF EXISTS testschema.foo") + + postgres.commit() + + cur.adbc_ingest("foo", table, mode="create", db_schema_name="testschema") + + cur.execute("SELECT * FROM testschema.foo ORDER BY numbers") + assert cur.fetch_arrow_table() == table + + def test_ingest_temporary(postgres: dbapi.Connection) -> None: table = pyarrow.Table.from_pydict( {