From 79690d5b892a97011902e55dabeeb429db698c76 Mon Sep 17 00:00:00 2001 From: Joseph Shearer Date: Fri, 8 Dec 2023 12:40:09 -0500 Subject: [PATCH] fix: Azure fragment store with multiple tenants (#354) * fix: Azure fragment store with multiple tenants. I knew that the service credentials would be scoped to tenant IDs, but I did not realize that the `service.Client`s and `UserDelegationCredential`s would also be tenant-scoped. This was causing problems where the credentials of whatever tenant was first seen would be used to sign _all_ `SignGet` URLs, resulting in signature mismatch errors when the blob was owned by a different tenant. This updates the cache logic to keep track of both the service clients and the user delegation credentials by tenant ID, which should solve the problem. --- broker/fragment/store_azure.go | 232 ++++++++++++++++++++++++--------- broker/fragment/stores.go | 29 +++-- 2 files changed, 191 insertions(+), 70 deletions(-) diff --git a/broker/fragment/store_azure.go b/broker/fragment/store_azure.go index cc55005f..bf6af4d4 100644 --- a/broker/fragment/store_azure.go +++ b/broker/fragment/store_azure.go @@ -2,6 +2,7 @@ package fragment import ( "context" + "errors" "fmt" "io" "net/url" @@ -22,9 +23,9 @@ import ( pb "go.gazette.dev/core/broker/protocol" ) -// AzureStoreConfig configures a Fragment store of the "azure://" or "azure-ad://" scheme. +// azureStoreConfig configures a Fragment store of the "azure://" or "azure-ad://" scheme. 
// It is initialized from parsed URL parametrs of the pb.FragmentStore -type AzureStoreConfig struct { +type azureStoreConfig struct { accountTenantID string // The tenant ID that owns the storage account that we're writing into // NOTE: This is not the tenant ID that owns the servie principal storageAccountName string // Storage accounts in Azure are the equivalent to a "bucket" in S3 @@ -34,20 +35,27 @@ type AzureStoreConfig struct { RewriterConfig } -func (cfg *AzureStoreConfig) serviceUrl() string { +func (cfg *azureStoreConfig) serviceUrl() string { return fmt.Sprintf("https://%s.blob.core.windows.net", cfg.storageAccountName) } -func (cfg *AzureStoreConfig) containerURL() string { +func (cfg *azureStoreConfig) containerURL() string { return fmt.Sprintf("%s/%s", cfg.serviceUrl(), cfg.containerName) } +type udcAndExp struct { + udc *service.UserDelegationCredential + exp *time.Time +} + type azureBackend struct { - clients map[string]pipeline.Pipeline - svcClient service.Client - clientMu sync.Mutex - udc *service.UserDelegationCredential - udcExp *time.Time + // This is a cache of configured Pipelines for each tenant. These do not expire + pipelines map[string]pipeline.Pipeline + // This is a cache of Azure storage clients for each tenant. These do not expire + clients map[string]*service.Client + mu sync.Mutex + // This is a cache of URL-signing credentials for each tenant. 
These DO expire + udcs map[string]udcAndExp } func (a *azureBackend) Provider() string { @@ -56,14 +64,14 @@ func (a *azureBackend) Provider() string { // See here for an example of how to use the Azure client libraries to create signatures: // https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/storage/azblob/service/examples_test.go#L285 -func (a *azureBackend) SignGet(ep *url.URL, fragment pb.Fragment, d time.Duration) (string, error) { - cfg, _, err := a.azureClient(ep) +func (a *azureBackend) SignGet(endpoint *url.URL, fragment pb.Fragment, d time.Duration) (string, error) { + cfg, err := parseAzureEndpoint(endpoint) if err != nil { return "", err } blobName := cfg.rewritePath(cfg.prefix, fragment.ContentPath()) - udc, err := a.getUserDelegationCredential() + udc, err := a.getUserDelegationCredential(endpoint) if err != nil { return "", err } @@ -73,20 +81,27 @@ func (a *azureBackend) SignGet(ep *url.URL, fragment pb.Fragment, d time.Duratio ExpiryTime: time.Now().UTC().Add(d), // Timestamps are expected in UTC https://docs.microsoft.com/en-us/rest/api/storageservices/create-service-sas#service-sas-example ContainerName: cfg.containerName, BlobName: blobName, - - // To produce a container SAS (as opposed to a blob SAS), assign to Permissions using - // ContainerSASPermissions and make sure the BlobName field is "" (the default). 
- Permissions: to.Ptr(sas.ContainerPermissions{Read: true, Add: true, Write: true}).String(), + // These are the permissions granted to the signed URLs + Permissions: to.Ptr(sas.BlobPermissions{Read: true}).String(), }.SignWithUserDelegation(udc) if err != nil { return "", err } + + log.WithFields(log.Fields{ + "tenantId": cfg.accountTenantID, + "storageAccountName": cfg.storageAccountName, + "containerName": cfg.containerName, + "blobName": blobName, + "expiryTime": sasQueryParams.ExpiryTime(), + }).Debug("Signed get request") + return fmt.Sprintf("%s/%s?%s", cfg.containerURL(), blobName, sasQueryParams.Encode()), nil } func (a *azureBackend) Exists(ctx context.Context, ep *url.URL, fragment pb.Fragment) (bool, error) { - cfg, client, err := a.azureClient(ep) + cfg, client, err := a.getAzurePipeline(ep) if err != nil { return false, err } @@ -108,7 +123,7 @@ func (a *azureBackend) Exists(ctx context.Context, ep *url.URL, fragment pb.Frag } func (a *azureBackend) Open(ctx context.Context, ep *url.URL, fragment pb.Fragment) (io.ReadCloser, error) { - cfg, client, err := a.azureClient(ep) + cfg, client, err := a.getAzurePipeline(ep) if err != nil { return nil, err } @@ -124,7 +139,7 @@ func (a *azureBackend) Open(ctx context.Context, ep *url.URL, fragment pb.Fragme } func (a *azureBackend) Persist(ctx context.Context, ep *url.URL, spool Spool) error { - cfg, client, err := a.azureClient(ep) + cfg, client, err := a.getAzurePipeline(ep) if err != nil { return err } @@ -147,7 +162,7 @@ func (a *azureBackend) Persist(ctx context.Context, ep *url.URL, spool Spool) er } func (a *azureBackend) List(ctx context.Context, store pb.FragmentStore, ep *url.URL, journal pb.Journal, callback func(pb.Fragment)) error { - cfg, client, err := a.azureClient(ep) + cfg, client, err := a.getAzurePipeline(ep) if err != nil { return err } @@ -188,7 +203,7 @@ func (a *azureBackend) List(ctx context.Context, store pb.FragmentStore, ep *url } func (a *azureBackend) Remove(ctx context.Context, 
fragment pb.Fragment) error { - cfg, client, err := a.azureClient(fragment.BackingStore.URL()) + cfg, client, err := a.getAzurePipeline(fragment.BackingStore.URL()) if err != nil { return err } @@ -222,47 +237,127 @@ func getAzureStorageCredential(coreCredential azcore.TokenCredential, tenant str return credential, nil } -func (a *azureBackend) azureClient(ep *url.URL) (cfg AzureStoreConfig, client pipeline.Pipeline, err error) { - if err = parseStoreArgs(ep, &cfg); err != nil { +func parseAzureEndpoint(endpoint *url.URL) (cfg azureStoreConfig, err error) { + if err = parseStoreArgs(endpoint, &cfg); err != nil { return } + // Omit leading slash from URI. Note that FragmentStore already // enforces that URL Paths end in '/'. - var splitPath = strings.Split(ep.Path[1:], "/") + var splitPath = strings.Split(endpoint.Path[1:], "/") - if ep.Scheme == "azure" { + if endpoint.Scheme == "azure" { // Since only one non-ad "Shared Key" credential can be injected via // environment variables, we should only keep around one client for // all `azure://` requests. 
cfg.accountTenantID = "AZURE_SHARED_KEY" cfg.storageAccountName = os.Getenv("AZURE_ACCOUNT_NAME") - cfg.containerName, cfg.prefix = ep.Host, ep.Path[1:] - } else if ep.Scheme == "azure-ad" { - cfg.accountTenantID, cfg.storageAccountName, cfg.containerName, cfg.prefix = ep.Host, splitPath[0], splitPath[1], strings.Join(splitPath[2:], "/") + cfg.containerName, cfg.prefix = endpoint.Host, endpoint.Path[1:] + } else if endpoint.Scheme == "azure-ad" { + cfg.accountTenantID, cfg.storageAccountName, cfg.containerName, cfg.prefix = endpoint.Host, splitPath[0], splitPath[1], strings.Join(splitPath[2:], "/") } - a.clientMu.Lock() - defer a.clientMu.Unlock() + return cfg, nil +} - if a.clients[cfg.accountTenantID] != nil { - client = a.clients[cfg.accountTenantID] - return - } +func (a *azureBackend) getAzureServiceClient(endpoint *url.URL) (client *service.Client, err error) { + var cfg azureStoreConfig - var credentials azblob.Credential + if cfg, err = parseAzureEndpoint(endpoint); err != nil { + return nil, err + } - if ep.Scheme == "azure" { + if endpoint.Scheme == "azure" { var accountName = os.Getenv("AZURE_ACCOUNT_NAME") var accountKey = os.Getenv("AZURE_ACCOUNT_KEY") + + a.mu.Lock() + client, ok := a.clients[accountName] + a.mu.Unlock() + + if ok { + log.WithFields(log.Fields{ + "storageAccountName": accountName, + }).Info("Re-using cached azure:// service client") + return client, nil + } + sharedKeyCred, err := service.NewSharedKeyCredential(accountName, accountKey) if err != nil { - return cfg, nil, err + return nil, err } serviceClient, err := service.NewClientWithSharedKeyCredential(cfg.serviceUrl(), sharedKeyCred, &service.ClientOptions{}) if err != nil { - return cfg, nil, err + return nil, err } - a.svcClient = *serviceClient + + a.mu.Lock() + a.clients[accountName] = serviceClient + a.mu.Unlock() + return serviceClient, nil + } else if endpoint.Scheme == "azure-ad" { + // Link to the Azure docs describing what fields are required for active directory auth + //
https://learn.microsoft.com/en-us/azure/developer/go/azure-sdk-authentication-service-principal?tabs=azure-cli#-option-1-authenticate-with-a-secret + var clientId = os.Getenv("AZURE_CLIENT_ID") + var clientSecret = os.Getenv("AZURE_CLIENT_SECRET") + + a.mu.Lock() + client, ok := a.clients[cfg.accountTenantID] + a.mu.Unlock() + + if ok { + log.WithFields(log.Fields{ + "accountTenantId": cfg.accountTenantID, + }).Info("Re-using cached azure-ad:// service client") + return client, nil + } + + identityCreds, err := azidentity.NewClientSecretCredential( + cfg.accountTenantID, + clientId, + clientSecret, + &azidentity.ClientSecretCredentialOptions{ + AdditionallyAllowedTenants: []string{cfg.accountTenantID}, + DisableInstanceDiscovery: true, + }, + ) + if err != nil { + return nil, err + } + + serviceClient, err := service.NewClient(cfg.serviceUrl(), identityCreds, &service.ClientOptions{}) + if err != nil { + return nil, err + } + + a.mu.Lock() + a.clients[cfg.accountTenantID] = serviceClient + a.mu.Unlock() + + return serviceClient, nil + } + return nil, errors.New("unrecognized URI scheme") +} + +func (a *azureBackend) getAzurePipeline(ep *url.URL) (cfg azureStoreConfig, client pipeline.Pipeline, err error) { + if cfg, err = parseAzureEndpoint(ep); err != nil { + return + } + + a.mu.Lock() + client = a.pipelines[cfg.accountTenantID] + a.mu.Unlock() + + if client != nil { + return + } + + var credentials azblob.Credential + + if ep.Scheme == "azure" { + var accountName = os.Getenv("AZURE_ACCOUNT_NAME") + var accountKey = os.Getenv("AZURE_ACCOUNT_KEY") + // Create an azblob credential that we can pass to `NewPipeline` credentials, err = azblob.NewSharedKeyCredential(accountName, accountKey) if err != nil { @@ -287,12 +382,6 @@ func (a *azureBackend) azureClient(ep *url.URL) (cfg AzureStoreConfig, client pi return cfg, nil, err } - serviceClient, err := service.NewClient(cfg.serviceUrl(), identityCreds, &service.ClientOptions{}) - if err != nil { - return cfg, nil, err - 
} - a.svcClient = *serviceClient - credentials, err = getAzureStorageCredential(identityCreds, cfg.accountTenantID) if err != nil { return cfg, nil, err @@ -300,21 +389,22 @@ func (a *azureBackend) azureClient(ep *url.URL) (cfg AzureStoreConfig, client pi } client = azblob.NewPipeline(credentials, azblob.PipelineOptions{}) - if a.clients == nil { - a.clients = make(map[string]pipeline.Pipeline) - } - a.clients[cfg.accountTenantID] = client + + a.mu.Lock() + a.pipelines[cfg.accountTenantID] = client + a.mu.Unlock() log.WithFields(log.Fields{ + "tenant": cfg.accountTenantID, "storageAccountName": cfg.storageAccountName, "storageContainerName": cfg.containerName, "pathPrefix": cfg.prefix, - }).Info("constructed new Azure Storage client") + }).Info("constructed new Azure Storage pipeline client") return cfg, client, nil } -func (a *azureBackend) buildBlobURL(cfg AzureStoreConfig, client pipeline.Pipeline, path string) (*azblob.BlockBlobURL, error) { +func (a *azureBackend) buildBlobURL(cfg azureStoreConfig, client pipeline.Pipeline, path string) (*azblob.BlockBlobURL, error) { u, err := url.Parse(fmt.Sprint(cfg.containerURL(), "/", cfg.rewritePath(cfg.prefix, path))) if err != nil { return nil, err @@ -324,7 +414,15 @@ func (a *azureBackend) buildBlobURL(cfg AzureStoreConfig, client pipeline.Pipeli } // Cache UserDelegationCredentials and refresh them when needed -func (a *azureBackend) getUserDelegationCredential() (*service.UserDelegationCredential, error) { +func (a *azureBackend) getUserDelegationCredential(endpoint *url.URL) (*service.UserDelegationCredential, error) { + var cfg, err = parseAzureEndpoint(endpoint) + if err != nil { + return nil, err + } + a.mu.Lock() + var udc, hasCachedUdc = a.udcs[cfg.accountTenantID] + a.mu.Unlock() + // https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blob-user-delegation-sas-create-cli#use-azure-ad-credentials-to-secure-a-sas // According to the above docs, signed URLs generated with a UDC are invalid after // 
that UDC expires. In addition, a UDC can live up to 7 days. So let's ensure that @@ -333,23 +431,41 @@ func (a *azureBackend) getUserDelegationCredential() (*service.UserDelegationCre // ----| NOW |------|NOW+5Day|-----| udcExp |---- No need to refresh // ----| NOW |-----| udcExp |-----|NOW+5Day|---- Need to refresh // ----|udcExp|-----| NOW | ------------------ Need to refresh - if a.udc == nil || (a.udcExp != nil && a.udcExp.Before(time.Now().Add(time.Hour*24*5))) { + if !hasCachedUdc || (udc.exp != nil && udc.exp.Before(time.Now().Add(time.Hour*24*5))) { // Generate UDCs that expire 6 days from now, and refresh them after they // have less than 5 days left until they expire. + var startTime = time.Now().Add(time.Second * -10) var expTime = time.Now().Add(time.Hour * 24 * 6) var info = service.KeyInfo{ - Start: to.Ptr(time.Now().Add(time.Second * -10).UTC().Format(sas.TimeFormat)), + Start: to.Ptr(startTime.UTC().Format(sas.TimeFormat)), Expiry: to.Ptr(expTime.UTC().Format(sas.TimeFormat)), } - udc, err := a.svcClient.GetUserDelegationCredential(context.Background(), info, nil) + var serviceClient, err = a.getAzureServiceClient(endpoint) + if err != nil { + return nil, err + } + + cred, err := serviceClient.GetUserDelegationCredential(context.Background(), info, nil) if err != nil { return nil, err } - a.udc = udc - a.udcExp = &expTime + log.WithFields(log.Fields{ + "newExpiration": expTime, + "newStart": startTime.String(), + "service.KeyInfo": info, + "tenant": cfg.accountTenantID, + }).Info("Refreshing Azure Storage UDC") + + udc = udcAndExp{ + udc: cred, + exp: &expTime, + } + a.mu.Lock() + a.udcs[cfg.accountTenantID] = udc + a.mu.Unlock() } - return a.udc, nil + return udc.udc, nil } diff --git a/broker/fragment/stores.go b/broker/fragment/stores.go index dfda6034..455b63fe 100644 --- a/broker/fragment/stores.go +++ b/broker/fragment/stores.go @@ -10,6 +10,8 @@ import ( "text/template" "time" + "github.com/Azure/azure-pipeline-go/pipeline" + 
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" "github.com/gorilla/schema" "github.com/pkg/errors" pb "go.gazette.dev/core/broker/protocol" @@ -35,10 +37,14 @@ var sharedStores = struct { azure *azureBackend fs *fsBackend }{ - s3: newS3Backend(), - gcs: &gcsBackend{}, - azure: &azureBackend{}, - fs: &fsBackend{}, + s3: newS3Backend(), + gcs: &gcsBackend{}, + azure: &azureBackend{ + pipelines: make(map[string]pipeline.Pipeline), + clients: make(map[string]*service.Client), + udcs: make(map[string]udcAndExp), + }, + fs: &fsBackend{}, } func getBackend(scheme string) backend { @@ -194,14 +200,13 @@ func evalPathPostfix(spool Spool, spec *pb.JournalSpec) (string, error) { // in the implementation of new journal naming taxonomies which don't disrupt // journal fragments that are already written. // -// var cfg = RewriterConfig{ -// Replace: "/old-path/page-views/ -// Find: "/bar/v1/page-views/", -// } -// // Remaps journal name => fragment store URL: -// // "/foo/bar/v1/page-views/part-000" => "s3://my-bucket/foo/old-path/page-views/part-000" // Matched. -// // "/foo/bar/v2/page-views/part-000" => "s3://my-bucket/foo/bar/v2/page-views/part-000" // Not matched. -// +// var cfg = RewriterConfig{ +// Replace: "/old-path/page-views/ +// Find: "/bar/v1/page-views/", +// } +// // Remaps journal name => fragment store URL: +// // "/foo/bar/v1/page-views/part-000" => "s3://my-bucket/foo/old-path/page-views/part-000" // Matched. +// // "/foo/bar/v2/page-views/part-000" => "s3://my-bucket/foo/bar/v2/page-views/part-000" // Not matched. type RewriterConfig struct { // Find is the string to replace in the unmodified journal name. Find string