Skip to content

Commit

Permalink
fix (c/driver-manager): Protect against uninitialized AdbcError
Browse files Browse the repository at this point in the history
ADBC C API functions should initialize AdbcError struct passed into
them instead of assuming that the caller did so. Given that these
are "output-only" type parameters there is usually no expectation
that they need to be zeroed coming into API calls.
  • Loading branch information
cdaudt committed Mar 31, 2023
1 parent 2b8e429 commit b84c307
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
32 changes: 32 additions & 0 deletions c/driver_manager/adbc_driver_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ void GetWinError(std::string* buffer) {

#endif // defined(_WIN32)

// Struct initializers
static void AdbcErrorInit(struct AdbcError* error) {
std::memset(error, 0, sizeof(*error));
}

// Error handling

void ReleaseError(struct AdbcError* error) {
Expand Down Expand Up @@ -234,6 +239,7 @@ struct TempConnection {
// Direct implementations of API methods

AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) {
AdbcErrorInit(error);
// Allocate a temporary structure to store options pre-Init
database->private_data = new TempDatabase();
database->private_driver = nullptr;
Expand All @@ -242,6 +248,7 @@ AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError*

AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
const char* value, struct AdbcError* error) {
AdbcErrorInit(error);
if (database->private_driver) {
return database->private_driver->DatabaseSetOption(database, key, value, error);
}
Expand All @@ -260,6 +267,7 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char*
AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database,
AdbcDriverInitFunc init_func,
struct AdbcError* error) {
AdbcErrorInit(error);
if (database->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -270,6 +278,7 @@ AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* databas
}

AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) {
AdbcErrorInit(error);
if (!database->private_data) {
SetError(error, "Must call AdbcDatabaseNew first");
return ADBC_STATUS_INVALID_STATE;
Expand Down Expand Up @@ -337,6 +346,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*

AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!database->private_driver) {
if (database->private_data) {
TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
Expand All @@ -358,6 +368,7 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,

AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -368,6 +379,7 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -381,6 +393,7 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d
const char* column_name,
struct ArrowArrayStream* stream,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -394,6 +407,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
const char* table_name,
struct ArrowSchema* schema,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -404,6 +418,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct ArrowArrayStream* stream,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -413,6 +428,7 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
struct AdbcDatabase* database,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_data) {
SetError(error, "Must call AdbcConnectionNew first");
return ADBC_STATUS_INVALID_STATE;
Expand All @@ -439,6 +455,7 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,

AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
struct AdbcError* error) {
AdbcErrorInit(error);
// Allocate a temporary structure to store options pre-Init, because
// we don't get access to the database (and hence the driver
// function table) until then
Expand All @@ -452,6 +469,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -461,6 +479,7 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,

AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
if (connection->private_data) {
TempConnection* args = reinterpret_cast<TempConnection*>(connection->private_data);
Expand All @@ -477,6 +496,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,

AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -485,6 +505,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection,

AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
const char* value, struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_data) {
SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first");
return ADBC_STATUS_INVALID_STATE;
Expand All @@ -501,6 +522,7 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const
AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
struct ArrowArray* values, struct ArrowSchema* schema,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -510,6 +532,7 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
struct ArrowArrayStream* stream,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -522,6 +545,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -533,6 +557,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -543,6 +568,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -552,6 +578,7 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
struct AdbcStatement* statement,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -562,6 +589,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,

AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -570,6 +598,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,

AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -580,6 +609,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,

AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
const char* value, struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -588,6 +618,7 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha

AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
const char* query, struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand All @@ -597,6 +628,7 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement,
const uint8_t* plan, size_t length,
struct AdbcError* error) {
AdbcErrorInit(error);
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
Expand Down
23 changes: 23 additions & 0 deletions c/driver_manager/adbc_driver_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ TEST_F(DriverManager, DatabaseCustomInitFunc) {
ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
}

TEST_F(DriverManager, UninitializedError) {
struct AdbcDatabase database;
struct AdbcError invalid_err;
std::memset(&database, 0, sizeof(database));

// Test out a few return error codes
std::memset(&invalid_err, 0xff, sizeof(invalid_err));
ASSERT_THAT(AdbcDatabaseInit(&database, &invalid_err),
IsStatus(ADBC_STATUS_INVALID_STATE, &invalid_err));

std::memset(&invalid_err, 0xff, sizeof(invalid_err));
ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&invalid_err));
ASSERT_THAT(
AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &invalid_err),
IsOkStatus(&invalid_err));

ASSERT_THAT(AdbcDatabaseInit(&database, &invalid_err), IsOkStatus(&invalid_err));
std::memset(&invalid_err, 0xff, sizeof(invalid_err));
ASSERT_THAT(
AdbcDatabaseSetOption(&database, "notavalidkey", "notavalidvalue", &invalid_err),
IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &invalid_err));
}

TEST_F(DriverManager, ConnectionOptions) {
struct AdbcDatabase database;
struct AdbcConnection connection;
Expand Down

0 comments on commit b84c307

Please sign in to comment.