Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate PerRPCCreds from Oauth creds to eliminate undefined behavior #603

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions go/pkg/balancer/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "balancer",
srcs = [
"roundrobin.go",
],
srcs = ["roundrobin.go"],
importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer",
visibility = ["//visibility:public"],
deps = [
"@org_golang_google_grpc//:go_default_library",
],
)

20 changes: 4 additions & 16 deletions go/pkg/credshelper/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,20 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "credshelper",
srcs = [
"credshelper.go",
],
srcs = ["credshelper.go"],
importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/credshelper",
visibility = ["//visibility:public"],
deps = [
"//go/api/credshelper",
"//go/pkg/digest",
"//go/pkg/digest",
"@com_github_golang_glog//:glog",
"@com_github_hectane_go_acl//:go-acl",
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/oauth",
"@org_golang_google_protobuf//encoding/prototext",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_x_oauth2//:oauth2",
"@org_golang_x_oauth2//google",
],
)

go_test(
name = "credshelper_test",
srcs = [
"credshelper_test.go",
],
srcs = ["credshelper_test.go"],
embed = [":credshelper"],
deps = [
"@com_github_google_go_cmp//cmp",
"@org_golang_x_oauth2//:oauth2",
],
)
89 changes: 61 additions & 28 deletions go/pkg/credshelper/credshelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

log "github.com/golang/glog"
"golang.org/x/oauth2"
"google.golang.org/grpc/credentials"
grpcOauth "google.golang.org/grpc/credentials/oauth"
)

