From b84c307b1c49879296b11bce95b2e65c6b09ecb1 Mon Sep 17 00:00:00 2001 From: Christian Daudt Date: Fri, 31 Mar 2023 09:52:17 -0700 Subject: [PATCH] fix (c/driver-manager): Protect against uninitialized AdbcError ADBC C API functions should initialize AdbcError struct passed into them instead of assuming that the caller did so. Given that these are "output-only" type parameters there is usually no expectation that they need to be zeroed coming into API calls. --- c/driver_manager/adbc_driver_manager.cc | 32 ++++++++++++++++++++ c/driver_manager/adbc_driver_manager_test.cc | 23 ++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index c63560a40e..b80e470f3a 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -57,6 +57,11 @@ void GetWinError(std::string* buffer) { #endif // defined(_WIN32) +// Struct initializers +static void AdbcErrorInit(struct AdbcError* error) { + std::memset(error, 0, sizeof(*error)); +} + // Error handling void ReleaseError(struct AdbcError* error) { @@ -234,6 +239,7 @@ struct TempConnection { // Direct implementations of API methods AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + AdbcErrorInit(error); // Allocate a temporary structure to store options pre-Init database->private_data = new TempDatabase(); database->private_driver = nullptr; @@ -242,6 +248,7 @@ AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { + AdbcErrorInit(error); if (database->private_driver) { return database->private_driver->DatabaseSetOption(database, key, value, error); } @@ -260,6 +267,7 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, AdbcDriverInitFunc init_func, struct AdbcError* error) { + AdbcErrorInit(error); if (database->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -270,6 +278,7 @@ AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* databas } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { + AdbcErrorInit(error); if (!database->private_data) { SetError(error, "Must call AdbcDatabaseNew first"); return ADBC_STATUS_INVALID_STATE; @@ -337,6 +346,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { + AdbcErrorInit(error); if (!database->private_driver) { if (database->private_data) { TempDatabase* args = reinterpret_cast(database->private_data); @@ -358,6 +368,7 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -368,6 +379,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -381,6 +393,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d const char* column_name, struct ArrowArrayStream* stream, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -394,6 +407,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -404,6 +418,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, struct ArrowArrayStream* stream, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -413,6 +428,7 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_data) { SetError(error, "Must call AdbcConnectionNew first"); return ADBC_STATUS_INVALID_STATE; @@ -439,6 +455,7 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, struct AdbcError* error) { + AdbcErrorInit(error); // Allocate a temporary structure to store options pre-Init, because // we don't get access to the database (and hence the driver // function table) until then @@ -452,6 +469,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, size_t serialized_length, struct ArrowArrayStream* out, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -461,6 +479,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { if (connection->private_data) { TempConnection* args = reinterpret_cast(connection->private_data); @@ -477,6 +496,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -485,6 +505,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, const char* value, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_data) { SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first"); return ADBC_STATUS_INVALID_STATE; @@ -501,6 +522,7 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -510,6 +532,7 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, struct ArrowArrayStream* stream, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -522,6 +545,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, struct AdbcPartitions* partitions, int64_t* rows_affected, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -533,6 +557,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -543,6 +568,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -552,6 +578,7 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, struct AdbcStatement* statement, struct AdbcError* error) { + AdbcErrorInit(error); if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -562,6 +589,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -570,6 +598,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -580,6 +609,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -588,6 +618,7 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } @@ -597,6 +628,7 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, const uint8_t* plan, size_t length, struct AdbcError* error) { + AdbcErrorInit(error); if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index 99fa477bfa..57c65256ab 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -86,6 +86,29 @@ TEST_F(DriverManager, DatabaseCustomInitFunc) { ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error)); } +TEST_F(DriverManager, UninitializedError) { + struct AdbcDatabase database; + struct AdbcError invalid_err; + std::memset(&database, 0, sizeof(database)); + + // Test out a few return error codes + std::memset(&invalid_err, 0xff, sizeof(invalid_err)); + ASSERT_THAT(AdbcDatabaseInit(&database, &invalid_err), + IsStatus(ADBC_STATUS_INVALID_STATE, &invalid_err)); + + std::memset(&invalid_err, 0xff, sizeof(invalid_err)); + ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&invalid_err)); + ASSERT_THAT( + AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &invalid_err), + IsOkStatus(&invalid_err)); + + ASSERT_THAT(AdbcDatabaseInit(&database, &invalid_err), IsOkStatus(&invalid_err)); + std::memset(&invalid_err, 0xff, sizeof(invalid_err)); + ASSERT_THAT( + AdbcDatabaseSetOption(&database, "notavalidkey", "notavalidvalue", &invalid_err), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &invalid_err)); +} + TEST_F(DriverManager, ConnectionOptions) { struct AdbcDatabase database; struct AdbcConnection connection;