Skip to content

Commit

Permalink
more defensive error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
paleolimbot committed Oct 21, 2024
1 parent 5248c34 commit 1c5ba60
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
11 changes: 11 additions & 0 deletions c/driver/framework/base_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,22 @@ class Driver {
}

auto error_obj = reinterpret_cast<Status*>(error->private_data);
if (!error_obj) {
return 0;
}
return error_obj->CDetailCount();
}

static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) {
if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) {
return {nullptr, nullptr, 0};
}

auto error_obj = reinterpret_cast<Status*>(error->private_data);
if (!error_obj) {
return {nullptr, nullptr, 0};
}

return error_obj->CDetail(index);
}

Expand Down
64 changes: 53 additions & 11 deletions c/driver_manager/adbc_driver_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,37 @@ void SetError(struct AdbcError* error, const std::string& message) {
error->release = ReleaseError;
}

// Copies src_error into error and releases src_error
void SetError(struct AdbcError* error, struct AdbcError* src_error) {
if (!error) return;
if (error->release) error->release(error);

if (src_error->message) {
size_t message_size = strlen(src_error->message);
error->message = new char[message_size];
std::memcpy(error->message, src_error->message, message_size);
error->message[message_size] = '\0';
} else {
error->message = nullptr;
}

error->release = ReleaseError;
if (src_error->release) {
src_error->release(src_error);
}
}

struct OwnedError {
struct AdbcError error {
ADBC_ERROR_INIT
};
~OwnedError() {
if (error.release) {
error.release(&error);
}
}
};

// Driver state

/// A driver DLL.
Expand Down Expand Up @@ -666,15 +697,15 @@ std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) {

int AdbcErrorGetDetailCount(const struct AdbcError* error) {
if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data &&
error->private_driver) {
error->private_driver && error->private_driver->ErrorGetDetailCount) {
return error->private_driver->ErrorGetDetailCount(error);
}
return 0;
}

struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) {
if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data &&
error->private_driver) {
error->private_driver && error->private_driver->ErrorGetDetail) {
return error->private_driver->ErrorGetDetail(error, index);
}
return {nullptr, nullptr, 0};
Expand Down Expand Up @@ -900,6 +931,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0,
database->private_driver, error);
}

if (status != ADBC_STATUS_OK) {
// Restore private_data so it will be released by AdbcDatabaseRelease
database->private_data = args;
Expand All @@ -910,10 +942,18 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
database->private_driver = nullptr;
return status;
}
status = database->private_driver->DatabaseNew(database, error);

// Errors that occur during AdbcDatabaseXXX() refer to the driver via
// the private_driver member; however, after we return we have released
// the driver and inspecting the error might segfault. Here, we scope
// the driver-produced error to this function and make a copy if necessary.
OwnedError driver_error;

status = database->private_driver->DatabaseNew(database, &driver_error.error);
if (status != ADBC_STATUS_OK) {
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
SetError(error, &driver_error.error);
database->private_driver->release(database->private_driver, nullptr);
}
delete database->private_driver;
database->private_driver = nullptr;
Expand All @@ -927,33 +967,34 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*

INIT_ERROR(error, database);
for (const auto& option : options) {
status = database->private_driver->DatabaseSetOption(database, option.first.c_str(),
option.second.c_str(), error);
status = database->private_driver->DatabaseSetOption(
database, option.first.c_str(), option.second.c_str(), &driver_error.error);
if (status != ADBC_STATUS_OK) break;
}
for (const auto& option : bytes_options) {
status = database->private_driver->DatabaseSetOptionBytes(
database, option.first.c_str(),
reinterpret_cast<const uint8_t*>(option.second.data()), option.second.size(),
error);
&driver_error.error);
if (status != ADBC_STATUS_OK) break;
}
for (const auto& option : int_options) {
status = database->private_driver->DatabaseSetOptionInt(
database, option.first.c_str(), option.second, error);
database, option.first.c_str(), option.second, &driver_error.error);
if (status != ADBC_STATUS_OK) break;
}
for (const auto& option : double_options) {
status = database->private_driver->DatabaseSetOptionDouble(
database, option.first.c_str(), option.second, error);
database, option.first.c_str(), option.second, &driver_error.error);
if (status != ADBC_STATUS_OK) break;
}

if (status != ADBC_STATUS_OK) {
// Release the database
std::ignore = database->private_driver->DatabaseRelease(database, error);
std::ignore = database->private_driver->DatabaseRelease(database, nullptr);
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
SetError(error, &driver_error.error);
database->private_driver->release(database->private_driver, nullptr);
}
delete database->private_driver;
database->private_driver = nullptr;
Expand All @@ -962,6 +1003,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
database->private_data = nullptr;
return status;
}

return database->private_driver->DatabaseInit(database, error);
}

Expand Down

0 comments on commit 1c5ba60

Please sign in to comment.