Expand Down Expand Up @@ -77,6 +78,16 @@ func (r *reusableCmd) Digest() digest.Digest {
// Credentials provides auth functionalities using an external credentials helper
type Credentials struct {
tokenSource *grpcOauth.TokenSource
perRPCCreds *perRPCCredentials
credsHelperCmd *reusableCmd
}

// perRPCCredentials fullfills the grpc.Credentials.PerRPCCredentials interface
// to provde auth functionalities with headers
type perRPCCredentials struct {
headers map[string]string
expiry time.Time
headersLock sync.RWMutex
credsHelperCmd *reusableCmd
}

Expand All @@ -86,9 +97,6 @@ type Credentials struct {
// oauth2.TokenSource and credentials.PerRPCCredentials interfaces.
type externalTokenSource struct {
credsHelperCmd *reusableCmd
headers map[string]string
expiry time.Time
headersLock sync.RWMutex
}

// TokenSource returns a token source for this credentials instance.
Expand All @@ -99,6 +107,20 @@ func (c *Credentials) TokenSource() *grpcOauth.TokenSource {
return c.tokenSource
}

// PerRPCCreds returns a perRPCCredentials for this credentials instance.
func (c *Credentials) PerRPCCreds() credentials.PerRPCCredentials {
if c == nil {
return nil
}
// If no perRPCCreds exist for this Credentials object, then
// grpcOauth.TokenSource will do since it implements the same interface
// and some credentials helpers may only provide a token without headers
if c.perRPCCreds == nil {
return c.TokenSource()
}
return c.perRPCCreds
}

// Token retrieves an oauth2 token from the external tokensource.
func (ts *externalTokenSource) Token() (*oauth2.Token, error) {
if ts == nil {
Expand All @@ -108,27 +130,30 @@ func (ts *externalTokenSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, err
}
if credsOut.tk.AccessToken == "" {
return nil, fmt.Errorf("no token was printed by the credentials helper")
}
log.Infof("'%s' credentials refreshed at %v, expires at %v", ts.credsHelperCmd, time.Now(), credsOut.tk.Expiry)
return credsOut.tk, err
}

// GetRequestMetadata gets the current request metadata, refreshing tokens if required.
func (ts *externalTokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
ts.headersLock.RLock()
defer ts.headersLock.RUnlock()
if ts.expiry.Before(nowFn().Add(-expiryBuffer)) {
credsOut, err := runCredsHelperCmd(ts.credsHelperCmd)
func (p *perRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
p.headersLock.RLock()
defer p.headersLock.RUnlock()
if p.expiry.Before(nowFn().Add(-expiryBuffer)) {
credsOut, err := runCredsHelperCmd(p.credsHelperCmd)
if err != nil {
return nil, err
}
ts.expiry = credsOut.tk.Expiry
ts.headers = credsOut.hdrs
p.expiry = credsOut.tk.Expiry
p.headers = credsOut.hdrs
}
return ts.headers, nil
return p.headers, nil
}

// RequireTransportSecurity indicates whether the credentials require transport security.
func (ts *externalTokenSource) RequireTransportSecurity() bool {
func (p *perRPCCredentials) RequireTransportSecurity() bool {
return true
}

Expand All @@ -149,28 +174,36 @@ func NewExternalCredentials(credshelper string, credshelperArgs []string) (*Cred
c := &Credentials{
credsHelperCmd: credsHelperCmd,
}
baseTS := &externalTokenSource{
credsHelperCmd: credsHelperCmd,
if len(credsOut.hdrs) != 0 {
c.perRPCCreds = &perRPCCredentials{
headers: credsOut.hdrs,
expiry: credsOut.tk.Expiry,
credsHelperCmd: credsHelperCmd,
}
}
c.tokenSource = &grpcOauth.TokenSource{
// Wrap the base token source with a ReuseTokenSource so that we only
// generate new credentials when the current one is about to expire.
// This is needed because retrieving the token is expensive and some
// token providers have per hour rate limits.
TokenSource: oauth2.ReuseTokenSourceWithExpiry(
credsOut.tk,
baseTS,
// Refresh tokens a bit early to be safe
expiryBuffer,
),
if credsOut.tk.AccessToken != "" {
baseTS := &externalTokenSource{
credsHelperCmd: credsHelperCmd,
}
c.tokenSource = &grpcOauth.TokenSource{
// Wrap the base token source with a ReuseTokenSource so that we only
// generate new credentials when the current one is about to expire.
// This is needed because retrieving the token is expensive and some
// token providers have per hour rate limits.
TokenSource: oauth2.ReuseTokenSourceWithExpiry(
credsOut.tk,
baseTS,
// Refresh tokens a bit early to be safe
expiryBuffer,
),
}
}
return c, nil
}

type credshelperOutput struct {
hdrs map[string]string
tk *oauth2.Token
rexp time.Time
}

func runCredsHelperCmd(credsHelperCmd *reusableCmd) (*credshelperOutput, error) {
Expand Down Expand Up @@ -203,8 +236,8 @@ func parseTokenExpiryFromOutput(out string) (*credshelperOutput, error) {
if err := json.Unmarshal([]byte(out), &jsonOut); err != nil {
return nil, fmt.Errorf("error while decoding credshelper output:%v", err)
}
if jsonOut.Token == "" {
return nil, fmt.Errorf("no token was printed by the credentials helper")
if jsonOut.Token == "" && len(jsonOut.Headers) == 0 {
return nil, fmt.Errorf("both token and headers are empty, invalid credentials")
}
credsOut.tk = &oauth2.Token{AccessToken: jsonOut.Token}
credsOut.hdrs = jsonOut.Headers
Expand Down
13 changes: 6 additions & 7 deletions go/pkg/credshelper/credshelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ func TestNewExternalCredentials(t *testing.T) {
credshelperOut: fmt.Sprintf(`{"token":"%v","expiry":""}`, testToken),
}, {
name: "No Token",
wantErr: true,
credshelperOut: `{"headers":{"hdr":"val"},"token":"","expiry":""}`,
}, {
name: "Credshelper Command Passed - No Expiry",
Expand Down Expand Up @@ -168,7 +167,7 @@ func TestNewExternalCredentials(t *testing.T) {
if test.wantErr && err == nil {
t.Fatalf("NewExternalCredentials did not return an error.")
}
if !test.wantErr {
if !test.wantErr && test.name != "No Token" {
if err != nil {
t.Fatalf("NewExternalCredentials returned an error: %v", err)
}
Expand Down Expand Up @@ -247,24 +246,24 @@ func TestGetRequestMetadata(t *testing.T) {
credshelperArgs = []string{test.credshelperOut}
}
credsHelperCmd := newReusableCmd(credshelper, credshelperArgs)
exTs := externalTokenSource{
p := perRPCCredentials{
credsHelperCmd: credsHelperCmd,
expiry: test.tsExp,
headers: test.tsHeaders,
headersLock: sync.RWMutex{},
}
hdrs, err := exTs.GetRequestMetadata(context.Background(), "uri")
hdrs, err := p.GetRequestMetadata(context.Background(), "uri")
if test.wantErr && err == nil {
t.Fatalf("GetRequestMetadata did not return an error.")
}
if !test.wantErr {
if err != nil {
t.Fatalf("GetRequestMetadata returned an error: %v", err)
}
if !reflect.DeepEqual(hdrs, exTs.headers) {
t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, exTs.headers)
if !reflect.DeepEqual(hdrs, p.headers) {
t.Errorf("GetRequestMetadata did not update headers in the tokensource: returned hdrs: %v, tokensource headers: %v", hdrs, p.headers)
}
if !exp.Equal(exTs.expiry) {
if !exp.Equal(p.expiry) {
t.Errorf("GetRequestMetadata did not update expiry in the tokensource")
}
if !test.wantExpired && !reflect.DeepEqual(hdrs, testHdrs) {
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client
if err != nil {
return nil, fmt.Errorf("credentials helper failed. Please try again or use another method of authentication:%v", err)
}
perRPCCreds = &client.PerRPCCreds{Creds: creds.TokenSource()}
perRPCCreds = &client.PerRPCCreds{Creds: creds.PerRPCCreds()}
}
opts = tOpts

Expand Down
Loading