Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding gRPC metadata to outgoing RPCs #601

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go/pkg/client/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ go_library(
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/oauth",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/prototext",
"@org_golang_google_protobuf//encoding/protowire",
Expand Down Expand Up @@ -82,6 +83,7 @@ go_test(
"@go_googleapis//google/rpc:status_go_proto",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//testing/protocmp",
Expand Down
33 changes: 33 additions & 0 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"time"

"errors"

"github.com/bazelbuild/remote-apis-sdks/go/pkg/actas"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
Expand All @@ -26,6 +27,7 @@
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/oauth"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

// Redundant imports are required for the google3 mirror. Aliases should not be changed.
Expand Down Expand Up @@ -189,6 +191,7 @@
casDownloaders *semaphore.Weighted
casDownloadRequests chan *downloadRequest
rpcTimeouts RPCTimeouts
remoteHeaders map[string]string

Check failure on line 194 in go/pkg/client/client.go

View workflow job for this annotation

GitHub Actions / lint

field `remoteHeaders` is unused (unused)
creds credentials.PerRPCCredentials
uploadOnce sync.Once
downloadOnce sync.Once
Expand Down Expand Up @@ -551,6 +554,10 @@
//
// If this is specified, TLSClientAuthCert must also be specified.
TLSClientAuthKey string

// RemoteHeaders specifies additional gRPC metadata headers to be passed with
// each RPC. These headers are not meant to be used for authentication.
RemoteHeaders map[string][]string
}

func createTLSConfig(params DialParams) (*tls.Config, error) {
Expand Down Expand Up @@ -651,6 +658,32 @@
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
}

if len(params.RemoteHeaders) > 0 {
md := metadata.MD(params.RemoteHeaders)
opts = append(
opts,
grpc.WithChainUnaryInterceptor(func(
ctx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption) error {
ctx = metadata.NewOutgoingContext(ctx, md)
return invoker(ctx, method, req, reply, cc, opts...)
}),
grpc.WithChainStreamInterceptor(func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
ctx = metadata.NewOutgoingContext(ctx, md)
return streamer(ctx, desc, cc, method, opts...)
}))
}

return opts, authUsed, nil
}

Expand Down
104 changes: 104 additions & 0 deletions go/pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@
import (
"context"
"errors"
"io"
"net"
"os"
"path"
"testing"

"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
svpb "github.com/bazelbuild/remote-apis/build/bazel/semver"
"github.com/google/go-cmp/cmp"
bspb "google.golang.org/genproto/googleapis/bytestream"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

const (
Expand Down Expand Up @@ -266,3 +275,98 @@
})
}
}

func TestRemoteHeaders(t *testing.T) {
one := []byte{1}
oneDigest := digest.NewFromBlob(one)
want := map[string][]string{"x-test": {"test123"}}
checkHeaders := func(t *testing.T, got metadata.MD) {
t.Helper()
for k, wantV := range want {
if gotV, ok := got[k]; !ok {
t.Errorf("header %s not seen in server metadata", k)
} else if len(gotV) != 1 {
t.Errorf("header %s seen %d times", k, len(wantV))
} else if diff := cmp.Diff(gotV, wantV); diff != "" {
t.Errorf("got header %s value %q; want %q; diff (-got, +want) %s", k, gotV, wantV, diff)
}
}
}

ctx := context.Background()
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot listen: %v", err)
}
defer listener.Close()
server := grpc.NewServer()
fake := &fakeByteStreamForRemoteHeaders{}
bspb.RegisterByteStreamServer(server, fake)
repb.RegisterCapabilitiesServer(server, &fakeCapabilitiesForRemoteHeaders{})
go server.Serve(listener)

Check failure on line 306 in go/pkg/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `server.Serve` is not checked (errcheck)
defer server.Stop()

client, err := NewClient(ctx, "instance", DialParams{
Service: listener.Addr().String(),
NoSecurity: true,
RemoteHeaders: want,
})
if err != nil {
t.Fatalf("Cannot create client: %v", err)
}
defer client.Close()

