From a17176cb574109a545c4cedd18bdb5e5f798c047 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 26 Jun 2023 12:22:15 -0400
Subject: [PATCH] feat(c/driver/postgresql): implement ADBC 1.1.0 features
- ADBC_INFO_DRIVER_ADBC_VERSION
- StatementExecuteSchema (#318)
- ADBC_CONNECTION_OPTION_CURRENT_{CATALOG, DB_SCHEMA) (#319)
---
c/driver/common/utils.c | 13 ++
c/driver/common/utils.h | 3 +
c/driver/postgresql/connection.cc | 58 ++++++++-
c/driver/postgresql/connection.h | 12 ++
c/driver/postgresql/database.h | 12 ++
c/driver/postgresql/postgres_copy_reader.h | 8 +-
c/driver/postgresql/postgresql.cc | 135 ++++++++++++++++++++-
c/driver/postgresql/postgresql_test.cc | 27 +++--
c/driver/postgresql/statement.cc | 105 +++++++++-------
c/driver/postgresql/statement.h | 14 +++
c/driver/sqlite/sqlite.c | 6 +
c/validation/adbc_validation.cc | 123 +++++++++++++++++++
c/validation/adbc_validation.h | 18 +++
13 files changed, 477 insertions(+), 57 deletions(-)
diff --git a/c/driver/common/utils.c b/c/driver/common/utils.c
index dfac14f5e4..eb01bc18a4 100644
--- a/c/driver/common/utils.c
+++ b/c/driver/common/utils.c
@@ -244,6 +244,19 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array,
return ADBC_STATUS_OK;
}
+AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array,
+ uint32_t info_code, int64_t info_value,
+ struct AdbcError* error) {
+ CHECK_NA(INTERNAL, ArrowArrayAppendUInt(array->children[0], info_code), error);
+ // Append to type variant
+ CHECK_NA(INTERNAL, ArrowArrayAppendInt(array->children[1]->children[2], info_value),
+ error);
+ // Append type code/offset
+ CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2),
+ error);
+ return ADBC_STATUS_OK;
+}
+
AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema,
struct AdbcError* error) {
ArrowSchemaInit(schema);
diff --git a/c/driver/common/utils.h b/c/driver/common/utils.h
index 5735bb945f..381c7b05ee 100644
--- a/c/driver/common/utils.h
+++ b/c/driver/common/utils.h
@@ -117,6 +117,9 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array,
uint32_t info_code,
const char* info_value,
struct AdbcError* error);
+AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array,
+ uint32_t info_code, int64_t info_value,
+ struct AdbcError* error);
AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema,
struct AdbcError* error);
diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index 1b256b6761..e5435c1911 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -36,8 +36,9 @@
namespace {
static const uint32_t kSupportedInfoCodes[] = {
- ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME,
- ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION,
+ ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION,
+ ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION,
+ ADBC_INFO_DRIVER_ARROW_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION,
};
static const std::unordered_map kPgTableTypes = {
@@ -771,6 +772,10 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes,
RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i],
NANOARROW_VERSION, error));
break;
+ case ADBC_INFO_DRIVER_ADBC_VERSION:
+ RAISE_ADBC(AdbcConnectionGetInfoAppendInt(array, info_codes[i],
+ ADBC_VERSION_1_1_0, error));
+ break;
default:
// Ignore
continue;
@@ -831,6 +836,55 @@ AdbcStatusCode PostgresConnection::GetObjects(
return BatchToArrayStream(&array, &schema, out, error);
}
+AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value,
+ size_t* length, struct AdbcError* error) {
+ std::string output;
+ if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) {
+ PqResultHelper result_helper{conn_, "SELECT CURRENT_CATALOG", {}, error};
+ RAISE_ADBC(result_helper.Prepare());
+ RAISE_ADBC(result_helper.Execute());
+ auto it = result_helper.begin();
+ if (it == result_helper.end()) {
+ SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_CATALOG'");
+ return ADBC_STATUS_INTERNAL;
+ }
+ output = (*it)[0].data;
+ } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) {
+ PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error};
+ RAISE_ADBC(result_helper.Prepare());
+ RAISE_ADBC(result_helper.Execute());
+ auto it = result_helper.begin();
+ if (it == result_helper.end()) {
+ SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'");
+ return ADBC_STATUS_INTERNAL;
+ }
+ output = (*it)[0].data;
+ } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) {
+ output = autocommit_ ? ADBC_OPTION_VALUE_ENABLED : ADBC_OPTION_VALUE_DISABLED;
+ } else {
+ return ADBC_STATUS_NOT_FOUND;
+ }
+
+ if (output.size() + 1 <= *length) {
+ std::memcpy(value, output.c_str(), output.size() + 1);
+ }
+ *length = output.size() + 1;
+ return ADBC_STATUS_OK;
+}
+AdbcStatusCode PostgresConnection::GetOptionBytes(const char* option, uint8_t* value,
+ size_t* length,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_FOUND;
+}
+AdbcStatusCode PostgresConnection::GetOptionInt(const char* option, int64_t* value,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_FOUND;
+}
+AdbcStatusCode PostgresConnection::GetOptionDouble(const char* option, double* value,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_FOUND;
+}
+
AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
const char* db_schema,
const char* table_name,
diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h
index 74315ee053..c39d742a20 100644
--- a/c/driver/postgresql/connection.h
+++ b/c/driver/postgresql/connection.h
@@ -40,6 +40,14 @@ class PostgresConnection {
const char* table_name, const char** table_types,
const char* column_name, struct ArrowArrayStream* out,
struct AdbcError* error);
+ AdbcStatusCode GetOption(const char* option, char* value, size_t* length,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionInt(const char* option, int64_t* value,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionDouble(const char* option, double* value,
+ struct AdbcError* error);
AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema,
const char* table_name, struct ArrowSchema* schema,
struct AdbcError* error);
@@ -49,6 +57,10 @@ class PostgresConnection {
AdbcStatusCode Release(struct AdbcError* error);
AdbcStatusCode Rollback(struct AdbcError* error);
AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error);
+ AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length,
+ struct AdbcError* error);
+ AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error);
+ AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error);
PGconn* conn() const { return conn_; }
const std::shared_ptr& type_resolver() const {
diff --git a/c/driver/postgresql/database.h b/c/driver/postgresql/database.h
index f10464787a..233182cc00 100644
--- a/c/driver/postgresql/database.h
+++ b/c/driver/postgresql/database.h
@@ -36,7 +36,19 @@ class PostgresDatabase {
AdbcStatusCode Init(struct AdbcError* error);
AdbcStatusCode Release(struct AdbcError* error);
+ AdbcStatusCode GetOption(const char* option, char* value, size_t* length,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionInt(const char* option, int64_t* value,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionDouble(const char* option, double* value,
+ struct AdbcError* error);
AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error);
+ AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length,
+ struct AdbcError* error);
+ AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error);
+ AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error);
// Internal implementation
diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 18d1fbd48c..e6844aaa44 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -673,12 +673,13 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type,
class PostgresCopyStreamReader {
public:
- ArrowErrorCode Init(const PostgresType& pg_type) {
+ ArrowErrorCode Init(PostgresType pg_type) {
if (pg_type.type_id() != PostgresTypeId::kRecord) {
return EINVAL;
}
- root_reader_.Init(pg_type);
+ pg_type_ = std::move(pg_type);
+ root_reader_.Init(pg_type_);
array_size_approx_bytes_ = 0;
return NANOARROW_OK;
}
@@ -802,7 +803,10 @@ class PostgresCopyStreamReader {
return NANOARROW_OK;
}
+ const PostgresType& pg_type() const { return pg_type_; }
+
private:
+ PostgresType pg_type_;
PostgresCopyFieldTupleReader root_reader_;
nanoarrow::UniqueSchema schema_;
nanoarrow::UniqueArray array_;
diff --git a/c/driver/postgresql/postgresql.cc b/c/driver/postgresql/postgresql.cc
index 95e6c8b881..48e884936f 100644
--- a/c/driver/postgresql/postgresql.cc
+++ b/c/driver/postgresql/postgresql.cc
@@ -142,6 +142,42 @@ AdbcStatusCode PostgresConnectionGetObjects(
table_types, column_name, stream, error);
}
+AdbcStatusCode PostgresConnectionGetOption(struct AdbcConnection* connection,
+ const char* key, char* value, size_t* length,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->GetOption(key, value, length, error);
+}
+
+AdbcStatusCode PostgresConnectionGetOptionBytes(struct AdbcConnection* connection,
+ const char* key, uint8_t* value,
+ size_t* length, struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->GetOptionBytes(key, value, length, error);
+}
+
+AdbcStatusCode PostgresConnectionGetOptionInt(struct AdbcConnection* connection,
+ const char* key, int64_t* value,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->GetOptionInt(key, value, error);
+}
+
+AdbcStatusCode PostgresConnectionGetOptionDouble(struct AdbcConnection* connection,
+ const char* key, double* value,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->GetOptionDouble(key, value, error);
+}
+
AdbcStatusCode PostgresConnectionGetTableSchema(
struct AdbcConnection* connection, const char* catalog, const char* db_schema,
const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) {
@@ -213,6 +249,33 @@ AdbcStatusCode PostgresConnectionSetOption(struct AdbcConnection* connection,
return (*ptr)->SetOption(key, value, error);
}
+AdbcStatusCode PostgresConnectionSetOptionBytes(struct AdbcConnection* connection,
+ const char* key, const uint8_t* value,
+ size_t length, struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->SetOptionBytes(key, value, length, error);
+}
+
+AdbcStatusCode PostgresConnectionSetOptionInt(struct AdbcConnection* connection,
+ const char* key, int64_t value,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->SetOptionInt(key, value, error);
+}
+
+AdbcStatusCode PostgresConnectionSetOptionDouble(struct AdbcConnection* connection,
+ const char* key, double value,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto ptr =
+ reinterpret_cast*>(connection->private_data);
+ return (*ptr)->SetOptionDouble(key, value, error);
+}
+
} // namespace
AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
struct AdbcError* error) {
@@ -237,6 +300,30 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
table_types, column_name, stream, error);
}
+AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key,
+ char* value, size_t* length,
+ struct AdbcError* error) {
+ return PostgresConnectionGetOption(connection, key, value, length, error);
+}
+
+AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection,
+ const char* key, uint8_t* value,
+ size_t* length, struct AdbcError* error) {
+ return PostgresConnectionGetOptionBytes(connection, key, value, length, error);
+}
+
+AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection,
+ const char* key, int64_t* value,
+ struct AdbcError* error) {
+ return PostgresConnectionGetOptionInt(connection, key, value, error);
+}
+
+AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection,
+ const char* key, double* value,
+ struct AdbcError* error) {
+ return PostgresConnectionGetOptionDouble(connection, key, value, error);
+}
+
AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* catalog, const char* db_schema,
const char* table_name,
@@ -287,6 +374,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const
return PostgresConnectionSetOption(connection, key, value, error);
}
+AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection,
+ const char* key, const uint8_t* value,
+ size_t length, struct AdbcError* error) {
+ return PostgresConnectionSetOptionBytes(connection, key, value, length, error);
+}
+
+AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection,
+ const char* key, int64_t value, size_t length,
+ struct AdbcError* error) {
+ return PostgresConnectionSetOptionInt(connection, key, value, error);
+}
+
+AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection,
+ const char* key, double value, size_t length,
+ struct AdbcError* error) {
+ return PostgresConnectionSetOptionDouble(connection, key, value, error);
+}
+
// ---------------------------------------------------------------------
// AdbcStatement
@@ -329,6 +434,15 @@ AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement,
return (*ptr)->ExecuteQuery(output, rows_affected, error);
}
+AdbcStatusCode PostgresStatementExecuteSchema(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_INVALID_STATE;
+ auto* ptr =
+ reinterpret_cast*>(statement->private_data);
+ return (*ptr)->ExecuteSchema(schema, error);
+}
+
AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement,
uint8_t* partition_desc,
struct AdbcError* error) {
@@ -423,6 +537,11 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
return PostgresStatementExecuteQuery(statement, output, rows_affected, error);
}
+AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement,
+ ArrowSchema* schema, struct AdbcError* error) {
+ return PostgresStatementExecuteSchema(statement, schema, error);
+}
+
AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement,
uint8_t* partition_desc,
struct AdbcError* error) {
@@ -474,7 +593,20 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e
if (!raw_driver) return ADBC_STATUS_INVALID_ARGUMENT;
auto* driver = reinterpret_cast(raw_driver);
- std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE);
+ if (version >= ADBC_VERSION_1_1_0) {
+ std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE);
+ driver->StatementExecuteSchema = PostgresStatementExecuteSchema;
+ driver->ConnectionGetOption = PostgresConnectionGetOption;
+ driver->ConnectionGetOptionBytes = PostgresConnectionGetOptionBytes;
+ driver->ConnectionGetOptionInt = PostgresConnectionGetOptionInt;
+ driver->ConnectionGetOptionDouble = PostgresConnectionGetOptionDouble;
+ driver->ConnectionSetOptionBytes = PostgresConnectionSetOptionBytes;
+ driver->ConnectionSetOptionInt = PostgresConnectionSetOptionInt;
+ driver->ConnectionSetOptionDouble = PostgresConnectionSetOptionDouble;
+ } else {
+ std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE);
+ }
+
driver->DatabaseInit = PostgresDatabaseInit;
driver->DatabaseNew = PostgresDatabaseNew;
driver->DatabaseRelease = PostgresDatabaseRelease;
@@ -502,6 +634,7 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e
driver->StatementRelease = PostgresStatementRelease;
driver->StatementSetOption = PostgresStatementSetOption;
driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery;
+
return ADBC_STATUS_OK;
}
}
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 153d8eb221..a4af67add1 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -103,6 +103,8 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
std::string catalog() const override { return "postgres"; }
std::string db_schema() const override { return "public"; }
+
+ bool supports_execute_schema() const override { return true; }
};
class PostgresDatabaseTest : public ::testing::Test,
@@ -134,10 +136,8 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) {
adbc_validation::StreamReader reader;
std::vector info = {
- ADBC_INFO_DRIVER_NAME,
- ADBC_INFO_DRIVER_VERSION,
- ADBC_INFO_VENDOR_NAME,
- ADBC_INFO_VENDOR_VERSION,
+ ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION,
+ ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION,
};
ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(),
&reader.stream.value, &error),
@@ -153,29 +153,30 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) {
ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row));
const uint32_t code =
reader.array_view->children[0]->buffer_views[1].data.as_uint32[row];
+ const uint32_t offset =
+ reader.array_view->children[1]->buffer_views[1].data.as_int32[row];
seen.push_back(code);
- int str_child_index = 0;
- struct ArrowArrayView* str_child =
- reader.array_view->children[1]->children[str_child_index];
+ struct ArrowArrayView* str_child = reader.array_view->children[1]->children[0];
+ struct ArrowArrayView* int_child = reader.array_view->children[1]->children[2];
switch (code) {
case ADBC_INFO_DRIVER_NAME: {
- ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 0);
+ ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset);
EXPECT_EQ("ADBC PostgreSQL Driver", std::string(val.data, val.size_bytes));
break;
}
case ADBC_INFO_DRIVER_VERSION: {
- ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 1);
+ ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset);
EXPECT_EQ("(unknown)", std::string(val.data, val.size_bytes));
break;
}
case ADBC_INFO_VENDOR_NAME: {
- ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 2);
+ ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset);
EXPECT_EQ("PostgreSQL", std::string(val.data, val.size_bytes));
break;
}
case ADBC_INFO_VENDOR_VERSION: {
- ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 3);
+ ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset);
#ifdef __WIN32
const char* pater = "\\d\\d\\d\\d\\d\\d";
#else
@@ -185,6 +186,10 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) {
::testing::MatchesRegex(pater));
break;
}
+ case ADBC_INFO_DRIVER_ADBC_VERSION: {
+ EXPECT_EQ(ADBC_VERSION_1_1_0, ArrowArrayViewGetIntUnsafe(int_child, offset));
+ break;
+ }
default:
// Ignored
break;
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 4cae15b631..e566dd5d69 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -867,50 +867,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
// 1. Prepare the query to get the schema
{
- // TODO: we should pipeline here and assume this will succeed
- PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(),
- /*nParams=*/0, nullptr);
- if (PQresultStatus(result) != PGRES_COMMAND_OK) {
- SetError(error,
- "[libpq] Failed to execute query: could not infer schema: failed to "
- "prepare query: %s\nQuery was:%s",
- PQerrorMessage(connection_->conn()), query_.c_str());
- PQclear(result);
- return ADBC_STATUS_IO;
- }
- PQclear(result);
- result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
- if (PQresultStatus(result) != PGRES_COMMAND_OK) {
- SetError(error,
- "[libpq] Failed to execute query: could not infer schema: failed to "
- "describe prepared statement: %s\nQuery was:%s",
- PQerrorMessage(connection_->conn()), query_.c_str());
- PQclear(result);
- return ADBC_STATUS_IO;
- }
-
- // Resolve the information from the PGresult into a PostgresType
- PostgresType root_type;
- AdbcStatusCode status =
- ResolvePostgresType(*type_resolver_, result, &root_type, error);
- PQclear(result);
- if (status != ADBC_STATUS_OK) return status;
-
- // Initialize the copy reader and infer the output schema (i.e., error for
- // unsupported types before issuing the COPY query)
- reader_.copy_reader_.reset(new PostgresCopyStreamReader());
- reader_.copy_reader_->Init(root_type);
- struct ArrowError na_error;
- int na_res = reader_.copy_reader_->InferOutputSchema(&na_error);
- if (na_res != NANOARROW_OK) {
- SetError(error, "[libpq] Failed to infer output schema: %s", na_error.message);
- return na_res;
- }
+ RAISE_ADBC(SetupReader(error));
// If the caller did not request a result set or if there are no
// inferred output columns (e.g. a CREATE or UPDATE), then don't
// use COPY (which would fail anyways)
- if (!stream || root_type.n_children() == 0) {
+ if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) {
RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error));
if (stream) {
struct ArrowSchema schema;
@@ -924,7 +886,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
// This resolves the reader specific to each PostgresType -> ArrowSchema
// conversion. It is unlikely that this will fail given that we have just
// inferred these conversions ourselves.
- na_res = reader_.copy_reader_->InitFieldReaders(&na_error);
+ struct ArrowError na_error;
+ int na_res = reader_.copy_reader_->InitFieldReaders(&na_error);
if (na_res != NANOARROW_OK) {
SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message);
return na_res;
@@ -953,6 +916,23 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
return ADBC_STATUS_OK;
}
+AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ ClearResult();
+ if (query_.empty()) {
+ SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery");
+ return ADBC_STATUS_INVALID_STATE;
+ } else if (bind_.release) {
+ // TODO: if we have parameters, bind them (since they can affect the output schema)
+ SetError(error, "[libpq] ExecuteSchema with parameters is not implemented");
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+ }
+
+ RAISE_ADBC(SetupReader(error));
+ CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error);
+ return ADBC_STATUS_OK;
+}
+
AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
struct AdbcError* error) {
if (!bind_.release) {
@@ -1070,6 +1050,49 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
return ADBC_STATUS_OK;
}
+AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) {
+ // TODO: we should pipeline here and assume this will succeed
+ PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(),
+ /*nParams=*/0, nullptr);
+ if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+ SetError(error,
+ "[libpq] Failed to execute query: could not infer schema: failed to "
+ "prepare query: %s\nQuery was:%s",
+ PQerrorMessage(connection_->conn()), query_.c_str());
+ PQclear(result);
+ return ADBC_STATUS_IO;
+ }
+ PQclear(result);
+ result = PQdescribePrepared(connection_->conn(), /*stmtName=*/"");
+ if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+ SetError(error,
+ "[libpq] Failed to execute query: could not infer schema: failed to "
+ "describe prepared statement: %s\nQuery was:%s",
+ PQerrorMessage(connection_->conn()), query_.c_str());
+ PQclear(result);
+ return ADBC_STATUS_IO;
+ }
+
+ // Resolve the information from the PGresult into a PostgresType
+ PostgresType root_type;
+ AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error);
+ PQclear(result);
+ if (status != ADBC_STATUS_OK) return status;
+
+ // Initialize the copy reader and infer the output schema (i.e., error for
+ // unsupported types before issuing the COPY query)
+ reader_.copy_reader_.reset(new PostgresCopyStreamReader());
+ reader_.copy_reader_->Init(root_type);
+ struct ArrowError na_error;
+ int na_res = reader_.copy_reader_->InferOutputSchema(&na_error);
+ if (na_res != NANOARROW_OK) {
+ SetError(error, "[libpq] Failed to infer output schema: (%d) %s: %s", na_res,
+ std::strerror(na_res), na_error.message);
+ return ADBC_STATUS_INTERNAL;
+ }
+ return ADBC_STATUS_OK;
+}
+
void PostgresStatement::ClearResult() {
// TODO: we may want to synchronize here for safety
reader_.Release();
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 62af2457d5..c72446d9b7 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -102,11 +102,24 @@ class PostgresStatement {
AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error);
AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected,
struct AdbcError* error);
+ AdbcStatusCode ExecuteSchema(struct ArrowSchema* schema, struct AdbcError* error);
+ AdbcStatusCode GetOption(const char* option, char* value, size_t* length,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionInt(const char* option, int64_t* value,
+ struct AdbcError* error);
+ AdbcStatusCode GetOptionDouble(const char* option, double* value,
+ struct AdbcError* error);
AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error);
AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error);
AdbcStatusCode Prepare(struct AdbcError* error);
AdbcStatusCode Release(struct AdbcError* error);
AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error);
+ AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length,
+ struct AdbcError* error);
+ AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error);
+ AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error);
AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error);
// ---------------------------------------------------------------------
@@ -122,6 +135,7 @@ class PostgresStatement {
AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
int64_t* rows_affected,
struct AdbcError* error);
+ AdbcStatusCode SetupReader(struct AdbcError* error);
private:
std::shared_ptr type_resolver_;
diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c
index 359c594dab..d622ceb054 100644
--- a/c/driver/sqlite/sqlite.c
+++ b/c/driver/sqlite/sqlite.c
@@ -1480,6 +1480,12 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
return SqliteStatementExecuteQuery(statement, out, rows_affected, error);
}
+AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
struct AdbcError* error) {
return SqliteStatementPrepare(statement, error);
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index 803f14c023..358afce9a7 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -33,6 +33,7 @@
#include
#include
#include
+#include
#include "adbc_validation_util.h"
@@ -247,6 +248,38 @@ void ConnectionTest::TestAutocommitToggle() {
//------------------------------------------------------------
// Tests of metadata
+std::optional ConnectionGetOption(struct AdbcConnection* connection,
+ std::string_view option,
+ struct AdbcError* error) {
+ char buffer[128];
+ size_t buffer_size = sizeof(buffer);
+ AdbcStatusCode status =
+ AdbcConnectionGetOption(connection, option.data(), buffer, &buffer_size, error);
+ EXPECT_THAT(status, IsOkStatus(error));
+ if (status != ADBC_STATUS_OK) return std::nullopt;
+ EXPECT_GT(buffer_size, 0);
+ if (buffer_size == 0) return std::nullopt;
+ return std::string(buffer, buffer_size - 1);
+}
+
+void ConnectionTest::TestMetadataCurrentCatalog() {
+ ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
+
+ auto catalog =
+ ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error);
+ ASSERT_THAT(catalog, ::testing::Optional(quirks()->catalog()));
+}
+
+void ConnectionTest::TestMetadataCurrentDbSchema() {
+ ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
+
+ auto db_schema =
+ ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, &error);
+ ASSERT_THAT(db_schema, ::testing::Optional(quirks()->db_schema()));
+}
+
void ConnectionTest::TestMetadataGetInfo() {
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
@@ -2042,6 +2075,11 @@ void StatementTest::TestTransactions() {
ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
IsOkStatus(&error));
+ auto autocommit =
+ ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error);
+ ASSERT_THAT(autocommit,
+ ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_ENABLED)));
+
Handle connection2;
ASSERT_THAT(AdbcConnectionNew(&connection2.value, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection2.value, &database, &error),
@@ -2051,6 +2089,11 @@ void StatementTest::TestTransactions() {
ADBC_OPTION_VALUE_DISABLED, &error),
IsOkStatus(&error));
+ autocommit =
+ ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error);
+ ASSERT_THAT(autocommit,
+ ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_DISABLED)));
+
// Uncommitted change
ASSERT_NO_FATAL_FAILURE(IngestSampleTable(&connection, &error));
@@ -2126,6 +2169,86 @@ void StatementTest::TestTransactions() {
}
}
+void StatementTest::TestSqlSchemaInts() {
+ if (!quirks()->supports_execute_schema()) {
+ GTEST_SKIP() << "Not supported";
+ }
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
+ IsOkStatus(&error));
+
+ nanoarrow::UniqueSchema schema;
+ ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
+ IsOkStatus(&error));
+
+ ASSERT_EQ(1, schema->n_children);
+ ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({
+ ::testing::StrEq("i"), // int32
+ ::testing::StrEq("l"), // int64
+ }));
+
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
+void StatementTest::TestSqlSchemaFloats() {
+ if (!quirks()->supports_execute_schema()) {
+ GTEST_SKIP() << "Not supported";
+ }
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT CAST(1.5 AS FLOAT)", &error),
+ IsOkStatus(&error));
+
+ nanoarrow::UniqueSchema schema;
+ ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
+ IsOkStatus(&error));
+
+ ASSERT_EQ(1, schema->n_children);
+ ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({
+ ::testing::StrEq("f"), // float32
+ ::testing::StrEq("g"), // float64
+ }));
+
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
+void StatementTest::TestSqlSchemaStrings() {
+ if (!quirks()->supports_execute_schema()) {
+ GTEST_SKIP() << "Not supported";
+ }
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'hi'", &error),
+ IsOkStatus(&error));
+
+ nanoarrow::UniqueSchema schema;
+ ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
+ IsOkStatus(&error));
+
+ ASSERT_EQ(1, schema->n_children);
+ ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({
+ ::testing::StrEq("u"), // string
+ ::testing::StrEq("U"), // large_string
+ }));
+
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
+void StatementTest::TestSqlSchemaErrors() {
+ if (!quirks()->supports_execute_schema()) {
+ GTEST_SKIP() << "Not supported";
+ }
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+ nanoarrow::UniqueSchema schema;
+ ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error),
+ IsStatus(ADBC_STATUS_INVALID_STATE, &error));
+
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
void StatementTest::TestConcurrentStatements() {
Handle statement1;
Handle statement2;
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 2a5883c00f..eca65e6f8b 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -89,6 +89,9 @@ class DriverQuirks {
/// single connection
virtual bool supports_concurrent_statements() const { return false; }
+ /// \brief Whether AdbcStatementExecuteSchema should work
+ virtual bool supports_execute_schema() const { return false; }
+
/// \brief Whether AdbcStatementExecutePartitions should work
virtual bool supports_partitioned_data() const { return false; }
@@ -157,6 +160,9 @@ class ConnectionTest {
void TestAutocommitToggle();
+ void TestMetadataCurrentCatalog();
+ void TestMetadataCurrentDbSchema();
+
void TestMetadataGetInfo();
void TestMetadataGetTableSchema();
void TestMetadataGetTableTypes();
@@ -183,6 +189,8 @@ class ConnectionTest {
TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \
TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \
TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \
+ TEST_F(FIXTURE, MetadataCurrentCatalog) { TestMetadataCurrentCatalog(); } \
+ TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \
TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \
TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \
TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \
@@ -257,6 +265,12 @@ class StatementTest {
void TestSqlQueryErrors();
+ void TestSqlSchemaInts();
+ void TestSqlSchemaFloats();
+ void TestSqlSchemaStrings();
+
+ void TestSqlSchemaErrors();
+
void TestTransactions();
void TestConcurrentStatements();
@@ -316,6 +330,10 @@ class StatementTest {
TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \
TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \
TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); } \
+ TEST_F(FIXTURE, SqlSchemaInts) { TestSqlSchemaInts(); } \
+ TEST_F(FIXTURE, SqlSchemaFloats) { TestSqlSchemaFloats(); } \
+ TEST_F(FIXTURE, SqlSchemaStrings) { TestSqlSchemaStrings(); } \
+ TEST_F(FIXTURE, SqlSchemaErrors) { TestSqlSchemaErrors(); } \
TEST_F(FIXTURE, Transactions) { TestTransactions(); } \
TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \
TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); }