From a57535d98573c35f636c194e74fa32be0f02383a Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 17 Oct 2024 17:25:12 -0500 Subject: [PATCH 1/9] start client --- c/driver/framework/base_client.h | 128 +++++++++++++++++++++++++++++++ c/driver/framework/status.h | 1 - 2 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 c/driver/framework/base_client.h diff --git a/c/driver/framework/base_client.h b/c/driver/framework/base_client.h new file mode 100644 index 0000000000..d55ec6a6b0 --- /dev/null +++ b/c/driver/framework/base_client.h @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "driver/framework/status.h" + +namespace adbc::client { + +class BaseDriver; +class BaseDatabase; +class BaseConnection; +class BaseStatement; + +using adbc::driver::Result; +using adbc::driver::Status; + +#define WRAP_CALL(func, ...) (driver_->func(&statement_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver_->func(&statement_, &error_)) + +class BaseObject { + protected: + BaseObject(AdbcDriver* driver) : driver_(driver) {} + AdbcDriver* driver_{nullptr}; + AdbcError error_{ADBC_ERROR_INIT}; +}; + +class BaseStatement : BaseObject { + friend class BaseConnection; + + public: + BaseStatement(AdbcDriver* driver) : BaseObject(driver) {} + ~BaseStatement() { + if (!statement_.private_data) { + return; + } + + AdbcStatusCode code = WRAP_CALL0(StatementRelease); + if (code != ADBC_STATUS_OK) { + // TODO: Register with connection or context + } + } + + Status SetSqlQuery(const std::string& query) { + AdbcStatusCode code = WRAP_CALL(StatementSetSqlQuery, query.c_str()); + return Status::FromAdbc(code, error_); + } + + Result ExecuteQuery(ArrowArrayStream* stream) { + int64_t affected_rows = -1; + AdbcStatusCode code = WRAP_CALL(StatementExecuteQuery, stream, &affected_rows); + return Status::FromAdbc(code, error_); + } + + private: + AdbcStatement statement_; +}; + +#undef WRAP_CALL +#undef WRAP_CALL0 +#define WRAP_CALL(func, ...) (driver_->func(&connection_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver_->func(&connection_, &error_)) + +class BaseConnection : BaseObject { + public: + BaseConnection(AdbcDriver* driver) : BaseObject(driver) {} + ~BaseConnection() { + AdbcStatusCode code = driver_->ConnectionRelease(&connection_, &error_); + if (code != ADBC_STATUS_OK) { + // TODO: Register with database or context + } + } + + Result NewStatement() { + BaseStatement out(driver_); + AdbcStatusCode code = driver_->StatementNew(&connection_, &out.statement_, &error_); + if (code != ADBC_STATUS_OK) { + return Status::FromAdbc(code, error_); + } + + return out; + } + + Status Cancel(const std::string& query) { + AdbcStatusCode code = WRAP_CALL0(ConnectionCancel); + return Status::FromAdbc(code, error_); + } + + Status GetInfo(const uint32_t* info_codes, size_t n_info_codes, ArrowArrayStream* out) { + AdbcStatusCode code = WRAP_CALL(ConnectionGetInfo, info_codes, n_info_codes, out); + return Status::FromAdbc(code, error_); + } + + private: + AdbcConnection connection_; +}; + +#undef WRAP_CALL +#undef WRAP_CALL0 + +class BaseDriver { + public: + Status Init(AdbcDriverInitFunc init_func) { + return Status::FromAdbc(init_func(ADBC_VERSION_1_1_0, &driver_, &error_), error_); + } + + private: + AdbcDriver driver_{}; + AdbcError error_{ADBC_ERROR_INIT}; +}; + +} // namespace adbc::client \ No newline at end of file diff --git a/c/driver/framework/status.h b/c/driver/framework/status.h index cfdca6ebbe..677cef15c0 100644 --- a/c/driver/framework/status.h +++ b/c/driver/framework/status.h @@ -111,7 +111,6 @@ class Status { } static Status FromAdbc(AdbcStatusCode code, AdbcError& error) { - // not really meant to be used, just something we have for now while porting if (code == ADBC_STATUS_OK) { if (error.release) { error.release(&error); From b725c67d38ac72a1eeaa4b1293bf71a7ba574806 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 00:21:36 -0500 Subject: [PATCH 2/9] heiarchy --- c/driver/framework/base_client.h | 231 ++++++++++++++++++++++++++----- 1 file changed, 193 insertions(+), 38 deletions(-) diff --git a/c/driver/framework/base_client.h b/c/driver/framework/base_client.h index d55ec6a6b0..d0c1c4dd9d 100644 --- a/c/driver/framework/base_client.h +++ b/c/driver/framework/base_client.h @@ -17,6 +17,9 @@ #pragma once +#include +#include + #include #include "driver/framework/status.h" @@ -31,98 +34,250 @@ class BaseStatement; using adbc::driver::Result; using adbc::driver::Status; -#define WRAP_CALL(func, ...) (driver_->func(&statement_, __VA_ARGS__, &error_)) -#define WRAP_CALL0(func) (driver_->func(&statement_, &error_)) +using SharedDriver = std::shared_ptr; +using SharedDatabase = std::shared_ptr; +using SharedConnection = std::shared_ptr; +using SharedStatement = std::shared_ptr; + +class Context { + public: + enum class LogLevel { kInfo, kWarn }; + + virtual void OnUnreleasableStatement(SharedConnection connection, + AdbcStatement* statement, const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable statement"); + } + + virtual void OnUnreleasableConnection(SharedDatabase database, + AdbcConnection* connection, + const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable connection"); + } + + virtual void OnUnreleaseableDatabase(SharedDriver driver, AdbcDatabase* database, + const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable database"); + } + + virtual void OnUnreleaseableDriver(AdbcDriver* driver, const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable driver"); + } + + virtual void Log(LogLevel level, std::string_view message) {} + + static std::shared_ptr Default() { return std::make_unique(); } +}; + +using SharedContext = std::shared_ptr; + +class BaseDriver { + public: + BaseDriver(SharedContext context = std::make_unique()) : context_(context) {} + Status Init(AdbcDriverInitFunc init_func) { + return Status::FromAdbc(init_func(ADBC_VERSION_1_1_0, &driver_, &error_), error_); + } + + AdbcDriver* driver() { return &driver_; } + + Context* context() { return context_.get(); } + + private: + SharedContext context_; + AdbcDriver driver_{}; + AdbcError error_{ADBC_ERROR_INIT}; +}; class BaseObject { + public: + const SharedDriver& GetSharedDriver() { return driver_; } + protected: - BaseObject(AdbcDriver* driver) : driver_(driver) {} - AdbcDriver* driver_{nullptr}; + SharedDriver driver_; AdbcError error_{ADBC_ERROR_INIT}; + + void NewBase(SharedDriver driver) { driver_ = driver; } + void ReleaseBase() { driver_.reset(); } + AdbcDriver* driver() { return driver_->driver(); } + Context* context() { return driver_->context(); } }; -class BaseStatement : BaseObject { - friend class BaseConnection; +#define WRAP_CALL(func, ...) (driver()->func(&database_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver()->func(&database_, &error_)) +class BaseDatabase : public BaseObject { public: - BaseStatement(AdbcDriver* driver) : BaseObject(driver) {} - ~BaseStatement() { - if (!statement_.private_data) { - return; + ~BaseDatabase() { + if (driver_ && database_.private_data) { + AdbcStatusCode code = WRAP_CALL0(DatabaseRelease); + if (code != ADBC_STATUS_OK) { + context()->OnUnreleaseableDatabase(driver_, &database_, &error_); + } } + } - AdbcStatusCode code = WRAP_CALL0(StatementRelease); - if (code != ADBC_STATUS_OK) { - // TODO: Register with connection or context - } + AdbcDatabase* database() { return &database_; } + + Status New(SharedDriver parent) { + NewBase(std::move(parent)); + AdbcStatusCode code = WRAP_CALL0(DatabaseNew); + return Status::FromAdbc(code, error_); } - Status SetSqlQuery(const std::string& query) { - AdbcStatusCode code = WRAP_CALL(StatementSetSqlQuery, query.c_str()); + Status Init() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL0(DatabaseInit); return Status::FromAdbc(code, error_); } - Result ExecuteQuery(ArrowArrayStream* stream) { - int64_t affected_rows = -1; - AdbcStatusCode code = WRAP_CALL(StatementExecuteQuery, stream, &affected_rows); + Status Release() { + AdbcStatusCode code = WRAP_CALL0(DatabaseRelease); + if (code == ADBC_STATUS_OK) { + ReleaseBase(); + std::memset(&database_, 0, sizeof(database_)); + } + return Status::FromAdbc(code, error_); } private: - AdbcStatement statement_; + AdbcDatabase database_{}; + + Status CheckValid() { + if (driver_) { + return Status::Ok(); + } else { + return driver::status::InvalidState("BaseDatabase is not valid"); + } + } }; #undef WRAP_CALL #undef WRAP_CALL0 -#define WRAP_CALL(func, ...) (driver_->func(&connection_, __VA_ARGS__, &error_)) -#define WRAP_CALL0(func) (driver_->func(&connection_, &error_)) -class BaseConnection : BaseObject { +#define WRAP_CALL(func, ...) (driver()->func(&connection_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver()->func(&connection_, &error_)) + +class BaseConnection : public BaseObject { public: - BaseConnection(AdbcDriver* driver) : BaseObject(driver) {} ~BaseConnection() { - AdbcStatusCode code = driver_->ConnectionRelease(&connection_, &error_); + AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); if (code != ADBC_STATUS_OK) { - // TODO: Register with database or context + context()->OnUnreleasableConnection(database_, &connection_, &error_); } } - Result NewStatement() { - BaseStatement out(driver_); - AdbcStatusCode code = driver_->StatementNew(&connection_, &out.statement_, &error_); - if (code != ADBC_STATUS_OK) { - return Status::FromAdbc(code, error_); + AdbcConnection* connection() { return &connection_; } + + Status New(SharedDatabase database) { + NewBase(database->GetSharedDriver()); + AdbcStatusCode code = WRAP_CALL0(ConnectionNew); + return Status::FromAdbc(code, error_); + } + + Status Init() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL(ConnectionInit, database_->database()); + return Status::FromAdbc(code, error_); + } + + Status Release() { + AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); + if (code == ADBC_STATUS_OK) { + ReleaseBase(); + database_.reset(); + std::memset(&connection_, 0, sizeof(connection_)); } - return out; + return Status::FromAdbc(code, error_); } Status Cancel(const std::string& query) { + UNWRAP_STATUS(CheckValid()); AdbcStatusCode code = WRAP_CALL0(ConnectionCancel); return Status::FromAdbc(code, error_); } Status GetInfo(const uint32_t* info_codes, size_t n_info_codes, ArrowArrayStream* out) { + UNWRAP_STATUS(CheckValid()); AdbcStatusCode code = WRAP_CALL(ConnectionGetInfo, info_codes, n_info_codes, out); return Status::FromAdbc(code, error_); } private: - AdbcConnection connection_; + SharedDatabase database_; + AdbcConnection connection_{}; + + Status CheckValid() { + if (driver_ && database_ && connection_.private_data) { + return Status::Ok(); + } else { + return driver::status::InvalidState("BaseConnection is not valid"); + } + } }; #undef WRAP_CALL #undef WRAP_CALL0 -class BaseDriver { +#define WRAP_CALL(func, ...) (driver_->driver()->func(&statement_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver_->driver()->func(&statement_, &error_)) + +class BaseStatement : BaseObject { public: - Status Init(AdbcDriverInitFunc init_func) { - return Status::FromAdbc(init_func(ADBC_VERSION_1_1_0, &driver_, &error_), error_); + ~BaseStatement() { + if (driver_ && statement_.private_data) { + AdbcStatusCode code = WRAP_CALL0(StatementRelease); + if (code != ADBC_STATUS_OK) { + context()->OnUnreleasableStatement(connection_, &statement_, &error_); + } + } + } + + Status New(SharedConnection connection) { + NewBase(connection->GetSharedDriver()); + AdbcStatusCode code = + driver()->StatementNew(connection->connection(), &statement_, &error_); + return Status::FromAdbc(code, error_); + } + + Status Release() { + AdbcStatusCode code = WRAP_CALL0(StatementRelease); + if (code == ADBC_STATUS_OK) { + ReleaseBase(); + connection_.reset(); + std::memset(&statement_, 0, sizeof(statement_)); + } + + return Status::FromAdbc(code, error_); + } + + Status SetSqlQuery(const std::string& query) { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL(StatementSetSqlQuery, query.c_str()); + return Status::FromAdbc(code, error_); + } + + Result ExecuteQuery(ArrowArrayStream* stream) { + UNWRAP_STATUS(CheckValid()); + int64_t affected_rows = -1; + AdbcStatusCode code = WRAP_CALL(StatementExecuteQuery, stream, &affected_rows); + return Status::FromAdbc(code, error_); } private: - AdbcDriver driver_{}; - AdbcError error_{ADBC_ERROR_INIT}; + SharedConnection connection_; + AdbcStatement statement_{}; + + Status CheckValid() { + if (driver_ && connection_ && statement_.private_data) { + return Status::Ok(); + } else { + return driver::status::InvalidState("BaseStatement is not valid"); + } + } }; +#undef WRAP_CALL +#undef WRAP_CALL0 + } // namespace adbc::client \ No newline at end of file From 3b7a1320c83ca041101857b7a74b7fb68661648b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 01:15:00 -0500 Subject: [PATCH 3/9] a go at the higher level --- c/driver/framework/base_client.h | 198 ++++++++++++++++++++++++++----- 1 file changed, 170 insertions(+), 28 deletions(-) diff --git a/c/driver/framework/base_client.h b/c/driver/framework/base_client.h index d0c1c4dd9d..32ce3052bb 100644 --- a/c/driver/framework/base_client.h +++ b/c/driver/framework/base_client.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include @@ -26,36 +25,34 @@ namespace adbc::client { +namespace internal { class BaseDriver; class BaseDatabase; class BaseConnection; class BaseStatement; +} // namespace internal using adbc::driver::Result; using adbc::driver::Status; -using SharedDriver = std::shared_ptr; -using SharedDatabase = std::shared_ptr; -using SharedConnection = std::shared_ptr; -using SharedStatement = std::shared_ptr; - class Context { public: enum class LogLevel { kInfo, kWarn }; - virtual void OnUnreleasableStatement(SharedConnection connection, - AdbcStatement* statement, const AdbcError* error) { + virtual void OnUnreleasableStatement( + std::shared_ptr connection, AdbcStatement* statement, + const AdbcError* error) { Log(LogLevel::kWarn, "leaking unreleasable statement"); } - virtual void OnUnreleasableConnection(SharedDatabase database, + virtual void OnUnreleasableConnection(std::shared_ptr database, AdbcConnection* connection, const AdbcError* error) { Log(LogLevel::kWarn, "leaking unreleasable connection"); } - virtual void OnUnreleaseableDatabase(SharedDriver driver, AdbcDatabase* database, - const AdbcError* error) { + virtual void OnUnreleaseableDatabase(std::shared_ptr driver, + AdbcDatabase* database, const AdbcError* error) { Log(LogLevel::kWarn, "leaking unreleasable database"); } @@ -68,11 +65,12 @@ class Context { static std::shared_ptr Default() { return std::make_unique(); } }; -using SharedContext = std::shared_ptr; +namespace internal { class BaseDriver { public: - BaseDriver(SharedContext context = std::make_unique()) : context_(context) {} + BaseDriver(std::shared_ptr context = std::make_unique()) + : context_(context) {} Status Init(AdbcDriverInitFunc init_func) { return Status::FromAdbc(init_func(ADBC_VERSION_1_1_0, &driver_, &error_), error_); } @@ -82,20 +80,20 @@ class BaseDriver { Context* context() { return context_.get(); } private: - SharedContext context_; + std::shared_ptr context_; AdbcDriver driver_{}; AdbcError error_{ADBC_ERROR_INIT}; }; class BaseObject { public: - const SharedDriver& GetSharedDriver() { return driver_; } + const std::shared_ptr& GetSharedDriver() { return driver_; } protected: - SharedDriver driver_; + std::shared_ptr driver_; AdbcError error_{ADBC_ERROR_INIT}; - void NewBase(SharedDriver driver) { driver_ = driver; } + void NewBase(std::shared_ptr driver) { driver_ = driver; } void ReleaseBase() { driver_.reset(); } AdbcDriver* driver() { return driver_->driver(); } Context* context() { return driver_->context(); } @@ -117,7 +115,7 @@ class BaseDatabase : public BaseObject { AdbcDatabase* database() { return &database_; } - Status New(SharedDriver parent) { + Status New(std::shared_ptr parent) { NewBase(std::move(parent)); AdbcStatusCode code = WRAP_CALL0(DatabaseNew); return Status::FromAdbc(code, error_); @@ -168,7 +166,7 @@ class BaseConnection : public BaseObject { AdbcConnection* connection() { return &connection_; } - Status New(SharedDatabase database) { + Status New(std::shared_ptr database) { NewBase(database->GetSharedDriver()); AdbcStatusCode code = WRAP_CALL0(ConnectionNew); return Status::FromAdbc(code, error_); @@ -191,7 +189,7 @@ class BaseConnection : public BaseObject { return Status::FromAdbc(code, error_); } - Status Cancel(const std::string& query) { + Status Cancel() { UNWRAP_STATUS(CheckValid()); AdbcStatusCode code = WRAP_CALL0(ConnectionCancel); return Status::FromAdbc(code, error_); @@ -204,7 +202,7 @@ class BaseConnection : public BaseObject { } private: - SharedDatabase database_; + std::shared_ptr database_; AdbcConnection connection_{}; Status CheckValid() { @@ -233,7 +231,7 @@ class BaseStatement : BaseObject { } } - Status New(SharedConnection connection) { + Status New(std::shared_ptr connection) { NewBase(connection->GetSharedDriver()); AdbcStatusCode code = driver()->StatementNew(connection->connection(), &statement_, &error_); @@ -251,21 +249,20 @@ class BaseStatement : BaseObject { return Status::FromAdbc(code, error_); } - Status SetSqlQuery(const std::string& query) { + Status SetSqlQuery(const char* query) { UNWRAP_STATUS(CheckValid()); - AdbcStatusCode code = WRAP_CALL(StatementSetSqlQuery, query.c_str()); + AdbcStatusCode code = WRAP_CALL(StatementSetSqlQuery, query); return Status::FromAdbc(code, error_); } - Result ExecuteQuery(ArrowArrayStream* stream) { + Status ExecuteQuery(ArrowArrayStream* stream, int64_t* affected_rows) { UNWRAP_STATUS(CheckValid()); - int64_t affected_rows = -1; - AdbcStatusCode code = WRAP_CALL(StatementExecuteQuery, stream, &affected_rows); + AdbcStatusCode code = WRAP_CALL(StatementExecuteQuery, stream, affected_rows); return Status::FromAdbc(code, error_); } private: - SharedConnection connection_; + std::shared_ptr connection_; AdbcStatement statement_{}; Status CheckValid() { @@ -280,4 +277,149 @@ class BaseStatement : BaseObject { #undef WRAP_CALL #undef WRAP_CALL0 +} // namespace internal + +template +class Stream { + public: + explicit Stream(Parent parent) : parent_(parent) {} + Stream(Stream&& rhs) : Stream(rhs.get()) { + parent_ = std::move(rhs.parent_); + std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); + std::memset(rhs.stream_, 0, sizeof(ArrowArrayStream)); + rows_affected_ = rhs.rows_affected_; + } + + Stream& operator=(Stream&& rhs) { + parent_ = std::move(rhs.parent_); + std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); + std::memset(rhs.stream_, 0, sizeof(ArrowArrayStream)); + rows_affected_ = rhs.rows_affected_; + return *this; + } + + Stream(const Stream& rhs) = delete; + + ArrowArrayStream* stream() { return &stream_; } + + int64_t* mutable_rows_affected() { return &rows_affected_; } + + ~Stream() { + if (stream_.release) { + stream_.release(&stream_); + } + } + + void Export(ArrowArrayStream* out) { + Stream* instance = new Stream(); + instance->parent_ = std::move(parent_); + std::memcpy(&instance->stream_, &stream_, sizeof(ArrowArrayStream)); + std::memset(stream_, 0, sizeof(ArrowArrayStream)); + instance->rows_affected_ = rows_affected_; + + out->get_schema = &CGetSchema; + out->get_next = &CGetNext; + out->get_last_error = &CGetLastError; + out->release = &CRelease; + out->private_data = instance; + } + + private: + Parent parent_; + ArrowArrayStream stream_{}; + int64_t rows_affected_{-1}; + + static int CGetSchema(ArrowArrayStream* stream, ArrowSchema* schema) { + return reinterpret_cast(stream->private_data)->GetSchema(schema); + } + + static int CGetNext(ArrowArrayStream* stream, ArrowArray* array) { + return reinterpret_cast(stream->private_data)->GetNext(array); + } + + static const char* CGetLastError(ArrowArrayStream* stream) { + return reinterpret_cast(stream->private_data)->GetLastError(); + } + + static void CRelease(ArrowArrayStream* stream) { + delete reinterpret_cast(stream->private_data); + stream->release = nullptr; + stream->private_data = nullptr; + } +}; + +using ConnectionStream = Stream>; +using StatementStream = Stream>; + +class Statement { + public: + Statement(std::shared_ptr base) : base_(base) {} + + Status SetSqlQuery(const std::string& query) { + return base_->SetSqlQuery(query.c_str()); + } + + Result ExecuteQuery() { + StatementStream out(base_); + UNWRAP_STATUS(base_->ExecuteQuery(out.stream(), out.mutable_rows_affected())); + return out; + } + + private: + std::shared_ptr base_; +}; + +class Connection { + public: + Connection(std::shared_ptr base) : base_(base) {} + + Result NewStatement() { + auto child = std::make_shared(); + UNWRAP_STATUS(child->New(base_)); + return Statement(child); + } + + Status Cancel() { return base_->Cancel(); } + + Result GetInfo(const std::vector& info_codes = {}) { + ConnectionStream out(base_); + UNWRAP_STATUS(base_->GetInfo(info_codes.data(), info_codes.size(), out.stream())); + return out; + } + + private: + std::shared_ptr base_; +}; + +class Database { + public: + Database(std::shared_ptr base) : base_(base) {} + + Result NewConnection() { + auto child = std::make_shared(); + UNWRAP_STATUS(child->New(base_)); + UNWRAP_STATUS(child->Init()); + return Connection(child); + } + + private: + std::shared_ptr base_; +}; + +class Driver { + Driver() : base_(std::make_shared()) {} + + Status Init(AdbcDriverInitFunc init_func) { return base_->Init(init_func); } + + Result NewDatabase() { + auto child = std::make_shared(); + UNWRAP_STATUS(child->New(base_)); + UNWRAP_STATUS(child->Init()); + return Database(child); + } + + private: + std::shared_ptr base_; +}; + } // namespace adbc::client \ No newline at end of file From c7fd5c7af3c46097c39fdc8ec3f009a1f1233d36 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 01:15:51 -0500 Subject: [PATCH 4/9] rename --- c/driver/framework/{base_client.h => client.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename c/driver/framework/{base_client.h => client.h} (100%) diff --git a/c/driver/framework/base_client.h b/c/driver/framework/client.h similarity index 100% rename from c/driver/framework/base_client.h rename to c/driver/framework/client.h From e8445e2c0e7bdbbb1cefb28c84ada0ae42a764a6 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 13:42:16 -0500 Subject: [PATCH 5/9] heiarchy validity checking --- c/driver/framework/client.h | 185 +++++++++++++++++++++++++++++------- c/driver/framework/status.h | 8 ++ 2 files changed, 158 insertions(+), 35 deletions(-) diff --git a/c/driver/framework/client.h b/c/driver/framework/client.h index 32ce3052bb..16ab6838cf 100644 --- a/c/driver/framework/client.h +++ b/c/driver/framework/client.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -60,6 +61,27 @@ class Context { Log(LogLevel::kWarn, "leaking unreleasable driver"); } + virtual void OnDeleteHandleWithoutClose( + const std::shared_ptr& statement) { + Log(LogLevel::kWarn, + "Leaking Statement handle; AdbcStatement will be auto-released when all child " + "readers are released. Use Statement::Release() to avoid this message."); + } + + virtual void OnDeleteHandleWithoutClose( + const std::shared_ptr& connection) { + Log(LogLevel::kWarn, + "Leaking Connection handle; AdbcConnection will be auto-released when all child " + "readers are released. Use Connection::Release() to avoid this message."); + } + + virtual void OnDeleteHandleWithoutClose( + const std::shared_ptr& database) { + Log(LogLevel::kWarn, + "Leaking Database handle; AdbcDatabase will be auto-released when all child " + "readers are released. Use Database::Release() to avoid this message."); + } + virtual void Log(LogLevel level, std::string_view message) {} static std::shared_ptr Default() { return std::make_unique(); } @@ -69,16 +91,32 @@ namespace internal { class BaseDriver { public: - BaseDriver(std::shared_ptr context = std::make_unique()) + explicit BaseDriver(std::shared_ptr context = std::make_unique()) : context_(context) {} - Status Init(AdbcDriverInitFunc init_func) { + + BaseDriver(const BaseDriver& rhs) = delete; + + Status Load(AdbcDriverInitFunc init_func) { return Status::FromAdbc(init_func(ADBC_VERSION_1_1_0, &driver_, &error_), error_); } + Status Unload() { + AdbcStatusCode code = driver_.release(&driver_, &error_); + return Status::FromAdbc(code, error_); + } + AdbcDriver* driver() { return &driver_; } Context* context() { return context_.get(); } + Status CheckValid() { + if (!driver_.release) { + return Status::InvalidState("Driver is released"); + } else { + return Status::Ok(); + } + } + private: std::shared_ptr context_; AdbcDriver driver_{}; @@ -104,6 +142,7 @@ class BaseObject { class BaseDatabase : public BaseObject { public: + BaseDatabase(const BaseDatabase& rhs) = delete; ~BaseDatabase() { if (driver_ && database_.private_data) { AdbcStatusCode code = WRAP_CALL0(DatabaseRelease); @@ -137,16 +176,16 @@ class BaseDatabase : public BaseObject { return Status::FromAdbc(code, error_); } - private: - AdbcDatabase database_{}; - Status CheckValid() { - if (driver_) { - return Status::Ok(); - } else { - return driver::status::InvalidState("BaseDatabase is not valid"); + if (!driver_ || !database_.private_data) { + return Status::InvalidState("BaseDatabase is released"); } + + return driver_->CheckValid(); } + + private: + AdbcDatabase database_{}; }; #undef WRAP_CALL @@ -157,6 +196,7 @@ class BaseDatabase : public BaseObject { class BaseConnection : public BaseObject { public: + BaseConnection(const BaseConnection& rhs) = delete; ~BaseConnection() { AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); if (code != ADBC_STATUS_OK) { @@ -201,17 +241,17 @@ class BaseConnection : public BaseObject { return Status::FromAdbc(code, error_); } - private: - std::shared_ptr database_; - AdbcConnection connection_{}; - Status CheckValid() { - if (driver_ && database_ && connection_.private_data) { - return Status::Ok(); - } else { - return driver::status::InvalidState("BaseConnection is not valid"); + if (!driver_ || !database_ || !connection_.private_data) { + return Status::InvalidState("BaseConnection is released"); } + + return database_->CheckValid(); } + + private: + std::shared_ptr database_; + AdbcConnection connection_{}; }; #undef WRAP_CALL @@ -220,8 +260,9 @@ class BaseConnection : public BaseObject { #define WRAP_CALL(func, ...) (driver_->driver()->func(&statement_, __VA_ARGS__, &error_)) #define WRAP_CALL0(func) (driver_->driver()->func(&statement_, &error_)) -class BaseStatement : BaseObject { +class BaseStatement : public BaseObject { public: + BaseStatement(const BaseStatement& rhs) = delete; ~BaseStatement() { if (driver_ && statement_.private_data) { AdbcStatusCode code = WRAP_CALL0(StatementRelease); @@ -261,17 +302,17 @@ class BaseStatement : BaseObject { return Status::FromAdbc(code, error_); } - private: - std::shared_ptr connection_; - AdbcStatement statement_{}; - Status CheckValid() { - if (driver_ && connection_ && statement_.private_data) { - return Status::Ok(); - } else { - return driver::status::InvalidState("BaseStatement is not valid"); + if (!driver_ || !connection_ || !statement_.private_data) { + return Status::InvalidState("BaseStatement is released"); } + + return connection_->CheckValid(); } + + private: + std::shared_ptr connection_; + AdbcStatement statement_{}; }; #undef WRAP_CALL @@ -302,6 +343,8 @@ class Stream { ArrowArrayStream* stream() { return &stream_; } + int64_t rows_affected() { return rows_affected_; } + int64_t* mutable_rows_affected() { return &rows_affected_; } ~Stream() { @@ -328,17 +371,37 @@ class Stream { Parent parent_; ArrowArrayStream stream_{}; int64_t rows_affected_{-1}; + // For the specific case of a stream whose parent is no longer valid, + // this lets us save the error message and return a const char* from + // get_last_error(). + Status last_status_; static int CGetSchema(ArrowArrayStream* stream, ArrowSchema* schema) { - return reinterpret_cast(stream->private_data)->GetSchema(schema); + auto private_data = reinterpret_cast(stream->private_data); + if (!private_data->parent_->CheckValid().ok()) { + return EADDRNOTAVAIL; + } + + return private_data->GetSchema(schema); } static int CGetNext(ArrowArrayStream* stream, ArrowArray* array) { - return reinterpret_cast(stream->private_data)->GetNext(array); + auto private_data = reinterpret_cast(stream->private_data); + if (!private_data->parent_->CheckValid().ok()) { + return EADDRNOTAVAIL; + } + + return private_data->GetNext(array); } static const char* CGetLastError(ArrowArrayStream* stream) { - return reinterpret_cast(stream->private_data)->GetLastError(); + auto private_data = reinterpret_cast(stream->private_data); + private_data->last_status_ = private_data->CheckValid(); + if (!private_data->last_status_.ok()) { + return private_data->last_status_.message(); + } + + return private_data->GetLastError(); } static void CRelease(ArrowArrayStream* stream) { @@ -348,12 +411,25 @@ class Stream { } }; +class Driver; +class Database; +class Connection; +class Statement; using ConnectionStream = Stream>; using StatementStream = Stream>; class Statement { public: - Statement(std::shared_ptr base) : base_(base) {} + Statement(const Statement& rhs) = delete; + Statement(const Statement&& rhs) : base_(std::move(rhs.base_)) {} + Statement& operator=(Statement&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + ~Statement() { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } + + Status Release() { return base_->Release(); } Status SetSqlQuery(const std::string& query) { return base_->SetSqlQuery(query.c_str()); @@ -367,11 +443,25 @@ class Statement { private: std::shared_ptr base_; + + friend class Connection; + Statement(std::shared_ptr base) : base_(base) {} }; class Connection { public: - Connection(std::shared_ptr base) : base_(base) {} + Connection(const Connection& rhs) = delete; + Connection(const Connection&& rhs) : base_(std::move(rhs.base_)) {} + Connection& operator=(Connection&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + ~Connection() { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } + + Status Release() { return base_->Release(); } Result NewStatement() { auto child = std::make_shared(); @@ -389,11 +479,23 @@ class Connection { private: std::shared_ptr base_; + + friend class Database; + Connection(std::shared_ptr base) : base_(base) {} }; class Database { public: - Database(std::shared_ptr base) : base_(base) {} + Database(const Database& rhs) = delete; + Database(const Database&& rhs) : base_(std::move(rhs.base_)) {} + Database& operator=(Database&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + ~Database() { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } + + Status Release() { return base_->Release(); } Result NewConnection() { auto child = std::make_shared(); @@ -404,12 +506,25 @@ class Database { private: std::shared_ptr base_; + + friend class Driver; + Database(std::shared_ptr base) : base_(base) {} }; class Driver { - Driver() : base_(std::make_shared()) {} + explicit Driver(std::shared_ptr context = Context::Default()) + : base_(std::make_shared()) {} + + Driver(const Driver& rhs) = delete; + Driver(const Driver&& rhs) : base_(std::move(rhs.base_)) {} + Driver& operator=(Driver&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + Status Load(AdbcDriverInitFunc init_func) { return base_->Load(init_func); } - Status Init(AdbcDriverInitFunc init_func) { return base_->Init(init_func); } + Status Unload() { return base_->Unload(); } Result NewDatabase() { auto child = std::make_shared(); @@ -422,4 +537,4 @@ class Driver { std::shared_ptr base_; }; -} // namespace adbc::client \ No newline at end of file +} // namespace adbc::client diff --git a/c/driver/framework/status.h b/c/driver/framework/status.h index 677cef15c0..8736cc565d 100644 --- a/c/driver/framework/status.h +++ b/c/driver/framework/status.h @@ -63,6 +63,14 @@ class Status { /// \brief Check if this is an error or not. bool ok() const { return impl_ == nullptr; } + const char* message() { + if (!impl_) { + return ""; + } else { + return impl_->message.c_str(); + } + } + /// \brief Add another error detail. void AddDetail(std::string key, std::string value) { assert(impl_ != nullptr); From 524bf9d222ced7c1011f08bdd96096e95589e3d5 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 14:54:08 -0500 Subject: [PATCH 6/9] even more lifecycle --- c/driver/framework/base_driver_test.cc | 77 ++++++++++++++++++++ c/driver/framework/client.h | 98 ++++++++++++++++++++++---- c/driver/framework/status.h | 10 ++- 3 files changed, 171 insertions(+), 14 deletions(-) diff --git a/c/driver/framework/base_driver_test.cc b/c/driver/framework/base_driver_test.cc index 1d8d61f60f..ad94c009ac 100644 --- a/c/driver/framework/base_driver_test.cc +++ b/c/driver/framework/base_driver_test.cc @@ -20,6 +20,7 @@ #include #include "driver/framework/base_driver.h" +#include "driver/framework/client.h" #include "driver/framework/connection.h" #include "driver/framework/database.h" #include "driver/framework/statement.h" @@ -235,3 +236,79 @@ TEST(TestDriverBase, TestVoidDriverMethods) { ADBC_STATUS_INVALID_ARGUMENT); EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); } + +class TestContext : public adbc::client::Context { + void Log(LogLevel level, std::string_view message) override { + GTEST_FAIL() << "Unexpected TestContext log message: " << message; + } +}; + +TEST(TestDriverBase, TestVoidDriverMethodsClient) { + using adbc::client::Connection; + using adbc::client::Database; + using adbc::client::Driver; + using adbc::client::Statement; + + Driver driver(std::make_shared()); + ASSERT_TRUE(driver.Load(VoidDriverInitFunc).ok()); + + auto maybe_database = driver.NewDatabase(); + ASSERT_TRUE(maybe_database.has_value()); + Database database = std::move(maybe_database.value()); + + // TODO: Test database methods + + auto maybe_connection = database.NewConnection(); + ASSERT_TRUE(maybe_connection.has_value()) << maybe_connection.status().message(); + Connection connection = std::move(maybe_connection.value()); + + // TODO: Test connection methods + + // EXPECT_EQ(driver.ConnectionCommit(&connection, nullptr), ADBC_STATUS_INVALID_STATE); + // EXPECT_EQ(driver.ConnectionGetInfo(&connection, nullptr, 0, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.ConnectionGetObjects(&connection, 0, nullptr, nullptr, 0, nullptr, + // nullptr, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.ConnectionGetTableSchema(&connection, nullptr, nullptr, nullptr, + // nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.ConnectionGetTableTypes(&connection, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.ConnectionReadPartition(&connection, nullptr, 0, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.ConnectionRollback(&connection, nullptr), + // ADBC_STATUS_INVALID_STATE); EXPECT_EQ(driver.ConnectionCancel(&connection, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); EXPECT_EQ(driver.ConnectionGetStatistics(&connection, + // nullptr, nullptr, nullptr, 0, + // nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.ConnectionGetStatisticNames(&connection, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + + auto maybe_statement = connection.NewStatement(); + ASSERT_TRUE(maybe_statement.has_value()); + Statement statement = std::move(maybe_statement.value()); + + // TODO: Test statement methods + // EXPECT_EQ(driver.StatementExecuteQuery(&statement, nullptr, nullptr, nullptr), + // ADBC_STATUS_INVALID_STATE); + // EXPECT_EQ(driver.StatementExecuteSchema(&statement, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.StatementPrepare(&statement, nullptr), ADBC_STATUS_INVALID_STATE); + // EXPECT_EQ(driver.StatementSetSqlQuery(&statement, "", nullptr), ADBC_STATUS_OK); + // EXPECT_EQ(driver.StatementSetSubstraitPlan(&statement, nullptr, 0, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.StatementBind(&statement, nullptr, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.StatementBindStream(&statement, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + + ASSERT_EQ(statement.SetSqlQuery("").code(), ADBC_STATUS_OK); + + ASSERT_EQ(statement.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(connection.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(statement.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(driver.Unload().code(), ADBC_STATUS_OK); +} diff --git a/c/driver/framework/client.h b/c/driver/framework/client.h index 16ab6838cf..6b085d00e0 100644 --- a/c/driver/framework/client.h +++ b/c/driver/framework/client.h @@ -40,6 +40,8 @@ class Context { public: enum class LogLevel { kInfo, kWarn }; + virtual ~Context() = default; + virtual void OnUnreleasableStatement( std::shared_ptr connection, AdbcStatement* statement, const AdbcError* error) { @@ -125,6 +127,7 @@ class BaseDriver { class BaseObject { public: + BaseObject() = default; const std::shared_ptr& GetSharedDriver() { return driver_; } protected: @@ -142,6 +145,7 @@ class BaseObject { class BaseDatabase : public BaseObject { public: + BaseDatabase() = default; BaseDatabase(const BaseDatabase& rhs) = delete; ~BaseDatabase() { if (driver_ && database_.private_data) { @@ -167,6 +171,7 @@ class BaseDatabase : public BaseObject { } Status Release() { + UNWRAP_STATUS(CheckValid()); AdbcStatusCode code = WRAP_CALL0(DatabaseRelease); if (code == ADBC_STATUS_OK) { ReleaseBase(); @@ -196,6 +201,7 @@ class BaseDatabase : public BaseObject { class BaseConnection : public BaseObject { public: + BaseConnection() = default; BaseConnection(const BaseConnection& rhs) = delete; ~BaseConnection() { AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); @@ -207,8 +213,13 @@ class BaseConnection : public BaseObject { AdbcConnection* connection() { return &connection_; } Status New(std::shared_ptr database) { + UNWRAP_STATUS(database->CheckValid()); NewBase(database->GetSharedDriver()); AdbcStatusCode code = WRAP_CALL0(ConnectionNew); + if (code == ADBC_STATUS_OK) { + database_ = database; + } + return Status::FromAdbc(code, error_); } @@ -219,6 +230,7 @@ class BaseConnection : public BaseObject { } Status Release() { + UNWRAP_STATUS(CheckValid()); AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); if (code == ADBC_STATUS_OK) { ReleaseBase(); @@ -262,6 +274,7 @@ class BaseConnection : public BaseObject { class BaseStatement : public BaseObject { public: + BaseStatement() = default; BaseStatement(const BaseStatement& rhs) = delete; ~BaseStatement() { if (driver_ && statement_.private_data) { @@ -276,10 +289,15 @@ class BaseStatement : public BaseObject { NewBase(connection->GetSharedDriver()); AdbcStatusCode code = driver()->StatementNew(connection->connection(), &statement_, &error_); + if (code == ADBC_STATUS_OK) { + connection_ = connection; + } + return Status::FromAdbc(code, error_); } Status Release() { + UNWRAP_STATUS(CheckValid()); AdbcStatusCode code = WRAP_CALL0(StatementRelease); if (code == ADBC_STATUS_OK) { ReleaseBase(); @@ -324,17 +342,16 @@ template class Stream { public: explicit Stream(Parent parent) : parent_(parent) {} - Stream(Stream&& rhs) : Stream(rhs.get()) { - parent_ = std::move(rhs.parent_); + Stream(Stream&& rhs) : Stream(std::move(rhs.parent_)) { std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); - std::memset(rhs.stream_, 0, sizeof(ArrowArrayStream)); + std::memset(&rhs.stream_, 0, sizeof(ArrowArrayStream)); rows_affected_ = rhs.rows_affected_; } Stream& operator=(Stream&& rhs) { parent_ = std::move(rhs.parent_); std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); - std::memset(rhs.stream_, 0, sizeof(ArrowArrayStream)); + std::memset(&rhs.stream_, 0, sizeof(ArrowArrayStream)); rows_affected_ = rhs.rows_affected_; return *this; } @@ -357,7 +374,7 @@ class Stream { Stream* instance = new Stream(); instance->parent_ = std::move(parent_); std::memcpy(&instance->stream_, &stream_, sizeof(ArrowArrayStream)); - std::memset(stream_, 0, sizeof(ArrowArrayStream)); + std::memset(&stream_, 0, sizeof(ArrowArrayStream)); instance->rows_affected_ = rows_affected_; out->get_schema = &CGetSchema; @@ -427,15 +444,24 @@ class Statement { return *this; } - ~Statement() { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } + ~Statement() { + if (base_) { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } + } - Status Release() { return base_->Release(); } + Status Release() { + UNWRAP_STATUS(CheckValid()); + return base_->Release(); + } Status SetSqlQuery(const std::string& query) { + UNWRAP_STATUS(CheckValid()); return base_->SetSqlQuery(query.c_str()); } Result ExecuteQuery() { + UNWRAP_STATUS(CheckValid()); StatementStream out(base_); UNWRAP_STATUS(base_->ExecuteQuery(out.stream(), out.mutable_rows_affected())); return out; @@ -446,6 +472,14 @@ class Statement { friend class Connection; Statement(std::shared_ptr base) : base_(base) {} + + Status CheckValid() { + if (!base_) { + return Status::InvalidState("Statement handle has been released"); + } else { + return Status::Ok(); + } + } }; class Connection { @@ -458,20 +492,30 @@ class Connection { } ~Connection() { - base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + if (base_) { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } } - Status Release() { return base_->Release(); } + Status Release() { + UNWRAP_STATUS(CheckValid()); + return base_->Release(); + } Result NewStatement() { + UNWRAP_STATUS(CheckValid()); auto child = std::make_shared(); UNWRAP_STATUS(child->New(base_)); return Statement(child); } - Status Cancel() { return base_->Cancel(); } + Status Cancel() { + UNWRAP_STATUS(CheckValid()); + return base_->Cancel(); + } Result GetInfo(const std::vector& info_codes = {}) { + UNWRAP_STATUS(CheckValid()); ConnectionStream out(base_); UNWRAP_STATUS(base_->GetInfo(info_codes.data(), info_codes.size(), out.stream())); return out; @@ -482,6 +526,14 @@ class Connection { friend class Database; Connection(std::shared_ptr base) : base_(base) {} + + Status CheckValid() { + if (!base_) { + return Status::InvalidState("Connection handle has been released"); + } else { + return Status::Ok(); + } + } }; class Database { @@ -493,11 +545,22 @@ class Database { return *this; } - ~Database() { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } + ~Database() { + if (base_) { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } + } - Status Release() { return base_->Release(); } + Status Release() { + UNWRAP_STATUS(CheckValid()); + UNWRAP_STATUS(base_->Release()); + base_.reset(); + return Status::Ok(); + } Result NewConnection() { + UNWRAP_STATUS(CheckValid()); + auto child = std::make_shared(); UNWRAP_STATUS(child->New(base_)); UNWRAP_STATUS(child->Init()); @@ -509,11 +572,20 @@ class Database { friend class Driver; Database(std::shared_ptr base) : base_(base) {} + + Status CheckValid() { + if (!base_) { + return Status::InvalidState("Database handle has been released"); + } else { + return Status::Ok(); + } + } }; class Driver { + public: explicit Driver(std::shared_ptr context = Context::Default()) - : base_(std::make_shared()) {} + : base_(std::make_shared(std::move(context))) {} Driver(const Driver& rhs) = delete; Driver(const Driver&& rhs) : base_(std::move(rhs.base_)) {} diff --git a/c/driver/framework/status.h b/c/driver/framework/status.h index 8736cc565d..b9e1da85db 100644 --- a/c/driver/framework/status.h +++ b/c/driver/framework/status.h @@ -63,7 +63,7 @@ class Status { /// \brief Check if this is an error or not. bool ok() const { return impl_ == nullptr; } - const char* message() { + const char* message() const { if (!impl_) { return ""; } else { @@ -71,6 +71,14 @@ class Status { } } + AdbcStatusCode code() const { + if (ok()) { + return ADBC_STATUS_OK; + } else { + return impl_->code; + } + } + /// \brief Add another error detail. void AddDetail(std::string key, std::string value) { assert(impl_ != nullptr); From 7c9cbf620ab9d3f823afa4d764360400dc0f6660 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 15:27:17 -0500 Subject: [PATCH 7/9] better lifecycle --- c/driver/framework/client.h | 43 ++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/c/driver/framework/client.h b/c/driver/framework/client.h index 6b085d00e0..7ead8d2cbf 100644 --- a/c/driver/framework/client.h +++ b/c/driver/framework/client.h @@ -204,9 +204,11 @@ class BaseConnection : public BaseObject { BaseConnection() = default; BaseConnection(const BaseConnection& rhs) = delete; ~BaseConnection() { - AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); - if (code != ADBC_STATUS_OK) { - context()->OnUnreleasableConnection(database_, &connection_, &error_); + if (driver_ && connection_.private_data) { + AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); + if (code != ADBC_STATUS_OK) { + context()->OnUnreleasableConnection(database_, &connection_, &error_); + } } } @@ -437,22 +439,25 @@ using StatementStream = Stream>; class Statement { public: + Statement& operator=(const Statement&) = delete; Statement(const Statement& rhs) = delete; Statement(const Statement&& rhs) : base_(std::move(rhs.base_)) {} - Statement& operator=(Statement&& rhs) { + Statement& operator=(const Statement&& rhs) { base_ = std::move(rhs.base_); return *this; } ~Statement() { - if (base_) { + if (base_ && base_->GetSharedDriver()) { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } } Status Release() { UNWRAP_STATUS(CheckValid()); - return base_->Release(); + UNWRAP_STATUS(base_->Release()); + base_.reset(); + return Status::Ok(); } Status SetSqlQuery(const std::string& query) { @@ -471,7 +476,7 @@ class Statement { std::shared_ptr base_; friend class Connection; - Statement(std::shared_ptr base) : base_(base) {} + Statement(std::shared_ptr base) : base_(std::move(base)) {} Status CheckValid() { if (!base_) { @@ -484,29 +489,32 @@ class Statement { class Connection { public: + Connection& operator=(const Connection&) = delete; Connection(const Connection& rhs) = delete; Connection(const Connection&& rhs) : base_(std::move(rhs.base_)) {} - Connection& operator=(Connection&& rhs) { + Connection& operator=(const Connection&& rhs) { base_ = std::move(rhs.base_); return *this; } ~Connection() { - if (base_) { + if (base_ && base_->GetSharedDriver()) { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } } Status Release() { UNWRAP_STATUS(CheckValid()); - return base_->Release(); + UNWRAP_STATUS(base_->Release()); + base_.reset(); + return Status::Ok(); } Result NewStatement() { UNWRAP_STATUS(CheckValid()); auto child = std::make_shared(); UNWRAP_STATUS(child->New(base_)); - return Statement(child); + return Statement(std::move(child)); } Status Cancel() { @@ -525,7 +533,7 @@ class Connection { std::shared_ptr base_; friend class Database; - Connection(std::shared_ptr base) : base_(base) {} + Connection(std::shared_ptr base) : base_(std::move(base)) {} Status CheckValid() { if (!base_) { @@ -538,15 +546,16 @@ class Connection { class Database { public: + Database& operator=(const Database&) = delete; Database(const Database& rhs) = delete; Database(const Database&& rhs) : base_(std::move(rhs.base_)) {} - Database& operator=(Database&& rhs) { + Database& operator=(const Database&& rhs) { base_ = std::move(rhs.base_); return *this; } ~Database() { - if (base_) { + if (base_ && base_->GetSharedDriver()) { base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); } } @@ -564,14 +573,14 @@ class Database { auto child = std::make_shared(); UNWRAP_STATUS(child->New(base_)); UNWRAP_STATUS(child->Init()); - return Connection(child); + return Connection(std::move(child)); } private: std::shared_ptr base_; friend class Driver; - Database(std::shared_ptr base) : base_(base) {} + Database(std::shared_ptr base) : base_(std::move(base)) {} Status CheckValid() { if (!base_) { @@ -602,7 +611,7 @@ class Driver { auto child = std::make_shared(); UNWRAP_STATUS(child->New(base_)); UNWRAP_STATUS(child->Init()); - return Database(child); + return Database(std::move(child)); } private: From 9cc2bef24b4c51f2f3a2c24b253cb164a1de2fda Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 15:36:16 -0500 Subject: [PATCH 8/9] test passed! --- c/driver/framework/base_driver_test.cc | 2 +- c/driver/framework/client.h | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/c/driver/framework/base_driver_test.cc b/c/driver/framework/base_driver_test.cc index ad94c009ac..3117705628 100644 --- a/c/driver/framework/base_driver_test.cc +++ b/c/driver/framework/base_driver_test.cc @@ -309,6 +309,6 @@ TEST(TestDriverBase, TestVoidDriverMethodsClient) { ASSERT_EQ(statement.Release().code(), ADBC_STATUS_OK); ASSERT_EQ(connection.Release().code(), ADBC_STATUS_OK); - ASSERT_EQ(statement.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(database.Release().code(), ADBC_STATUS_OK); ASSERT_EQ(driver.Unload().code(), ADBC_STATUS_OK); } diff --git a/c/driver/framework/client.h b/c/driver/framework/client.h index 7ead8d2cbf..05dc7c4c3a 100644 --- a/c/driver/framework/client.h +++ b/c/driver/framework/client.h @@ -344,6 +344,10 @@ template class Stream { public: explicit Stream(Parent parent) : parent_(parent) {} + + Stream& operator=(const Stream& rhs) = delete; + Stream(const Stream& rhs) = delete; + Stream(Stream&& rhs) : Stream(std::move(rhs.parent_)) { std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); std::memset(&rhs.stream_, 0, sizeof(ArrowArrayStream)); @@ -358,8 +362,6 @@ class Stream { return *this; } - Stream(const Stream& rhs) = delete; - ArrowArrayStream* stream() { return &stream_; } int64_t rows_affected() { return rows_affected_; } @@ -441,8 +443,8 @@ class Statement { public: Statement& operator=(const Statement&) = delete; Statement(const Statement& rhs) = delete; - Statement(const Statement&& rhs) : base_(std::move(rhs.base_)) {} - Statement& operator=(const Statement&& rhs) { + Statement(Statement&& rhs) : base_(std::move(rhs.base_)) {} + Statement& operator=(Statement&& rhs) { base_ = std::move(rhs.base_); return *this; } @@ -491,8 +493,8 @@ class Connection { public: Connection& operator=(const Connection&) = delete; Connection(const Connection& rhs) = delete; - Connection(const Connection&& rhs) : base_(std::move(rhs.base_)) {} - Connection& operator=(const Connection&& rhs) { + Connection(Connection&& rhs) : base_(std::move(rhs.base_)) {} + Connection& operator=(Connection&& rhs) { base_ = std::move(rhs.base_); return *this; } @@ -548,8 +550,8 @@ class Database { public: Database& operator=(const Database&) = delete; Database(const Database& rhs) = delete; - Database(const Database&& rhs) : base_(std::move(rhs.base_)) {} - Database& operator=(const Database&& rhs) { + Database(Database&& rhs) : base_(std::move(rhs.base_)) {} + Database& operator=(Database&& rhs) { base_ = std::move(rhs.base_); return *this; } @@ -597,7 +599,7 @@ class Driver { : base_(std::make_shared(std::move(context))) {} Driver(const Driver& rhs) = delete; - Driver(const Driver&& rhs) : base_(std::move(rhs.base_)) {} + Driver(Driver&& rhs) : base_(std::move(rhs.base_)) {} Driver& operator=(Driver&& rhs) { base_ = std::move(rhs.base_); return *this; From 1b32a398b2ac4deba6e1300ec7fbb81cdc56cc69 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 23 Oct 2024 15:41:15 -0500 Subject: [PATCH 9/9] lint --- c/driver/framework/client.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/c/driver/framework/client.h b/c/driver/framework/client.h index 05dc7c4c3a..cc2124768b 100644 --- a/c/driver/framework/client.h +++ b/c/driver/framework/client.h @@ -19,6 +19,9 @@ #include #include +#include +#include +#include #include @@ -478,7 +481,8 @@ class Statement { std::shared_ptr base_; friend class Connection; - Statement(std::shared_ptr base) : base_(std::move(base)) {} + explicit Statement(std::shared_ptr base) + : base_(std::move(base)) {} Status CheckValid() { if (!base_) { @@ -535,7 +539,8 @@ class Connection { std::shared_ptr base_; friend class Database; - Connection(std::shared_ptr base) : base_(std::move(base)) {} + explicit Connection(std::shared_ptr base) + : base_(std::move(base)) {} Status CheckValid() { if (!base_) { @@ -582,7 +587,8 @@ class Database { std::shared_ptr base_; friend class Driver; - Database(std::shared_ptr base) : base_(std::move(base)) {} + explicit Database(std::shared_ptr base) + : base_(std::move(base)) {} Status CheckValid() { if (!base_) {