t.Run("unary", func(t *testing.T) {
if _, err := client.WriteBlob(ctx, one); err != nil {
t.Fatalf("Writing blob: %v", err)
}
checkHeaders(t, fake.writeHeaders)
})

t.Run("stream", func(t *testing.T) {
if _, _, err := client.ReadBlob(ctx, oneDigest); err != nil {
t.Fatalf("Reading blob: %v", err)
}
checkHeaders(t, fake.readHeaders)
})
}

type fakeByteStreamForRemoteHeaders struct {
bspb.UnimplementedByteStreamServer
readHeaders, writeHeaders metadata.MD
}

func (bs *fakeByteStreamForRemoteHeaders) Read(req *bspb.ReadRequest, stream bspb.ByteStream_ReadServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
return status.Error(codes.InvalidArgument, "metadata not found")

Check failure on line 342 in go/pkg/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func google.golang.org/grpc/status.Error(c google.golang.org/grpc/codes.Code, msg string) error (wrapcheck)
}
bs.readHeaders = md
stream.Send(&bspb.ReadResponse{Data: []byte{1}})

Check failure on line 345 in go/pkg/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `stream.Send` is not checked (errcheck)
return nil
}

func (bs *fakeByteStreamForRemoteHeaders) Write(stream bspb.ByteStream_WriteServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
return status.Error(codes.InvalidArgument, "metadata not found")

Check failure on line 352 in go/pkg/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func google.golang.org/grpc/status.Error(c google.golang.org/grpc/codes.Code, msg string) error (wrapcheck)
}
bs.writeHeaders = md
for {
_, err := stream.Recv()
if err == io.EOF {
break
} else if err != nil {
return err

Check failure on line 360 in go/pkg/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

error returned from interface method should be wrapped: sig: func (google.golang.org/genproto/googleapis/bytestream.ByteStream_WriteServer).Recv() (*google.golang.org/genproto/googleapis/bytestream.WriteRequest, error) (wrapcheck)
}
}
return stream.SendAndClose(&bspb.WriteResponse{})

Check failure on line 363 in go/pkg/client/client_test.go

View workflow job for this annotation

GitHub Actions / lint

error returned from interface method should be wrapped: sig: func (google.golang.org/genproto/googleapis/bytestream.ByteStream_WriteServer).SendAndClose(*google.golang.org/genproto/googleapis/bytestream.WriteResponse) error (wrapcheck)
}

type fakeCapabilitiesForRemoteHeaders struct {
repb.UnimplementedCapabilitiesServer
}

