diff --git a/r/adbcdrivermanager/src/driver_base.h b/r/adbcdrivermanager/src/driver_base.h index 465ef4dead..df591a1311 100644 --- a/r/adbcdrivermanager/src/driver_base.h +++ b/r/adbcdrivermanager/src/driver_base.h @@ -29,17 +29,17 @@ class Error { int DetailCount() const { return details_.size(); } - AdbcErrorDetail Detail(int index) { - const auto detail = details_[index]; + AdbcErrorDetail Detail(int index) const { + const auto& detail = details_[index]; return {detail.first.c_str(), reinterpret_cast(detail.second.data()), detail.second.size()}; } - void ToAdbc(AdbcError* adbc_error) { + void ToAdbc(AdbcError* adbc_error, AdbcDriver* driver = nullptr) { auto error_owned_by_adbc_error = new Error(message_, details_); adbc_error->message = const_cast(error_owned_by_adbc_error->message_.c_str()); adbc_error->private_data = error_owned_by_adbc_error; - adbc_error->private_driver = nullptr; + adbc_error->private_driver = driver; adbc_error->vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; for (size_t i = 0; i < 5; i++) { adbc_error->sqlstate[i] = error_owned_by_adbc_error->sql_state_[i]; @@ -148,6 +148,8 @@ class Option { // This class handles option setting and getting. class ObjectBase { public: + ObjectBase() : driver_(nullptr) {} + virtual ~ObjectBase() {} virtual bool OptionKeySupported(const std::string& key, const Option& value) const { @@ -158,11 +160,6 @@ class ObjectBase { virtual AdbcStatusCode Release(AdbcError* error) { return ADBC_STATUS_OK; } - bool HasOption(const std::string& key) { - auto result = options_.find(key); - return result != options_.end(); - } - const Option& GetOption(const std::string& key, const Option& default_value = Option()) const { auto result = options_.find(key); @@ -183,12 +180,15 @@ class ObjectBase { } private: + AdbcDriver* driver_; std::unordered_map options_; // Let the Driver use these to expose C callables wrapping option setters/getters template friend class Driver; + void set_driver(AdbcDriver* driver) { driver_ = driver; } + template AdbcStatusCode CSetOption(const char* key, T value, AdbcError* error) { Option option(value); @@ -247,28 +247,28 @@ class ObjectBase { } } - static void InitErrorNotFound(const char* key, AdbcError* error) { + void InitErrorNotFound(const char* key, AdbcError* error) const { std::stringstream msg_builder; msg_builder << "Option not found for key '" << key << "'"; Error cpperror(msg_builder.str()); cpperror.AddDetail("adbc.r.option_key", key); - cpperror.ToAdbc(error); + cpperror.ToAdbc(error, driver_); } - static void InitErrorWrongType(const char* key, AdbcError* error) { + void InitErrorWrongType(const char* key, AdbcError* error) const { std::stringstream msg_builder; msg_builder << "Wrong type requested for option key '" << key << "'"; Error cpperror(msg_builder.str()); cpperror.AddDetail("adbc.r.option_key", key); - cpperror.ToAdbc(error); + cpperror.ToAdbc(error, driver_); } - static void InitErrorOptionNotSupported(const char* key, AdbcError* error) { + void InitErrorOptionNotSupported(const char* key, AdbcError* error) const { std::stringstream msg_builder; msg_builder << "Option '" << key << "' is not supported"; Error cpperror(msg_builder.str()); cpperror.AddDetail("adbc.r.option_key", key); - cpperror.ToAdbc(error); + cpperror.ToAdbc(error, driver_); } }; @@ -293,7 +293,7 @@ class StatementObjectBase : public ObjectBase { }; template -class Driver final { +class Driver { public: static AdbcStatusCode Init(int version, void* raw_driver, AdbcError* error) { if (version != ADBC_VERSION_1_1_0) return ADBC_STATUS_NOT_IMPLEMENTED; @@ -460,13 +460,15 @@ class Driver final { // Database trampolines static AdbcStatusCode CDatabaseInit(AdbcDatabase* database, AdbcError* error) { auto private_data = reinterpret_cast(database->private_data); - return private_data->Init(database->private_driver, error); + private_data->set_driver(database->private_driver); + return private_data->Init(database->private_driver->private_data, error); } // Connection trampolines static AdbcStatusCode CConnectionInit(AdbcConnection* connection, AdbcDatabase* database, AdbcError* error) { auto private_data = reinterpret_cast(connection->private_data); + private_data->set_driver(connection->private_driver); return private_data->Init(database->private_data, error); } @@ -474,6 +476,7 @@ class Driver final { static AdbcStatusCode CStatementNew(AdbcConnection* connection, AdbcStatement* statement, AdbcError* error) { auto private_data = new StatementT(); + private_data->set_driver(connection->private_driver); AdbcStatusCode status = private_data->Init(connection->private_data, error); if (status != ADBC_STATUS_OK) { delete private_data; diff --git a/r/adbcdrivermanager/tests/testthat/test-error.R b/r/adbcdrivermanager/tests/testthat/test-error.R index 15a9c21e61..15e9f4e92a 100644 --- a/r/adbcdrivermanager/tests/testthat/test-error.R +++ b/r/adbcdrivermanager/tests/testthat/test-error.R @@ -46,13 +46,16 @@ test_that("stop_for_error() gives a custom error class with extra info", { had_error <- FALSE tryCatch({ db <- adbc_database_init(adbc_driver_void()) - adbc_database_release(db) - adbc_database_release(db) + adbc_database_get_option(db, "this option does not exist") }, adbc_status = function(e) { had_error <<- TRUE expect_s3_class(e, "adbc_status") - expect_s3_class(e, "adbc_status_invalid_state") - expect_identical(e$error$status, 6L) + expect_s3_class(e, "adbc_status_not_found") + expect_identical(e$error$status, 3L) + expect_identical( + e$error$detail[["adbc.r.option_key"]], + charToRaw("this option does not exist") + ) }) expect_true(had_error)