Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce heap balancer for the gRPC connection pool #589

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions go/pkg/balancer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go_library(
name = "balancer",
srcs = [
"roundrobin.go",
"common.go",
"heap.go",
],
importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer",
visibility = ["//visibility:public"],
Expand All @@ -12,3 +14,13 @@ go_library(
],
)

go_test(
name = "balancer_test",
srcs = [
"heap_test.go",
],
embed = [":balancer"],
deps = [
"@org_golang_google_grpc//:go_default_library",
],
)
10 changes: 10 additions & 0 deletions go/pkg/balancer/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package balancer

import (
"context"

"google.golang.org/grpc"
)

// DialFunc defines the dial function used in creating the pool.
type DialFunc func(ctx context.Context) (*grpc.ClientConn, error)
118 changes: 118 additions & 0 deletions go/pkg/balancer/heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package balancer

import (
"container/heap"
"context"
"errors"
"io"
"sync"

"google.golang.org/grpc"
)

// item is a connection in the queue.
type item struct {
conn *grpc.ClientConn
invokeCount int
streamCount int
}

// priorityQueue implements heap.Interface and holds Items.
type priorityQueue []*item

func (pq priorityQueue) Len() int { return len(pq) }

func (pq priorityQueue) Less(i, j int) bool {
// Prioritize spearding streams first.
if pq[i].streamCount < pq[j].streamCount {
return true
}
return pq[i].invokeCount < pq[j].invokeCount
}

func (pq priorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
}

// Not used in this implementation.
func (pq *priorityQueue) Push(x any) {}

// Not used in this implementation.
func (pq *priorityQueue) Pop() any {
return nil
}

func (pq *priorityQueue) peek() *item {
return (*pq)[0]
}

func (pq *priorityQueue) fix() {
heap.Fix(pq, 0)
}

// HeapConnPool is a pool of *grpc.ClientConn that are selected using a min-heap.
// That is, the least used connection is selected.
type HeapConnPool struct {
grpc.ClientConnInterface
io.Closer

pq priorityQueue
mu sync.Mutex
}

// Invoke picks up a connection from the pool and delegates the call to it.
func (p *HeapConnPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
p.mu.Lock()
defer p.mu.Unlock()
connItem := p.pq.peek()
connItem.invokeCount += 1
p.pq.fix()
return connItem.conn.Invoke(ctx, method, args, reply, opts...)

Check failure on line 70 in go/pkg/balancer/heap.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func (*google.golang.org/grpc.ClientConn).Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...google.golang.org/grpc.CallOption) error (wrapcheck)
}

// NewStream picks up a connection from the pool and delegates the call to it.
func (p *HeapConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
p.mu.Lock()
defer p.mu.Unlock()
connItem := p.pq.peek()
connItem.streamCount += 1
p.pq.fix()
return connItem.conn.NewStream(ctx, desc, method, opts...)

Check failure on line 80 in go/pkg/balancer/heap.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func (*google.golang.org/grpc.ClientConn).NewStream(ctx context.Context, desc *google.golang.org/grpc.StreamDesc, method string, opts ...google.golang.org/grpc.CallOption) (google.golang.org/grpc.ClientStream, error) (wrapcheck)
}

// Close closes all connections in the bool.
func (p *HeapConnPool) Close() error {
p.mu.Lock()
defer p.mu.Unlock()

var errs error
for i, item := range p.pq {
item.invokeCount = 0
item.streamCount = 0
heap.Fix(&p.pq, i)

if err := item.conn.Close(); err != nil {
errs = errors.Join(errs, err)
}
}
return errs

Check failure on line 98 in go/pkg/balancer/heap.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func errors.Join(errs ...error) error (wrapcheck)
}

// NewHeapConnPool makes a new instance of the min-heap connection pool and dials as many as poolSize connections
// using the provided dialFn.
func NewHeapConnPool(ctx context.Context, poolSize int, dialFn DialFunc) (*HeapConnPool, error) {
pool := &HeapConnPool{}

for i := 0; i < poolSize; i++ {
conn, err := dialFn(ctx)
if err != nil {
defer pool.Close()
return nil, err
}
pool.pq = append(pool.pq, &item{
conn: conn,
})
}
heap.Init(&pool.pq)
return pool, nil
}
48 changes: 48 additions & 0 deletions go/pkg/balancer/heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package balancer