func (cap *fakeCapabilitiesForRemoteHeaders) GetCapabilities(ctx context.Context, req *repb.GetCapabilitiesRequest) (*repb.ServerCapabilities, error) {
return &repb.ServerCapabilities{}, nil
}
5 changes: 5 additions & 0 deletions go/pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ var (
// Instance gives the instance of remote execution to test (in
// projects/[PROJECT_ID]/instances/[INSTANCE_NAME] format for Google RBE).
Instance = flag.String("instance", "", "The instance ID to target when calling remote execution via gRPC (e.g., projects/$PROJECT/instances/default_instance for Google RBE).")
// RemoteHeaders stores additional headers to pass with each RPC.
RemoteHeaders map[string][]string
// CASConcurrency specifies the maximum number of concurrent upload & download RPCs that can be in flight.
CASConcurrency = flag.Int("cas_concurrency", client.DefaultCASConcurrency, "Num concurrent upload / download RPCs that the SDK is allowed to do.")
// MaxConcurrentRequests denotes the maximum number of concurrent RPCs on a single gRPC connection.
Expand Down Expand Up @@ -85,6 +87,8 @@ func init() {
// themselves with every RPC, otherwise it is easy to accidentally enforce a timeout on
// WaitExecution, for example.
flag.Var((*moreflag.StringMapValue)(&RPCTimeouts), "rpc_timeouts", "Comma-separated key value pairs in the form rpc_name=timeout. The key for default RPC is named default. 0 indicates no timeout. Example: GetActionResult=500ms,Execute=0,default=10s.")

flag.Var((*moreflag.StringListMapValue)(&RemoteHeaders), "remote_headers", "Comma-separated headers to pass with each RPC in the form key=value.")
}

// NewClientFromFlags connects to a remote execution service and returns a client suitable for higher-level
Expand Down Expand Up @@ -152,5 +156,6 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client
TLSClientAuthCert: *TLSClientAuthCert,
TLSClientAuthKey: *TLSClientAuthKey,
MaxConcurrentRequests: uint32(*MaxConcurrentRequests),
RemoteHeaders: RemoteHeaders,
}, opts...)
}
84 changes: 69 additions & 15 deletions go/pkg/moreflag/moreflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,16 @@ func (m *StringMapValue) String() string {
// Set updates the map with key and value pair(s) in the format key1=value1,key2=value2.
func (m *StringMapValue) Set(s string) error {
*m = make(map[string]string)
pairs := strings.Split(s, ",")
for _, p := range pairs {
if p == "" {
continue
}
pair := strings.Split(p, "=")
if len(pair) != 2 {
return fmt.Errorf("wrong format for key-value pair: %v", p)
}
if pair[0] == "" {
return fmt.Errorf("key not provided")
}
if _, ok := (*m)[pair[0]]; ok {
return fmt.Errorf("key %v already defined in list of key-value pairs %v", pair[0], s)
pairs, err := parsePairs(s)
if err != nil {
return err
}
for i := 0; i < len(pairs); i += 2 {
k, v := pairs[i], pairs[i+1]
if _, ok := (*m)[k]; ok {
return fmt.Errorf("key %v already defined in list of key-value pairs %v", k, s)
}
(*m)[pair[0]] = pair[1]
(*m)[k] = v
}
return nil
}
Expand Down Expand Up @@ -107,3 +101,63 @@ func (m *StringListValue) Set(s string) error {
func (m *StringListValue) Get() interface{} {
return []string(*m)
}

// StringListMapValue is like StringMapValue, but it allows a key to be used
// with multiple values. The command-line syntax is the same: for example,
// the string key1=a,key1=b,key2=c parses as a map with "key1" having values
// "a" and "b", and "key2" having the value "c".
type StringListMapValue map[string][]string

func (m *StringListMapValue) String() string {
keys := make([]string, 0, len(*m))
for key := range *m {
keys = append(keys, key)
}
sort.Strings(keys)
var b bytes.Buffer
for _, key := range keys {
for _, value := range (*m)[key] {
if b.Len() > 0 {
b.WriteRune(',')
}
b.WriteString(key)
b.WriteRune('=')
b.WriteString(value)
}
}
return b.String()
}

func (m *StringListMapValue) Set(s string) error {
*m = make(map[string][]string)
pairs, err := parsePairs(s)
if err != nil {
return err
}
for i := 0; i < len(pairs); i += 2 {
k, v := pairs[i], pairs[i+1]
(*m)[k] = append((*m)[k], v)
}
return nil
}

// parsePairs parses a string of the form "key1=value1,key2=value2", returning
// a slice with an even number of strings like "key1", "value1", "key2",
// "value2". Pairs are separated by ','; keys and values are separated by '='.
func parsePairs(s string) ([]string, error) {
var pairs []string
for _, p := range strings.Split(s, ",") {
if p == "" {
continue
}
k, v, ok := strings.Cut(p, "=")
if !ok {
return nil, fmt.Errorf("wrong format for key=value pair: %v", p)
}
if k == "" {
return nil, fmt.Errorf("key not provided")
}
pairs = append(pairs, k, v)
}
return pairs, nil
}
Loading
Loading