Skip to content

Commit

Permalink
support azure app service mid endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Hitoshi Kamezaki <[email protected]>
  • Loading branch information
apc-kamezaki committed Oct 2, 2024
1 parent 69cd388 commit 5d09567
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 14 deletions.
49 changes: 36 additions & 13 deletions plugins/rest/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"io"
"net/http"
"net/url"
"os"
"time"
)

var (
azureIMDSEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
defaultAPIVersion = "2018-02-01"
defaultResource = "https://storage.azure.com/"
timeout = 5 * time.Second
azureIMDSEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
defaultAPIVersion = "2018-02-01"
defaultResource = "https://storage.azure.com/"
timeout = 5 * time.Second
defaultAPIVersionForAppServiceMsi = "2019-08-01"
)

// azureManagedIdentitiesToken holds a token for managed identities for Azure resources
Expand All @@ -41,12 +43,13 @@ func (e *azureManagedIdentitiesError) Error() string {

// azureManagedIdentitiesAuthPlugin uses an azureManagedIdentitiesToken.AccessToken for bearer authorization
type azureManagedIdentitiesAuthPlugin struct {
Endpoint string `json:"endpoint"`
APIVersion string `json:"api_version"`
Resource string `json:"resource"`
ObjectID string `json:"object_id"`
ClientID string `json:"client_id"`
MiResID string `json:"mi_res_id"`
Endpoint string `json:"endpoint"`
APIVersion string `json:"api_version"`
Resource string `json:"resource"`
ObjectID string `json:"object_id"`
ClientID string `json:"client_id"`
MiResID string `json:"mi_res_id"`
UseAppServiceMsi bool `json:"use_app_service_msi,omitempty"`
}

func (ap *azureManagedIdentitiesAuthPlugin) NewClient(c Config) (*http.Client, error) {
Expand All @@ -55,15 +58,25 @@ func (ap *azureManagedIdentitiesAuthPlugin) NewClient(c Config) (*http.Client, e
}

if ap.Endpoint == "" {
ap.Endpoint = azureIMDSEndpoint
identityEndpoint := os.Getenv("IDENTITY_ENDPOINT")
if identityEndpoint != "" {
ap.UseAppServiceMsi = true
ap.Endpoint = identityEndpoint
} else {
ap.Endpoint = azureIMDSEndpoint
}
}

if ap.Resource == "" {
ap.Resource = defaultResource
}

if ap.APIVersion == "" {
ap.APIVersion = defaultAPIVersion
if ap.UseAppServiceMsi {
ap.APIVersion = defaultAPIVersionForAppServiceMsi
} else {
ap.APIVersion = defaultAPIVersion
}
}

t, err := DefaultTLSConfig(c)
Expand All @@ -78,6 +91,7 @@ func (ap *azureManagedIdentitiesAuthPlugin) Prepare(req *http.Request) error {
token, err := azureManagedIdentitiesTokenRequest(
ap.Endpoint, ap.APIVersion, ap.Resource,
ap.ObjectID, ap.ClientID, ap.MiResID,
ap.UseAppServiceMsi,
)
if err != nil {
return err
Expand All @@ -90,6 +104,7 @@ func (ap *azureManagedIdentitiesAuthPlugin) Prepare(req *http.Request) error {
// azureManagedIdentitiesTokenRequest fetches an azureManagedIdentitiesToken
func azureManagedIdentitiesTokenRequest(
endpoint, apiVersion, resource, objectID, clientID, miResID string,
useAppServiceMsi bool,
) (azureManagedIdentitiesToken, error) {
var token azureManagedIdentitiesToken
e := buildAzureManagedIdentitiesRequestPath(endpoint, apiVersion, resource, objectID, clientID, miResID)
Expand All @@ -98,7 +113,15 @@ func azureManagedIdentitiesTokenRequest(
if err != nil {
return token, err
}
request.Header.Add("Metadata", "true")
if useAppServiceMsi {
identityHeader := os.Getenv("IDENTITY_HEADER")
if identityHeader == "" {
return token, errors.New("azure managed identities auth: IDENTITY_HEADER env var not found")
}
request.Header.Add("x-identity-header", identityHeader)
} else {
request.Header.Add("Metadata", "true")
}

httpClient := http.Client{Timeout: timeout}
response, err := httpClient.Do(request)
Expand Down
59 changes: 58 additions & 1 deletion plugins/rest/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ func assertParamsEqual(t *testing.T, expected url.Values, actual url.Values, lab
t.Errorf("%s: expected %s, got %s", label, expected.Encode(), actual.Encode())
}
}

func TestAzureManagedIdentitiesAuthPlugin_NewClient(t *testing.T) {
tests := []struct {
label string
Expand Down Expand Up @@ -79,6 +78,64 @@ func TestAzureManagedIdentitiesAuthPlugin_NewClient(t *testing.T) {
}
}

func TestAzureManagedIdentitiesAuthPluginForAppService_NewClient(t *testing.T) {
tests := []struct {
label string
endpoint string
apiVersion string
resource string
objectID string
clientID string
miResID string
}{
{
"test all defaults",
"", "", "", "", "", "",
},
{
"test no defaults",
"some_endpoint", "some_version", "some_resource", "some_oid", "some_cid", "some_miresid",
},
}

nonEmptyString := func(value string, defaultValue string) string {
if value == "" {
return defaultValue
}
return value
}

defaultIdentityEndpoint := "http://localhost:42356/msi/token"
defaultIdentityHeader := "IdentityHeader"
t.Setenv("IDENTITY_ENDPOINT", defaultIdentityEndpoint)
t.Setenv("IDENTITY_HEADER", defaultIdentityHeader)

for _, tt := range tests {
config := generateConfigString(tt.endpoint, tt.apiVersion, tt.resource, tt.objectID, tt.clientID, tt.miResID)

client, err := New([]byte(config), map[string]*keys.Config{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

ap := client.config.Credentials.AzureManagedIdentity
_, err = ap.NewClient(client.config)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// We test that default values are set correctly in the azureManagedIdentitiesAuthPlugin
// Note that there is significant overlap between TestAzureManagedIdentitiesAuthPlugin_NewClient and TestAzureManagedIdentitiesAuthPlugin
// This is because the latter cannot test default endpoint setting, which we do here
assertStringsEqual(t, nonEmptyString(tt.endpoint, defaultIdentityEndpoint), ap.Endpoint, tt.label)
assertStringsEqual(t, nonEmptyString(tt.apiVersion, defaultAPIVersionForAppServiceMsi), ap.APIVersion, tt.label)
assertStringsEqual(t, nonEmptyString(tt.resource, defaultResource), ap.Resource, tt.label)
assertStringsEqual(t, tt.objectID, ap.ObjectID, tt.label)
assertStringsEqual(t, tt.clientID, ap.ClientID, tt.label)
assertStringsEqual(t, tt.miResID, ap.MiResID, tt.label)
}
}

func TestAzureManagedIdentitiesAuthPlugin(t *testing.T) {
tests := []struct {
label string
Expand Down

0 comments on commit 5d09567

Please sign in to comment.