From a17176cb574109a545c4cedd18bdb5e5f798c047 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 26 Jun 2023 12:22:15 -0400 Subject: [PATCH] feat(c/driver/postgresql): implement ADBC 1.1.0 features - ADBC_INFO_DRIVER_ADBC_VERSION - StatementExecuteSchema (#318) - ADBC_CONNECTION_OPTION_CURRENT_{CATALOG, DB_SCHEMA} (#319) --- c/driver/common/utils.c | 13 ++ c/driver/common/utils.h | 3 + c/driver/postgresql/connection.cc | 58 ++++++++- c/driver/postgresql/connection.h | 12 ++ c/driver/postgresql/database.h | 12 ++ c/driver/postgresql/postgres_copy_reader.h | 8 +- c/driver/postgresql/postgresql.cc | 135 ++++++++++++++++++++- c/driver/postgresql/postgresql_test.cc | 27 +++-- c/driver/postgresql/statement.cc | 105 +++++++++------- c/driver/postgresql/statement.h | 14 +++ c/driver/sqlite/sqlite.c | 6 + c/validation/adbc_validation.cc | 123 +++++++++++++++++++ c/validation/adbc_validation.h | 18 +++ 13 files changed, 477 insertions(+), 57 deletions(-) diff --git a/c/driver/common/utils.c b/c/driver/common/utils.c index dfac14f5e4..eb01bc18a4 100644 --- a/c/driver/common/utils.c +++ b/c/driver/common/utils.c @@ -244,6 +244,19 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, return ADBC_STATUS_OK; } +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error) { + CHECK_NA(INTERNAL, ArrowArrayAppendUInt(array->children[0], info_code), error); + // Append to type variant + CHECK_NA(INTERNAL, ArrowArrayAppendInt(array->children[1]->children[2], info_value), + error); + // Append type code/offset + CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2), + error); + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error) { ArrowSchemaInit(schema); diff --git a/c/driver/common/utils.h b/c/driver/common/utils.h index 5735bb945f..381c7b05ee 100644 --- 
a/c/driver/common/utils.h +++ b/c/driver/common/utils.h @@ -117,6 +117,9 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, const char* info_value, struct AdbcError* error); +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error); AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error); diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index 1b256b6761..e5435c1911 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -36,8 +36,9 @@ namespace { static const uint32_t kSupportedInfoCodes[] = { - ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ARROW_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, }; static const std::unordered_map kPgTableTypes = { @@ -771,6 +772,10 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], NANOARROW_VERSION, error)); break; + case ADBC_INFO_DRIVER_ADBC_VERSION: + RAISE_ADBC(AdbcConnectionGetInfoAppendInt(array, info_codes[i], + ADBC_VERSION_1_1_0, error)); + break; default: // Ignore continue; @@ -831,6 +836,55 @@ AdbcStatusCode PostgresConnection::GetObjects( return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + std::string output; + if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { + PqResultHelper result_helper{conn_, "SELECT CURRENT_CATALOG", {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); 
+ if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_CATALOG'"); + return ADBC_STATUS_INTERNAL; + } + output = (*it)[0].data; + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + return ADBC_STATUS_INTERNAL; + } + output = (*it)[0].data; + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { + output = autocommit_ ? ADBC_OPTION_VALUE_ENABLED : ADBC_OPTION_VALUE_DISABLED; + } else { + return ADBC_STATUS_NOT_FOUND; + } + + if (output.size() + 1 <= *length) { + std::memcpy(value, output.c_str(), output.size() + 1); + } + *length = output.size() + 1; + return ADBC_STATUS_OK; +} +AdbcStatusCode PostgresConnection::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h index 74315ee053..c39d742a20 100644 --- a/c/driver/postgresql/connection.h +++ b/c/driver/postgresql/connection.h @@ -40,6 +40,14 @@ class PostgresConnection { const char* table_name, const char** table_types, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode GetOption(const char* 
option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error); @@ -49,6 +57,10 @@ class PostgresConnection { AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode Rollback(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); PGconn* conn() const { return conn_; } const std::shared_ptr& type_resolver() const { diff --git a/c/driver/postgresql/database.h b/c/driver/postgresql/database.h index f10464787a..233182cc00 100644 --- a/c/driver/postgresql/database.h +++ b/c/driver/postgresql/database.h @@ -36,7 +36,19 @@ class PostgresDatabase { AdbcStatusCode Init(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, 
const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); // Internal implementation diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h index 18d1fbd48c..e6844aaa44 100644 --- a/c/driver/postgresql/postgres_copy_reader.h +++ b/c/driver/postgresql/postgres_copy_reader.h @@ -673,12 +673,13 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type, class PostgresCopyStreamReader { public: - ArrowErrorCode Init(const PostgresType& pg_type) { + ArrowErrorCode Init(PostgresType pg_type) { if (pg_type.type_id() != PostgresTypeId::kRecord) { return EINVAL; } - root_reader_.Init(pg_type); + pg_type_ = std::move(pg_type); + root_reader_.Init(pg_type_); array_size_approx_bytes_ = 0; return NANOARROW_OK; } @@ -802,7 +803,10 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } + const PostgresType& pg_type() const { return pg_type_; } + private: + PostgresType pg_type_; PostgresCopyFieldTupleReader root_reader_; nanoarrow::UniqueSchema schema_; nanoarrow::UniqueArray array_; diff --git a/c/driver/postgresql/postgresql.cc b/c/driver/postgresql/postgresql.cc index 95e6c8b881..48e884936f 100644 --- a/c/driver/postgresql/postgresql.cc +++ b/c/driver/postgresql/postgresql.cc @@ -142,6 +142,42 @@ AdbcStatusCode PostgresConnectionGetObjects( table_types, column_name, stream, error); } +AdbcStatusCode PostgresConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, 
uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + AdbcStatusCode PostgresConnectionGetTableSchema( struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { @@ -213,6 +249,33 @@ AdbcStatusCode PostgresConnectionSetOption(struct AdbcConnection* connection, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + +AdbcStatusCode 
PostgresConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + } // namespace AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { @@ -237,6 +300,30 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_types, column_name, stream, error); } +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -287,6 +374,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return PostgresConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return 
PostgresConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, size_t length, + struct AdbcError* error) { + return PostgresConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, size_t length, + struct AdbcError* error) { + return PostgresConnectionSetOptionDouble(connection, key, value, error); +} + // --------------------------------------------------------------------- // AdbcStatement @@ -329,6 +434,15 @@ AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement, return (*ptr)->ExecuteQuery(output, rows_affected, error); } +AdbcStatusCode PostgresStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->ExecuteSchema(schema, error); +} + AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement, uint8_t* partition_desc, struct AdbcError* error) { @@ -423,6 +537,11 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return PostgresStatementExecuteQuery(statement, output, rows_affected, error); } +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + ArrowSchema* schema, struct AdbcError* error) { + return PostgresStatementExecuteSchema(statement, schema, error); +} + AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement, uint8_t* partition_desc, struct AdbcError* error) { @@ -474,7 +593,20 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e if (!raw_driver) return ADBC_STATUS_INVALID_ARGUMENT; auto* driver = reinterpret_cast(raw_driver); - std::memset(driver, 0, 
ADBC_DRIVER_1_0_0_SIZE); + if (version >= ADBC_VERSION_1_1_0) { + std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE); + driver->StatementExecuteSchema = PostgresStatementExecuteSchema; + driver->ConnectionGetOption = PostgresConnectionGetOption; + driver->ConnectionGetOptionBytes = PostgresConnectionGetOptionBytes; + driver->ConnectionGetOptionInt = PostgresConnectionGetOptionInt; + driver->ConnectionGetOptionDouble = PostgresConnectionGetOptionDouble; + driver->ConnectionSetOptionBytes = PostgresConnectionSetOptionBytes; + driver->ConnectionSetOptionInt = PostgresConnectionSetOptionInt; + driver->ConnectionSetOptionDouble = PostgresConnectionSetOptionDouble; + } else { + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + } + driver->DatabaseInit = PostgresDatabaseInit; driver->DatabaseNew = PostgresDatabaseNew; driver->DatabaseRelease = PostgresDatabaseRelease; @@ -502,6 +634,7 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e driver->StatementRelease = PostgresStatementRelease; driver->StatementSetOption = PostgresStatementSetOption; driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery; + return ADBC_STATUS_OK; } } diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 153d8eb221..a4af67add1 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -103,6 +103,8 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { std::string catalog() const override { return "postgres"; } std::string db_schema() const override { return "public"; } + + bool supports_execute_schema() const override { return true; } }; class PostgresDatabaseTest : public ::testing::Test, @@ -134,10 +136,8 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { adbc_validation::StreamReader reader; std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, 
ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, }; ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), &reader.stream.value, &error), @@ -153,29 +153,30 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); const uint32_t code = reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + const uint32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; seen.push_back(code); - int str_child_index = 0; - struct ArrowArrayView* str_child = - reader.array_view->children[1]->children[str_child_index]; + struct ArrowArrayView* str_child = reader.array_view->children[1]->children[0]; + struct ArrowArrayView* int_child = reader.array_view->children[1]->children[2]; switch (code) { case ADBC_INFO_DRIVER_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 0); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("ADBC PostgreSQL Driver", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_DRIVER_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 1); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("(unknown)", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 2); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("PostgreSQL", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 3); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); #ifdef __WIN32 const char* pater = "\\d\\d\\d\\d\\d\\d"; #else @@ -185,6 +186,10 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ::testing::MatchesRegex(pater)); break; } + case 
ADBC_INFO_DRIVER_ADBC_VERSION: { + EXPECT_EQ(ADBC_VERSION_1_1_0, ArrowArrayViewGetIntUnsafe(int_child, offset)); + break; + } default: // Ignored break; diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 4cae15b631..e566dd5d69 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -867,50 +867,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // 1. Prepare the query to get the schema { - // TODO: we should pipeline here and assume this will succeed - PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), - /*nParams=*/0, nullptr); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "prepare query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - PQclear(result); - result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "describe prepared statement: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - - // Resolve the information from the PGresult into a PostgresType - PostgresType root_type; - AdbcStatusCode status = - ResolvePostgresType(*type_resolver_, result, &root_type, error); - PQclear(result); - if (status != ADBC_STATUS_OK) return status; - - // Initialize the copy reader and infer the output schema (i.e., error for - // unsupported types before issuing the COPY query) - reader_.copy_reader_.reset(new PostgresCopyStreamReader()); - reader_.copy_reader_->Init(root_type); - struct ArrowError na_error; - int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); - if (na_res != NANOARROW_OK) { - SetError(error, "[libpq] Failed to infer 
output schema: %s", na_error.message); - return na_res; - } + RAISE_ADBC(SetupReader(error)); // If the caller did not request a result set or if there are no // inferred output columns (e.g. a CREATE or UPDATE), then don't // use COPY (which would fail anyways) - if (!stream || root_type.n_children() == 0) { + if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) { RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error)); if (stream) { struct ArrowSchema schema; @@ -924,7 +886,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // This resolves the reader specific to each PostgresType -> ArrowSchema // conversion. It is unlikely that this will fail given that we have just // inferred these conversions ourselves. - na_res = reader_.copy_reader_->InitFieldReaders(&na_error); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InitFieldReaders(&na_error); if (na_res != NANOARROW_OK) { SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message); return na_res; @@ -953,6 +916,23 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, + struct AdbcError* error) { + ClearResult(); + if (query_.empty()) { + SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery"); + return ADBC_STATUS_INVALID_STATE; + } else if (bind_.release) { + // TODO: if we have parameters, bind them (since they can affect the output schema) + SetError(error, "[libpq] ExecuteSchema with parameters is not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + RAISE_ADBC(SetupReader(error)); + CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error); + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error) { if (!bind_.release) { @@ -1070,6 +1050,49 @@ AdbcStatusCode 
PostgresStatement::SetOption(const char* key, const char* value, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) { + // TODO: we should pipeline here and assume this will succeed + PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), + /*nParams=*/0, nullptr); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + SetError(error, + "[libpq] Failed to execute query: could not infer schema: failed to " + "prepare query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return ADBC_STATUS_IO; + } + PQclear(result); + result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + SetError(error, + "[libpq] Failed to execute query: could not infer schema: failed to " + "describe prepared statement: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return ADBC_STATUS_IO; + } + + // Resolve the information from the PGresult into a PostgresType + PostgresType root_type; + AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error); + PQclear(result); + if (status != ADBC_STATUS_OK) return status; + + // Initialize the copy reader and infer the output schema (i.e., error for + // unsupported types before issuing the COPY query) + reader_.copy_reader_.reset(new PostgresCopyStreamReader()); + reader_.copy_reader_->Init(root_type); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); + if (na_res != NANOARROW_OK) { + SetError(error, "[libpq] Failed to infer output schema: (%d) %s: %s", na_res, + std::strerror(na_res), na_error.message); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + void PostgresStatement::ClearResult() { // TODO: we may want to synchronize here for safety reader_.Release(); diff --git a/c/driver/postgresql/statement.h 
b/c/driver/postgresql/statement.h index 62af2457d5..c72446d9b7 100644 --- a/c/driver/postgresql/statement.h +++ b/c/driver/postgresql/statement.h @@ -102,11 +102,24 @@ class PostgresStatement { AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error); AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode ExecuteSchema(struct ArrowSchema* schema, struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error); AdbcStatusCode Prepare(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error); // --------------------------------------------------------------------- @@ -122,6 +135,7 @@ class PostgresStatement { AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode SetupReader(struct AdbcError* error); private: std::shared_ptr type_resolver_; diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c 
index 359c594dab..d622ceb054 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -1480,6 +1480,12 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return SqliteStatementExecuteQuery(statement, out, rows_affected, error); } +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, struct AdbcError* error) { return SqliteStatementPrepare(statement, error); diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index 803f14c023..358afce9a7 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include "adbc_validation_util.h" @@ -247,6 +248,38 @@ void ConnectionTest::TestAutocommitToggle() { //------------------------------------------------------------ // Tests of metadata +std::optional ConnectionGetOption(struct AdbcConnection* connection, + std::string_view option, + struct AdbcError* error) { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + AdbcStatusCode status = + AdbcConnectionGetOption(connection, option.data(), buffer, &buffer_size, error); + EXPECT_THAT(status, IsOkStatus(error)); + if (status != ADBC_STATUS_OK) return std::nullopt; + EXPECT_GT(buffer_size, 0); + if (buffer_size == 0) return std::nullopt; + return std::string(buffer, buffer_size - 1); +} + +void ConnectionTest::TestMetadataCurrentCatalog() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + auto catalog = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error); + ASSERT_THAT(catalog, ::testing::Optional(quirks()->catalog())); +} + +void ConnectionTest::TestMetadataCurrentDbSchema() { + 
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + auto db_schema = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, &error); + ASSERT_THAT(db_schema, ::testing::Optional(quirks()->db_schema())); +} + void ConnectionTest::TestMetadataGetInfo() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); @@ -2042,6 +2075,11 @@ void StatementTest::TestTransactions() { ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_ENABLED))); + Handle connection2; ASSERT_THAT(AdbcConnectionNew(&connection2.value, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection2.value, &database, &error), @@ -2051,6 +2089,11 @@ void StatementTest::TestTransactions() { ADBC_OPTION_VALUE_DISABLED, &error), IsOkStatus(&error)); + autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_DISABLED))); + // Uncommitted change ASSERT_NO_FATAL_FAILURE(IngestSampleTable(&connection, &error)); @@ -2126,6 +2169,86 @@ void StatementTest::TestTransactions() { } } +void StatementTest::TestSqlSchemaInts() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, 
schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("i"), // int32 + ::testing::StrEq("l"), // int64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaFloats() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT CAST(1.5 AS FLOAT)", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("f"), // float32 + ::testing::StrEq("g"), // float64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaStrings() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'hi'", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("u"), // string + ::testing::StrEq("U"), // large_string + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaErrors() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + 
ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestConcurrentStatements() { Handle<struct AdbcStatement> statement1; Handle<struct AdbcStatement> statement2; diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index 2a5883c00f..eca65e6f8b 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -89,6 +89,9 @@ class DriverQuirks { /// single connection virtual bool supports_concurrent_statements() const { return false; } + /// \brief Whether AdbcStatementExecuteSchema should work + virtual bool supports_execute_schema() const { return false; } + /// \brief Whether AdbcStatementExecutePartitions should work virtual bool supports_partitioned_data() const { return false; } @@ -157,6 +160,9 @@ class ConnectionTest { void TestAutocommitToggle(); + void TestMetadataCurrentCatalog(); + void TestMetadataCurrentDbSchema(); + void TestMetadataGetInfo(); void TestMetadataGetTableSchema(); void TestMetadataGetTableTypes(); @@ -183,6 +189,8 @@ class ConnectionTest { TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ + TEST_F(FIXTURE, MetadataCurrentCatalog) { TestMetadataCurrentCatalog(); } \ + TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \ TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ @@ -257,6 +265,12 @@ class StatementTest { void TestSqlQueryErrors(); + void TestSqlSchemaInts(); + void TestSqlSchemaFloats(); + void TestSqlSchemaStrings(); + + void TestSqlSchemaErrors(); + void TestTransactions(); void TestConcurrentStatements(); @@ -316,6 +330,10 @@ class StatementTest {
TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \ TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); } \ + TEST_F(FIXTURE, SqlSchemaInts) { TestSqlSchemaInts(); } \ + TEST_F(FIXTURE, SqlSchemaFloats) { TestSqlSchemaFloats(); } \ + TEST_F(FIXTURE, SqlSchemaStrings) { TestSqlSchemaStrings(); } \ + TEST_F(FIXTURE, SqlSchemaErrors) { TestSqlSchemaErrors(); } \ TEST_F(FIXTURE, Transactions) { TestTransactions(); } \ TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \ TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); }