import (
"context"
"fmt"
"testing"

"google.golang.org/grpc"
)

func TestHeapConnPool(t *testing.T) {
ctx := context.Background()
dial := func(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, "example.com", grpc.WithInsecure())

Check failure on line 14 in go/pkg/balancer/heap_test.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: grpc.WithInsecure is deprecated: use WithTransportCredentials and insecure.NewCredentials() instead. Will be supported throughout 1.x. (staticcheck)
}
heapPool, err := NewHeapConnPool(ctx, 5, dial)
if err != nil {
t.Fatal(err)
}

if len(heapPool.pq) != 5 {
t.Fatal(fmt.Sprintf("pool size is incorrect, want 5, got %d", len(heapPool.pq)))

Check failure on line 22 in go/pkg/balancer/heap_test.go

View workflow job for this annotation

GitHub Actions / lint

S1038: should use t.Fatalf(...) instead of t.Fatal(fmt.Sprintf(...)) (gosimple)
}

// When used for Invoke (or NewStream) only, it should act as round-robin.
for i := 0; i < 5; i++ {
if c := heapPool.pq.peek().invokeCount; c > 0 {
t.Errorf(fmt.Sprintf("invokeCount for item #%d should be 0, not %d", i, c))
}
_ = heapPool.Invoke(ctx, "foo", nil, nil)
}
if c := heapPool.pq.peek().invokeCount; c != 1 {
t.Errorf(fmt.Sprintf("invokeCount should be 1, not %d", c))

}

// With mixed-use, NewStream calls ignore invokeCount to prioritize spreading streams across the pool.
for i := 0; i < 5; i++ {
if c := heapPool.pq.peek().streamCount; c > 0 {
t.Errorf(fmt.Sprintf("streamCount for item #%d should be 0, not %d", i, c))
}
_, _ = heapPool.NewStream(ctx, nil, "foo")
}
if c := heapPool.pq.peek().streamCount; c != 1 {
t.Errorf(fmt.Sprintf("streamCount should be 1, not %d", c))

}
}
11 changes: 4 additions & 7 deletions go/pkg/balancer/roundrobin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
idx uint32 // access via sync/atomic
}

// Conn picks the next connection from the pool in a round-robin fasion.
func (p *RRConnPool) Conn() *grpc.ClientConn {
// conn picks the next connection from the pool in a round-robin fasion.
func (p *RRConnPool) conn() *grpc.ClientConn {
i := atomic.AddUint32(&p.idx, 1)
return p.conns[i%uint32(len(p.conns))]
}
Expand All @@ -40,17 +40,14 @@

// Invoke picks up a connection from the pool and delegates the call to it.
func (p *RRConnPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
return p.Conn().Invoke(ctx, method, args, reply, opts...)
return p.conn().Invoke(ctx, method, args, reply, opts...)

Check failure on line 43 in go/pkg/balancer/roundrobin.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func (*google.golang.org/grpc.ClientConn).Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...google.golang.org/grpc.CallOption) error (wrapcheck)
}

// NewStream picks up a connection from the pool and delegates the call to it.
func (p *RRConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return p.Conn().NewStream(ctx, desc, method, opts...)
return p.conn().NewStream(ctx, desc, method, opts...)

Check failure on line 48 in go/pkg/balancer/roundrobin.go

View workflow job for this annotation

GitHub Actions / lint

error returned from external package is unwrapped: sig: func (*google.golang.org/grpc.ClientConn).NewStream(ctx context.Context, desc *google.golang.org/grpc.StreamDesc, method string, opts ...google.golang.org/grpc.CallOption) (google.golang.org/grpc.ClientStream, error) (wrapcheck)
}

// DialFunc defines the dial function used in creating the pool.
type DialFunc func(ctx context.Context) (*grpc.ClientConn, error)

