From c14ee701fd82ac0376f63f5c0437e36f1cfc4201 Mon Sep 17 00:00:00 2001 From: Quentin GODEAU Date: Tue, 26 Mar 2024 00:59:31 +0100 Subject: [PATCH 1/4] Add devicecode authentification --- CMakeLists.txt | 13 +- src/auth/azure_device_code_credential.cpp | 264 ++++++++++++++++++ src/auth/azure_device_code_function.cpp | 131 +++++++++ src/auth/azure_device_codes_context.cpp | 6 + src/azure_extension.cpp | 4 + src/azure_filesystem.cpp | 4 +- src/azure_secret.cpp | 63 ++++- src/azure_storage_account_client.cpp | 76 ++++- .../auth/azure_device_code_credential.hpp | 88 ++++++ .../auth/azure_device_code_function.hpp | 9 + .../auth/azure_device_codes_context.hpp | 13 + src/include/azure_secret.hpp | 14 +- src/include/azure_storage_account_client.hpp | 4 + 13 files changed, 658 insertions(+), 31 deletions(-) create mode 100644 src/auth/azure_device_code_credential.cpp create mode 100644 src/auth/azure_device_code_function.cpp create mode 100644 src/auth/azure_device_codes_context.cpp create mode 100644 src/include/auth/azure_device_code_credential.hpp create mode 100644 src/include/auth/azure_device_code_function.hpp create mode 100644 src/include/auth/azure_device_codes_context.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 341eb3c..be3f76c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,14 +11,17 @@ set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED True) set(EXTENSION_SOURCES + src/auth/azure_device_code_credential.cpp + src/auth/azure_device_code_function.cpp + src/auth/azure_device_codes_context.cpp + src/azure_blob_filesystem.cpp + src/azure_dfs_filesystem.cpp src/azure_extension.cpp - src/azure_secret.cpp src/azure_filesystem.cpp + src/azure_parsed_url.cpp + src/azure_secret.cpp src/azure_storage_account_client.cpp - src/azure_blob_filesystem.cpp - src/azure_dfs_filesystem.cpp - src/http_state_policy.cpp - src/azure_parsed_url.cpp) + src/http_state_policy.cpp) add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES}) set(PARAMETERS "-warnings") diff --git a/src/auth/azure_device_code_credential.cpp b/src/auth/azure_device_code_credential.cpp new file mode 100644 index 0000000..02d3c91 --- /dev/null +++ b/src/auth/azure_device_code_credential.cpp @@ -0,0 +1,264 @@ +#include "auth/azure_device_code_credential.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/string_util.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace duckdb { + +// TODO https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token +// TODO replace AccessToken by this class ?? +struct RequestDeviceCodeResponse {}; + +struct HttpResponseError { + std::string error; + std::string error_description; + std::vector error_codes; + std::string timestamp; + std::string trace_id; + std::string correlation_id; + std::string error_uri; +}; + +static void ParseJson(const std::string &json_str, AzureDeviceCodeInfo *response) { + auto now = std::chrono::system_clock::now(); + try { + auto json = Azure::Core::Json::_internal::json::parse(json_str); + + response->device_code = json.at("device_code").get(); + response->user_code = json.at("user_code").get(); + response->verification_uri = json.at("verification_uri").get(); + response->expires_at = now + std::chrono::seconds(json.at("expires_in").get()); + response->interval = std::chrono::seconds(json.at("interval").get()); + response->message = json.at("message").get(); + } catch (const Azure::Core::Json::_internal::json::out_of_range &ex) { + throw IOException("[AzureDeviceCodeCredential] Failed to parse Azure response '%s'", ex.what()); + } catch (const Azure::Core::Json::_internal::json::exception &ex) { + throw IOException("[AzureDeviceCodeCredential] Failed to parse JSON Azure response '%s'", ex.what()); + } +} +static void ParseJson(const std::string &json_str, Azure::Core::Credentials::AccessToken *token) { + try { + auto json = Azure::Core::Json::_internal::json::parse(json_str); + + token->Token = json.at("access_token").get(); + token->ExpiresOn = Azure::DateTime(std::chrono::system_clock::now()) + + std::chrono::seconds(json.at("expires_in").get()); + } catch (const Azure::Core::Json::_internal::json::out_of_range &ex) { + throw IOException("[AzureDeviceCodeCredential] Failed to parse Azure response '%s'", ex.what()); + } catch (const Azure::Core::Json::_internal::json::exception &ex) { + throw IOException("[AzureDeviceCodeCredential] Failed to parse JSON Azure response '%s'", ex.what()); + } +} + +static bool TryParseJson(const std::string &json_str, HttpResponseError *error) { + try { + auto json = Azure::Core::Json::_internal::json::parse(json_str); + + error->error = json.at("error").get(); + error->error_description = json.at("error_description").get(); + error->error_codes = json.at("error_codes").get>(); + error->timestamp = json.at("timestamp").get(); + error->trace_id = json.at("trace_id").get(); + error->correlation_id = json.at("correlation_id").get(); + error->error_uri = json.at("error_uri").get(); + return true; + } catch (...) { + } + return false; +} + +static std::string EncodeScopes(const std::unordered_set &scopes) { + // The result + std::string result; + + // If the input isn't empty, append the first element. We do this so we + // don't need to introduce an if into the loop. + if (scopes.size() > 0) { + auto it = scopes.begin(); + const auto end = scopes.end(); + result = *it; + + // Append the remaining input components, after the first + while (++it != end) { + result += ' ' + *it; + } + } + return Azure::Core::Url::Encode(result); +} + +static std::string CacheScopeString(const std::vector &scopes) { + std::string result; + if (scopes.size() <= 1) { + for (const auto &scope : scopes) { + result += scope; + } + } else { + auto copy_scopes = scopes; + std::sort(copy_scopes.begin(), copy_scopes.end()); + for (const auto &scope : copy_scopes) { + result += scope; + } + } + return result; +} + +AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes, + const Azure::Core::Credentials::TokenCredentialOptions &options) + : AzureDeviceCodeCredential(std::move(tenant_id), std::move(client_id), std::move(scopes), options, nullptr) { +} + +AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes, + AzureDeviceCodeInfo device_code_info, + const Azure::Core::Credentials::TokenCredentialOptions &options) + : AzureDeviceCodeCredential(std::move(tenant_id), std::move(client_id), std::move(scopes), options, + make_uniq(std::move(device_code_info))) { +} + +AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes_p, + const Azure::Core::Credentials::TokenCredentialOptions &options, + std::unique_ptr device_code_info) + : Azure::Core::Credentials::TokenCredential("DeviceCodeCredential"), tenant_id(std::move(tenant_id)), + client_id(std::move(client_id)), scopes(std::move(scopes_p)), encoded_scopes(EncodeScopes(scopes)), + device_code_info(std::move(device_code_info)), http_pipeline(options, "identity", "DuckDB", {}, {}) + +{ +} + +AzureDeviceCodeInfo AzureDeviceCodeCredential::RequestDeviceCode() { + const std::string url = "https://login.microsoftonline.com/" + tenant_id + "/oauth2/v2.0/devicecode"; + const std::string body = "client_id=" + Azure::Core::Url::Encode(client_id) + "&scope=" + encoded_scopes; + Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); + + Azure::Core::Http::Request http_request(Azure::Core::Http::HttpMethod::Post, Azure::Core::Url(url), &body_stream); + http_request.SetHeader("Content-Type", "application/x-www-form-urlencoded"); + http_request.SetHeader("Content-Length", std::to_string(body.size())); + http_request.SetHeader("Accept", "application/json"); + + auto response = http_pipeline.Send(http_request, Azure::Core::Context()); + return HandleDeviceAuthorizationResponse(*response); +} + +AzureDeviceCodeInfo +AzureDeviceCodeCredential::HandleDeviceAuthorizationResponse(const Azure::Core::Http::RawResponse &response) { + const auto &response_body = response.GetBody(); + const auto response_body_str = std::string(response_body.begin(), response_body.end()); + if (response.GetStatusCode() == Azure::Core::Http::HttpStatusCode::Ok) { + AzureDeviceCodeInfo parsed_response; + ParseJson(std::string(response_body.begin(), response_body.end()), &parsed_response); + return parsed_response; + } else { + throw IOException("[AzureDeviceCodeCredential] Failed to retrieve devicecode HTTP code: %d, details: %s", + response.GetStatusCode(), response_body_str); + } +} + +Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingUser() const { + // Check if it still possible to retrieve a token! + auto now = std::chrono::system_clock::now(); + if (now >= device_code_info->expires_at) { + throw IOException("[AzureDeviceCodeCredential] Your previous credential has already expired please " + "renew it by calling `SELECT * FROM azure_devicecode('')`;"); + } + + const std::string url = "https://login.microsoftonline.com/" + tenant_id + "/oauth2/v2.0/token"; + const std::string body = "grant_type=urn:ietf:params:oauth:grant-type:device_code" + "&client_id=" + + Azure::Core::Url::Encode(client_id) + "&device_code=" + device_code_info->device_code; + Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); + + Azure::Core::Http::Request http_request(Azure::Core::Http::HttpMethod::Post, Azure::Core::Url(url), &body_stream); + http_request.SetHeader("Content-Type", "application/x-www-form-urlencoded"); + http_request.SetHeader("Content-Length", std::to_string(body.size())); + http_request.SetHeader("Accept", "application/json"); + + while (true) { + auto response = http_pipeline.Send(http_request, Azure::Core::Context()); + const auto &response_body = response->GetBody(); + const auto response_body_str = std::string(response_body.begin(), response_body.end()); + + switch (response->GetStatusCode()) { + case Azure::Core::Http::HttpStatusCode::Ok: { + Azure::Core::Credentials::AccessToken token; + ParseJson(response_body_str, &token); + return token; + } break; + + default: { + HttpResponseError error; + TryParseJson(response_body_str, &error); + if ("authorization_pending" == error.error) { + // Wait before retry + std::this_thread::sleep_for(device_code_info->interval); + } else if ("authorization_declined" == error.error) { + throw IOException("[AzureDeviceCodeCredential] Failed to retrieve user token, end user denied the " + "authorization request. (error msg: %s)", + response_body_str); + } else if ("bad_verification_code" == error.error) { + throw IOException( + "[AzureDeviceCodeCredential] Failed to retrieve recognized device code. (error msg: %s)", + response_body_str); + } else if ("expired_token" == error.error) { + throw IOException( + "[AzureDeviceCodeCredential] Failed to retrieve user token already expired. (error msg: %s)", + response_body_str); + } else { + throw IOException("[AzureDeviceCodeCredential] Unexpected error: %s", response_body_str); + } + } break; + } + } +} + +Azure::Core::Credentials::AccessToken +AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContext const &token_request_context, + Azure::Core::Context const &context) const { + using Azure::Core::_internal::StringExtensions; + + if (!device_code_info) { + throw IOException("[AzureDeviceCodeCredential] No device/user code register did you call `SELECT * FROM " + "azure_devicecode('')`;"); + } + + if (!token_request_context.TenantId.empty() && + !StringExtensions::LocaleInvariantCaseInsensitiveEqual(token_request_context.TenantId, tenant_id)) { + + throw IOException( + "[AzureDeviceCodeCredential] The current credential is not configured to acquire tokens for tenant '%s'.", + token_request_context.TenantId); + } + for (const auto &scope : token_request_context.Scopes) { + if (scopes.find(scope) == scopes.end()) { + throw IOException("[AzureDeviceCodeCredential] The required scope %s is not part of the requested scope, " + "please check secret defintion.", + scope); + } + } + auto request_scopes = token_request_context.Scopes; + std::sort(request_scopes.begin(), request_scopes.end()); + + return token_cache.GetToken(CacheScopeString(token_request_context.Scopes), token_request_context.TenantId, + token_request_context.MinimumExpiration, [&]() { return AuthenticatingUser(); }); +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/auth/azure_device_code_function.cpp b/src/auth/azure_device_code_function.cpp new file mode 100644 index 0000000..c6a0d2d --- /dev/null +++ b/src/auth/azure_device_code_function.cpp @@ -0,0 +1,131 @@ +#include "auth/azure_device_code_function.hpp" +#include "auth/azure_device_codes_context.hpp" +#include "azure_storage_account_client.hpp" +#include "duckdb/catalog/catalog_transaction.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/enums/vector_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/extension_util.hpp" +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/secret/secret_manager.hpp" +#include +#include +#include + +namespace duckdb { + +struct AzureDeviceCodeBindData : public FunctionData { + const std::string secret_name; + + AzureDeviceCodeBindData(std::string secret_name) : secret_name(std::move(secret_name)) { + } + + duckdb::unique_ptr Copy() const override { + return make_uniq(secret_name); + } + + bool Equals(const FunctionData &other_p) const override { + if (&other_p == this) + return true; + auto &other = other_p.Cast(); + return other.secret_name == this->secret_name; + } +}; + +struct AzureDeviceCodeCompleted : public GlobalTableFunctionState { + AzureDeviceCodeCompleted() : completed(false) { + } + + bool completed; +}; + +static void AzureDeviceCodeImplementation(ClientContext &context, TableFunctionInput &data, DataChunk &output) { + auto &bind_data = data.bind_data->Cast(); + auto &global_data = data.global_state->Cast(); + + if (global_data.completed) { + return; + } + + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); + auto secret = context.db->config.secret_manager->GetSecretByName(transaction, bind_data.secret_name); + if (!secret) { + throw InvalidInputException("azure_devicecode no secret found named %s", bind_data.secret_name); + } + + auto device_code_credential = CreateDeviceCodeCredential(ClientData::Get(context).file_opener.get(), + dynamic_cast(*secret->secret)); + auto device_code_info = device_code_credential->RequestDeviceCode(); + + auto &device_code_context = context.registered_state[AzureDeviceCodesClientContextState::CONTEXT_KEY]; + if (!device_code_context) { + device_code_context = make_shared(); + } + + D_ASSERT(reinterpret_cast(device_code_context.get()) != nullptr); + reinterpret_cast(*device_code_context) + .device_code_info_by_secret.insert(std::make_pair(bind_data.secret_name, device_code_info)); + + output.SetCapacity(1); + output.SetValue(0, 0, bind_data.secret_name); + output.SetValue(1, 0, device_code_info.user_code); + output.SetValue(2, 0, device_code_info.verification_uri); + output.SetValue(3, 0, device_code_info.message); + auto expires_at = std::chrono::duration_cast(device_code_info.expires_at.time_since_epoch()); + output.SetValue(4, 0, Value::TIMESTAMP(Timestamp::FromEpochSeconds(expires_at.count()))); + output.SetCardinality(1); + global_data.completed = true; +} + +static unique_ptr AzureDeviceCodeBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + if (input.inputs.empty()) { + throw BinderException("azure_devicecode takes at least one argument"); + } + if (input.inputs[0].IsNull()) { + throw BinderException("azure_devicecode first parameter cannot be NULL"); + } + + auto secret_name = StringValue::Get(input.inputs[0]); + + names.emplace_back("secret_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("user_code"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("verification_uri"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("message"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("expires_in"); + return_types.emplace_back(LogicalType::TIMESTAMP); + + return make_uniq(secret_name); +} + +unique_ptr AzureDeviceCodeInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +void RegisterAzureDeviceCodeFunction(DatabaseInstance &instance) { + + TableFunction azure_devicecode("azure_devicecode", {LogicalType::VARCHAR}, AzureDeviceCodeImplementation, + AzureDeviceCodeBind, AzureDeviceCodeInit); + ExtensionUtil::RegisterFunction(instance, azure_devicecode); +} +} // namespace duckdb diff --git a/src/auth/azure_device_codes_context.cpp b/src/auth/azure_device_codes_context.cpp new file mode 100644 index 0000000..ff6945f --- /dev/null +++ b/src/auth/azure_device_codes_context.cpp @@ -0,0 +1,6 @@ +#include "auth/azure_device_codes_context.hpp" + +namespace duckdb { +const std::string AzureDeviceCodesClientContextState::CONTEXT_KEY = "auth/azure_device_codes_context"; + +} \ No newline at end of file diff --git a/src/azure_extension.cpp b/src/azure_extension.cpp index af3a084..bc44a56 100644 --- a/src/azure_extension.cpp +++ b/src/azure_extension.cpp @@ -1,6 +1,7 @@ #define DUCKDB_EXTENSION_MAIN #include "azure_extension.hpp" +#include "auth/azure_device_code_function.hpp" #include "azure_blob_filesystem.hpp" #include "azure_dfs_filesystem.hpp" #include "azure_secret.hpp" @@ -16,6 +17,9 @@ static void LoadInternal(DatabaseInstance &instance) { // Load Secret functions CreateAzureSecretFunctions::Register(instance); + // Load functions + RegisterAzureDeviceCodeFunction(instance); + // Load extension config auto &config = DBConfig::GetConfig(instance); config.AddExtensionOption("azure_storage_connection_string", diff --git a/src/azure_filesystem.cpp b/src/azure_filesystem.cpp index 4c2caed..2566b8c 100644 --- a/src/azure_filesystem.cpp +++ b/src/azure_filesystem.cpp @@ -42,8 +42,8 @@ void AzureStorageFileSystem::LoadFileInfo(AzureFileHandle &handle) { LoadRemoteFileInfo(handle); } catch (const Azure::Storage::StorageException &e) { throw IOException( - "AzureBlobStorageFileSystem open file '%s' failed with code'%s', Reason Phrase: '%s', Message: '%s'", - handle.path, e.ErrorCode, e.ReasonPhrase, e.Message); + "AzureBlobStorageFileSystem open file '%s' failed with code '%d', Reason Phrase: '%s', Message: '%s'", + handle.path, e.StatusCode, e.ReasonPhrase, e.Message); } catch (const std::exception &e) { throw IOException( "AzureBlobStorageFileSystem could not open file: '%s', unknown error occurred, this could mean " diff --git a/src/azure_secret.cpp b/src/azure_secret.cpp index 4fd9a8d..47813e6 100644 --- a/src/azure_secret.cpp +++ b/src/azure_secret.cpp @@ -1,6 +1,7 @@ #include "azure_secret.hpp" #include "azure_dfs_filesystem.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/common/types/value.hpp" #include "duckdb/common/unique_ptr.hpp" #include "duckdb/main/extension_util.hpp" #include "duckdb/main/secret/secret.hpp" @@ -9,7 +10,6 @@ #include #include #include -#include namespace duckdb { constexpr auto COMMON_OPTIONS = { @@ -26,6 +26,24 @@ static void CopySecret(const std::string &key, const CreateSecretInput &input, K } } +template +static void CopySecret(const std::string &key, const CreateSecretInput &input, KeyValueSecret &result, + const T &default_value) { + auto val = input.options.find(key); + + if (val != input.options.end()) { + result.secret_map[key] = val->second; + } else { + result.secret_map[key] = Value(default_value); + } +} + +static void AddDefaultScopes(vector *scope) { + scope->push_back("azure://"); + scope->push_back("az://"); + scope->push_back(AzureDfsStorageFileSystem::PATH_PREFIX); +} + static void RedactCommonKeys(KeyValueSecret &result) { result.redact_keys.insert("proxy_password"); } @@ -33,9 +51,7 @@ static void RedactCommonKeys(KeyValueSecret &result) { static unique_ptr CreateAzureSecretFromConfig(ClientContext &context, CreateSecretInput &input) { auto scope = input.scope; if (scope.empty()) { - scope.push_back("azure://"); - scope.push_back("az://"); - scope.push_back(AzureDfsStorageFileSystem::PATH_PREFIX); + AddDefaultScopes(&scope); } auto result = make_uniq(scope, input.type, input.provider, input.name); @@ -58,9 +74,7 @@ static unique_ptr CreateAzureSecretFromConfig(ClientContext &context static unique_ptr CreateAzureSecretFromCredentialChain(ClientContext &context, CreateSecretInput &input) { auto scope = input.scope; if (scope.empty()) { - scope.push_back("azure://"); - scope.push_back("az://"); - scope.push_back(AzureDfsStorageFileSystem::PATH_PREFIX); + AddDefaultScopes(&scope); } auto result = make_uniq(scope, input.type, input.provider, input.name); @@ -82,9 +96,7 @@ static unique_ptr CreateAzureSecretFromCredentialChain(ClientContext static unique_ptr CreateAzureSecretFromServicePrincipal(ClientContext &context, CreateSecretInput &input) { auto scope = input.scope; if (scope.empty()) { - scope.push_back("azure://"); - scope.push_back("az://"); - scope.push_back(AzureDfsStorageFileSystem::PATH_PREFIX); + AddDefaultScopes(&scope); } auto result = make_uniq(scope, input.type, input.provider, input.name); @@ -108,6 +120,30 @@ static unique_ptr CreateAzureSecretFromServicePrincipal(ClientContex return std::move(result); } +static unique_ptr CreateAzureSecretFromDeviceCode(ClientContext &context, CreateSecretInput &input) { + auto scope = input.scope; + if (scope.empty()) { + AddDefaultScopes(&scope); + } + + auto result = make_uniq(scope, input.type, input.provider, input.name); + + // Manage common option that all secret type share + for (const auto *key : COMMON_OPTIONS) { + CopySecret(key, input, *result); + } + + // Manage specific secret option + CopySecret("tenant_id", input, *result); + CopySecret("client_id", input, *result); + CopySecret("oauth_scopes", input, *result, "https://storage.azure.com/.default"); + + // Redact sensible keys + RedactCommonKeys(*result); + + return std::move(result); +} + static void RegisterCommonSecretParameters(CreateSecretFunction &function) { // Register azure common parameters function.named_parameters["account_name"] = LogicalType::VARCHAR; @@ -149,6 +185,13 @@ void CreateAzureSecretFunctions::Register(DatabaseInstance &instance) { service_principal_function.named_parameters["client_certificate_path"] = LogicalType::VARCHAR; RegisterCommonSecretParameters(service_principal_function); ExtensionUtil::RegisterFunction(instance, service_principal_function); + + CreateSecretFunction device_code_function = {type, "device_code", CreateAzureSecretFromDeviceCode}; + device_code_function.named_parameters["tenant_id"] = LogicalType::VARCHAR; + device_code_function.named_parameters["client_id"] = LogicalType::VARCHAR; + device_code_function.named_parameters["oauth_scopes"] = LogicalType::VARCHAR; + RegisterCommonSecretParameters(device_code_function); + ExtensionUtil::RegisterFunction(instance, device_code_function); } } // namespace duckdb diff --git a/src/azure_storage_account_client.cpp b/src/azure_storage_account_client.cpp index e54ca93..c91fde3 100644 --- a/src/azure_storage_account_client.cpp +++ b/src/azure_storage_account_client.cpp @@ -1,10 +1,15 @@ #include "azure_storage_account_client.hpp" +#include "auth/azure_device_code_credential.hpp" +#include "auth/azure_device_codes_context.hpp" #include "duckdb/catalog/catalog_transaction.hpp" +#include "duckdb/common/assert.hpp" #include "duckdb/common/enums/statement_type.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/file_opener.hpp" +#include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/value.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" #include "duckdb/main/secret/secret.hpp" @@ -30,6 +35,8 @@ #include #include #include +#include +#include namespace duckdb { const static std::string DEFAULT_BLOB_ENDPOINT = "blob.core.windows.net"; @@ -195,6 +202,28 @@ CreateClientCredential(const KeyValueSecret &secret, transport_options); } +static std::shared_ptr +CreateDeviceCodeCredential(const KeyValueSecret &secret, + const Azure::Core::Http::Policies::TransportOptions &transport_options, + const optional_ptr &device_code_info = nullptr) { + constexpr bool error_on_missing = true; + auto tenant_id = secret.TryGetValue("tenant_id", error_on_missing).ToString(); + auto client_id = secret.TryGetValue("client_id", error_on_missing).ToString(); + auto oauth_scopes_value = secret.TryGetValue("oauth_scopes", error_on_missing).ToString(); + std::vector oauth_scopes = StringUtil::Split(oauth_scopes_value, ' '); + + if (device_code_info) { + return std::make_shared( + tenant_id, client_id, std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), + *device_code_info, ToTokenCredentialOptions(transport_options)); + + } else { + return std::make_shared( + tenant_id, client_id, std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), + ToTokenCredentialOptions(transport_options)); + } +} + static std::shared_ptr CreateCurlTransport(const std::string &proxy, const std::string &proxy_username, const std::string &proxy_password) { Azure::Core::Http::CurlTransportOptions curl_transport_options; @@ -403,7 +432,42 @@ GetDfsStorageAccountClientFromServicePrincipalProvider(FileOpener *opener, const auto account_url = azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); - ; + + auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); + return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, token_credential, dfs_options); +} + +static Azure::Storage::Files::DataLake::DataLakeServiceClient +GetDfsStorageAccountClientFromDeviceCodeProvider(FileOpener *opener, const KeyValueSecret &secret, + const AzureParsedUrl &azure_parsed_url) { + auto context = opener->TryGetClientContext(); + if (!context) { + throw InternalException("Context cannot be null!"); + } + auto device_codes_info_context_it = context->registered_state.find(AzureDeviceCodesClientContextState::CONTEXT_KEY); + if (device_codes_info_context_it == context->registered_state.end()) { + throw InternalException( + "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", + secret.GetName()); + } + + D_ASSERT(dynamic_cast(device_codes_info_context_it->second.get()) != nullptr); + const auto &device_code_info_by_secret = + reinterpret_cast(*device_codes_info_context_it->second) + .device_code_info_by_secret; + auto device_code_info_it = device_code_info_by_secret.find(secret.GetName()); + if (device_code_info_it == device_code_info_by_secret.end()) { + throw InternalException( + "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", + secret.GetName()); + } + + auto transport_options = GetTransportOptions(opener, secret); + auto token_credential = CreateDeviceCodeCredential(secret, transport_options, &device_code_info_it->second); + + auto account_url = + azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); + auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, token_credential, dfs_options); } @@ -418,6 +482,8 @@ GetBlobStorageAccountClient(FileOpener *opener, const KeyValueSecret &secret, co return GetBlobStorageAccountClientFromCredentialChainProvider(opener, secret, azure_parsed_url); } else if (provider == "service_principal") { return GetBlobStorageAccountClientFromServicePrincipalProvider(opener, secret, azure_parsed_url); + } else if (provider == "device_code") { + // TODO implements: } throw InvalidInputException("Unsupported provider type %s for azure", provider); @@ -433,6 +499,8 @@ GetDfsStorageAccountClient(FileOpener *opener, const KeyValueSecret &secret, con return GetDfsStorageAccountClientFromCredentialChainProvider(opener, secret, azure_parsed_url); } else if (provider == "service_principal") { return GetDfsStorageAccountClientFromServicePrincipalProvider(opener, secret, azure_parsed_url); + } else if (provider == "device_code") { + return GetDfsStorageAccountClientFromDeviceCodeProvider(opener, secret, azure_parsed_url); } throw InvalidInputException("Unsupported provider type %s for azure", provider); @@ -505,6 +573,12 @@ const SecretMatch LookupSecret(FileOpener *opener, const std::string &path) { return {}; } +std::shared_ptr CreateDeviceCodeCredential(FileOpener *opener, + const KeyValueSecret &secret) { + auto transport_options = GetTransportOptions(opener, secret); + return CreateDeviceCodeCredential(secret, transport_options); +} + Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(FileOpener *opener, const std::string &path, const AzureParsedUrl &azure_parsed_url) { diff --git a/src/include/auth/azure_device_code_credential.hpp b/src/include/auth/azure_device_code_credential.hpp new file mode 100644 index 0000000..47e0855 --- /dev/null +++ b/src/include/auth/azure_device_code_credential.hpp @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace duckdb { + +struct AzureDeviceCodeInfo { + // A long string used to verify the session between the client and the authorization server. + // The client uses this parameter to request the access token from the authorization server. + std::string device_code; + // A short string shown to the user used to identify the session on a secondary device. + std::string user_code; + // The URI the user should go to with the user_code in order to sign in. + std::string verification_uri; + // The number of seconds before the device_code and user_code expire. + std::chrono::system_clock::time_point expires_at; + // The number of seconds the client should wait between polling requests. + std::chrono::seconds interval; + // A human-readable string with instructions for the user. This can be localized by including a + // query parameter in the request of the form ?mkt=xx-XX, filling in the appropriate language + // culture code. + std::string message; +}; + +/** + * Implement the missing DeviceCodeCredential from the C++ SDK + * https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-device-code + * + * Note: The way this has been develop is also a hack on how the workflow (should) work. + * In theory the scopes shouldn't be an args of the constructor, they are given when a request + * call the #GetToken method and we should call a callback that would inform the user that they + * have to go to an URL and enter the user code. + * In our case it's hard to prompt the user because when queries are performed we do not known + * how DuckDB is really being use(cmd, lib...) + * So we split the way we obtains the user/device code and the token retrieval. + */ +class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCredential { +public: + explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes, + Azure::Core::Credentials::TokenCredentialOptions const &options = + Azure::Core::Credentials::TokenCredentialOptions()); + + explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes, AzureDeviceCodeInfo device_code, + const Azure::Core::Credentials::TokenCredentialOptions &options = + Azure::Core::Credentials::TokenCredentialOptions()); + Azure::Core::Credentials::AccessToken + GetToken(Azure::Core::Credentials::TokenRequestContext const &token_request_context, + Azure::Core::Context const &context) const override; + + /** + * Send a request to get the user & device code + */ + AzureDeviceCodeInfo RequestDeviceCode(); + +private: + explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes, + const Azure::Core::Credentials::TokenCredentialOptions &options, + std::unique_ptr device_code_info); + + AzureDeviceCodeInfo HandleDeviceAuthorizationResponse(const Azure::Core::Http::RawResponse &response); + Azure::Core::Credentials::AccessToken AuthenticatingUser() const; + +private: + const std::string tenant_id; + const std::string client_id; + const std::unordered_set scopes; + const std::string encoded_scopes; + const std::unique_ptr device_code_info; + + Azure::Identity::_detail::TokenCache token_cache; + Azure::Core::Http::_internal::HttpPipeline http_pipeline; +}; + +} // namespace duckdb \ No newline at end of file diff --git a/src/include/auth/azure_device_code_function.hpp b/src/include/auth/azure_device_code_function.hpp new file mode 100644 index 0000000..c8377b0 --- /dev/null +++ b/src/include/auth/azure_device_code_function.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "duckdb/main/database.hpp" + +namespace duckdb { + +void RegisterAzureDeviceCodeFunction(DatabaseInstance &instance); + +} // namespace duckdb diff --git a/src/include/auth/azure_device_codes_context.hpp b/src/include/auth/azure_device_codes_context.hpp new file mode 100644 index 0000000..bab37da --- /dev/null +++ b/src/include/auth/azure_device_codes_context.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "auth/azure_device_code_credential.hpp" +#include "duckdb/main/client_context_state.hpp" +#include + +namespace duckdb { +class AzureDeviceCodesClientContextState final : public ClientContextState { +public: + const static std::string CONTEXT_KEY; + std::unordered_map device_code_info_by_secret; +}; +} // namespace duckdb \ No newline at end of file diff --git a/src/include/azure_secret.hpp b/src/include/azure_secret.hpp index 3dc39e9..e732e53 100644 --- a/src/include/azure_secret.hpp +++ b/src/include/azure_secret.hpp @@ -1,20 +1,8 @@ #pragma once -#include "azure_extension.hpp" -#include "duckdb.hpp" -#include - -#include -#include -#include -#include -#include -#include +#include "duckdb/main/database.hpp" namespace duckdb { -struct CreateSecretInput; -class CreateSecretFunction; - struct CreateAzureSecretFunctions { public: //! Register all CreateSecretFunctions diff --git a/src/include/azure_storage_account_client.hpp b/src/include/azure_storage_account_client.hpp index 2e22ee0..3ce2aa8 100644 --- a/src/include/azure_storage_account_client.hpp +++ b/src/include/azure_storage_account_client.hpp @@ -1,13 +1,17 @@ #pragma once +#include "auth/azure_device_code_credential.hpp" #include "azure_parsed_url.hpp" #include "duckdb/common/file_opener.hpp" +#include "duckdb/main/secret/secret.hpp" #include #include #include namespace duckdb { +std::shared_ptr CreateDeviceCodeCredential(FileOpener *opener, const KeyValueSecret &secret); + Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(FileOpener *opener, const std::string &path, const AzureParsedUrl &azure_parsed_url); From da8adb2062b1706ca4f8efcb88a5ca5c33f021ff Mon Sep 17 00:00:00 2001 From: Quentin GODEAU Date: Tue, 26 Mar 2024 21:16:07 +0100 Subject: [PATCH 2/4] Rework credential creation --- src/auth/azure_device_code_credential.cpp | 72 +++-- src/auth/azure_device_code_function.cpp | 6 +- src/azure_storage_account_client.cpp | 268 +++++++----------- .../auth/azure_device_code_credential.hpp | 68 +++-- src/include/azure_storage_account_client.hpp | 3 +- 5 files changed, 180 insertions(+), 237 deletions(-) diff --git a/src/auth/azure_device_code_credential.cpp b/src/auth/azure_device_code_credential.cpp index 02d3c91..675c393 100644 --- a/src/auth/azure_device_code_credential.cpp +++ b/src/auth/azure_device_code_credential.cpp @@ -105,47 +105,33 @@ static std::string EncodeScopes(const std::unordered_set &scopes) { } static std::string CacheScopeString(const std::vector &scopes) { - std::string result; - if (scopes.size() <= 1) { - for (const auto &scope : scopes) { - result += scope; - } - } else { + switch (scopes.size()) { + case 0: + return ""; + + case 1: + return scopes[0]; + + default: { + std::string result; auto copy_scopes = scopes; std::sort(copy_scopes.begin(), copy_scopes.end()); for (const auto &scope : copy_scopes) { result += scope; } + return result; + } } - return result; -} - -AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes, - const Azure::Core::Credentials::TokenCredentialOptions &options) - : AzureDeviceCodeCredential(std::move(tenant_id), std::move(client_id), std::move(scopes), options, nullptr) { -} - -AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes, - AzureDeviceCodeInfo device_code_info, - const Azure::Core::Credentials::TokenCredentialOptions &options) - : AzureDeviceCodeCredential(std::move(tenant_id), std::move(client_id), std::move(scopes), options, - make_uniq(std::move(device_code_info))) { } -AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes_p, - const Azure::Core::Credentials::TokenCredentialOptions &options, - std::unique_ptr device_code_info) - : Azure::Core::Credentials::TokenCredential("DeviceCodeCredential"), tenant_id(std::move(tenant_id)), - client_id(std::move(client_id)), scopes(std::move(scopes_p)), encoded_scopes(EncodeScopes(scopes)), - device_code_info(std::move(device_code_info)), http_pipeline(options, "identity", "DuckDB", {}, {}) - -{ +AzureDeviceCodeCredentialRequester::AzureDeviceCodeCredentialRequester( + std::string tenant_id, std::string client_id, std::unordered_set scopes_p, + const Azure::Core::Credentials::TokenCredentialOptions &options) + : tenant_id(std::move(tenant_id)), client_id(std::move(client_id)), scopes(std::move(scopes_p)), + encoded_scopes(EncodeScopes(scopes)), http_pipeline(options, "identity", "DuckDB", {}, {}) { } -AzureDeviceCodeInfo AzureDeviceCodeCredential::RequestDeviceCode() { +AzureDeviceCodeInfo AzureDeviceCodeCredentialRequester::RequestDeviceCode() { const std::string url = "https://login.microsoftonline.com/" + tenant_id + "/oauth2/v2.0/devicecode"; const std::string body = "client_id=" + Azure::Core::Url::Encode(client_id) + "&scope=" + encoded_scopes; Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); @@ -160,7 +146,7 @@ AzureDeviceCodeInfo AzureDeviceCodeCredential::RequestDeviceCode() { } AzureDeviceCodeInfo -AzureDeviceCodeCredential::HandleDeviceAuthorizationResponse(const Azure::Core::Http::RawResponse &response) { +AzureDeviceCodeCredentialRequester::HandleDeviceAuthorizationResponse(const Azure::Core::Http::RawResponse &response) { const auto &response_body = response.GetBody(); const auto response_body_str = std::string(response_body.begin(), response_body.end()); if (response.GetStatusCode() == Azure::Core::Http::HttpStatusCode::Ok) { @@ -168,15 +154,25 @@ AzureDeviceCodeCredential::HandleDeviceAuthorizationResponse(const Azure::Core:: ParseJson(std::string(response_body.begin(), response_body.end()), &parsed_response); return parsed_response; } else { - throw IOException("[AzureDeviceCodeCredential] Failed to retrieve devicecode HTTP code: %d, details: %s", - response.GetStatusCode(), response_body_str); + throw IOException( + "[AzureDeviceCodeCredentialRequester] Failed to retrieve devicecode HTTP code: %d, details: %s", + response.GetStatusCode(), response_body_str); } } +AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, + std::unordered_set scopes_p, + AzureDeviceCodeInfo device_code_info, + const Azure::Core::Credentials::TokenCredentialOptions &options) + : Azure::Core::Credentials::TokenCredential("DeviceCodeCredential"), tenant_id(std::move(tenant_id)), + client_id(std::move(client_id)), scopes(std::move(scopes_p)), device_code_info(std::move(device_code_info)), + http_pipeline(options, "identity", "DuckDB", {}, {}) { +} + Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingUser() const { // Check if it still possible to retrieve a token! auto now = std::chrono::system_clock::now(); - if (now >= device_code_info->expires_at) { + if (now >= device_code_info.expires_at) { throw IOException("[AzureDeviceCodeCredential] Your previous credential has already expired please " "renew it by calling `SELECT * FROM azure_devicecode('')`;"); } @@ -184,7 +180,7 @@ Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingU const std::string url = "https://login.microsoftonline.com/" + tenant_id + "/oauth2/v2.0/token"; const std::string body = "grant_type=urn:ietf:params:oauth:grant-type:device_code" "&client_id=" + - Azure::Core::Url::Encode(client_id) + "&device_code=" + device_code_info->device_code; + Azure::Core::Url::Encode(client_id) + "&device_code=" + device_code_info.device_code; Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); Azure::Core::Http::Request http_request(Azure::Core::Http::HttpMethod::Post, Azure::Core::Url(url), &body_stream); @@ -209,7 +205,7 @@ Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingU TryParseJson(response_body_str, &error); if ("authorization_pending" == error.error) { // Wait before retry - std::this_thread::sleep_for(device_code_info->interval); + std::this_thread::sleep_for(device_code_info.interval); } else if ("authorization_declined" == error.error) { throw IOException("[AzureDeviceCodeCredential] Failed to retrieve user token, end user denied the " "authorization request. (error msg: %s)", @@ -235,7 +231,7 @@ AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContex Azure::Core::Context const &context) const { using Azure::Core::_internal::StringExtensions; - if (!device_code_info) { + if (device_code_info.device_code.empty()) { throw IOException("[AzureDeviceCodeCredential] No device/user code register did you call `SELECT * FROM " "azure_devicecode('')`;"); } diff --git a/src/auth/azure_device_code_function.cpp b/src/auth/azure_device_code_function.cpp index c6a0d2d..fa36c46 100644 --- a/src/auth/azure_device_code_function.cpp +++ b/src/auth/azure_device_code_function.cpp @@ -65,9 +65,9 @@ static void AzureDeviceCodeImplementation(ClientContext &context, TableFunctionI throw InvalidInputException("azure_devicecode no secret found named %s", bind_data.secret_name); } - auto device_code_credential = CreateDeviceCodeCredential(ClientData::Get(context).file_opener.get(), - dynamic_cast(*secret->secret)); - auto device_code_info = device_code_credential->RequestDeviceCode(); + auto device_code_credential = CreateDeviceCodeCredentialRequester( + ClientData::Get(context).file_opener.get(), dynamic_cast(*secret->secret)); + auto device_code_info = device_code_credential.RequestDeviceCode(); auto &device_code_context = context.registered_state[AzureDeviceCodesClientContextState::CONTEXT_KEY]; if (!device_code_context) { diff --git a/src/azure_storage_account_client.cpp b/src/azure_storage_account_client.cpp index c91fde3..3a683d5 100644 --- a/src/azure_storage_account_client.cpp +++ b/src/azure_storage_account_client.cpp @@ -50,8 +50,14 @@ static std::string TryGetCurrentSetting(FileOpener *opener, const std::string &n return ""; } -static bool ConnectionStringMatchStorageAccountName(const std::string &connection_string, - const std::string &provided_storage_account) { +static bool TryMatchStorageAccountName(const std::string &connection_string, const AzureParsedUrl &azure_parsed_url) { + if (!azure_parsed_url.is_fully_qualified) { + // No way to check if it match + return true; + } + + const auto &provided_storage_account = azure_parsed_url.storage_account_name; + auto storage_account_name_pos = connection_string.find("AccountName="); if (storage_account_name_pos == std::string::npos) { throw InvalidInputException("A invalid connection string has been provided."); @@ -204,24 +210,38 @@ CreateClientCredential(const KeyValueSecret &secret, static std::shared_ptr CreateDeviceCodeCredential(const KeyValueSecret &secret, - const Azure::Core::Http::Policies::TransportOptions &transport_options, - const optional_ptr &device_code_info = nullptr) { + const Azure::Core::Http::Policies::TransportOptions &transport_options, FileOpener *opener) { + auto context = opener->TryGetClientContext(); + if (!context) { + throw InternalException("Context cannot be null!"); + } + auto device_codes_info_context_it = context->registered_state.find(AzureDeviceCodesClientContextState::CONTEXT_KEY); + if (device_codes_info_context_it == context->registered_state.end()) { + throw InternalException( + "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", + secret.GetName()); + } + + D_ASSERT(dynamic_cast(device_codes_info_context_it->second.get()) != nullptr); + const auto &device_code_info_by_secret = + reinterpret_cast(*device_codes_info_context_it->second) + .device_code_info_by_secret; + auto device_code_info_it = device_code_info_by_secret.find(secret.GetName()); + if (device_code_info_it == device_code_info_by_secret.end()) { + throw InternalException( + "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", + secret.GetName()); + } + constexpr bool error_on_missing = true; auto tenant_id = secret.TryGetValue("tenant_id", error_on_missing).ToString(); auto client_id = secret.TryGetValue("client_id", error_on_missing).ToString(); auto oauth_scopes_value = secret.TryGetValue("oauth_scopes", error_on_missing).ToString(); std::vector oauth_scopes = StringUtil::Split(oauth_scopes_value, ' '); - if (device_code_info) { - return std::make_shared( - tenant_id, client_id, std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), - *device_code_info, ToTokenCredentialOptions(transport_options)); - - } else { - return std::make_shared( - tenant_id, client_id, std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), - ToTokenCredentialOptions(transport_options)); - } + return std::make_shared( + tenant_id, client_id, std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), + device_code_info_it->second, ToTokenCredentialOptions(transport_options)); } static std::shared_ptr @@ -330,180 +350,91 @@ static Azure::Core::Http::Policies::TransportOptions GetTransportOptions(FileOpe return GetTransportOptions(transport_option_type, http_proxy, http_proxy_username, http_proxy_password); } -static Azure::Storage::Blobs::BlobServiceClient -GetBlobStorageAccountClientFromConfigProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { - auto transport_options = GetTransportOptions(opener, secret); +std::shared_ptr +CreateAzureCredential(const KeyValueSecret &secret, + const Azure::Core::Http::Policies::TransportOptions &transport_options, FileOpener *opener) { + const auto &provider = secret.GetProvider(); - // If connection string, we're done heres - auto connection_string_val = secret.TryGetValue("connection_string"); - if (!connection_string_val.IsNull()) { - auto connection_string = connection_string_val.ToString(); - if (azure_parsed_url.is_fully_qualified && - !ConnectionStringMatchStorageAccountName(connection_string, azure_parsed_url.storage_account_name)) { - throw InvalidInputException("The provided connection string does not match the storage account named %s", - azure_parsed_url.storage_account_name); - } - - auto blob_options = ToBlobClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Blobs::BlobServiceClient::CreateFromConnectionString(connection_string, blob_options); + if (provider == "credential_chain") { + return CreateChainedTokenCredential(secret, transport_options); + } else if (provider == "service_principal") { + return CreateClientCredential(secret, transport_options); + } else if (provider == "device_code") { + return CreateDeviceCodeCredential(secret, transport_options, opener); } - // Default provider (config) with no connection string => public storage account - auto account_url = - azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_BLOB_ENDPOINT); - auto blob_options = ToBlobClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Blobs::BlobServiceClient(account_url, blob_options); + throw InvalidInputException("Unsupported provider type %s for azure", provider); } -static Azure::Storage::Files::DataLake::DataLakeServiceClient -GetDfsStorageAccountClientFromConfigProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { +static Azure::Storage::Blobs::BlobServiceClient +GetBlobStorageAccountClient(FileOpener *opener, const KeyValueSecret &secret, const AzureParsedUrl &azure_parsed_url) { + auto &provider = secret.GetProvider(); auto transport_options = GetTransportOptions(opener, secret); + auto blob_options = ToBlobClientOptions(transport_options, GetHttpState(opener)); - // If connection string, we're done heres - auto connection_string_val = secret.TryGetValue("connection_string"); - if (!connection_string_val.IsNull()) { - auto connection_string = connection_string_val.ToString(); - if (azure_parsed_url.is_fully_qualified && - !ConnectionStringMatchStorageAccountName(connection_string, azure_parsed_url.storage_account_name)) { - throw InvalidInputException("The provided connection string does not match the storage account named %s", - azure_parsed_url.storage_account_name); + // default provider + if (provider == "config") { + // If connection string, we're done heres + auto connection_string_val = secret.TryGetValue("connection_string"); + if (!connection_string_val.IsNull()) { + auto connection_string = connection_string_val.ToString(); + if (!TryMatchStorageAccountName(connection_string, azure_parsed_url)) { + throw InvalidInputException( + "The provided connection string does not match the storage account named %s", + azure_parsed_url.storage_account_name); + } + + return Azure::Storage::Blobs::BlobServiceClient::CreateFromConnectionString(connection_string, + blob_options); } - auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Files::DataLake::DataLakeServiceClient::CreateFromConnectionString(connection_string, - dfs_options); + // Default provider (config) with no connection string => public storage account + auto account_url = azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) + : AccountUrl(secret, DEFAULT_BLOB_ENDPOINT); + return Azure::Storage::Blobs::BlobServiceClient(account_url, blob_options); } - // Default provider (config) with no connection string => public storage account - auto account_url = - azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); - auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, dfs_options); -} + // All other provider have token credential + auto credential = CreateAzureCredential(secret, transport_options, opener); -static Azure::Storage::Blobs::BlobServiceClient -GetBlobStorageAccountClientFromCredentialChainProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { - auto transport_options = GetTransportOptions(opener, secret); - // Create credential chain - auto credential = CreateChainedTokenCredential(secret, transport_options); - - // Connect to storage account auto account_url = azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_BLOB_ENDPOINT); - auto blob_options = ToBlobClientOptions(transport_options, GetHttpState(opener)); return Azure::Storage::Blobs::BlobServiceClient(account_url, std::move(credential), blob_options); } static Azure::Storage::Files::DataLake::DataLakeServiceClient -GetDfsStorageAccountClientFromCredentialChainProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { +GetDfsStorageAccountClient(FileOpener *opener, const KeyValueSecret &secret, const AzureParsedUrl &azure_parsed_url) { + auto &provider = secret.GetProvider(); auto transport_options = GetTransportOptions(opener, secret); - // Create credential chain - auto credential = CreateChainedTokenCredential(secret, transport_options); - - // Connect to storage account - auto account_url = - azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, std::move(credential), dfs_options); -} -static Azure::Storage::Blobs::BlobServiceClient -GetBlobStorageAccountClientFromServicePrincipalProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { - auto transport_options = GetTransportOptions(opener, secret); - auto token_credential = CreateClientCredential(secret, transport_options); - - auto account_url = - azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_BLOB_ENDPOINT); - ; - auto blob_options = ToBlobClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Blobs::BlobServiceClient(account_url, token_credential, blob_options); -} - -static Azure::Storage::Files::DataLake::DataLakeServiceClient -GetDfsStorageAccountClientFromServicePrincipalProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { - auto transport_options = GetTransportOptions(opener, secret); - auto token_credential = CreateClientCredential(secret, transport_options); - - auto account_url = - azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); - - auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, token_credential, dfs_options); -} + // default provider + if (provider == "config") { + auto connection_string_val = secret.TryGetValue("connection_string"); + if (!connection_string_val.IsNull()) { + auto connection_string = connection_string_val.ToString(); + if (!TryMatchStorageAccountName(connection_string, azure_parsed_url)) { + throw InvalidInputException( + "The provided connection string does not match the storage account named %s", + azure_parsed_url.storage_account_name); + } -static Azure::Storage::Files::DataLake::DataLakeServiceClient -GetDfsStorageAccountClientFromDeviceCodeProvider(FileOpener *opener, const KeyValueSecret &secret, - const AzureParsedUrl &azure_parsed_url) { - auto context = opener->TryGetClientContext(); - if (!context) { - throw InternalException("Context cannot be null!"); - } - auto device_codes_info_context_it = context->registered_state.find(AzureDeviceCodesClientContextState::CONTEXT_KEY); - if (device_codes_info_context_it == context->registered_state.end()) { - throw InternalException( - "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", - secret.GetName()); - } + return Azure::Storage::Files::DataLake::DataLakeServiceClient::CreateFromConnectionString(connection_string, + dfs_options); + } - D_ASSERT(dynamic_cast(device_codes_info_context_it->second.get()) != nullptr); - const auto &device_code_info_by_secret = - reinterpret_cast(*device_codes_info_context_it->second) - .device_code_info_by_secret; - auto device_code_info_it = device_code_info_by_secret.find(secret.GetName()); - if (device_code_info_it == device_code_info_by_secret.end()) { - throw InternalException( - "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", - secret.GetName()); + // Default provider (config) with no connection string => public storage account + auto account_url = azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) + : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); + return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, dfs_options); } - auto transport_options = GetTransportOptions(opener, secret); - auto token_credential = CreateDeviceCodeCredential(secret, transport_options, &device_code_info_it->second); + // All other provider have token credential + auto credential = CreateAzureCredential(secret, transport_options, opener); auto account_url = azure_parsed_url.is_fully_qualified ? AccountUrl(azure_parsed_url) : AccountUrl(secret, DEFAULT_DFS_ENDPOINT); - - auto dfs_options = ToDfsClientOptions(transport_options, GetHttpState(opener)); - return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, token_credential, dfs_options); -} - -static Azure::Storage::Blobs::BlobServiceClient -GetBlobStorageAccountClient(FileOpener *opener, const KeyValueSecret &secret, const AzureParsedUrl &azure_parsed_url) { - auto &provider = secret.GetProvider(); - // default provider - if (provider == "config") { - return GetBlobStorageAccountClientFromConfigProvider(opener, secret, azure_parsed_url); - } else if (provider == "credential_chain") { - return GetBlobStorageAccountClientFromCredentialChainProvider(opener, secret, azure_parsed_url); - } else if (provider == "service_principal") { - return GetBlobStorageAccountClientFromServicePrincipalProvider(opener, secret, azure_parsed_url); - } else if (provider == "device_code") { - // TODO implements: - } - - throw InvalidInputException("Unsupported provider type %s for azure", provider); -} - -static Azure::Storage::Files::DataLake::DataLakeServiceClient -GetDfsStorageAccountClient(FileOpener *opener, const KeyValueSecret &secret, const AzureParsedUrl &azure_parsed_url) { - auto &provider = secret.GetProvider(); - // default provider - if (provider == "config") { - return GetDfsStorageAccountClientFromConfigProvider(opener, secret, azure_parsed_url); - } else if (provider == "credential_chain") { - return GetDfsStorageAccountClientFromCredentialChainProvider(opener, secret, azure_parsed_url); - } else if (provider == "service_principal") { - return GetDfsStorageAccountClientFromServicePrincipalProvider(opener, secret, azure_parsed_url); - } else if (provider == "device_code") { - return GetDfsStorageAccountClientFromDeviceCodeProvider(opener, secret, azure_parsed_url); - } - - throw InvalidInputException("Unsupported provider type %s for azure", provider); + return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, std::move(credential), dfs_options); } static Azure::Core::Http::Policies::TransportOptions GetTransportOptions(FileOpener *opener) { @@ -524,8 +455,7 @@ static Azure::Storage::Blobs::BlobServiceClient GetBlobStorageAccountClient(File auto blob_options = ToBlobClientOptions(transport_options, GetHttpState(opener)); auto connection_string = TryGetCurrentSetting(opener, "azure_storage_connection_string"); - if (!connection_string.empty() && - ConnectionStringMatchStorageAccountName(connection_string, provided_storage_account)) { + if (!connection_string.empty()) { return Azure::Storage::Blobs::BlobServiceClient::CreateFromConnectionString(connection_string, blob_options); } @@ -573,10 +503,18 @@ const SecretMatch LookupSecret(FileOpener *opener, const std::string &path) { return {}; } -std::shared_ptr CreateDeviceCodeCredential(FileOpener *opener, - const KeyValueSecret &secret) { +AzureDeviceCodeCredentialRequester CreateDeviceCodeCredentialRequester(FileOpener *opener, + const KeyValueSecret &secret) { auto transport_options = GetTransportOptions(opener, secret); - return CreateDeviceCodeCredential(secret, transport_options); + constexpr bool error_on_missing = true; + auto tenant_id = secret.TryGetValue("tenant_id", error_on_missing).ToString(); + auto client_id = secret.TryGetValue("client_id", error_on_missing).ToString(); + auto oauth_scopes_value = secret.TryGetValue("oauth_scopes", error_on_missing).ToString(); + std::vector oauth_scopes = StringUtil::Split(oauth_scopes_value, ' '); + + return AzureDeviceCodeCredentialRequester(tenant_id, client_id, + std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), + ToTokenCredentialOptions(transport_options)); } Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(FileOpener *opener, const std::string &path, diff --git a/src/include/auth/azure_device_code_credential.hpp b/src/include/auth/azure_device_code_credential.hpp index 47e0855..4e4693c 100644 --- a/src/include/auth/azure_device_code_credential.hpp +++ b/src/include/auth/azure_device_code_credential.hpp @@ -13,6 +13,19 @@ #include #include +/** + * Implement the missing DeviceCodeCredential from the C++ SDK + * https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-device-code + * + * Note: The way this has been develop is also a hack on how the workflow (should) work. + * In theory the scopes shouldn't be an args of the constructor, they are given when a request + * call the #GetToken method and we should call a callback that would inform the user that they + * have to go to an URL and enter the user code. + * In our case it's hard to prompt the user because when queries are performed we do not known + * how DuckDB is really being use(cmd, lib...) + * So we split the way we obtains the user/device code and the token retrieval. + */ + namespace duckdb { struct AzureDeviceCodeInfo { @@ -33,25 +46,32 @@ struct AzureDeviceCodeInfo { std::string message; }; -/** - * Implement the missing DeviceCodeCredential from the C++ SDK - * https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-device-code - * - * Note: The way this has been develop is also a hack on how the workflow (should) work. - * In theory the scopes shouldn't be an args of the constructor, they are given when a request - * call the #GetToken method and we should call a callback that would inform the user that they - * have to go to an URL and enter the user code. - * In our case it's hard to prompt the user because when queries are performed we do not known - * how DuckDB is really being use(cmd, lib...) - * So we split the way we obtains the user/device code and the token retrieval. - */ -class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCredential { +class AzureDeviceCodeCredentialRequester final { public: - explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes, - Azure::Core::Credentials::TokenCredentialOptions const &options = - Azure::Core::Credentials::TokenCredentialOptions()); + explicit AzureDeviceCodeCredentialRequester(std::string tenant_id, std::string client_id, + std::unordered_set scopes, + const Azure::Core::Credentials::TokenCredentialOptions &options = + Azure::Core::Credentials::TokenCredentialOptions()); + /** + * Send a request to get the user & device code + */ + AzureDeviceCodeInfo RequestDeviceCode(); + +private: + AzureDeviceCodeInfo HandleDeviceAuthorizationResponse(const Azure::Core::Http::RawResponse &response); + +private: + const std::string tenant_id; + const std::string client_id; + const std::unordered_set scopes; + const std::string encoded_scopes; + + Azure::Core::Http::_internal::HttpPipeline http_pipeline; +}; + +class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCredential { +public: explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, std::unordered_set scopes, AzureDeviceCodeInfo device_code, const Azure::Core::Credentials::TokenCredentialOptions &options = @@ -60,26 +80,14 @@ class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCr GetToken(Azure::Core::Credentials::TokenRequestContext const &token_request_context, Azure::Core::Context const &context) const override; - /** - * Send a request to get the user & device code - */ - AzureDeviceCodeInfo RequestDeviceCode(); - private: - explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes, - const Azure::Core::Credentials::TokenCredentialOptions &options, - std::unique_ptr device_code_info); - - AzureDeviceCodeInfo HandleDeviceAuthorizationResponse(const Azure::Core::Http::RawResponse &response); Azure::Core::Credentials::AccessToken AuthenticatingUser() const; private: const std::string tenant_id; const std::string client_id; const std::unordered_set scopes; - const std::string encoded_scopes; - const std::unique_ptr device_code_info; + const AzureDeviceCodeInfo device_code_info; Azure::Identity::_detail::TokenCache token_cache; Azure::Core::Http::_internal::HttpPipeline http_pipeline; diff --git a/src/include/azure_storage_account_client.hpp b/src/include/azure_storage_account_client.hpp index 3ce2aa8..cce2004 100644 --- a/src/include/azure_storage_account_client.hpp +++ b/src/include/azure_storage_account_client.hpp @@ -10,7 +10,8 @@ namespace duckdb { -std::shared_ptr CreateDeviceCodeCredential(FileOpener *opener, const KeyValueSecret &secret); +AzureDeviceCodeCredentialRequester CreateDeviceCodeCredentialRequester(FileOpener *opener, + const KeyValueSecret &secret); Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(FileOpener *opener, const std::string &path, const AzureParsedUrl &azure_parsed_url); From 76f05c1ccdb9f73662186495faff4192b29bcbca Mon Sep 17 00:00:00 2001 From: Quentin GODEAU Date: Wed, 27 Mar 2024 21:13:11 +0100 Subject: [PATCH 3/4] Add management of refresh token --- src/auth/azure_device_code_credential.cpp | 154 +++++++++++------- src/azure_secret.cpp | 2 +- .../auth/azure_device_code_credential.hpp | 31 +++- 3 files changed, 124 insertions(+), 63 deletions(-) diff --git a/src/auth/azure_device_code_credential.cpp b/src/auth/azure_device_code_credential.cpp index 675c393..754b704 100644 --- a/src/auth/azure_device_code_credential.cpp +++ b/src/auth/azure_device_code_credential.cpp @@ -7,14 +7,15 @@ #include #include #include -#include #include #include -#include #include #include +#include #include #include +#include +#include #include #include #include @@ -23,10 +24,6 @@ namespace duckdb { -// TODO https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token -// TODO replace AccessToken by this class ?? -struct RequestDeviceCodeResponse {}; - struct HttpResponseError { std::string error; std::string error_description; @@ -54,19 +51,6 @@ static void ParseJson(const std::string &json_str, AzureDeviceCodeInfo *response throw IOException("[AzureDeviceCodeCredential] Failed to parse JSON Azure response '%s'", ex.what()); } } -static void ParseJson(const std::string &json_str, Azure::Core::Credentials::AccessToken *token) { - try { - auto json = Azure::Core::Json::_internal::json::parse(json_str); - - token->Token = json.at("access_token").get(); - token->ExpiresOn = Azure::DateTime(std::chrono::system_clock::now()) + - std::chrono::seconds(json.at("expires_in").get()); - } catch (const Azure::Core::Json::_internal::json::out_of_range &ex) { - throw IOException("[AzureDeviceCodeCredential] Failed to parse Azure response '%s'", ex.what()); - } catch (const Azure::Core::Json::_internal::json::exception &ex) { - throw IOException("[AzureDeviceCodeCredential] Failed to parse JSON Azure response '%s'", ex.what()); - } -} static bool TryParseJson(const std::string &json_str, HttpResponseError *error) { try { @@ -104,26 +88,6 @@ static std::string EncodeScopes(const std::unordered_set &scopes) { return Azure::Core::Url::Encode(result); } -static std::string CacheScopeString(const std::vector &scopes) { - switch (scopes.size()) { - case 0: - return ""; - - case 1: - return scopes[0]; - - default: { - std::string result; - auto copy_scopes = scopes; - std::sort(copy_scopes.begin(), copy_scopes.end()); - for (const auto &scope : copy_scopes) { - result += scope; - } - return result; - } - } -} - AzureDeviceCodeCredentialRequester::AzureDeviceCodeCredentialRequester( std::string tenant_id, std::string client_id, std::unordered_set scopes_p, const Azure::Core::Credentials::TokenCredentialOptions &options) @@ -132,7 +96,7 @@ AzureDeviceCodeCredentialRequester::AzureDeviceCodeCredentialRequester( } AzureDeviceCodeInfo AzureDeviceCodeCredentialRequester::RequestDeviceCode() { - const std::string url = "https://login.microsoftonline.com/" + tenant_id + "/oauth2/v2.0/devicecode"; + const std::string url = Azure::Identity::_detail::AadGlobalAuthority + tenant_id + "/oauth2/v2.0/devicecode"; const std::string body = "client_id=" + Azure::Core::Url::Encode(client_id) + "&scope=" + encoded_scopes; Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); @@ -165,11 +129,11 @@ AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std: AzureDeviceCodeInfo device_code_info, const Azure::Core::Credentials::TokenCredentialOptions &options) : Azure::Core::Credentials::TokenCredential("DeviceCodeCredential"), tenant_id(std::move(tenant_id)), - client_id(std::move(client_id)), scopes(std::move(scopes_p)), device_code_info(std::move(device_code_info)), - http_pipeline(options, "identity", "DuckDB", {}, {}) { + client_id(std::move(client_id)), scopes(std::move(scopes_p)), encoded_scopes(EncodeScopes(scopes)), + device_code_info(std::move(device_code_info)), http_pipeline(options, "identity", "DuckDB", {}, {}) { } -Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingUser() const { +AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::AuthenticatingUser() const { // Check if it still possible to retrieve a token! auto now = std::chrono::system_clock::now(); if (now >= device_code_info.expires_at) { @@ -177,10 +141,13 @@ Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingU "renew it by calling `SELECT * FROM azure_devicecode('')`;"); } - const std::string url = "https://login.microsoftonline.com/" + tenant_id + "/oauth2/v2.0/token"; + const std::string url = Azure::Identity::_detail::AadGlobalAuthority + tenant_id + "/oauth2/v2.0/token"; + // clang-format off const std::string body = "grant_type=urn:ietf:params:oauth:grant-type:device_code" - "&client_id=" + - Azure::Core::Url::Encode(client_id) + "&device_code=" + device_code_info.device_code; + "&client_id=" + Azure::Core::Url::Encode(client_id) + + "&device_code=" + device_code_info.device_code; + // clang-format on + Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); Azure::Core::Http::Request http_request(Azure::Core::Http::HttpMethod::Post, Azure::Core::Url(url), &body_stream); @@ -193,9 +160,10 @@ Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingU const auto &response_body = response->GetBody(); const auto response_body_str = std::string(response_body.begin(), response_body.end()); - switch (response->GetStatusCode()) { + const auto response_code = response->GetStatusCode(); + switch (response_code) { case Azure::Core::Http::HttpStatusCode::Ok: { - Azure::Core::Credentials::AccessToken token; + OAuthAccessToken token; ParseJson(response_body_str, &token); return token; } break; @@ -219,25 +187,79 @@ Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::AuthenticatingU "[AzureDeviceCodeCredential] Failed to retrieve user token already expired. (error msg: %s)", response_body_str); } else { - throw IOException("[AzureDeviceCodeCredential] Unexpected error: %s", response_body_str); + throw IOException("[AzureDeviceCodeCredential] Unexpected error (HTTP: %d): %s", response_code, + response_body_str); } } break; } } } +AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::RefreshToken() const { + const std::string url = Azure::Identity::_detail::AadGlobalAuthority + tenant_id + "/oauth2/v2.0/token"; + // clang-format off + const std::string body = "grant_type=refresh_token" + "&client_id=" + Azure::Core::Url::Encode(client_id) + + "&scope=" + encoded_scopes + + "&refresh_token=" + token.refresh_token; + // clang-format on + Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); + + Azure::Core::Http::Request http_request(Azure::Core::Http::HttpMethod::Post, Azure::Core::Url(url), &body_stream); + http_request.SetHeader("Content-Type", "application/x-www-form-urlencoded"); + http_request.SetHeader("Content-Length", std::to_string(body.size())); + http_request.SetHeader("Accept", "application/json"); + + auto response = http_pipeline.Send(http_request, Azure::Core::Context()); + const auto &response_body = response->GetBody(); + const auto response_body_str = std::string(response_body.begin(), response_body.end()); + const auto response_code = response->GetStatusCode(); + if (Azure::Core::Http::HttpStatusCode::Ok == response_code) { + OAuthAccessToken token; + ParseJson(response_body_str, &token); + return token; + } else { + throw IOException( + "[AzureDeviceCodeCredential] Failed to refresh token due to the following error (HTTP %d): %s", + response_code, response_body_str); + } +} + +void AzureDeviceCodeCredential::ParseJson(const std::string &json_str, OAuthAccessToken *token) { + try { + auto json = Azure::Core::Json::_internal::json::parse(json_str); + + // Mandatory + token->access_token = json.at("access_token").get(); + token->expires_at = Azure::DateTime(std::chrono::system_clock::now()) + + std::chrono::seconds(json.at("expires_in").get()); + + // Optional depending of the scopes + if (json.contains("refresh_token")) { + token->refresh_token = json.at("refresh_token").get(); + } + } catch (const Azure::Core::Json::_internal::json::out_of_range &ex) { + throw IOException("[AzureDeviceCodeCredential] Failed to parse Azure response '%s'", ex.what()); + } catch (const Azure::Core::Json::_internal::json::exception &ex) { + throw IOException("[AzureDeviceCodeCredential] Failed to parse JSON Azure response '%s'", ex.what()); + } +} + +bool AzureDeviceCodeCredential::IsFresh(const AzureDeviceCodeCredential::OAuthAccessToken &token, + Azure::DateTime::duration minimum_expiration, + std::chrono::system_clock::time_point now) { + return token.expires_at > (Azure::DateTime(now) + minimum_expiration); +} + Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContext const &token_request_context, Azure::Core::Context const &context) const { - using Azure::Core::_internal::StringExtensions; - if (device_code_info.device_code.empty()) { throw IOException("[AzureDeviceCodeCredential] No device/user code register did you call `SELECT * FROM " "azure_devicecode('')`;"); } - if (!token_request_context.TenantId.empty() && - !StringExtensions::LocaleInvariantCaseInsensitiveEqual(token_request_context.TenantId, tenant_id)) { + if (!token_request_context.TenantId.empty() && !StringUtil::CIEquals(token_request_context.TenantId, tenant_id)) { throw IOException( "[AzureDeviceCodeCredential] The current credential is not configured to acquire tokens for tenant '%s'.", @@ -250,11 +272,29 @@ AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContex scope); } } - auto request_scopes = token_request_context.Scopes; - std::sort(request_scopes.begin(), request_scopes.end()); - return token_cache.GetToken(CacheScopeString(token_request_context.Scopes), token_request_context.TenantId, - token_request_context.MinimumExpiration, [&]() { return AuthenticatingUser(); }); + { + std::shared_lock read_lock(token_mutex); + if (!token.access_token.empty() && + IsFresh(token, token_request_context.MinimumExpiration, std::chrono::system_clock::now())) { + return Azure::Core::Credentials::AccessToken {token.access_token, token.expires_at}; + } + } + + { + std::unique_lock write_lock(token_mutex); + if (!token.access_token.empty() && + IsFresh(token, token_request_context.MinimumExpiration, std::chrono::system_clock::now())) { + return Azure::Core::Credentials::AccessToken {token.access_token, token.expires_at}; + } + + if (token.refresh_token.empty()) { + token = AuthenticatingUser(); + } else { + token = RefreshToken(); + } + return Azure::Core::Credentials::AccessToken {token.access_token, token.expires_at}; + } } } // namespace duckdb \ No newline at end of file diff --git a/src/azure_secret.cpp b/src/azure_secret.cpp index 47813e6..a57e8c0 100644 --- a/src/azure_secret.cpp +++ b/src/azure_secret.cpp @@ -136,7 +136,7 @@ static unique_ptr CreateAzureSecretFromDeviceCode(ClientContext &con // Manage specific secret option CopySecret("tenant_id", input, *result); CopySecret("client_id", input, *result); - CopySecret("oauth_scopes", input, *result, "https://storage.azure.com/.default"); + CopySecret("oauth_scopes", input, *result, "https://storage.azure.com/.default offline_access"); // Redact sensible keys RedactCommonKeys(*result); diff --git a/src/include/auth/azure_device_code_credential.hpp b/src/include/auth/azure_device_code_credential.hpp index 4e4693c..9f4c08f 100644 --- a/src/include/auth/azure_device_code_credential.hpp +++ b/src/include/auth/azure_device_code_credential.hpp @@ -2,14 +2,13 @@ #include #include +#include #include #include -#include -#include #include #include -#include #include +#include #include #include @@ -81,16 +80,38 @@ class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCr Azure::Core::Context const &context) const override; private: - Azure::Core::Credentials::AccessToken AuthenticatingUser() const; + /** + * When refresh token is set it seen to be valid for 90 days + * @see https://learn.microsoft.com/en-us/entra/identity-platform/refresh-tokens#token-lifetime + * @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token + */ + struct OAuthAccessToken { + // Number of seconds the included access token is valid for. + Azure::DateTime expires_at; + // Issued for the scopes that were requested. + std::string access_token; + // Issued if the original scope parameter included offline_access. + std::string refresh_token; + }; + +private: + OAuthAccessToken AuthenticatingUser() const; + OAuthAccessToken RefreshToken() const; + static bool IsFresh(const OAuthAccessToken &token, Azure::DateTime::duration minimum_expiration, + std::chrono::system_clock::time_point now); + static void ParseJson(const std::string &json_str, OAuthAccessToken *token); private: const std::string tenant_id; const std::string client_id; const std::unordered_set scopes; + const std::string encoded_scopes; const AzureDeviceCodeInfo device_code_info; - Azure::Identity::_detail::TokenCache token_cache; Azure::Core::Http::_internal::HttpPipeline http_pipeline; + + mutable std::shared_timed_mutex token_mutex; + mutable OAuthAccessToken token; }; } // namespace duckdb \ No newline at end of file From 7232f99cc3397ebc5dd022f69c4ae74d528ca7ab Mon Sep 17 00:00:00 2001 From: Quentin GODEAU Date: Thu, 28 Mar 2024 07:32:48 +0100 Subject: [PATCH 4/4] Keep token cache in context to allow reuse in other queries --- CMakeLists.txt | 1 - src/auth/azure_device_code_credential.cpp | 37 +++++++------- src/auth/azure_device_code_function.cpp | 19 ++++---- src/auth/azure_device_codes_context.cpp | 6 --- src/azure_storage_account_client.cpp | 23 ++++----- .../auth/azure_device_code_context.hpp | 47 ++++++++++++++++++ .../auth/azure_device_code_credential.hpp | 48 +++++++++---------- .../auth/azure_device_codes_context.hpp | 13 ----- 8 files changed, 109 insertions(+), 85 deletions(-) delete mode 100644 src/auth/azure_device_codes_context.cpp create mode 100644 src/include/auth/azure_device_code_context.hpp delete mode 100644 src/include/auth/azure_device_codes_context.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index be3f76c..e90e21c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED True) set(EXTENSION_SOURCES src/auth/azure_device_code_credential.cpp src/auth/azure_device_code_function.cpp - src/auth/azure_device_codes_context.cpp src/azure_blob_filesystem.cpp src/azure_dfs_filesystem.cpp src/azure_extension.cpp diff --git a/src/auth/azure_device_code_credential.cpp b/src/auth/azure_device_code_credential.cpp index 754b704..1e683d8 100644 --- a/src/auth/azure_device_code_credential.cpp +++ b/src/auth/azure_device_code_credential.cpp @@ -1,4 +1,5 @@ #include "auth/azure_device_code_credential.hpp" +#include "auth/azure_device_code_context.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/common/string_util.hpp" @@ -124,16 +125,16 @@ AzureDeviceCodeCredentialRequester::HandleDeviceAuthorizationResponse(const Azur } } -AzureDeviceCodeCredential::AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes_p, - AzureDeviceCodeInfo device_code_info, - const Azure::Core::Credentials::TokenCredentialOptions &options) +AzureDeviceCodeCredential::AzureDeviceCodeCredential( + std::string tenant_id, std::string client_id, std::unordered_set scopes_p, + std::shared_ptr device_code_context, + const Azure::Core::Credentials::TokenCredentialOptions &options) : Azure::Core::Credentials::TokenCredential("DeviceCodeCredential"), tenant_id(std::move(tenant_id)), client_id(std::move(client_id)), scopes(std::move(scopes_p)), encoded_scopes(EncodeScopes(scopes)), - device_code_info(std::move(device_code_info)), http_pipeline(options, "identity", "DuckDB", {}, {}) { + device_code_context(std::move(device_code_context)), http_pipeline(options, "identity", "DuckDB", {}, {}) { } -AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::AuthenticatingUser() const { +AzureOAuthAccessToken AzureDeviceCodeCredential::AuthenticatingUser(const AzureDeviceCodeInfo &device_code_info) const { // Check if it still possible to retrieve a token! auto now = std::chrono::system_clock::now(); if (now >= device_code_info.expires_at) { @@ -163,7 +164,7 @@ AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::Authentic const auto response_code = response->GetStatusCode(); switch (response_code) { case Azure::Core::Http::HttpStatusCode::Ok: { - OAuthAccessToken token; + AzureOAuthAccessToken token; ParseJson(response_body_str, &token); return token; } break; @@ -195,13 +196,13 @@ AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::Authentic } } -AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::RefreshToken() const { +AzureOAuthAccessToken AzureDeviceCodeCredential::RefreshToken(const std::string &refresh_token) const { const std::string url = Azure::Identity::_detail::AadGlobalAuthority + tenant_id + "/oauth2/v2.0/token"; // clang-format off const std::string body = "grant_type=refresh_token" "&client_id=" + Azure::Core::Url::Encode(client_id) + "&scope=" + encoded_scopes + - "&refresh_token=" + token.refresh_token; + "&refresh_token=" + refresh_token; // clang-format on Azure::Core::IO::MemoryBodyStream body_stream(reinterpret_cast(body.data()), body.size()); @@ -215,7 +216,7 @@ AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::RefreshTo const auto response_body_str = std::string(response_body.begin(), response_body.end()); const auto response_code = response->GetStatusCode(); if (Azure::Core::Http::HttpStatusCode::Ok == response_code) { - OAuthAccessToken token; + AzureOAuthAccessToken token; ParseJson(response_body_str, &token); return token; } else { @@ -225,7 +226,7 @@ AzureDeviceCodeCredential::OAuthAccessToken AzureDeviceCodeCredential::RefreshTo } } -void AzureDeviceCodeCredential::ParseJson(const std::string &json_str, OAuthAccessToken *token) { +void AzureDeviceCodeCredential::ParseJson(const std::string &json_str, AzureOAuthAccessToken *token) { try { auto json = Azure::Core::Json::_internal::json::parse(json_str); @@ -245,7 +246,7 @@ void AzureDeviceCodeCredential::ParseJson(const std::string &json_str, OAuthAcce } } -bool AzureDeviceCodeCredential::IsFresh(const AzureDeviceCodeCredential::OAuthAccessToken &token, +bool AzureDeviceCodeCredential::IsFresh(const AzureOAuthAccessToken &token, Azure::DateTime::duration minimum_expiration, std::chrono::system_clock::time_point now) { return token.expires_at > (Azure::DateTime(now) + minimum_expiration); @@ -254,7 +255,7 @@ bool AzureDeviceCodeCredential::IsFresh(const AzureDeviceCodeCredential::OAuthAc Azure::Core::Credentials::AccessToken AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContext const &token_request_context, Azure::Core::Context const &context) const { - if (device_code_info.device_code.empty()) { + if (device_code_context->device_code_info.device_code.empty()) { throw IOException("[AzureDeviceCodeCredential] No device/user code register did you call `SELECT * FROM " "azure_devicecode('')`;"); } @@ -274,7 +275,8 @@ AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContex } { - std::shared_lock read_lock(token_mutex); + std::shared_lock read_lock(*device_code_context); + auto &token = device_code_context->cache_token; if (!token.access_token.empty() && IsFresh(token, token_request_context.MinimumExpiration, std::chrono::system_clock::now())) { return Azure::Core::Credentials::AccessToken {token.access_token, token.expires_at}; @@ -282,16 +284,17 @@ AzureDeviceCodeCredential::GetToken(Azure::Core::Credentials::TokenRequestContex } { - std::unique_lock write_lock(token_mutex); + std::unique_lock write_lock(*device_code_context); + auto &token = device_code_context->cache_token; if (!token.access_token.empty() && IsFresh(token, token_request_context.MinimumExpiration, std::chrono::system_clock::now())) { return Azure::Core::Credentials::AccessToken {token.access_token, token.expires_at}; } if (token.refresh_token.empty()) { - token = AuthenticatingUser(); + token = AuthenticatingUser(device_code_context->device_code_info); } else { - token = RefreshToken(); + token = RefreshToken(token.refresh_token); } return Azure::Core::Credentials::AccessToken {token.access_token, token.expires_at}; } diff --git a/src/auth/azure_device_code_function.cpp b/src/auth/azure_device_code_function.cpp index fa36c46..e08a111 100644 --- a/src/auth/azure_device_code_function.cpp +++ b/src/auth/azure_device_code_function.cpp @@ -1,5 +1,5 @@ #include "auth/azure_device_code_function.hpp" -#include "auth/azure_device_codes_context.hpp" +#include "auth/azure_device_code_context.hpp" #include "azure_storage_account_client.hpp" #include "duckdb/catalog/catalog_transaction.hpp" #include "duckdb/common/assert.hpp" @@ -65,19 +65,18 @@ static void AzureDeviceCodeImplementation(ClientContext &context, TableFunctionI throw InvalidInputException("azure_devicecode no secret found named %s", bind_data.secret_name); } - auto device_code_credential = CreateDeviceCodeCredentialRequester( - ClientData::Get(context).file_opener.get(), dynamic_cast(*secret->secret)); + auto &kv_secret = dynamic_cast(*secret->secret); + auto device_code_credential = + CreateDeviceCodeCredentialRequester(ClientData::Get(context).file_opener.get(), kv_secret); auto device_code_info = device_code_credential.RequestDeviceCode(); - auto &device_code_context = context.registered_state[AzureDeviceCodesClientContextState::CONTEXT_KEY]; - if (!device_code_context) { - device_code_context = make_shared(); + const auto context_key = AzureDeviceCodeClientContextState::BuildContextKey(bind_data.secret_name); + auto device_code_context_it = context.registered_state.find(context_key); + if (device_code_context_it == context.registered_state.end()) { + auto device_code_context = make_shared(device_code_info); + context.registered_state.insert(std::make_pair(context_key, std::move(device_code_context))); } - D_ASSERT(reinterpret_cast(device_code_context.get()) != nullptr); - reinterpret_cast(*device_code_context) - .device_code_info_by_secret.insert(std::make_pair(bind_data.secret_name, device_code_info)); - output.SetCapacity(1); output.SetValue(0, 0, bind_data.secret_name); output.SetValue(1, 0, device_code_info.user_code); diff --git a/src/auth/azure_device_codes_context.cpp b/src/auth/azure_device_codes_context.cpp deleted file mode 100644 index ff6945f..0000000 --- a/src/auth/azure_device_codes_context.cpp +++ /dev/null @@ -1,6 +0,0 @@ -#include "auth/azure_device_codes_context.hpp" - -namespace duckdb { -const std::string AzureDeviceCodesClientContextState::CONTEXT_KEY = "auth/azure_device_codes_context"; - -} \ No newline at end of file diff --git a/src/azure_storage_account_client.cpp b/src/azure_storage_account_client.cpp index 3a683d5..f0ced4e 100644 --- a/src/azure_storage_account_client.cpp +++ b/src/azure_storage_account_client.cpp @@ -1,7 +1,7 @@ #include "azure_storage_account_client.hpp" #include "auth/azure_device_code_credential.hpp" -#include "auth/azure_device_codes_context.hpp" +#include "auth/azure_device_code_context.hpp" #include "duckdb/catalog/catalog_transaction.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/enums/statement_type.hpp" @@ -215,23 +215,18 @@ CreateDeviceCodeCredential(const KeyValueSecret &secret, if (!context) { throw InternalException("Context cannot be null!"); } - auto device_codes_info_context_it = context->registered_state.find(AzureDeviceCodesClientContextState::CONTEXT_KEY); - if (device_codes_info_context_it == context->registered_state.end()) { + auto context_key = AzureDeviceCodeClientContextState::BuildContextKey(secret.GetName()); + auto device_code_info_context_it = context->registered_state.find(context_key); + if (device_code_info_context_it == context->registered_state.end()) { throw InternalException( "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", secret.GetName()); } - D_ASSERT(dynamic_cast(device_codes_info_context_it->second.get()) != nullptr); - const auto &device_code_info_by_secret = - reinterpret_cast(*device_codes_info_context_it->second) - .device_code_info_by_secret; - auto device_code_info_it = device_code_info_by_secret.find(secret.GetName()); - if (device_code_info_it == device_code_info_by_secret.end()) { - throw InternalException( - "Not device code has been initialized did you run `SELECT * FROM azure_devicecode('%s');`", - secret.GetName()); - } + D_ASSERT(dynamic_cast(device_code_info_context_it->second.get()) != nullptr); + auto device_code_info_context = std::shared_ptr( + device_code_info_context_it->second, + reinterpret_cast(device_code_info_context_it->second.get())); constexpr bool error_on_missing = true; auto tenant_id = secret.TryGetValue("tenant_id", error_on_missing).ToString(); @@ -241,7 +236,7 @@ CreateDeviceCodeCredential(const KeyValueSecret &secret, return std::make_shared( tenant_id, client_id, std::unordered_set(oauth_scopes.begin(), oauth_scopes.end()), - device_code_info_it->second, ToTokenCredentialOptions(transport_options)); + std::move(device_code_info_context), ToTokenCredentialOptions(transport_options)); } static std::shared_ptr diff --git a/src/include/auth/azure_device_code_context.hpp b/src/include/auth/azure_device_code_context.hpp new file mode 100644 index 0000000..958b602 --- /dev/null +++ b/src/include/auth/azure_device_code_context.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include "auth/azure_device_code_credential.hpp" +#include "duckdb/main/client_context_state.hpp" +#include +#include +#include + +namespace duckdb { +class AzureDeviceCodeClientContextState final : public ClientContextState { +public: + const AzureDeviceCodeInfo device_code_info; + // Access to this attributes should always be protected by firstly acquiring the lock. + AzureOAuthAccessToken cache_token; + + AzureDeviceCodeClientContextState(AzureDeviceCodeInfo device_code_info) + : device_code_info(std::move(device_code_info)) { + } + + static std::string BuildContextKey(const std::string &secret_name) { + return "azure:device_codes:" + secret_name; + } + +public: // mutex API + void lock() { + cache_token_mutex.lock(); + } + bool try_lock() { + return cache_token_mutex.try_lock(); + } + void unlock() { + cache_token_mutex.unlock(); + } + void lock_shared() { + cache_token_mutex.lock_shared(); + } + bool try_lock_shared() { + return cache_token_mutex.try_lock_shared(); + } + void unlock_shared() { + cache_token_mutex.unlock_shared(); + } + +private: + std::shared_timed_mutex cache_token_mutex; +}; +} // namespace duckdb \ No newline at end of file diff --git a/src/include/auth/azure_device_code_credential.hpp b/src/include/auth/azure_device_code_credential.hpp index 9f4c08f..de1d025 100644 --- a/src/include/auth/azure_device_code_credential.hpp +++ b/src/include/auth/azure_device_code_credential.hpp @@ -27,6 +27,8 @@ namespace duckdb { +class AzureDeviceCodeClientContextState; + struct AzureDeviceCodeInfo { // A long string used to verify the session between the client and the authorization server. // The client uses this parameter to request the access token from the authorization server. @@ -45,6 +47,20 @@ struct AzureDeviceCodeInfo { std::string message; }; +/** + * When refresh token is set it seen to be valid for 90 days + * @see https://learn.microsoft.com/en-us/entra/identity-platform/refresh-tokens#token-lifetime + * @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token + */ +struct AzureOAuthAccessToken { + // Number of seconds the included access token is valid for. + Azure::DateTime expires_at; + // Issued for the scopes that were requested. + std::string access_token; + // Issued if the original scope parameter included offline_access. + std::string refresh_token; +}; + class AzureDeviceCodeCredentialRequester final { public: explicit AzureDeviceCodeCredentialRequester(std::string tenant_id, std::string client_id, @@ -72,7 +88,8 @@ class AzureDeviceCodeCredentialRequester final { class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCredential { public: explicit AzureDeviceCodeCredential(std::string tenant_id, std::string client_id, - std::unordered_set scopes, AzureDeviceCodeInfo device_code, + std::unordered_set scopes, + std::shared_ptr device_code_context, const Azure::Core::Credentials::TokenCredentialOptions &options = Azure::Core::Credentials::TokenCredentialOptions()); Azure::Core::Credentials::AccessToken @@ -80,38 +97,21 @@ class AzureDeviceCodeCredential final : public Azure::Core::Credentials::TokenCr Azure::Core::Context const &context) const override; private: - /** - * When refresh token is set it seen to be valid for 90 days - * @see https://learn.microsoft.com/en-us/entra/identity-platform/refresh-tokens#token-lifetime - * @see https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token - */ - struct OAuthAccessToken { - // Number of seconds the included access token is valid for. - Azure::DateTime expires_at; - // Issued for the scopes that were requested. - std::string access_token; - // Issued if the original scope parameter included offline_access. - std::string refresh_token; - }; - -private: - OAuthAccessToken AuthenticatingUser() const; - OAuthAccessToken RefreshToken() const; - static bool IsFresh(const OAuthAccessToken &token, Azure::DateTime::duration minimum_expiration, + AzureOAuthAccessToken AuthenticatingUser(const AzureDeviceCodeInfo &device_code_info) const; + AzureOAuthAccessToken RefreshToken(const std::string &refresh_token) const; + static bool IsFresh(const AzureOAuthAccessToken &token, Azure::DateTime::duration minimum_expiration, std::chrono::system_clock::time_point now); - static void ParseJson(const std::string &json_str, OAuthAccessToken *token); + static void ParseJson(const std::string &json_str, AzureOAuthAccessToken *token); private: const std::string tenant_id; const std::string client_id; const std::unordered_set scopes; const std::string encoded_scopes; - const AzureDeviceCodeInfo device_code_info; - Azure::Core::Http::_internal::HttpPipeline http_pipeline; + const std::shared_ptr device_code_context; - mutable std::shared_timed_mutex token_mutex; - mutable OAuthAccessToken token; + Azure::Core::Http::_internal::HttpPipeline http_pipeline; }; } // namespace duckdb \ No newline at end of file diff --git a/src/include/auth/azure_device_codes_context.hpp b/src/include/auth/azure_device_codes_context.hpp deleted file mode 100644 index bab37da..0000000 --- a/src/include/auth/azure_device_codes_context.hpp +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include "auth/azure_device_code_credential.hpp" -#include "duckdb/main/client_context_state.hpp" -#include - -namespace duckdb { -class AzureDeviceCodesClientContextState final : public ClientContextState { -public: - const static std::string CONTEXT_KEY; - std::unordered_map device_code_info_by_secret; -}; -} // namespace duckdb \ No newline at end of file