Skip to content

Commit

Permalink
fix: Azure fragment store with multiple tenants (#354)
Browse files Browse the repository at this point in the history
* fix: Azure fragment store with multiple tenants.

I knew that the service credentials would be scoped to tenant IDs, but I did not realize that the `service.Client`s and `UserDelegationCredential`s would also be tenant-scoped. This was causing problems where the credentials of whatever tenant was first seen would be used to sign _all_ `SignGet` URLs, resulting in signature mismatch errors when the blob was owned by a different tenant.

This updates the cache logic to keep track of both the service clients and the user delegation credentials by tenant ID, which should solve the problem.
  • Loading branch information
jshearer authored Dec 8, 2023
1 parent 0fa83a6 commit 79690d5
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 70 deletions.
232 changes: 174 additions & 58 deletions broker/fragment/store_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fragment

import (
"context"
"errors"
"fmt"
"io"
"net/url"
Expand All @@ -22,9 +23,9 @@ import (
pb "go.gazette.dev/core/broker/protocol"
)

// AzureStoreConfig configures a Fragment store of the "azure://" or "azure-ad://" scheme.
// azureStoreConfig configures a Fragment store of the "azure://" or "azure-ad://" scheme.
// It is initialized from parsed URL parameters of the pb.FragmentStore.
type AzureStoreConfig struct {
type azureStoreConfig struct {
accountTenantID string // The tenant ID that owns the storage account that we're writing into
// NOTE: This is not the tenant ID that owns the service principal
storageAccountName string // Storage accounts in Azure are the equivalent to a "bucket" in S3
Expand All @@ -34,20 +35,27 @@ type AzureStoreConfig struct {
RewriterConfig
}

func (cfg *AzureStoreConfig) serviceUrl() string {
// serviceUrl returns the HTTPS blob-service endpoint of the configured
// storage account, without a trailing slash.
func (cfg *azureStoreConfig) serviceUrl() string {
	var account = cfg.storageAccountName
	return fmt.Sprint("https://", account, ".blob.core.windows.net")
}

func (cfg *AzureStoreConfig) containerURL() string {
// containerURL returns the URL of the configured container: the service URL
// followed by "/" and the container name.
func (cfg *azureStoreConfig) containerURL() string {
	var base = cfg.serviceUrl()
	return fmt.Sprint(base, "/", cfg.containerName)
}

// udcAndExp pairs a user delegation credential with the expiry time that was
// requested for it, so that cached credentials can be refreshed before they lapse.
type udcAndExp struct {
// udc is the credential passed to SignWithUserDelegation when signing URLs.
udc *service.UserDelegationCredential
// exp is the expiry requested when udc was obtained. getUserDelegationCredential
// refreshes the entry once exp is less than five days in the future.
exp *time.Time
}

// azureBackend holds the per-tenant caches shared by the Azure store methods
// (SignGet, Exists, Open, Persist, List, Remove).
type azureBackend struct {
// NOTE(review): the next five fields look like the pre-change fields that this
// commit removes (single shared client / credential); confirm against the
// committed file before relying on them.
clients map[string]pipeline.Pipeline
svcClient service.Client
clientMu sync.Mutex
udc *service.UserDelegationCredential
udcExp *time.Time
// This is a cache of configured Pipelines for each tenant. These do not expire.
// Keyed by azureStoreConfig.accountTenantID.
pipelines map[string]pipeline.Pipeline
// This is a cache of Azure storage clients for each tenant. These do not expire.
// Keyed by accountTenantID for azure-ad:// endpoints, and by the
// AZURE_ACCOUNT_NAME environment value for azure:// endpoints.
clients map[string]*service.Client
// mu guards access to the pipelines, clients and udcs maps.
mu sync.Mutex
// This is a cache of URL-signing credentials for each tenant. These DO expire.
// Keyed by azureStoreConfig.accountTenantID.
udcs map[string]udcAndExp
}

func (a *azureBackend) Provider() string {
Expand All @@ -56,14 +64,14 @@ func (a *azureBackend) Provider() string {

// See here for an example of how to use the Azure client libraries to create signatures:
// https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/storage/azblob/service/examples_test.go#L285
func (a *azureBackend) SignGet(ep *url.URL, fragment pb.Fragment, d time.Duration) (string, error) {
cfg, _, err := a.azureClient(ep)
func (a *azureBackend) SignGet(endpoint *url.URL, fragment pb.Fragment, d time.Duration) (string, error) {
cfg, err := parseAzureEndpoint(endpoint)
if err != nil {
return "", err
}
blobName := cfg.rewritePath(cfg.prefix, fragment.ContentPath())

udc, err := a.getUserDelegationCredential()
udc, err := a.getUserDelegationCredential(endpoint)
if err != nil {
return "", err
}
Expand All @@ -73,20 +81,27 @@ func (a *azureBackend) SignGet(ep *url.URL, fragment pb.Fragment, d time.Duratio
ExpiryTime: time.Now().UTC().Add(d), // Timestamps are expected in UTC https://docs.microsoft.com/en-us/rest/api/storageservices/create-service-sas#service-sas-example
ContainerName: cfg.containerName,
BlobName: blobName,

// To produce a container SAS (as opposed to a blob SAS), assign to Permissions using
// ContainerSASPermissions and make sure the BlobName field is "" (the default).
Permissions: to.Ptr(sas.ContainerPermissions{Read: true, Add: true, Write: true}).String(),
// These are the permissions granted to the signed URLs
Permissions: to.Ptr(sas.BlobPermissions{Read: true}).String(),
}.SignWithUserDelegation(udc)

if err != nil {
return "", err
}

log.WithFields(log.Fields{
"tenantId": cfg.accountTenantID,
"storageAccountName": cfg.storageAccountName,
"containerName": cfg.containerName,
"blobName": blobName,
"expiryTime": sasQueryParams.ExpiryTime(),
}).Debug("Signed get request")

return fmt.Sprintf("%s/%s?%s", cfg.containerURL(), blobName, sasQueryParams.Encode()), nil
}

func (a *azureBackend) Exists(ctx context.Context, ep *url.URL, fragment pb.Fragment) (bool, error) {
cfg, client, err := a.azureClient(ep)
cfg, client, err := a.getAzurePipeline(ep)
if err != nil {
return false, err
}
Expand All @@ -108,7 +123,7 @@ func (a *azureBackend) Exists(ctx context.Context, ep *url.URL, fragment pb.Frag
}

func (a *azureBackend) Open(ctx context.Context, ep *url.URL, fragment pb.Fragment) (io.ReadCloser, error) {
cfg, client, err := a.azureClient(ep)
cfg, client, err := a.getAzurePipeline(ep)
if err != nil {
return nil, err
}
Expand All @@ -124,7 +139,7 @@ func (a *azureBackend) Open(ctx context.Context, ep *url.URL, fragment pb.Fragme
}

func (a *azureBackend) Persist(ctx context.Context, ep *url.URL, spool Spool) error {
cfg, client, err := a.azureClient(ep)
cfg, client, err := a.getAzurePipeline(ep)
if err != nil {
return err
}
Expand All @@ -147,7 +162,7 @@ func (a *azureBackend) Persist(ctx context.Context, ep *url.URL, spool Spool) er
}

func (a *azureBackend) List(ctx context.Context, store pb.FragmentStore, ep *url.URL, journal pb.Journal, callback func(pb.Fragment)) error {
cfg, client, err := a.azureClient(ep)
cfg, client, err := a.getAzurePipeline(ep)
if err != nil {
return err
}
Expand Down Expand Up @@ -188,7 +203,7 @@ func (a *azureBackend) List(ctx context.Context, store pb.FragmentStore, ep *url
}

func (a *azureBackend) Remove(ctx context.Context, fragment pb.Fragment) error {
cfg, client, err := a.azureClient(fragment.BackingStore.URL())
cfg, client, err := a.getAzurePipeline(fragment.BackingStore.URL())
if err != nil {
return err
}
Expand Down Expand Up @@ -222,47 +237,127 @@ func getAzureStorageCredential(coreCredential azcore.TokenCredential, tenant str
return credential, nil
}

func (a *azureBackend) azureClient(ep *url.URL) (cfg AzureStoreConfig, client pipeline.Pipeline, err error) {
if err = parseStoreArgs(ep, &cfg); err != nil {
func parseAzureEndpoint(endpoint *url.URL) (cfg azureStoreConfig, err error) {
if err = parseStoreArgs(endpoint, &cfg); err != nil {
return
}

// Omit leading slash from URI. Note that FragmentStore already
// enforces that URL Paths end in '/'.
var splitPath = strings.Split(ep.Path[1:], "/")
var splitPath = strings.Split(endpoint.Path[1:], "/")

if ep.Scheme == "azure" {
if endpoint.Scheme == "azure" {
// Since only one non-ad "Shared Key" credential can be injected via
// environment variables, we should only keep around one client for
// all `azure://` requests.
cfg.accountTenantID = "AZURE_SHARED_KEY"
cfg.storageAccountName = os.Getenv("AZURE_ACCOUNT_NAME")
cfg.containerName, cfg.prefix = ep.Host, ep.Path[1:]
} else if ep.Scheme == "azure-ad" {
cfg.accountTenantID, cfg.storageAccountName, cfg.containerName, cfg.prefix = ep.Host, splitPath[0], splitPath[1], strings.Join(splitPath[2:], "/")
cfg.containerName, cfg.prefix = endpoint.Host, endpoint.Path[1:]
} else if endpoint.Scheme == "azure-ad" {
cfg.accountTenantID, cfg.storageAccountName, cfg.containerName, cfg.prefix = endpoint.Host, splitPath[0], splitPath[1], strings.Join(splitPath[2:], "/")
}

a.clientMu.Lock()
defer a.clientMu.Unlock()
return cfg, nil
}

if a.clients[cfg.accountTenantID] != nil {
client = a.clients[cfg.accountTenantID]
return
}
func (a *azureBackend) getAzureServiceClient(endpoint *url.URL) (client *service.Client, err error) {
var cfg azureStoreConfig

var credentials azblob.Credential
if cfg, err = parseAzureEndpoint(endpoint); err != nil {
return nil, err
}

if ep.Scheme == "azure" {
if endpoint.Scheme == "azure" {
var accountName = os.Getenv("AZURE_ACCOUNT_NAME")
var accountKey = os.Getenv("AZURE_ACCOUNT_KEY")

a.mu.Lock()
client, ok := a.clients[accountName]
a.mu.Unlock()

if ok {
log.WithFields(log.Fields{
"storageAccountName": accountName,
}).Info("Re-using cached azure:// service client")
return client, nil
}

sharedKeyCred, err := service.NewSharedKeyCredential(accountName, accountKey)
if err != nil {
return cfg, nil, err
return nil, err
}
serviceClient, err := service.NewClientWithSharedKeyCredential(cfg.serviceUrl(), sharedKeyCred, &service.ClientOptions{})
if err != nil {
return cfg, nil, err
return nil, err
}
a.svcClient = *serviceClient

a.mu.Lock()
a.clients[accountName] = serviceClient
a.mu.Lock()
return serviceClient, nil
} else if endpoint.Scheme == "azure-ad" {
// Link to the Azure docs describing what fields are required for active directory auth
// https://learn.microsoft.com/en-us/azure/developer/go/azure-sdk-authentication-service-principal?tabs=azure-cli#-option-1-authenticate-with-a-secret
var clientId = os.Getenv("AZURE_CLIENT_ID")
var clientSecret = os.Getenv("AZURE_CLIENT_SECRET")

a.mu.Lock()
client, ok := a.clients[cfg.accountTenantID]
a.mu.Unlock()

if ok {
log.WithFields(log.Fields{
"accountTenantId": cfg.accountTenantID,
}).Info("Re-using cached azure-ad:// service client")
return client, nil
}

identityCreds, err := azidentity.NewClientSecretCredential(
cfg.accountTenantID,
clientId,
clientSecret,
&azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: []string{cfg.accountTenantID},
DisableInstanceDiscovery: true,
},
)
if err != nil {
return nil, err
}

serviceClient, err := service.NewClient(cfg.serviceUrl(), identityCreds, &service.ClientOptions{})
if err != nil {
return nil, err
}

a.mu.Lock()
a.clients[cfg.accountTenantID] = serviceClient
a.mu.Unlock()

return serviceClient, nil
}
return nil, errors.New("unrecognized URI scheme")
}

func (a *azureBackend) getAzurePipeline(ep *url.URL) (cfg azureStoreConfig, client pipeline.Pipeline, err error) {
if cfg, err = parseAzureEndpoint(ep); err != nil {
return
}

a.mu.Lock()
client = a.pipelines[cfg.accountTenantID]
a.mu.Unlock()

if client != nil {
return
}

var credentials azblob.Credential

if ep.Scheme == "azure" {
var accountName = os.Getenv("AZURE_ACCOUNT_NAME")
var accountKey = os.Getenv("AZURE_ACCOUNT_KEY")

// Create an azblob credential that we can pass to `NewPipeline`
credentials, err = azblob.NewSharedKeyCredential(accountName, accountKey)
if err != nil {
Expand All @@ -287,34 +382,29 @@ func (a *azureBackend) azureClient(ep *url.URL) (cfg AzureStoreConfig, client pi
return cfg, nil, err
}

serviceClient, err := service.NewClient(cfg.serviceUrl(), identityCreds, &service.ClientOptions{})
if err != nil {
return cfg, nil, err
}
a.svcClient = *serviceClient

credentials, err = getAzureStorageCredential(identityCreds, cfg.accountTenantID)
if err != nil {
return cfg, nil, err
}
}

client = azblob.NewPipeline(credentials, azblob.PipelineOptions{})
if a.clients == nil {
a.clients = make(map[string]pipeline.Pipeline)
}
a.clients[cfg.accountTenantID] = client

a.mu.Lock()
a.pipelines[cfg.accountTenantID] = client
a.mu.Unlock()

log.WithFields(log.Fields{
"tenant": cfg.accountTenantID,
"storageAccountName": cfg.storageAccountName,
"storageContainerName": cfg.containerName,
"pathPrefix": cfg.prefix,
}).Info("constructed new Azure Storage client")
}).Info("constructed new Azure Storage pipeline client")

return cfg, client, nil
}

func (a *azureBackend) buildBlobURL(cfg AzureStoreConfig, client pipeline.Pipeline, path string) (*azblob.BlockBlobURL, error) {
func (a *azureBackend) buildBlobURL(cfg azureStoreConfig, client pipeline.Pipeline, path string) (*azblob.BlockBlobURL, error) {
u, err := url.Parse(fmt.Sprint(cfg.containerURL(), "/", cfg.rewritePath(cfg.prefix, path)))
if err != nil {
return nil, err
Expand All @@ -324,7 +414,15 @@ func (a *azureBackend) buildBlobURL(cfg AzureStoreConfig, client pipeline.Pipeli
}

// Cache UserDelegationCredentials and refresh them when needed
func (a *azureBackend) getUserDelegationCredential() (*service.UserDelegationCredential, error) {
func (a *azureBackend) getUserDelegationCredential(endpoint *url.URL) (*service.UserDelegationCredential, error) {
var cfg, err = parseAzureEndpoint(endpoint)
if err != nil {
return nil, err
}
a.mu.Lock()
var udc, hasCachedUdc = a.udcs[cfg.accountTenantID]
a.mu.Unlock()

// https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blob-user-delegation-sas-create-cli#use-azure-ad-credentials-to-secure-a-sas
// According to the above docs, signed URLs generated with a UDC are invalid after
// that UDC expires. In addition, a UDC can live up to 7 days. So let's ensure that
Expand All @@ -333,23 +431,41 @@ func (a *azureBackend) getUserDelegationCredential() (*service.UserDelegationCre
// ----| NOW |------|NOW+5Day|-----| udcExp |---- No need to refresh
// ----| NOW |-----| udcExp |-----|NOW+5Day|---- Need to refresh
// ----|udcExp|-----| NOW | ------------------ Need to refresh
if a.udc == nil || (a.udcExp != nil && a.udcExp.Before(time.Now().Add(time.Hour*24*5))) {
if !hasCachedUdc || (udc.exp != nil && udc.exp.Before(time.Now().Add(time.Hour*24*5))) {
// Generate UDCs that expire 6 days from now, and refresh them after they
// have less than 5 days left until they expire.
var startTime = time.Now().Add(time.Second * -10)
var expTime = time.Now().Add(time.Hour * 24 * 6)
var info = service.KeyInfo{
Start: to.Ptr(time.Now().Add(time.Second * -10).UTC().Format(sas.TimeFormat)),
Start: to.Ptr(startTime.UTC().Format(sas.TimeFormat)),
Expiry: to.Ptr(expTime.UTC().Format(sas.TimeFormat)),
}

udc, err := a.svcClient.GetUserDelegationCredential(context.Background(), info, nil)
var serviceClient, err = a.getAzureServiceClient(endpoint)
if err != nil {
return nil, err
}

cred, err := serviceClient.GetUserDelegationCredential(context.Background(), info, nil)
if err != nil {
return nil, err
}

a.udc = udc
a.udcExp = &expTime
log.WithFields(log.Fields{
"newExpiration": expTime,
"newStart": startTime.String(),
"service.KeyInfo": info,
"tenant": cfg.accountTenantID,
}).Info("Refreshing Azure Storage UDC")

udc = udcAndExp{
udc: cred,
exp: &expTime,
}
a.mu.Lock()
a.udcs[cfg.accountTenantID] = udc
a.mu.Unlock()
}

return a.udc, nil
return udc.udc, nil
}
Loading

0 comments on commit 79690d5

Please sign in to comment.