diff --git a/cmd/aws/aws.go b/cmd/aws/aws.go index 17fc5db..2480bd8 100644 --- a/cmd/aws/aws.go +++ b/cmd/aws/aws.go @@ -15,17 +15,18 @@ package aws import ( + "context" "encoding/json" "errors" "fmt" "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" "github.com/boxboat/dockcmd/cmd/common" "github.com/patrickmn/go-cache" ) @@ -34,8 +35,9 @@ const latestVersion = "AWSCURRENT" type SecretsClient struct { common.SecretClient - secretsManagerClient *secretsmanager.SecretsManager + secretsManagerClient *secretsmanager.Client secretCache *cache.Cache + ctx context.Context } type SecretsClientOpt interface { @@ -43,6 +45,7 @@ type SecretsClientOpt interface { } type secretsClientOpts struct { + ctx context.Context region string profile string accessKeyID string @@ -89,6 +92,13 @@ func UseChainCredentials() SecretsClientOpt { }) } +func WithContext(ctx context.Context) SecretsClientOpt { + return secretClientOptFn(func(opts *secretsClientOpts) error { + opts.ctx = ctx + return nil + }) +} + func (opt secretClientOptFn) configureSecretsClient(opts *secretsClientOpts) error { return opt(opts) } @@ -103,28 +113,40 @@ func NewSecretsClient(opts ...SecretsClientOpt) (*SecretsClient, error) { } } + if o.ctx == nil { + o.ctx = context.Background() + } + client := &SecretsClient{ secretCache: cache.New(o.cacheTTL, o.cacheTTL), + ctx: o.ctx, } - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } + var cfg aws.Config + var err error - var creds = sess.Config.Credentials if !o.useChainCredentials { if o.accessKeyID == "" || o.secretAccessKey == "" { return nil, errors.New("no aws credentials provided") } - creds = credentials.NewStaticCredentials(o.accessKeyID, o.secretAccessKey, "") + cfg, err = config.LoadDefaultConfig( + o.ctx, + config.WithRegion(o.region), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(o.accessKeyID, o.secretAccessKey, ""))) + if err != nil { + return nil, err + } + } else { + cfg, err = config.LoadDefaultConfig( + o.ctx, + config.WithRegion(o.region), + config.WithSharedConfigProfile(o.profile)) + if err != nil { + return nil, err + } } - client.secretsManagerClient = secretsmanager.New( - sess, - aws.NewConfig().WithRegion(o.region).WithCredentials(creds)) + client.secretsManagerClient = secretsmanager.NewFromConfig(cfg) return client, nil } @@ -148,43 +170,28 @@ func (c *SecretsClient) getSecret(secretName string) (string, string, error) { } common.Logger.Debugf("retrieving [%s] from AWS Secrets Manager", adjustedSecretName) - result, err := c.secretsManagerClient.GetSecretValue(input) + result, err := c.secretsManagerClient.GetSecretValue(c.ctx, input) if err != nil { var errorMessage string - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case secretsmanager.ErrCodeDecryptionFailure: - // Secrets Manager can't decrypt the protected secret text using the provided KMS key. - errorMessage = fmt.Sprintf("secret{%s}: %v %v", adjustedSecretName, secretsmanager.ErrCodeDecryptionFailure, aerr.Error()) - break - - case secretsmanager.ErrCodeInternalServiceError: - // An error occurred on the server side. - errorMessage = fmt.Sprintf("secret{%s}: %v %v", adjustedSecretName, secretsmanager.ErrCodeInternalServiceError, aerr.Error()) - break - - case secretsmanager.ErrCodeInvalidParameterException: - // You provided an invalid value for a parameter. - errorMessage = fmt.Sprintf("secret{%s}: %v %v", adjustedSecretName, secretsmanager.ErrCodeInvalidParameterException, aerr.Error()) - break - - case secretsmanager.ErrCodeInvalidRequestException: - // You provided a parameter value that is not valid for the current state of the resource. - errorMessage = fmt.Sprintf("secret{%s}: %v %v", adjustedSecretName, secretsmanager.ErrCodeInvalidRequestException, aerr.Error()) - break - - case secretsmanager.ErrCodeResourceNotFoundException: - // We can't find the resource that you asked for. - errorMessage = fmt.Sprintf("secret{%s}: %v %v", adjustedSecretName, secretsmanager.ErrCodeResourceNotFoundException, aerr.Error()) - break - - default: - errorMessage = fmt.Sprintf("secret{%s[%s]}: %v", adjustedSecretName, aerr.Error()) - break - } + var decryptionFailure *types.DecryptionFailure + var internalServer *types.InternalServiceError + var invalidParameter *types.InvalidParameterException + var invalidRequest *types.InvalidRequestException + var notFound *types.ResourceNotFoundException + + if errors.As(err, &decryptionFailure) { + errorMessage = fmt.Sprintf("secret{%s}: %s %s", adjustedSecretName, decryptionFailure.ErrorCode(), decryptionFailure.ErrorMessage()) + } else if errors.As(err, &internalServer) { + errorMessage = fmt.Sprintf("secret{%s}: %s %s", adjustedSecretName, internalServer.ErrorCode(), internalServer.ErrorMessage()) + } else if errors.As(err, &invalidParameter) { + errorMessage = fmt.Sprintf("secret{%s}: %s %s", adjustedSecretName, invalidParameter.ErrorCode(), invalidParameter.ErrorMessage()) + } else if errors.As(err, &invalidRequest) { + errorMessage = fmt.Sprintf("secret{%s}: %s %s", adjustedSecretName, invalidRequest.ErrorCode(), invalidRequest.ErrorMessage()) + } else if errors.As(err, ¬Found) { + errorMessage = fmt.Sprintf("secret{%s}: %s %s", adjustedSecretName, notFound.ErrorCode(), notFound.ErrorMessage()) } else { - errorMessage = fmt.Sprintln(err.Error()) + errorMessage = fmt.Sprintf("secret{%s[%s]}: %v", adjustedSecretName, err) } return adjustedSecretName, "", errors.New(errorMessage) }