Skip to content

Commit

Permalink
plugins/rest/azure: Support managed identity for App Servivce / Conta…
Browse files Browse the repository at this point in the history
…iner Apps

IDENTITY_ENDPOINT and IDENTITY_HEADER envirnnment variables are
provided on Azure App Service for getting the token.
We can detect these variables and switch the endpoint
and header value from IMDS.

Fixes: open-policy-agent#7085
Signed-off-by: Hitoshi Kamezaki <[email protected]>
  • Loading branch information
apc-kamezaki committed Oct 4, 2024
1 parent 69cd388 commit 815d1cb
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 14 deletions.
49 changes: 36 additions & 13 deletions plugins/rest/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"io"
"net/http"
"net/url"
"os"
"time"
)

var (
azureIMDSEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
defaultAPIVersion = "2018-02-01"
defaultResource = "https://storage.azure.com/"
timeout = 5 * time.Second
azureIMDSEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
defaultAPIVersion = "2018-02-01"
defaultResource = "https://storage.azure.com/"
timeout = 5 * time.Second
defaultAPIVersionForAppServiceMsi = "2019-08-01"
)

// azureManagedIdentitiesToken holds a token for managed identities for Azure resources
Expand All @@ -41,12 +43,13 @@ func (e *azureManagedIdentitiesError) Error() string {

// azureManagedIdentitiesAuthPlugin uses an azureManagedIdentitiesToken.AccessToken for bearer authorization
type azureManagedIdentitiesAuthPlugin struct {
Endpoint string `json:"endpoint"`
APIVersion string `json:"api_version"`
Resource string `json:"resource"`
ObjectID string `json:"object_id"`
ClientID string `json:"client_id"`
MiResID string `json:"mi_res_id"`
Endpoint string `json:"endpoint"`
APIVersion string `json:"api_version"`
Resource string `json:"resource"`
ObjectID string `json:"object_id"`
ClientID string `json:"client_id"`
MiResID string `json:"mi_res_id"`
UseAppServiceMsi bool `json:"use_app_service_msi,omitempty"`
}

func (ap *azureManagedIdentitiesAuthPlugin) NewClient(c Config) (*http.Client, error) {
Expand All @@ -55,15 +58,25 @@ func (ap *azureManagedIdentitiesAuthPlugin) NewClient(c Config) (*http.Client, e
}

if ap.Endpoint == "" {
ap.Endpoint = azureIMDSEndpoint
identityEndpoint := os.Getenv("IDENTITY_ENDPOINT")
if identityEndpoint != "" {
ap.UseAppServiceMsi = true
ap.Endpoint = identityEndpoint
} else {
ap.Endpoint = azureIMDSEndpoint
}
}

if ap.Resource == "" {
ap.Resource = defaultResource
}

if ap.APIVersion == "" {
ap.APIVersion = defaultAPIVersion
if ap.UseAppServiceMsi {
ap.APIVersion = defaultAPIVersionForAppServiceMsi
} else {
ap.APIVersion = defaultAPIVersion
}
}

t, err := DefaultTLSConfig(c)
Expand All @@ -78,6 +91,7 @@ func (ap *azureManagedIdentitiesAuthPlugin) Prepare(req *http.Request) error {
token, err := azureManagedIdentitiesTokenRequest(
ap.Endpoint, ap.APIVersion, ap.Resource,
ap.ObjectID, ap.ClientID, ap.MiResID,
ap.UseAppServiceMsi,
)
if err != nil {
return err
Expand All @@ -90,6 +104,7 @@ func (ap *azureManagedIdentitiesAuthPlugin) Prepare(req *http.Request) error {
// azureManagedIdentitiesTokenRequest fetches an azureManagedIdentitiesToken
func azureManagedIdentitiesTokenRequest(
endpoint, apiVersion, resource, objectID, clientID, miResID string,
useAppServiceMsi bool,
) (azureManagedIdentitiesToken, error) {
var token azureManagedIdentitiesToken
e := buildAzureManagedIdentitiesRequestPath(endpoint, apiVersion, resource, objectID, clientID, miResID)
Expand All @@ -98,7 +113,15 @@ func azureManagedIdentitiesTokenRequest(
if err != nil {
return token, err
}
request.Header.Add("Metadata", "true")
if useAppServiceMsi {
identityHeader := os.Getenv("IDENTITY_HEADER")
if identityHeader == "" {
return token, errors.New("azure managed identities auth: IDENTITY_HEADER env var not found")
}
request.Header.Add("x-identity-header", identityHeader)
} else {
request.Header.Add("Metadata", "true")
}

httpClient := http.Client{Timeout: timeout}
response, err := httpClient.Do(request)
Expand Down
59 changes: 58 additions & 1 deletion plugins/rest/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ func assertParamsEqual(t *testing.T, expected url.Values, actual url.Values, lab
t.Errorf("%s: expected %s, got %s", label, expected.Encode(), actual.Encode())
}
}

func TestAzureManagedIdentitiesAuthPlugin_NewClient(t *testing.T) {
tests := []struct {
label string
Expand Down Expand Up @@ -79,6 +78,64 @@ func TestAzureManagedIdentitiesAuthPlugin_NewClient(t *testing.T) {
}
}

func TestAzureManagedIdentitiesAuthPluginForAppService_NewClient(t *testing.T) {
tests := []struct {
label string
endpoint string
apiVersion string
resource string
objectID string
clientID string
miResID string
}{
{
"test all defaults",
"", "", "", "", "", "",
},
{
"test no defaults",
"some_endpoint", "some_version", "some_resource", "some_oid", "some_cid", "some_miresid",
},
}

nonEmptyString := func(value string, defaultValue string) string {
if value == "" {
return defaultValue
}
return value
}

defaultIdentityEndpoint := "http://localhost:42356/msi/token"
defaultIdentityHeader := "IdentityHeader"
t.Setenv("IDENTITY_ENDPOINT", defaultIdentityEndpoint)
t.Setenv("IDENTITY_HEADER", defaultIdentityHeader)

for _, tt := range tests {
config := generateConfigString(tt.endpoint, tt.apiVersion, tt.resource, tt.objectID, tt.clientID, tt.miResID)

client, err := New([]byte(config), map[string]*keys.Config{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

ap := client.config.Credentials.AzureManagedIdentity
_, err = ap.NewClient(client.config)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// We test that default values are set correctly in the azureManagedIdentitiesAuthPlugin
// Note that there is significant overlap between TestAzureManagedIdentitiesAuthPlugin_NewClient and TestAzureManagedIdentitiesAuthPlugin
// This is because the latter cannot test default endpoint setting, which we do here
assertStringsEqual(t, nonEmptyString(tt.endpoint, defaultIdentityEndpoint), ap.Endpoint, tt.label)
assertStringsEqual(t, nonEmptyString(tt.apiVersion, defaultAPIVersionForAppServiceMsi), ap.APIVersion, tt.label)
assertStringsEqual(t, nonEmptyString(tt.resource, defaultResource), ap.Resource, tt.label)
assertStringsEqual(t, tt.objectID, ap.ObjectID, tt.label)
assertStringsEqual(t, tt.clientID, ap.ClientID, tt.label)
assertStringsEqual(t, tt.miResID, ap.MiResID, tt.label)
}
}

func TestAzureManagedIdentitiesAuthPlugin(t *testing.T) {
tests := []struct {
label string
Expand Down

0 comments on commit 815d1cb

Please sign in to comment.