// NewRRConnPool makes a new instance of the round-robin connection pool and dials as many as poolSize connections
// using the provided dialFn.
func NewRRConnPool(ctx context.Context, poolSize int, dialFn DialFunc) (*RRConnPool, error) {
Expand Down
37 changes: 32 additions & 5 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,9 @@ const DefaultCASConcurrency = 500
// that the GRPC balancer can perform.
const DefaultMaxConcurrentRequests = 25

// DefaultGrpcClientBalancer specifies the default balancer implementation.
const DefaultGRPCClientBalancer = "roundrobin"

// Apply sets the CASConcurrency flag on a client.
func (cy CASConcurrency) Apply(c *Client) {
c.casConcurrency = int64(cy)
Expand Down Expand Up @@ -542,6 +545,9 @@ type DialParams struct {
// MaxConcurrentRequests specifies the maximum number of concurrent RPCs on a single connection.
MaxConcurrentRequests uint32

// GRPCClientBalancer specifies the balancer implementation to use.
GRPCClientBalancer string

// TLSClientAuthCert specifies the public key in PEM format for using mTLS auth to connect to the RBE service.
//
// If this is specified, TLSClientAuthKey must also be specified.
Expand Down Expand Up @@ -666,29 +672,50 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
if params.MaxConcurrentRequests == 0 {
params.MaxConcurrentRequests = DefaultMaxConcurrentRequests
}
if params.GRPCClientBalancer == "" {
params.GRPCClientBalancer = DefaultGRPCClientBalancer
}
log.Infof("Connecting to remote execution instance %s", instanceName)
log.Infof("Connecting to remote execution service %s", params.Service)
dialOpts, authUsed, err := OptsFromParams(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to prepare gRPC dial options: %v", err)
}

var conn, casConn GrpcClientConn
newConn := func(dial balancer.DialFunc) (GrpcClientConn, error) {
switch params.GRPCClientBalancer {
case "roundrobin":
conn, err := balancer.NewRRConnPool(ctx, int(params.MaxConcurrentRequests), dial)
if err != nil {
return nil, fmt.Errorf("couldn't dial gRPC %q: %v", params.Service, err)
}
return conn, nil
case "heap":
conn, err := balancer.NewHeapConnPool(ctx, int(params.MaxConcurrentRequests), dial)
if err != nil {
return nil, fmt.Errorf("couldn't dial gRPC %q: %v", params.Service, err)
}
return conn, nil
default:
return nil, fmt.Errorf("unknown gRPC client balancer implementation: %q", params.GRPCClientBalancer)
}
}

dial := func(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, params.Service, dialOpts...)
}
conn, err = balancer.NewRRConnPool(ctx, int(params.MaxConcurrentRequests), dial)
conn, err := newConn(dial)
if err != nil {
return nil, fmt.Errorf("couldn't dial gRPC %q: %v", params.Service, err)
return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed}
}

casConn = conn
casConn := conn
if params.CASService != "" && params.CASService != params.Service {
log.Infof("Connecting to CAS service %s", params.CASService)
dial := func(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, params.CASService, dialOpts...)
}
casConn, err = balancer.NewRRConnPool(ctx, int(params.MaxConcurrentRequests), dial)
casConn, err = newConn(dial)
}
if err != nil {
return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed}
Expand Down
5 changes: 4 additions & 1 deletion go/pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ var (
Instance = flag.String("instance", "", "The instance ID to target when calling remote execution via gRPC (e.g., projects/$PROJECT/instances/default_instance for Google RBE).")
// CASConcurrency specifies the maximum number of concurrent upload & download RPCs that can be in flight.
CASConcurrency = flag.Int("cas_concurrency", client.DefaultCASConcurrency, "Num concurrent upload / download RPCs that the SDK is allowed to do.")
// MaxConcurrentRequests denotes the maximum number of concurrent RPCs on a single gRPC connection.
// MaxConcurrentRequests specifies the maximum number of concurrent RPCs on a single gRPC connection.
MaxConcurrentRequests = flag.Uint("max_concurrent_requests_per_conn", client.DefaultMaxConcurrentRequests, "Maximum number of concurrent RPCs on a single gRPC connection.")
// GRPCClientBalancer denotes the balancer implementation to use on the client side..
GRPCClientBalancer = flag.String("grpc_client_balancer", client.DefaultGRPCClientBalancer, "gRPC client balancer implementation to use. Possible options are [roundrobin, heap]. Defaults to roundrobin.")
// TLSServerName overrides the server name sent in the TLS session.
TLSServerName = flag.String("tls_server_name", "", "Override the TLS server name")
// TLSCACert loads CA certificates from a file
Expand Down Expand Up @@ -154,5 +156,6 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client
TLSClientAuthCert: *TLSClientAuthCert,
TLSClientAuthKey: *TLSClientAuthKey,
MaxConcurrentRequests: uint32(*MaxConcurrentRequests),
GRPCClientBalancer: *GRPCClientBalancer,
}, opts...)
}
Loading