diff --git a/go.mod b/go.mod index 064ae841202..679d8cb1458 100644 --- a/go.mod +++ b/go.mod @@ -109,6 +109,7 @@ require ( go.uber.org/goleak v1.2.1 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/sync v0.3.0 + gonum.org/v1/gonum v0.14.0 modernc.org/sqlite v1.20.3 ) diff --git a/go.sum b/go.sum index 8ad13f84e5e..f0b6cc35a35 100644 --- a/go.sum +++ b/go.sum @@ -944,6 +944,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= +gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= diff --git a/go/atomic2/atomic128.go b/go/atomic2/atomic128.go new file mode 100644 index 00000000000..f2a44ad1643 --- /dev/null +++ b/go/atomic2/atomic128.go @@ -0,0 +1,63 @@ +//go:build amd64 || arm64 + +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package atomic2 + +import ( + "unsafe" +) + +//go:linkname writeBarrier runtime.writeBarrier +var writeBarrier struct { + enabled bool // compiler emits a check of this before calling write barrier + pad [3]byte // compiler uses 32-bit load for "enabled" field + needed bool // identical to enabled, for now (TODO: dedup) + alignme uint64 // guarantee alignment so that compiler can use a 32 or 64-bit load +} + +//go:linkname atomicwb runtime.atomicwb +//go:nosplit +func atomicwb(ptr *unsafe.Pointer, new unsafe.Pointer) + +type PointerAndUint64[T any] struct { + p unsafe.Pointer + u uint64 +} + +//go:nosplit +func loadUint128_(addr *unsafe.Pointer) (pp unsafe.Pointer, uu uint64) + +func (x *PointerAndUint64[T]) Load() (*T, uint64) { + p, u := loadUint128_(&x.p) + return (*T)(p), u +} + +//go:nosplit +func compareAndSwapUint128_(addr *unsafe.Pointer, oldp unsafe.Pointer, oldu uint64, newp unsafe.Pointer, newu uint64) (swapped bool) + +//go:nosplit +func compareAndSwapUint128(addr *unsafe.Pointer, oldp unsafe.Pointer, oldu uint64, newp unsafe.Pointer, newu uint64) bool { + if writeBarrier.enabled { + atomicwb(addr, newp) + } + return compareAndSwapUint128_(addr, oldp, oldu, newp, newu) +} + +func (x *PointerAndUint64[T]) CompareAndSwap(oldp *T, oldu uint64, newp *T, newu uint64) bool { + return compareAndSwapUint128(&x.p, unsafe.Pointer(oldp), oldu, unsafe.Pointer(newp), newu) +} diff --git a/go/atomic2/atomic128_amd64.s b/go/atomic2/atomic128_amd64.s new file mode 100644 index 00000000000..99931032dc0 --- /dev/null +++ b/go/atomic2/atomic128_amd64.s @@ -0,0 +1,46 @@ +// Copyright 2023 The Vitess Authors. +// Copyright (c) 2021, Carlo Alberto Ferraris +// Copyright (c) 2017, Tom Thorogood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Use of this source code is governed by a +// Modified BSD License that can be found in +// the LICENSE file. + +//+build !noasm,!appengine + +#include "textflag.h" + +TEXT ·compareAndSwapUint128_(SB), NOSPLIT, $0-41 + MOVQ addr+0(FP), R8 + MOVQ oldp+8(FP), AX + MOVQ oldu+16(FP), DX + MOVQ newp+24(FP), BX + MOVQ newu+32(FP), CX + LOCK + CMPXCHG16B (R8) + SETEQ swapped+40(FP) + RET + +TEXT ·loadUint128_(SB), NOSPLIT, $0-24 + MOVQ addr+0(FP), R8 + XORQ AX, AX + XORQ DX, DX + XORQ BX, BX + XORQ CX, CX + LOCK + CMPXCHG16B (R8) + MOVQ AX, pp+8(FP) + MOVQ DX, uu+16(FP) + RET diff --git a/go/atomic2/atomic128_arm64.s b/go/atomic2/atomic128_arm64.s new file mode 100644 index 00000000000..96f91010707 --- /dev/null +++ b/go/atomic2/atomic128_arm64.s @@ -0,0 +1,39 @@ +// Copyright 2023 The Vitess Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build !noasm,!appengine + +#include "textflag.h" + +TEXT ·compareAndSwapUint128_(SB), NOSPLIT, $0-41 + MOVD addr+0(FP), R5 + MOVD oldp+8(FP), R0 + MOVD oldu+16(FP), R1 + MOVD newp+24(FP), R2 + MOVD newu+32(FP), R3 + MOVD R0, R6 + MOVD R1, R7 + CASPD (R0, R1), (R5), (R2, R3) + CMP R0, R6 + CCMP EQ, R1, R7, $0 + CSET EQ, R0 + MOVB R0, ret+40(FP) + RET + +TEXT ·loadUint128_(SB), NOSPLIT, $0-24 + MOVD addr+0(FP), R3 + LDAXP (R3), (R0, R1) + MOVD R0, val+8(FP) + MOVD R1, val+16(FP) + RET diff --git a/go/atomic2/atomic128_spinlock.go b/go/atomic2/atomic128_spinlock.go new file mode 100644 index 00000000000..deefc118564 --- /dev/null +++ b/go/atomic2/atomic128_spinlock.go @@ -0,0 +1,61 @@ +//go:build !amd64 && !arm64 + +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package atomic2 + +import ( + "runtime" + "sync/atomic" +) + +type PointerAndUint64[T any] struct { + spin atomic.Uint64 + p *T + u uint64 +} + +func (x *PointerAndUint64[T]) Store(p *T, u uint64) { + for !x.spin.CompareAndSwap(0, 1) { + runtime.Gosched() + } + defer x.spin.Store(0) + x.p = p + x.u = u +} + +func (x *PointerAndUint64[T]) Load() (*T, uint64) { + for !x.spin.CompareAndSwap(0, 1) { + runtime.Gosched() + } + defer x.spin.Store(0) + return x.p, x.u +} + +func (x *PointerAndUint64[T]) CompareAndSwap(oldp *T, oldu uint64, newp *T, newu uint64) bool { + for !x.spin.CompareAndSwap(0, 1) { + runtime.Gosched() + } + defer x.spin.Store(0) + + if x.p == oldp && x.u == oldu { + x.p = newp + x.u = newu + return true + } + return false +} diff --git a/go/atomic2/atomic128_test.go b/go/atomic2/atomic128_test.go new file mode 100644 index 00000000000..499514f688a --- /dev/null +++ b/go/atomic2/atomic128_test.go @@ -0,0 +1,44 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package atomic2 + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/require" +) + +func TestCompareAndSwap(t *testing.T) { + i1 := new(int) + i2 := new(int) + n := &PointerAndUint64[int]{p: unsafe.Pointer(i1), u: 12345} + + ok := n.CompareAndSwap(i1, 12345, i2, 67890) + require.Truef(t, ok, "unexpected CAS failure") + + pp, uu := n.Load() + require.Equal(t, i2, pp) + require.Equal(t, uint64(67890), uu) + + ok = n.CompareAndSwap(i1, 12345, nil, 0) + require.Falsef(t, ok, "unexpected CAS success") + + pp, uu = n.Load() + require.Equal(t, pp, i2) + require.Equal(t, uu, uint64(67890)) +} diff --git a/go/hack/runtime.go b/go/hack/runtime.go index c80ac1d38e5..724a6c34f8d 100644 --- a/go/hack/runtime.go +++ b/go/hack/runtime.go @@ -54,3 +54,6 @@ func RuntimeAllocSize(size int64) int64 { //go:linkname ParseFloatPrefix strconv.parseFloatPrefix func ParseFloatPrefix(s string, bitSize int) (float64, int, error) + +//go:linkname FastRand runtime.fastrand +func FastRand() uint32 diff --git a/go/list/list.go b/go/list/list.go new file mode 100644 index 00000000000..2ad837b7c64 --- /dev/null +++ b/go/list/list.go @@ -0,0 +1,161 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2009 The Go Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package list is the standard library's 'container/list', but using Generics +// for performance. +package list + +import "sync/atomic" + +// Element is an element of a linked list. +type Element[T any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] + + // The list to which this element belongs. + list *List[T] + + // The value stored with this element. + Value T +} + +// Next returns the next list element or nil. +func (e *Element[T]) Next() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *Element[T]) Prev() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[T any] struct { + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len atomic.Int64 +} + +// Init initializes or clears list l. +func (l *List[T]) Init() *List[T] { + l.root.next = &l.root + l.root.prev = &l.root + return l +} + +func (l *List[T]) Len() int { + return int(l.len.Load()) +} + +// New returns an initialized list. +func New[T any]() *List[T] { return new(List[T]).Init() } + +// Front returns the first element of list l or nil if the list is empty. 
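// Illustrative aside, not part of this patch: the generic List is a drop-in
// replacement for container/list without interface{} boxing. A minimal usage
// sketch, assuming the package is imported as "vitess.io/vitess/go/list":
//
//	l := list.New[string]()
//	e := l.PushBack("b")       // e is *list.Element[string]
//	l.PushFront("a")
//	for it := l.Front(); it != nil; it = it.Next() {
//		_ = it.Value // already a string, no type assertion needed
//	}
//	l.Remove(e)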
+func (l *List[T]) Front() *Element[T] { + if l.len.Load() == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[T]) Back() *Element[T] { + if l.len.Load() == 0 { + return nil + } + return l.root.prev +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len.Add(1) + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{Value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len.Add(-1) +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *List[T]) Remove(e *Element[T]) { + if e.list != l { + panic("removing from wrong List") + } + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *List[T]) PushFront(v T) *Element[T] { + return l.insertValue(v, &l.root) +} + +func (l *List[T]) PushFrontValue(v *Element[T]) { + l.insert(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *List[T]) PushBack(v T) *Element[T] { + return l.insertValue(v, l.root.prev) +} + +func (l *List[T]) PushBackValue(v *Element[T]) { + l.insert(v, l.root.prev) +} diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 6f3643ebc7f..b1159f11c31 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -28,6 +28,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" "vitess.io/vitess/go/bucketpool" @@ -1707,6 +1708,44 @@ func (c *Conn) IsMarkedForClose() bool { return c.closing } +// ConnCheck ensures that this connection to the MySQL server hasn't been broken. +// This is a fast, non-blocking check. For details on its implementation, please read +// "Three Bugs in the Go MySQL Driver" (Vicent Marti, GitHub, 2020) +// https://github.blog/2020-05-20-three-bugs-in-the-go-mysql-driver/ +func (c *Conn) ConnCheck() error { + conn := c.conn + if tlsconn, ok := conn.(*tls.Conn); ok { + conn = tlsconn.NetConn() + } + if conn, ok := conn.(syscall.Conn); ok { + rc, err := conn.SyscallConn() + if err != nil { + return err + } + + var n int + var buff [1]byte + rerr := rc.Read(func(fd uintptr) bool { + n, err = syscall.Read(int(fd), buff[:]) + return true + }) + + switch { + case rerr != nil: + return rerr + case n == 0 && err == nil: + return io.EOF + case n > 0: + return sqlerror.NewSQLError(sqlerror.CRUnknownError, sqlerror.SSUnknownSQLState, "unexpected read from conn") + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + return nil + default: + return err + } + } + return nil +} + // GetTestConn returns a conn for testing purpose only. 
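// Editorial aside, not part of this patch: ConnCheck (defined above) is meant
// to run against an idle connection before it is handed out again; a non-nil
// error means the server closed its end while the conn sat unused. A hedged
// sketch — the connPool type and its popIdle/dial helpers are hypothetical,
// only (*Conn).ConnCheck comes from this diff:
//
//	func (p *connPool) getChecked(ctx context.Context) (*Conn, error) {
//		if c := p.popIdle(); c != nil {
//			if err := c.ConnCheck(); err == nil {
//				return c, nil
//			}
//			c.Close() // stale: the server went away while the conn sat idle
//		}
//		return p.dial(ctx) // hypothetical: open a fresh connection
//	}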
func GetTestConn() *Conn { return newConn(testConn{}) diff --git a/go/pools/resource_pool.go b/go/pools/resource_pool.go index f049b667fe6..939b73fa66c 100644 --- a/go/pools/resource_pool.go +++ b/go/pools/resource_pool.go @@ -35,41 +35,12 @@ import ( ) type ( - IResourcePool interface { - Close() - Name() string - Get(ctx context.Context, setting *Setting) (resource Resource, err error) - Put(resource Resource) - SetCapacity(capacity int) error - SetIdleTimeout(idleTimeout time.Duration) - StatsJSON() string - Capacity() int64 - Available() int64 - Active() int64 - InUse() int64 - MaxCap() int64 - WaitCount() int64 - WaitTime() time.Duration - IdleTimeout() time.Duration - IdleClosed() int64 - MaxLifetimeClosed() int64 - Exhausted() int64 - GetCount() int64 - GetSettingCount() int64 - DiffSettingCount() int64 - ResetSettingCount() int64 - } - // Resource defines the interface that every resource must provide. // Thread synchronization between Close() and IsClosed() // is the responsibility of the caller. Resource interface { - Close() Expired(time.Duration) bool - ApplySetting(ctx context.Context, setting *Setting) error - IsSettingApplied() bool - IsSameSetting(setting string) bool - ResetSetting(ctx context.Context) error + Close() } // Factory is a function that can be used to create a resource. @@ -80,12 +51,6 @@ type ( timeUsed time.Time } - // Setting represents a set query and reset query for system settings. - Setting struct { - query string - resetQuery string - } - // ResourcePool allows you to use a pool of resources. ResourcePool struct { available atomic.Int64 @@ -106,11 +71,7 @@ type ( idleTimer *timer.Timer logWait func(time.Time) - settingResources chan resourceWrapper - getCount atomic.Int64 - getSettingCount atomic.Int64 - diffSettingCount atomic.Int64 - resetSettingCount atomic.Int64 + getCount atomic.Int64 reopenMutex sync.Mutex refresh *poolRefresh @@ -128,21 +89,6 @@ var ( ErrCtxTimeout = vterrors.New(vtrpcpb.Code_DEADLINE_EXCEEDED, "resource pool context already expired") ) -func NewSetting(query, resetQuery string) *Setting { - return &Setting{ - query: query, - resetQuery: resetQuery, - } -} - -func (s *Setting) GetQuery() string { - return s.query -} - -func (s *Setting) GetResetQuery() string { - return s.resetQuery -} - // NewResourcePool creates a new ResourcePool pool. // capacity is the number of possible resources in the pool: // there can be up to 'capacity' of these at a given time. 
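// Editorial aside, not part of this patch: with the Setting plumbing removed
// above, a pools.Resource implementation now only has to provide Close and
// Expired. A minimal sketch (the dbConn type is illustrative, not from this
// diff; Expired semantics follow how Put uses it with the pool's max lifetime):
//
//	type dbConn struct {
//		conn    *mysql.Conn
//		created time.Time
//	}
//
//	func (c *dbConn) Close() { c.conn.Close() }
//
//	// Expired reports whether the connection has outlived maxLifetime.
//	func (c *dbConn) Expired(maxLifetime time.Duration) bool {
//		return maxLifetime > 0 && time.Since(c.created) > maxLifetime
//	}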
@@ -162,10 +108,9 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur panic(errors.New("invalid/out of range capacity")) } rp := &ResourcePool{ - resources: make(chan resourceWrapper, maxCap), - settingResources: make(chan resourceWrapper, maxCap), - factory: factory, - logWait: logWait, + resources: make(chan resourceWrapper, maxCap), + factory: factory, + logWait: logWait, } rp.available.Store(int64(capacity)) rp.capacity.Store(int64(capacity)) @@ -210,33 +155,19 @@ func (rp *ResourcePool) closeIdleResources() { for i := 0; i < available; i++ { var wrapper resourceWrapper - var origPool bool select { case wrapper = <-rp.resources: - origPool = true - case wrapper = <-rp.settingResources: - origPool = false default: // stop early if we don't get anything new from the pool return } - var reopened bool if wrapper.resource != nil && idleTimeout > 0 && time.Until(wrapper.timeUsed.Add(idleTimeout)) < 0 { wrapper.resource.Close() rp.idleClosed.Add(1) rp.reopenResource(&wrapper) - reopened = true } - rp.returnResource(&wrapper, origPool, reopened) - } -} - -func (rp *ResourcePool) returnResource(wrapper *resourceWrapper, origPool bool, reopened bool) { - if origPool || reopened { - rp.resources <- *wrapper - } else { - rp.settingResources <- *wrapper + rp.resources <- wrapper } } @@ -258,15 +189,12 @@ func (rp *ResourcePool) reopen() { // has not been reached, it will create a new one using the factory. Otherwise, // it will wait till the next resource becomes available or a timeout. // A timeout of 0 is an indefinite wait. -func (rp *ResourcePool) Get(ctx context.Context, setting *Setting) (resource Resource, err error) { +func (rp *ResourcePool) Get(ctx context.Context) (resource Resource, err error) { // If ctx has already expired, avoid racing with rp's resource channel. 
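	// (Go's select picks randomly among ready cases, so without this early
	// check a caller whose context had already expired could still be handed
	// a resource from the channel instead of a deterministic ErrCtxTimeout.)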
if ctx.Err() != nil { return nil, ErrCtxTimeout } - if setting == nil { - return rp.get(ctx) - } - return rp.getWithSettings(ctx, setting) + return rp.get(ctx) } func (rp *ResourcePool) get(ctx context.Context) (resource Resource, err error) { @@ -281,95 +209,19 @@ func (rp *ResourcePool) get(ctx context.Context) (resource Resource, err error) // check normal resources first case wrapper, ok = <-rp.resources: default: + // now waiting + startTime := time.Now() select { - // then checking setting resources - case wrapper, ok = <-rp.settingResources: - default: - // now waiting - startTime := time.Now() - select { - case wrapper, ok = <-rp.resources: - case wrapper, ok = <-rp.settingResources: - case <-ctx.Done(): - return nil, ErrTimeout - } - rp.recordWait(startTime) - } - } - if !ok { - return nil, ErrClosed - } - - // if the resource has setting applied, we will close it and return a new one - if wrapper.resource != nil && wrapper.resource.IsSettingApplied() { - rp.resetSettingCount.Add(1) - err = wrapper.resource.ResetSetting(ctx) - if err != nil { - // as reset is unsuccessful, we will close this resource - wrapper.resource.Close() - wrapper.resource = nil - rp.active.Add(-1) - } - } - - // Unwrap - if wrapper.resource == nil { - wrapper.resource, err = rp.factory(ctx) - if err != nil { - rp.resources <- resourceWrapper{} - return nil, err - } - rp.active.Add(1) - } - if rp.available.Add(-1) <= 0 { - rp.exhausted.Add(1) - } - rp.inUse.Add(1) - return wrapper.resource, err -} - -func (rp *ResourcePool) getWithSettings(ctx context.Context, setting *Setting) (Resource, error) { - rp.getSettingCount.Add(1) - var wrapper resourceWrapper - var ok bool - var err error - - // Fetch - select { - // check setting resources first - case wrapper, ok = <-rp.settingResources: - default: - select { - // then, check normal resources case wrapper, ok = <-rp.resources: - default: - // now waiting - startTime := time.Now() - select { - case wrapper, ok = <-rp.settingResources: - case wrapper, ok = <-rp.resources: - case <-ctx.Done(): - return nil, ErrTimeout - } - rp.recordWait(startTime) + case <-ctx.Done(): + return nil, ErrTimeout } + rp.recordWait(startTime) } if !ok { return nil, ErrClosed } - // Checking setting hash id, if it is different, we will close the resource and return a new one later in unwrap - if wrapper.resource != nil && wrapper.resource.IsSettingApplied() && !wrapper.resource.IsSameSetting(setting.query) { - rp.diffSettingCount.Add(1) - err = wrapper.resource.ResetSetting(ctx) - if err != nil { - // as reset is unsuccessful, we will close this resource - wrapper.resource.Close() - wrapper.resource = nil - rp.active.Add(-1) - } - } - // Unwrap if wrapper.resource == nil { wrapper.resource, err = rp.factory(ctx) @@ -379,16 +231,6 @@ func (rp *ResourcePool) getWithSettings(ctx context.Context, setting *Setting) ( } rp.active.Add(1) } - - if !wrapper.resource.IsSettingApplied() { - if err = wrapper.resource.ApplySetting(ctx, setting); err != nil { - // as we are not able to apply setting, we can return this connection to non-setting channel. - // TODO: may check the error code to see if it is recoverable or not. - rp.resources <- wrapper - return nil, err - } - } - if rp.available.Add(-1) <= 0 { rp.exhausted.Add(1) } @@ -402,14 +244,11 @@ func (rp *ResourcePool) getWithSettings(ctx context.Context, setting *Setting) ( // This will cause a new resource to be created in its place. 
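// Editorial aside, not part of this patch: with the setting parameter gone,
// the checkout/checkin cycle is simply Get(ctx) followed by Put. A sketch
// (factory is any pools.Factory; capacities and timeouts are arbitrary):
//
//	rp := NewResourcePool(factory, 5, 10, time.Minute, 0, nil, nil, 0)
//	defer rp.Close()
//
//	r, err := rp.Get(ctx)
//	if err != nil {
//		return err
//	}
//	defer rp.Put(r) // every successful Get needs a matching Put
//	// ... use r ...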
func (rp *ResourcePool) Put(resource Resource) { var wrapper resourceWrapper - var recreated bool - var hasSettings bool if resource != nil { wrapper = resourceWrapper{ resource: resource, timeUsed: time.Now(), } - hasSettings = resource.IsSettingApplied() if resource.Expired(rp.extendedMaxLifetime()) { rp.maxLifetimeClosed.Add(1) resource.Close() @@ -419,20 +258,11 @@ func (rp *ResourcePool) Put(resource Resource) { if resource == nil { // Create new resource rp.reopenResource(&wrapper) - recreated = true } - if !hasSettings || recreated { - select { - case rp.resources <- wrapper: - default: - panic(errors.New("attempt to Put into a full ResourcePool")) - } - } else { - select { - case rp.settingResources <- wrapper: - default: - panic(errors.New("attempt to Put into a full ResourcePool")) - } + select { + case rp.resources <- wrapper: + default: + panic(errors.New("attempt to Put into a full ResourcePool")) } rp.inUse.Add(-1) rp.available.Add(1) @@ -466,7 +296,6 @@ func (rp *ResourcePool) SetCapacity(capacity int) error { if oldcap == 0 && capacity > 0 { // Closed this before, re-open the channel rp.resources = make(chan resourceWrapper, cap(rp.resources)) - rp.settingResources = make(chan resourceWrapper, cap(rp.settingResources)) } if oldcap == capacity { return nil @@ -483,11 +312,7 @@ func (rp *ResourcePool) SetCapacity(capacity int) error { // then we just add empty resource to the channel. if capacity < oldcap { for i := 0; i < oldcap-capacity; i++ { - var wrapper resourceWrapper - select { - case wrapper = <-rp.resources: - case wrapper = <-rp.settingResources: - } + wrapper := <-rp.resources if wrapper.resource != nil { wrapper.resource.Close() rp.active.Add(-1) @@ -502,7 +327,6 @@ func (rp *ResourcePool) SetCapacity(capacity int) error { } if capacity == 0 { close(rp.resources) - close(rp.settingResources) } return nil } @@ -612,18 +436,3 @@ func (rp *ResourcePool) Exhausted() int64 { func (rp *ResourcePool) GetCount() int64 { return rp.getCount.Load() } - -// GetSettingCount returns the number of times getWithSettings was called -func (rp *ResourcePool) GetSettingCount() int64 { - return rp.getSettingCount.Load() -} - -// DiffSettingCount returns the number of times different setting were applied on the resource. -func (rp *ResourcePool) DiffSettingCount() int64 { - return rp.diffSettingCount.Load() -} - -// ResetSettingCount returns the number of times setting were reset on the resource. 
-func (rp *ResourcePool) ResetSettingCount() int64 { - return rp.resetSettingCount.Load() -} diff --git a/go/pools/resource_pool_test.go b/go/pools/resource_pool_test.go index 886fab34751..933476361d8 100644 --- a/go/pools/resource_pool_test.go +++ b/go/pools/resource_pool_test.go @@ -19,7 +19,6 @@ package pools import ( "context" "errors" - "fmt" "sync/atomic" "testing" "time" @@ -30,42 +29,13 @@ import ( var ( lastID, count, closeCount, resetCount atomic.Int64 - - waitStarts []time.Time - - sFoo = &Setting{query: "set foo=1"} - sBar = &Setting{query: "set bar=1"} - sFooBar = &Setting{query: "set foo=1, bar=2"} + waitStarts []time.Time ) type TestResource struct { num int64 timeCreated time.Time closed bool - setting string - failApply bool -} - -func (tr *TestResource) ResetSetting(ctx context.Context) error { - resetCount.Add(1) - tr.setting = "" - return nil -} - -func (tr *TestResource) ApplySetting(ctx context.Context, setting *Setting) error { - if tr.failApply { - return fmt.Errorf("ApplySetting failed") - } - tr.setting = setting.query - return nil -} - -func (tr *TestResource) IsSettingApplied() bool { - return len(tr.setting) > 0 -} - -func (tr *TestResource) IsSameSetting(setting string) bool { - return tr.setting == setting } func (tr *TestResource) Close() { @@ -100,11 +70,6 @@ func SlowFailFactory(context.Context) (Resource, error) { return nil, errors.New("Failed") } -func DisallowSettingsFactory(context.Context) (Resource, error) { - count.Add(1) - return &TestResource{num: lastID.Add(1), failApply: true}, nil -} - func TestOpen(t *testing.T) { ctx := context.Background() lastID.Store(0) @@ -119,11 +84,7 @@ func TestOpen(t *testing.T) { // Test Get for i := 0; i < 5; i++ { - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r assert.EqualValues(t, 5-i-1, p.Available()) @@ -138,11 +99,7 @@ func TestOpen(t *testing.T) { ch := make(chan bool) go func() { for i := 0; i < 5; i++ { - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } @@ -168,7 +125,7 @@ func TestOpen(t *testing.T) { assert.NotZero(t, p.WaitTime()) assert.EqualValues(t, 5, lastID.Load()) // Test Close resource - r, err = p.Get(ctx, nil) + r, err = p.Get(ctx) require.NoError(t, err) r.Close() // A nil Put should cause the resource to be reopened. 
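// Editorial note, not part of this patch: the Put(nil) pattern exercised by
// this test is the documented contract of the pool — once the caller has
// closed a resource it must hand back nil so the slot can be recreated:
//
//	r, _ := p.Get(ctx)
//	r.Close()
//	p.Put(nil) // the factory refills the slot; p.Put(r) would recycle a closed resource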
@@ -177,11 +134,7 @@ func TestOpen(t *testing.T) { assert.EqualValues(t, 6, lastID.Load()) for i := 0; i < 5; i++ { - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } @@ -203,11 +156,7 @@ func TestOpen(t *testing.T) { assert.EqualValues(t, 6, p.Available()) for i := 0; i < 6; i++ { - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } @@ -236,11 +185,7 @@ func TestShrinking(t *testing.T) { for i := 0; i < 4; i++ { var r Resource var err error - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } @@ -277,17 +222,13 @@ func TestShrinking(t *testing.T) { var err error for i := 0; i < 3; i++ { var r Resource - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } // This will wait because pool is empty go func() { - r, err := p.Get(ctx, nil) + r, err := p.Get(ctx) require.NoError(t, err) p.Put(r) done <- true @@ -316,18 +257,13 @@ func TestShrinking(t *testing.T) { p.SetCapacity(3) for i := 0; i < 3; i++ { var r Resource - var err error - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } // This will wait because pool is empty go func() { - r, err := p.Get(ctx, nil) + r, err := p.Get(ctx) require.NoError(t, err) p.Put(r) done <- true @@ -368,11 +304,7 @@ func TestClosing(t *testing.T) { for i := 0; i < 5; i++ { var r Resource var err error - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } @@ -415,11 +347,7 @@ func TestReopen(t *testing.T) { for i := 0; i < 5; i++ { var r Resource var err error - if i%2 == 0 { - r, err = p.Get(ctx, nil) - } else { - r, err = p.Get(ctx, sFoo) - } + r, err = p.Get(ctx) require.NoError(t, err) resources[i] = r } @@ -448,7 +376,7 @@ func TestIdleTimeout(t *testing.T) { p := NewResourcePool(PoolFactory, 1, 1, 10*time.Millisecond, 0, logWait, nil, 0) defer p.Close() - r, err := p.Get(ctx, nil) + r, err := p.Get(ctx) require.NoError(t, err) assert.EqualValues(t, 1, count.Load()) assert.EqualValues(t, 0, p.IdleClosed()) @@ -462,7 +390,7 @@ func TestIdleTimeout(t *testing.T) { assert.EqualValues(t, 1, count.Load()) assert.EqualValues(t, 1, p.IdleClosed()) - r, err = p.Get(ctx, nil) + r, err = p.Get(ctx) require.NoError(t, err) assert.EqualValues(t, 2, lastID.Load()) assert.EqualValues(t, 1, count.Load()) @@ -476,7 +404,7 @@ func TestIdleTimeout(t *testing.T) { assert.EqualValues(t, 1, p.IdleClosed()) p.Put(r) - r, err = p.Get(ctx, nil) + r, err = p.Get(ctx) require.NoError(t, err) assert.EqualValues(t, 2, lastID.Load()) assert.EqualValues(t, 1, count.Load()) @@ -493,7 +421,7 @@ func TestIdleTimeout(t *testing.T) { assert.EqualValues(t, 1, p.IdleClosed()) // Get and Put to refresh timeUsed - r, err = p.Get(ctx, nil) + r, err = p.Get(ctx) require.NoError(t, err) p.Put(r) p.SetIdleTimeout(10 * time.Millisecond) @@ -503,93 +431,30 @@ func TestIdleTimeout(t *testing.T) { assert.EqualValues(t, 2, p.IdleClosed()) } -func TestIdleTimeoutWithSettings(t *testing.T) { +func TestIdleTimeoutCreateFail(t *testing.T) { ctx := context.Background() lastID.Store(0) count.Store(0) p := 
NewResourcePool(PoolFactory, 1, 1, 10*time.Millisecond, 0, logWait, nil, 0) defer p.Close() - r, err := p.Get(ctx, sFooBar) + r, err := p.Get(ctx) require.NoError(t, err) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 0, p.IdleClosed()) - + // Change the factory before putting back + // to prevent race with the idle closer, who will + // try to use it. + p.factory = FailFactory p.Put(r) - assert.EqualValues(t, 1, lastID.Load()) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 0, p.IdleClosed()) - - time.Sleep(15 * time.Millisecond) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 1, p.IdleClosed()) - - r, err = p.Get(ctx, sFooBar) - require.NoError(t, err) - assert.EqualValues(t, 2, lastID.Load()) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 1, p.IdleClosed()) - - // sleep to let the idle closer run while all resources are in use - // then make sure things are still as we expect - time.Sleep(15 * time.Millisecond) - assert.EqualValues(t, 2, lastID.Load()) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 1, p.IdleClosed()) - - p.Put(r) - r, err = p.Get(ctx, sFooBar) - require.NoError(t, err) - assert.EqualValues(t, 2, lastID.Load()) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 1, p.IdleClosed()) - - // the idle close thread wakes up every 1/100 of the idle time, so ensure - // the timeout change applies to newly added resources - p.SetIdleTimeout(1000 * time.Millisecond) - p.Put(r) - - time.Sleep(15 * time.Millisecond) - assert.EqualValues(t, 2, lastID.Load()) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 1, p.IdleClosed()) - - // Get and Put to refresh timeUsed - r, err = p.Get(ctx, sFooBar) - require.NoError(t, err) - p.Put(r) - p.SetIdleTimeout(10 * time.Millisecond) - time.Sleep(15 * time.Millisecond) - assert.EqualValues(t, 3, lastID.Load()) - assert.EqualValues(t, 1, count.Load()) - assert.EqualValues(t, 2, p.IdleClosed()) -} - -func TestIdleTimeoutCreateFail(t *testing.T) { - ctx := context.Background() - lastID.Store(0) - count.Store(0) - p := NewResourcePool(PoolFactory, 1, 1, 10*time.Millisecond, 0, logWait, nil, 0) - defer p.Close() - for _, setting := range []*Setting{nil, sFoo} { - r, err := p.Get(ctx, setting) - require.NoError(t, err) - // Change the factory before putting back - // to prevent race with the idle closer, who will - // try to use it. - p.factory = FailFactory - p.Put(r) - timeout := time.After(1 * time.Second) - for p.Active() != 0 { - select { - case <-timeout: - t.Errorf("Timed out waiting for resource to be closed by idle timeout") - default: - } + timeout := time.After(1 * time.Second) + for p.Active() != 0 { + select { + case <-timeout: + t.Errorf("Timed out waiting for resource to be closed by idle timeout") + default: } - // reset factory for next run. - p.factory = PoolFactory } + // reset factory for next run. 
+ p.factory = PoolFactory } func TestMaxLifetime(t *testing.T) { @@ -601,7 +466,7 @@ func TestMaxLifetime(t *testing.T) { p := NewResourcePool(PoolFactory, 1, 1, 10*time.Second, 0, logWait, nil, 0) defer p.Close() - r, err := p.Get(ctx, nil) + r, err := p.Get(ctx) require.NoError(t, err) assert.EqualValues(t, 1, count.Load()) assert.EqualValues(t, 0, p.MaxLifetimeClosed()) @@ -621,7 +486,7 @@ func TestMaxLifetime(t *testing.T) { p = NewResourcePool(PoolFactory, 1, 1, 10*time.Second, 10*time.Millisecond, logWait, nil, 0) defer p.Close() - r, err = p.Get(ctx, nil) + r, err = p.Get(ctx) require.NoError(t, err) assert.EqualValues(t, 1, count.Load()) assert.EqualValues(t, 0, p.MaxLifetimeClosed()) @@ -633,7 +498,7 @@ func TestMaxLifetime(t *testing.T) { assert.EqualValues(t, 1, count.Load()) assert.EqualValues(t, 0, p.MaxLifetimeClosed()) - r, err = p.Get(ctx, nil) + r, err = p.Get(ctx) require.NoError(t, err) assert.EqualValues(t, 1, count.Load()) assert.EqualValues(t, 0, p.MaxLifetimeClosed()) @@ -669,14 +534,12 @@ func TestCreateFail(t *testing.T) { p := NewResourcePool(FailFactory, 5, 5, time.Second, 0, logWait, nil, 0) defer p.Close() - for _, setting := range []*Setting{nil, sFoo} { - if _, err := p.Get(ctx, setting); err.Error() != "Failed" { - t.Errorf("Expecting Failed, received %v", err) - } - stats := p.StatsJSON() - expected := `{"Capacity": 5, "Available": 5, "Active": 0, "InUse": 0, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000, "IdleClosed": 0, "MaxLifetimeClosed": 0, "Exhausted": 0}` - assert.Equal(t, expected, stats) + if _, err := p.Get(ctx); err.Error() != "Failed" { + t.Errorf("Expecting Failed, received %v", err) } + stats := p.StatsJSON() + expected := `{"Capacity": 5, "Available": 5, "Active": 0, "InUse": 0, "MaxCapacity": 5, "WaitCount": 0, "WaitTime": 0, "IdleTimeout": 1000000000, "IdleClosed": 0, "MaxLifetimeClosed": 0, "Exhausted": 0}` + assert.Equal(t, expected, stats) } func TestCreateFailOnPut(t *testing.T) { @@ -686,18 +549,13 @@ func TestCreateFailOnPut(t *testing.T) { p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0) defer p.Close() - for _, setting := range []*Setting{nil, sFoo} { - _, err := p.Get(ctx, setting) - require.NoError(t, err) - - // change factory to fail the put. - p.factory = FailFactory - p.Put(nil) - assert.Zero(t, p.Active()) + _, err := p.Get(ctx) + require.NoError(t, err) - // change back for next iteration. - p.factory = PoolFactory - } + // change factory to fail the put. 
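	// (Put(nil) goes through reopenResource, which calls the factory; with
	// FailFactory installed the reopen fails and Active drops to zero, which
	// is what the assertion below checks.)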
+ p.factory = FailFactory + p.Put(nil) + assert.Zero(t, p.Active()) } func TestSlowCreateFail(t *testing.T) { @@ -707,19 +565,18 @@ func TestSlowCreateFail(t *testing.T) { p := NewResourcePool(SlowFailFactory, 2, 2, time.Second, 0, logWait, nil, 0) defer p.Close() ch := make(chan bool) - for _, setting := range []*Setting{nil, sFoo} { - // The third Get should not wait indefinitely - for i := 0; i < 3; i++ { - go func() { - p.Get(ctx, setting) - ch <- true - }() - } - for i := 0; i < 3; i++ { - <-ch - } - assert.EqualValues(t, 2, p.Available()) + + // The third Get should not wait indefinitely + for i := 0; i < 3; i++ { + go func() { + p.Get(ctx) + ch <- true + }() } + for i := 0; i < 3; i++ { + <-ch + } + assert.EqualValues(t, 2, p.Available()) } func TestTimeout(t *testing.T) { @@ -730,17 +587,14 @@ func TestTimeout(t *testing.T) { defer p.Close() // take the only connection available - r, err := p.Get(ctx, nil) + r, err := p.Get(ctx) require.NoError(t, err) - for _, setting := range []*Setting{nil, sFoo} { - // trying to get the connection without a timeout. - newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - _, err = p.Get(newctx, setting) - cancel() - assert.EqualError(t, err, "resource pool timed out") - - } + // trying to get the connection without a timeout. + newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + _, err = p.Get(newctx) + cancel() + assert.EqualError(t, err, "resource pool timed out") // put the connection take was taken initially. p.Put(r) @@ -752,178 +606,9 @@ func TestExpired(t *testing.T) { p := NewResourcePool(PoolFactory, 1, 1, time.Second, 0, logWait, nil, 0) defer p.Close() - for _, setting := range []*Setting{nil, sFoo} { - // expired context - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) - _, err := p.Get(ctx, setting) - cancel() - require.EqualError(t, err, "resource pool context already expired") - } -} - -func TestMultiSettings(t *testing.T) { - ctx := context.Background() - lastID.Store(0) - count.Store(0) - waitStarts = waitStarts[:0] - - p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0) - var resources [10]Resource - var r Resource - var err error - - settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} - - // Test Get - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, settings[i]) - require.NoError(t, err) - resources[i] = r - assert.EqualValues(t, 5-i-1, p.Available()) - assert.Zero(t, p.WaitCount()) - assert.Zero(t, len(waitStarts)) - assert.Zero(t, p.WaitTime()) - assert.EqualValues(t, i+1, lastID.Load()) - assert.EqualValues(t, i+1, count.Load()) - } - - // Test that Get waits - ch := make(chan bool) - go func() { - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, settings[i]) - require.NoError(t, err) - resources[i] = r - } - for i := 0; i < 5; i++ { - p.Put(resources[i]) - } - ch <- true - }() - for i := 0; i < 5; i++ { - // Sleep to ensure the goroutine waits - time.Sleep(10 * time.Millisecond) - p.Put(resources[i]) - } - <-ch - assert.EqualValues(t, 5, p.WaitCount()) - assert.Equal(t, 5, len(waitStarts)) - // verify start times are monotonic increasing - for i := 1; i < len(waitStarts); i++ { - if waitStarts[i].Before(waitStarts[i-1]) { - t.Errorf("Expecting monotonic increasing start times") - } - } - assert.NotZero(t, p.WaitTime()) - assert.EqualValues(t, 5, lastID.Load()) - - // Close - p.Close() - assert.EqualValues(t, 0, p.Capacity()) - assert.EqualValues(t, 0, p.Available()) - assert.EqualValues(t, 0, count.Load()) -} - -func 
TestMultiSettingsWithReset(t *testing.T) { - ctx := context.Background() - lastID.Store(0) - count.Store(0) - resetCount.Store(0) - - p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0) - var resources [10]Resource - var r Resource - var err error - - settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} - - // Test Get - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, settings[i]) - require.NoError(t, err) - resources[i] = r - assert.EqualValues(t, 5-i-1, p.Available()) - assert.EqualValues(t, i+1, lastID.Load()) - assert.EqualValues(t, i+1, count.Load()) - } - - // Put all of them back - for i := 0; i < 5; i++ { - p.Put(resources[i]) - } - - // Getting all with same setting. - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, settings[1]) // {foo} - require.NoError(t, err) - p.Put(r) - } - assert.EqualValues(t, 2, resetCount.Load()) // when setting was {bar} and getting for {foo} - assert.EqualValues(t, 5, p.Available()) - assert.EqualValues(t, 5, lastID.Load()) - assert.EqualValues(t, 5, count.Load()) - - // Close - p.Close() - assert.EqualValues(t, 0, p.Capacity()) - assert.EqualValues(t, 0, p.Available()) - assert.EqualValues(t, 0, count.Load()) -} - -func TestApplySettingsFailure(t *testing.T) { - ctx := context.Background() - var resources []Resource - var r Resource - var err error - - p := NewResourcePool(PoolFactory, 5, 5, time.Second, 0, logWait, nil, 0) - defer p.Close() - - settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} - // get the resource and mark for failure - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, settings[i]) - require.NoError(t, err) - r.(*TestResource).failApply = true - resources = append(resources, r) - } - // put them back - for _, r = range resources { - p.Put(r) - } - - // any new connection created will fail to apply setting - p.factory = DisallowSettingsFactory - - // Get the resource with "foo" setting - // For an applied connection if the setting are same it will be returned as-is. - // Otherwise, will fail to get the resource. 
- var failCount int - resources = nil - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, settings[1]) - if err != nil { - failCount++ - assert.EqualError(t, err, "ApplySetting failed") - continue - } - resources = append(resources, r) - } - // put them back - for _, r = range resources { - p.Put(r) - } - require.Equal(t, 3, failCount) - - // should be able to get all the resource with no setting - resources = nil - for i := 0; i < 5; i++ { - r, err = p.Get(ctx, nil) - require.NoError(t, err) - resources = append(resources, r) - } - // put them back - for _, r = range resources { - p.Put(r) - } + // expired context + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + _, err := p.Get(ctx) + cancel() + require.EqualError(t, err, "resource pool context already expired") } diff --git a/go/pools/rp_bench_test.go b/go/pools/rp_bench_test.go index a045c31d52f..ddd36159a4b 100644 --- a/go/pools/rp_bench_test.go +++ b/go/pools/rp_bench_test.go @@ -35,7 +35,7 @@ func BenchmarkGetPut(b *testing.B) { b.RunParallel(func(pb *testing.PB) { var ctx = context.Background() for pb.Next() { - if conn, err := pool.Get(ctx, nil); err != nil { + if conn, err := pool.Get(ctx); err != nil { b.Error(err) } else { pool.Put(conn) @@ -46,94 +46,3 @@ func BenchmarkGetPut(b *testing.B) { } } } - -func BenchmarkGetPutWithSettings(b *testing.B) { - testResourceFactory := func(context.Context) (Resource, error) { - return &TestResource{}, nil - } - setting := &Setting{query: "set a=1, b=2, c=3"} - for _, size := range []int{64, 128, 512} { - for _, parallelism := range []int{1, 8, 32, 128} { - rName := fmt.Sprintf("x%d-cap%d", parallelism, size) - b.Run(rName, func(b *testing.B) { - pool := NewResourcePool(testResourceFactory, size, size, 0, 0, nil, nil, 0) - defer pool.Close() - - b.ReportAllocs() - b.SetParallelism(parallelism) - b.RunParallel(func(pb *testing.PB) { - var ctx = context.Background() - for pb.Next() { - if conn, err := pool.Get(ctx, setting); err != nil { - b.Error(err) - } else { - pool.Put(conn) - } - } - }) - }) - } - } -} - -func BenchmarkGetPutMixed(b *testing.B) { - testResourceFactory := func(context.Context) (Resource, error) { - return &TestResource{}, nil - } - settings := []*Setting{nil, {query: "set a=1, b=2, c=3"}} - for _, size := range []int{64, 128, 512} { - for _, parallelism := range []int{1, 8, 32, 128} { - rName := fmt.Sprintf("x%d-cap%d", parallelism, size) - b.Run(rName, func(b *testing.B) { - pool := NewResourcePool(testResourceFactory, size, size, 0, 0, nil, nil, 0) - defer pool.Close() - - b.ReportAllocs() - b.SetParallelism(parallelism) - b.RunParallel(func(pb *testing.PB) { - var ctx = context.Background() - i := 0 - for pb.Next() { - if conn, err := pool.Get(ctx, settings[i]); err != nil { - b.Error(err) - } else { - pool.Put(conn) - } - i = (i + 1) % 2 - } - }) - }) - } - } -} - -func BenchmarkGetPutMixedMulti(b *testing.B) { - testResourceFactory := func(context.Context) (Resource, error) { - return &TestResource{}, nil - } - settings := []*Setting{nil, {query: "set a=1"}, {query: "set a=1, b=2"}, {query: "set c=1, d=2, e=3"}, {query: "set x=1, y=2, z=3"}} - for _, size := range []int{64, 128, 512} { - for _, parallelism := range []int{1, 8, 32, 128} { - rName := fmt.Sprintf("x%d-cap%d", parallelism, size) - b.Run(rName, func(b *testing.B) { - pool := NewResourcePool(testResourceFactory, size, size, 0, 0, nil, nil, 0) - defer pool.Close() - - b.ReportAllocs() - b.SetParallelism(parallelism) - b.RunParallel(func(pb *testing.PB) { - var ctx 
= context.Background() - i := 0 - for pb.Next() { - if conn, err := pool.Get(ctx, settings[i]); err != nil { - b.Error(err) - } else { - pool.Put(conn) - } - i = (i + 1) % 5 - } - }) - }) - } - } -} diff --git a/go/pools/rpc_pool.go b/go/pools/rpc_pool.go index 7ed1349e89e..8f20641dc5d 100644 --- a/go/pools/rpc_pool.go +++ b/go/pools/rpc_pool.go @@ -19,10 +19,6 @@ package pools import ( "context" "time" - - "vitess.io/vitess/go/vt/vterrors" - - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) // RPCPool is a specialized version of the ResourcePool, for bounding concurrent @@ -36,7 +32,7 @@ import ( // one method of acquisition, Acquire(context.Context), which always uses the // lower of the pool-global timeout or the context deadline. type RPCPool struct { - rp IResourcePool + rp *ResourcePool waitTimeout time.Duration } @@ -71,7 +67,7 @@ func (pool *RPCPool) Acquire(ctx context.Context) error { defer cancel() } - _, err := pool.rp.Get(ctx, nil) + _, err := pool.rp.Get(ctx) return err } @@ -92,25 +88,6 @@ var rpc = &_rpc{} // Close implements Resource for _rpc. func (*_rpc) Close() {} -// ApplySetting implements Resource for _rpc. -func (r *_rpc) ApplySetting(context.Context, *Setting) error { - // should be unreachable - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG]: _rpc does not support ApplySetting") -} - -func (r *_rpc) IsSettingApplied() bool { - return false -} - -func (r *_rpc) IsSameSetting(string) bool { - return true -} - -func (r *_rpc) ResetSetting(context.Context) error { - // should be unreachable - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG]: _rpc does not support ResetSetting") -} - func (r *_rpc) Expired(time.Duration) bool { return false } diff --git a/go/pools/smartconnpool/benchmarking/legacy/refresh_pool.go b/go/pools/smartconnpool/benchmarking/legacy/refresh_pool.go new file mode 100644 index 00000000000..006497e8168 --- /dev/null +++ b/go/pools/smartconnpool/benchmarking/legacy/refresh_pool.go @@ -0,0 +1,97 @@ +/* +Copyright 2022 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package legacy + +import ( + "sync" + "time" + + "vitess.io/vitess/go/vt/log" +) + +type ( + // RefreshCheck is a function used to determine if a resource pool should be + // refreshed (i.e. closed and reopened) + RefreshCheck func() (bool, error) + + // poolRefresh refreshes the pool by calling the RefreshCheck function. + // If the RefreshCheck returns true, the pool is closed and reopened. + poolRefresh struct { + refreshCheck RefreshCheck + refreshInterval time.Duration + refreshTicker *time.Ticker + refreshStop chan struct{} + refreshWg sync.WaitGroup + + pool refreshPool + } +) + +type refreshPool interface { + // reopen drains and reopens the connection pool + reopen() + + // closeIdleResources scans the pool for idle resources and closes them. 
+ closeIdleResources() +} + +func newPoolRefresh(pool refreshPool, refreshCheck RefreshCheck, refreshInterval time.Duration) *poolRefresh { + if refreshCheck == nil || refreshInterval <= 0 { + return nil + } + return &poolRefresh{ + pool: pool, + refreshInterval: refreshInterval, + refreshCheck: refreshCheck, + } +} + +func (pr *poolRefresh) startRefreshTicker() { + if pr == nil { + return + } + pr.refreshTicker = time.NewTicker(pr.refreshInterval) + pr.refreshStop = make(chan struct{}) + pr.refreshWg.Add(1) + go func() { + defer pr.refreshWg.Done() + for { + select { + case <-pr.refreshTicker.C: + val, err := pr.refreshCheck() + if err != nil { + log.Info(err) + } + if val { + go pr.pool.reopen() + return + } + case <-pr.refreshStop: + return + } + } + }() +} + +func (pr *poolRefresh) stop() { + if pr == nil || pr.refreshTicker == nil { + return + } + pr.refreshTicker.Stop() + close(pr.refreshStop) + pr.refreshWg.Wait() +} diff --git a/go/pools/smartconnpool/benchmarking/legacy/resource_pool.go b/go/pools/smartconnpool/benchmarking/legacy/resource_pool.go new file mode 100644 index 00000000000..df8c44e1530 --- /dev/null +++ b/go/pools/smartconnpool/benchmarking/legacy/resource_pool.go @@ -0,0 +1,610 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package legacy + +import ( + "context" + "errors" + "fmt" + "math/rand" + "sync" + "sync/atomic" + "time" + + "vitess.io/vitess/go/pools/smartconnpool" + "vitess.io/vitess/go/timer" + "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/vterrors" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" +) + +type ( + IResourcePool interface { + Close() + Name() string + Get(ctx context.Context, setting *Setting) (resource Resource, err error) + Put(resource Resource) + SetCapacity(capacity int) error + SetIdleTimeout(idleTimeout time.Duration) + StatsJSON() string + Capacity() int64 + Available() int64 + Active() int64 + InUse() int64 + MaxCap() int64 + WaitCount() int64 + WaitTime() time.Duration + IdleTimeout() time.Duration + IdleClosed() int64 + MaxLifetimeClosed() int64 + Exhausted() int64 + GetCount() int64 + GetSettingCount() int64 + DiffSettingCount() int64 + ResetSettingCount() int64 + } + + // Resource defines the interface that every resource must provide. + // Thread synchronization between Close() and IsClosed() + // is the responsibility of the caller. + Resource interface { + Close() + Expired(time.Duration) bool + ApplySetting(ctx context.Context, setting *Setting) error + IsSettingApplied() bool + IsSameSetting(setting string) bool + ResetSetting(ctx context.Context) error + } + + // Factory is a function that can be used to create a resource. + Factory func(context.Context) (Resource, error) + + resourceWrapper struct { + resource Resource + timeUsed time.Time + } + + // Setting represents a set query and reset query for system settings. + Setting = smartconnpool.Setting + + // ResourcePool allows you to use a pool of resources. 
+ ResourcePool struct { + available atomic.Int64 + active atomic.Int64 + inUse atomic.Int64 + waitCount atomic.Int64 + waitTime atomic.Int64 + idleClosed atomic.Int64 + maxLifetimeClosed atomic.Int64 + exhausted atomic.Int64 + + capacity atomic.Int64 + idleTimeout atomic.Int64 + maxLifetime atomic.Int64 + + resources chan resourceWrapper + factory Factory + idleTimer *timer.Timer + logWait func(time.Time) + + settingResources chan resourceWrapper + getCount atomic.Int64 + getSettingCount atomic.Int64 + diffSettingCount atomic.Int64 + resetSettingCount atomic.Int64 + + reopenMutex sync.Mutex + refresh *poolRefresh + } +) + +var ( + // ErrClosed is returned if ResourcePool is used when it's closed. + ErrClosed = errors.New("resource pool is closed") + + // ErrTimeout is returned if a resource get times out. + ErrTimeout = vterrors.New(vtrpcpb.Code_RESOURCE_EXHAUSTED, "resource pool timed out") + + // ErrCtxTimeout is returned if a ctx is already expired by the time the resource pool is used + ErrCtxTimeout = vterrors.New(vtrpcpb.Code_DEADLINE_EXCEEDED, "resource pool context already expired") +) + +// NewResourcePool creates a new ResourcePool pool. +// capacity is the number of possible resources in the pool: +// there can be up to 'capacity' of these at a given time. +// maxCap specifies the extent to which the pool can be resized +// in the future through the SetCapacity function. +// You cannot resize the pool beyond maxCap. +// If a resource is unused beyond idleTimeout, it's replaced +// with a new one. +// An idleTimeout of 0 means that there is no timeout. +// An maxLifetime of 0 means that there is no timeout. +// A non-zero value of prefillParallelism causes the pool to be pre-filled. +// The value specifies how many resources can be opened in parallel. +// refreshCheck is a function we consult at refreshInterval +// intervals to determine if the pool should be drained and reopened +func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Duration, maxLifetime time.Duration, logWait func(time.Time), refreshCheck RefreshCheck, refreshInterval time.Duration) *ResourcePool { + if capacity <= 0 || maxCap <= 0 || capacity > maxCap { + panic(errors.New("invalid/out of range capacity")) + } + rp := &ResourcePool{ + resources: make(chan resourceWrapper, maxCap), + settingResources: make(chan resourceWrapper, maxCap), + factory: factory, + logWait: logWait, + } + rp.available.Store(int64(capacity)) + rp.capacity.Store(int64(capacity)) + rp.idleTimeout.Store(idleTimeout.Nanoseconds()) + rp.maxLifetime.Store(maxLifetime.Nanoseconds()) + + for i := 0; i < capacity; i++ { + rp.resources <- resourceWrapper{} + } + + if idleTimeout != 0 { + rp.idleTimer = timer.NewTimer(idleTimeout / 10) + rp.idleTimer.Start(rp.closeIdleResources) + } + + rp.refresh = newPoolRefresh(rp, refreshCheck, refreshInterval) + rp.refresh.startRefreshTicker() + + return rp +} + +func (rp *ResourcePool) Name() string { + return "ResourcePool" +} + +// Close empties the pool calling Close on all its resources. +// You can call Close while there are outstanding resources. +// It waits for all resources to be returned (Put). +// After a Close, Get is not allowed. 
+func (rp *ResourcePool) Close() { + if rp.idleTimer != nil { + rp.idleTimer.Stop() + } + rp.refresh.stop() + _ = rp.SetCapacity(0) +} + +// closeIdleResources scans the pool for idle resources +func (rp *ResourcePool) closeIdleResources() { + available := int(rp.Available()) + idleTimeout := rp.IdleTimeout() + + for i := 0; i < available; i++ { + var wrapper resourceWrapper + var origPool bool + select { + case wrapper = <-rp.resources: + origPool = true + case wrapper = <-rp.settingResources: + origPool = false + default: + // stop early if we don't get anything new from the pool + return + } + + var reopened bool + if wrapper.resource != nil && idleTimeout > 0 && time.Until(wrapper.timeUsed.Add(idleTimeout)) < 0 { + wrapper.resource.Close() + rp.idleClosed.Add(1) + rp.reopenResource(&wrapper) + reopened = true + } + rp.returnResource(&wrapper, origPool, reopened) + } +} + +func (rp *ResourcePool) returnResource(wrapper *resourceWrapper, origPool bool, reopened bool) { + if origPool || reopened { + rp.resources <- *wrapper + } else { + rp.settingResources <- *wrapper + } +} + +// reopen drains and reopens the connection pool +func (rp *ResourcePool) reopen() { + rp.reopenMutex.Lock() // Avoid race, since we can refresh asynchronously + defer rp.reopenMutex.Unlock() + capacity := int(rp.capacity.Load()) + log.Infof("Draining and reopening resource pool with capacity %d by request", capacity) + rp.Close() + _ = rp.SetCapacity(capacity) + if rp.idleTimer != nil { + rp.idleTimer.Start(rp.closeIdleResources) + } + rp.refresh.startRefreshTicker() +} + +// Get will return the next available resource. If capacity +// has not been reached, it will create a new one using the factory. Otherwise, +// it will wait till the next resource becomes available or a timeout. +// A timeout of 0 is an indefinite wait. +func (rp *ResourcePool) Get(ctx context.Context, setting *Setting) (resource Resource, err error) { + // If ctx has already expired, avoid racing with rp's resource channel. + if ctx.Err() != nil { + return nil, ErrCtxTimeout + } + if setting == nil { + return rp.get(ctx) + } + return rp.getWithSettings(ctx, setting) +} + +func (rp *ResourcePool) get(ctx context.Context) (resource Resource, err error) { + rp.getCount.Add(1) + // Fetch + var wrapper resourceWrapper + var ok bool + // If we put both the channel together, then, go select can read from any channel + // this way we guarantee it will try to read from the channel we intended to read it from first + // and then try to read from next best available resource. 
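	// (A single select over both channels would pick randomly whenever both
	// have an idle resource ready; nesting the selects is what gives the
	// plain-resource channel priority over the setting-resource channel.)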
+ select { + // check normal resources first + case wrapper, ok = <-rp.resources: + default: + select { + // then checking setting resources + case wrapper, ok = <-rp.settingResources: + default: + // now waiting + startTime := time.Now() + select { + case wrapper, ok = <-rp.resources: + case wrapper, ok = <-rp.settingResources: + case <-ctx.Done(): + return nil, ErrTimeout + } + rp.recordWait(startTime) + } + } + if !ok { + return nil, ErrClosed + } + + // if the resource has setting applied, we will close it and return a new one + if wrapper.resource != nil && wrapper.resource.IsSettingApplied() { + rp.resetSettingCount.Add(1) + err = wrapper.resource.ResetSetting(ctx) + if err != nil { + // as reset is unsuccessful, we will close this resource + wrapper.resource.Close() + wrapper.resource = nil + rp.active.Add(-1) + } + } + + // Unwrap + if wrapper.resource == nil { + wrapper.resource, err = rp.factory(ctx) + if err != nil { + rp.resources <- resourceWrapper{} + return nil, err + } + rp.active.Add(1) + } + if rp.available.Add(-1) <= 0 { + rp.exhausted.Add(1) + } + rp.inUse.Add(1) + return wrapper.resource, err +} + +func (rp *ResourcePool) getWithSettings(ctx context.Context, setting *Setting) (Resource, error) { + rp.getSettingCount.Add(1) + var wrapper resourceWrapper + var ok bool + var err error + + // Fetch + select { + // check setting resources first + case wrapper, ok = <-rp.settingResources: + default: + select { + // then, check normal resources + case wrapper, ok = <-rp.resources: + default: + // now waiting + startTime := time.Now() + select { + case wrapper, ok = <-rp.settingResources: + case wrapper, ok = <-rp.resources: + case <-ctx.Done(): + return nil, ErrTimeout + } + rp.recordWait(startTime) + } + } + if !ok { + return nil, ErrClosed + } + + // Checking setting hash id, if it is different, we will close the resource and return a new one later in unwrap + if wrapper.resource != nil && wrapper.resource.IsSettingApplied() && !wrapper.resource.IsSameSetting(setting.ApplyQuery()) { + rp.diffSettingCount.Add(1) + err = wrapper.resource.ResetSetting(ctx) + if err != nil { + // as reset is unsuccessful, we will close this resource + wrapper.resource.Close() + wrapper.resource = nil + rp.active.Add(-1) + } + } + + // Unwrap + if wrapper.resource == nil { + wrapper.resource, err = rp.factory(ctx) + if err != nil { + rp.resources <- resourceWrapper{} + return nil, err + } + rp.active.Add(1) + } + + if !wrapper.resource.IsSettingApplied() { + if err = wrapper.resource.ApplySetting(ctx, setting); err != nil { + // as we are not able to apply setting, we can return this connection to non-setting channel. + // TODO: may check the error code to see if it is recoverable or not. + rp.resources <- wrapper + return nil, err + } + } + + if rp.available.Add(-1) <= 0 { + rp.exhausted.Add(1) + } + rp.inUse.Add(1) + return wrapper.resource, err +} + +// Put will return a resource to the pool. For every successful Get, +// a corresponding Put is required. If you no longer need a resource, +// you will need to call Put(nil) instead of returning the closed resource. +// This will cause a new resource to be created in its place. 
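// For example (illustrative sketch, not part of this diff; workFailed is a
// hypothetical health check):
//
//	conn, err := rp.Get(ctx, nil)
//	if err != nil {
//		return err
//	}
//	if workFailed(conn) {
//		conn.Close()
//		rp.Put(nil) // the empty slot is refilled with a fresh resource
//	} else {
//		rp.Put(conn)
//	}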
+func (rp *ResourcePool) Put(resource Resource) { + var wrapper resourceWrapper + var recreated bool + var hasSettings bool + if resource != nil { + wrapper = resourceWrapper{ + resource: resource, + timeUsed: time.Now(), + } + hasSettings = resource.IsSettingApplied() + if resource.Expired(rp.extendedMaxLifetime()) { + rp.maxLifetimeClosed.Add(1) + resource.Close() + resource = nil + } + } + if resource == nil { + // Create new resource + rp.reopenResource(&wrapper) + recreated = true + } + if !hasSettings || recreated { + select { + case rp.resources <- wrapper: + default: + panic(errors.New("attempt to Put into a full ResourcePool")) + } + } else { + select { + case rp.settingResources <- wrapper: + default: + panic(errors.New("attempt to Put into a full ResourcePool")) + } + } + rp.inUse.Add(-1) + rp.available.Add(1) +} + +func (rp *ResourcePool) reopenResource(wrapper *resourceWrapper) { + if r, err := rp.factory(context.TODO()); err == nil { + wrapper.resource = r + wrapper.timeUsed = time.Now() + } else { + wrapper.resource = nil + rp.active.Add(-1) + } +} + +// SetCapacity changes the capacity of the pool. +// You can use it to shrink or expand, but not beyond +// the max capacity. If the change requires the pool +// to be shrunk, SetCapacity waits till the necessary +// number of resources are returned to the pool. +// A SetCapacity of 0 is equivalent to closing the ResourcePool. +func (rp *ResourcePool) SetCapacity(capacity int) error { + if capacity < 0 || capacity > cap(rp.resources) { + return fmt.Errorf("capacity %d is out of range", capacity) + } + + // Atomically swap new capacity with old + var oldcap int + for { + oldcap = int(rp.capacity.Load()) + if oldcap == 0 && capacity > 0 { + // Closed this before, re-open the channel + rp.resources = make(chan resourceWrapper, cap(rp.resources)) + rp.settingResources = make(chan resourceWrapper, cap(rp.settingResources)) + } + if oldcap == capacity { + return nil + } + if rp.capacity.CompareAndSwap(int64(oldcap), int64(capacity)) { + break + } + } + + // If the required capacity is less than the current capacity, + // then we need to wait till the current resources are returned + // to the pool and close them from any of the channel. + // Otherwise, if the required capacity is more than the current capacity, + // then we just add empty resource to the channel. + if capacity < oldcap { + for i := 0; i < oldcap-capacity; i++ { + var wrapper resourceWrapper + select { + case wrapper = <-rp.resources: + case wrapper = <-rp.settingResources: + } + if wrapper.resource != nil { + wrapper.resource.Close() + rp.active.Add(-1) + } + rp.available.Add(-1) + } + } else { + for i := 0; i < capacity-oldcap; i++ { + rp.resources <- resourceWrapper{} + rp.available.Add(1) + } + } + if capacity == 0 { + close(rp.resources) + close(rp.settingResources) + } + return nil +} + +func (rp *ResourcePool) recordWait(start time.Time) { + rp.waitCount.Add(1) + rp.waitTime.Add(time.Since(start).Nanoseconds()) + if rp.logWait != nil { + rp.logWait(start) + } +} + +// SetIdleTimeout sets the idle timeout. It can only be used if there was an +// idle timeout set when the pool was created. +func (rp *ResourcePool) SetIdleTimeout(idleTimeout time.Duration) { + if rp.idleTimer == nil { + panic("SetIdleTimeout called when timer not initialized") + } + + rp.idleTimeout.Store(idleTimeout.Nanoseconds()) + rp.idleTimer.SetInterval(idleTimeout / 10) +} + +// StatsJSON returns the stats in JSON format. 
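// For a freshly opened pool of capacity 5 (maxCap 10, 1m idle timeout) the output
// would look roughly like:
//
//	{"Capacity": 5, "Available": 5, "Active": 0, "InUse": 0, "MaxCapacity": 10,
//	"WaitCount": 0, "WaitTime": 0, "IdleTimeout": 60000000000, "IdleClosed": 0,
//	"MaxLifetimeClosed": 0, "Exhausted": 0}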
+func (rp *ResourcePool) StatsJSON() string { + return fmt.Sprintf(`{"Capacity": %v, "Available": %v, "Active": %v, "InUse": %v, "MaxCapacity": %v, "WaitCount": %v, "WaitTime": %v, "IdleTimeout": %v, "IdleClosed": %v, "MaxLifetimeClosed": %v, "Exhausted": %v}`, + rp.Capacity(), + rp.Available(), + rp.Active(), + rp.InUse(), + rp.MaxCap(), + rp.WaitCount(), + rp.WaitTime().Nanoseconds(), + rp.IdleTimeout().Nanoseconds(), + rp.IdleClosed(), + rp.MaxLifetimeClosed(), + rp.Exhausted(), + ) +} + +// Capacity returns the capacity. +func (rp *ResourcePool) Capacity() int64 { + return rp.capacity.Load() +} + +// Available returns the number of currently unused and available resources. +func (rp *ResourcePool) Available() int64 { + return rp.available.Load() +} + +// Active returns the number of active (i.e. non-nil) resources either in the +// pool or claimed for use +func (rp *ResourcePool) Active() int64 { + return rp.active.Load() +} + +// InUse returns the number of claimed resources from the pool +func (rp *ResourcePool) InUse() int64 { + return rp.inUse.Load() +} + +// MaxCap returns the max capacity. +func (rp *ResourcePool) MaxCap() int64 { + return int64(cap(rp.resources)) +} + +// WaitCount returns the total number of waits. +func (rp *ResourcePool) WaitCount() int64 { + return rp.waitCount.Load() +} + +// WaitTime returns the total wait time. +func (rp *ResourcePool) WaitTime() time.Duration { + return time.Duration(rp.waitTime.Load()) +} + +// IdleTimeout returns the resource idle timeout. +func (rp *ResourcePool) IdleTimeout() time.Duration { + return time.Duration(rp.idleTimeout.Load()) +} + +// IdleClosed returns the count of resources closed due to idle timeout. +func (rp *ResourcePool) IdleClosed() int64 { + return rp.idleClosed.Load() +} + +// extendedMaxLifetime returns random duration within range [maxLifetime, 2*maxLifetime) +func (rp *ResourcePool) extendedMaxLifetime() time.Duration { + maxLifetime := rp.maxLifetime.Load() + if maxLifetime == 0 { + return 0 + } + return time.Duration(maxLifetime + rand.Int63n(maxLifetime)) +} + +// MaxLifetimeClosed returns the count of resources closed due to refresh timeout. +func (rp *ResourcePool) MaxLifetimeClosed() int64 { + return rp.maxLifetimeClosed.Load() +} + +// Exhausted returns the number of times Available dropped below 1 +func (rp *ResourcePool) Exhausted() int64 { + return rp.exhausted.Load() +} + +// GetCount returns the number of times get was called +func (rp *ResourcePool) GetCount() int64 { + return rp.getCount.Load() +} + +// GetSettingCount returns the number of times getWithSettings was called +func (rp *ResourcePool) GetSettingCount() int64 { + return rp.getSettingCount.Load() +} + +// DiffSettingCount returns the number of times different setting were applied on the resource. +func (rp *ResourcePool) DiffSettingCount() int64 { + return rp.diffSettingCount.Load() +} + +// ResetSettingCount returns the number of times setting were reset on the resource. +func (rp *ResourcePool) ResetSettingCount() int64 { + return rp.resetSettingCount.Load() +} diff --git a/go/pools/smartconnpool/benchmarking/load_test.go b/go/pools/smartconnpool/benchmarking/load_test.go new file mode 100644 index 00000000000..537daf2c357 --- /dev/null +++ b/go/pools/smartconnpool/benchmarking/load_test.go @@ -0,0 +1,494 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package benchmarking_test + +import ( + "context" + "encoding/json" + "fmt" + "math" + "math/rand" + "os" + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "gonum.org/v1/gonum/floats" + "gonum.org/v1/gonum/stat/distuv" + + "vitess.io/vitess/go/pools/smartconnpool" + pools "vitess.io/vitess/go/pools/smartconnpool/benchmarking/legacy" +) + +type Request struct { + Delay time.Duration + Duration time.Duration + Setting int +} + +type ConnStats struct { + Requests int + Reset int + Apply int +} + +type BenchConn struct { + Stats ConnStats + + setting *smartconnpool.Setting + latency time.Duration + closed bool +} + +func (b *BenchConn) Expired(_ time.Duration) bool { + return false +} + +func (b *BenchConn) IsSettingApplied() bool { + return b.setting != nil +} + +func (b *BenchConn) IsSameSetting(setting string) bool { + return b.setting != nil && b.setting.ApplyQuery() == setting +} + +var _ smartconnpool.Connection = (*BenchConn)(nil) +var _ pools.Resource = (*BenchConn)(nil) + +func (b *BenchConn) ApplySetting(ctx context.Context, setting *smartconnpool.Setting) error { + time.Sleep(b.latency) + b.setting = setting + b.Stats.Apply++ + return nil +} + +func (b *BenchConn) ResetSetting(ctx context.Context) error { + time.Sleep(b.latency) + b.setting = nil + b.Stats.Reset++ + return nil +} + +func (b *BenchConn) Setting() *smartconnpool.Setting { + return b.setting +} + +func (b *BenchConn) IsClosed() bool { + return b.closed +} + +func (b *BenchConn) Close() { + b.closed = true +} + +type Trace []Request +type Perform func(ctx context.Context, setting *smartconnpool.Setting, delay time.Duration) + +type Benchmark struct { + t testing.TB + name string + trace Trace + settings []*smartconnpool.Setting + latency time.Duration + + wg sync.WaitGroup + progress atomic.Int64 + concurrent atomic.Int64 + concurrency []int64 + + mu sync.Mutex + waits []time.Duration + connstats []*ConnStats +} + +func NewBenchmark(t testing.TB, name string, opts *TraceOptions) *Benchmark { + bench := &Benchmark{ + t: t, + name: name, + trace: opts.Generate(), + latency: opts.Latency, + } + + bench.settings = append(bench.settings, nil) + for i := 1; i < len(opts.Settings); i++ { + bench.settings = append(bench.settings, smartconnpool.NewSetting(fmt.Sprintf("set setting%d=1", i), "")) + } + + return bench +} + +func (b *Benchmark) displayProgress(done <-chan struct{}, total int) { + tick1 := time.NewTicker(time.Second) + defer tick1.Stop() + + tick2 := time.NewTicker(100 * time.Millisecond) + defer tick2.Stop() + + for { + select { + case <-done: + return + case <-tick1.C: + count := b.progress.Load() + b.t.Logf("benchmark: %d/%d (%.02f%%), concurrency = %v", count, total, 100*float64(count)/float64(total), b.concurrency[len(b.concurrency)-1]) + + case <-tick2.C: + b.concurrency = append(b.concurrency, b.concurrent.Load()) + } + } +} + +func (b *Benchmark) run(perform Perform) { + trace := b.trace + + b.progress.Store(0) + b.concurrent.Store(0) + b.waits = make([]time.Duration, 0, len(trace)) + b.connstats = make([]*ConnStats, 0, 64) + b.concurrency = nil + + done := 
make(chan struct{}) + go b.displayProgress(done, len(trace)) + + b.wg.Add(len(trace)) + + for _, req := range trace { + b.progress.Add(1) + time.Sleep(req.Delay) + + go func(req Request) { + b.concurrent.Add(1) + defer func() { + b.concurrent.Add(-1) + b.wg.Done() + }() + + start := time.Now() + perform(context.Background(), b.settings[req.Setting], req.Duration) + wait := time.Since(start) - req.Duration + + b.mu.Lock() + b.waits = append(b.waits, wait) + b.mu.Unlock() + }(req) + } + + b.wg.Wait() + close(done) +} + +func (b *Benchmark) waitTotal() (t time.Duration) { + for _, w := range b.waits { + t += w + } + return +} + +type InternalStatistics struct { + Capacity int + WaitCount int64 + WaitTime time.Duration + DiffCount int64 + ResetCount int64 +} + +type Statistics struct { + Connections []*ConnStats + Waits []time.Duration + Trace []Request + + Settings int + Internal InternalStatistics +} + +func (b *Benchmark) serialize(suffix string, internal *InternalStatistics) { + stats := &Statistics{ + Connections: b.connstats, + Waits: b.waits, + Trace: b.trace, + Settings: len(b.settings), + Internal: *internal, + } + + f, err := os.Create(b.name + "_pool_" + suffix + ".json") + require.NoError(b.t, err) + defer f.Close() + + enc := json.NewEncoder(f) + enc.SetEscapeHTML(false) + enc.Encode(stats) + + b.t.Logf("written %s", f.Name()) +} + +func (b *Benchmark) ResourcePool(capacity int) { + factory := func(ctx context.Context) (pools.Resource, error) { + conn := &BenchConn{latency: b.latency} + + b.mu.Lock() + b.connstats = append(b.connstats, &conn.Stats) + b.mu.Unlock() + + return conn, nil + } + pool := pools.NewResourcePool(factory, capacity, capacity, 0, 0, nil, nil, 0) + + perform := func(ctx context.Context, setting *smartconnpool.Setting, delay time.Duration) { + conn, err := pool.Get(context.Background(), setting) + if err != nil { + panic(err) + } + + conn.(*BenchConn).Stats.Requests++ + time.Sleep(delay) + pool.Put(conn) + } + + b.run(perform) + b.serialize("before", &InternalStatistics{ + Capacity: capacity, + WaitCount: pool.WaitCount(), + WaitTime: pool.WaitTime(), + DiffCount: pool.DiffSettingCount(), + ResetCount: pool.ResetSettingCount(), + }) +} + +func (b *Benchmark) SmartConnPool(capacity int) { + connect := func(ctx context.Context) (*BenchConn, error) { + conn := &BenchConn{latency: b.latency} + + b.mu.Lock() + b.connstats = append(b.connstats, &conn.Stats) + b.mu.Unlock() + + return conn, nil + } + + pool := smartconnpool.NewPool(&smartconnpool.Config[*BenchConn]{ + Capacity: int64(capacity), + }).Open(connect, nil) + + perform := func(ctx context.Context, setting *smartconnpool.Setting, delay time.Duration) { + conn, err := pool.Get(context.Background(), setting) + if err != nil { + panic(err) + } + + conn.Conn.Stats.Requests++ + time.Sleep(delay) + conn.Recycle() + } + + b.run(perform) + b.serialize("after", &InternalStatistics{ + Capacity: capacity, + WaitCount: pool.Metrics.WaitCount(), + WaitTime: pool.Metrics.WaitTime(), + DiffCount: pool.Metrics.DiffSettingCount(), + ResetCount: pool.Metrics.ResetSettingCount(), + }) +} + +type TraceOptions struct { + RequestsPerSecond int + DecayRate float64 + Duration time.Duration + Latency time.Duration + Settings []float64 +} + +func (opt *TraceOptions) arrivalTimes() (out []time.Duration) { + var t time.Duration + for t < opt.Duration { + currentRate := float64(opt.RequestsPerSecond) * math.Exp(-opt.DecayRate*t.Seconds()) + interArrivalTime := time.Duration((rand.ExpFloat64() / currentRate) * float64(time.Second)) + if 
interArrivalTime >= opt.Duration { + continue + } + + out = append(out, interArrivalTime) + t += interArrivalTime + } + return +} + +func weightedDraw(p []float64, n int) []int { + // Initialization: create the discrete CDF + // We know that cdf is sorted in ascending order + cdf := make([]float64, len(p)) + floats.CumSum(cdf, p) + // Generation: + // 1. Generate a uniformly-random value x in the range [0,1) + // 2. Using a binary search, find the index of the smallest element in cdf larger than x + var val float64 + indices := make([]int, n) + for i := range indices { + // multiply the sample with the largest CDF value; easier than normalizing to [0,1) + val = distuv.UnitUniform.Rand() * cdf[len(cdf)-1] + // Search returns the smallest index i such that cdf[i] > val + indices[i] = sort.Search(len(cdf), func(i int) bool { return cdf[i] > val }) + } + return indices +} + +func (opt *TraceOptions) Generate() Trace { + times := opt.arrivalTimes() + + var settings []int + if len(opt.Settings) > 1 { + settings = weightedDraw(opt.Settings, len(times)) + } + + durations := distuv.Pareto{ + Xm: float64(opt.Latency), + Alpha: 1, + } + + var trace []Request + for i := range times { + req := Request{} + req.Delay = times[i] + req.Duration = time.Duration(durations.Rand()) + for req.Duration > opt.Duration/4 { + req.Duration = time.Duration(durations.Rand()) + } + if settings != nil { + req.Setting = settings[i] + } + + trace = append(trace, req) + } + return trace +} + +func TestPoolPerformance(t *testing.T) { + t.Skipf("skipping load tests...") + + t.Run("Contended", func(t *testing.T) { + opt := TraceOptions{ + RequestsPerSecond: 100, + DecayRate: 0.01, + Duration: 30 * time.Second, + Latency: 15 * time.Millisecond, + Settings: []float64{5, 1, 1, 1}, + } + + bench := NewBenchmark(t, "contended", &opt) + bench.ResourcePool(8) + bench.SmartConnPool(8) + }) + + t.Run("Uncontended", func(t *testing.T) { + opt := TraceOptions{ + RequestsPerSecond: 20, + DecayRate: 0.01, + Duration: 30 * time.Second, + Latency: 15 * time.Millisecond, + Settings: []float64{5, 1, 1, 1, 1, 1}, + } + + bench := NewBenchmark(t, "uncontended", &opt) + bench.ResourcePool(16) + bench.SmartConnPool(16) + }) + + t.Run("Uncontended Without Settings", func(t *testing.T) { + opt := TraceOptions{ + RequestsPerSecond: 20, + DecayRate: 0.01, + Duration: 30 * time.Second, + Latency: 15 * time.Millisecond, + Settings: []float64{5, 1}, + } + + bench := NewBenchmark(t, "uncontended_no_settings", &opt) + bench.ResourcePool(16) + bench.SmartConnPool(16) + }) + + t.Run("Points", func(t *testing.T) { + opt := TraceOptions{ + RequestsPerSecond: 2000, + DecayRate: 0.01, + Duration: 30 * time.Second, + Latency: 2 * time.Millisecond, + Settings: []float64{5, 2, 1, 1}, + } + + bench := NewBenchmark(t, "points", &opt) + bench.ResourcePool(16) + bench.SmartConnPool(16) + }) +} + +func BenchmarkGetPut(b *testing.B) { + connLegacy := func(context.Context) (pools.Resource, error) { + return &BenchConn{}, nil + } + connSmart := func(ctx context.Context) (*BenchConn, error) { + return &BenchConn{}, nil + } + + for _, size := range []int{64, 128, 512} { + for _, parallelism := range []int{1, 8, 32, 128} { + rName := fmt.Sprintf("x%d-cap%d", parallelism, size) + + b.Run("Legacy/"+rName, func(b *testing.B) { + pool := pools.NewResourcePool(connLegacy, size, size, 0, 0, nil, nil, 0) + defer pool.Close() + + b.ReportAllocs() + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + var ctx = context.Background() + for pb.Next() { + if conn, err 
:= pool.Get(ctx, nil); err != nil { + b.Error(err) + } else { + pool.Put(conn) + } + } + }) + }) + + b.Run("Smart/"+rName, func(b *testing.B) { + pool := smartconnpool.NewPool[*BenchConn](&smartconnpool.Config[*BenchConn]{ + Capacity: int64(size), + }).Open(connSmart, nil) + + defer pool.Close() + + b.ReportAllocs() + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + var ctx = context.Background() + for pb.Next() { + if conn, err := pool.Get(ctx, nil); err != nil { + b.Error(err) + } else { + conn.Recycle() + } + } + }) + }) + } + } +} diff --git a/go/pools/cached_size.go b/go/pools/smartconnpool/cached_size.go similarity index 78% rename from go/pools/cached_size.go rename to go/pools/smartconnpool/cached_size.go index 79bc4e6e28a..8c985349db3 100644 --- a/go/pools/cached_size.go +++ b/go/pools/smartconnpool/cached_size.go @@ -15,7 +15,7 @@ limitations under the License. */ // Code generated by Sizegen. DO NOT EDIT. -package pools +package smartconnpool import hack "vitess.io/vitess/go/hack" @@ -25,11 +25,11 @@ func (cached *Setting) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(32) + size += int64(48) } - // field query string - size += hack.RuntimeAllocSize(int64(len(cached.query))) - // field resetQuery string - size += hack.RuntimeAllocSize(int64(len(cached.resetQuery))) + // field queryApply string + size += hack.RuntimeAllocSize(int64(len(cached.queryApply))) + // field queryReset string + size += hack.RuntimeAllocSize(int64(len(cached.queryReset))) return size } diff --git a/go/pools/smartconnpool/connection.go b/go/pools/smartconnpool/connection.go new file mode 100644 index 00000000000..cdb5720596e --- /dev/null +++ b/go/pools/smartconnpool/connection.go @@ -0,0 +1,64 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "context" + "sync/atomic" + "time" +) + +type Connection interface { + ApplySetting(ctx context.Context, setting *Setting) error + ResetSetting(ctx context.Context) error + Setting() *Setting + + IsClosed() bool + Close() +} + +type Pooled[C Connection] struct { + next atomic.Pointer[Pooled[C]] + timeCreated time.Time + timeUsed time.Time + pool *ConnPool[C] + + Conn C +} + +func (dbc *Pooled[C]) Close() { + dbc.Conn.Close() +} + +func (dbc *Pooled[C]) Recycle() { + switch { + case dbc.pool == nil: + dbc.Conn.Close() + case dbc.Conn.IsClosed(): + dbc.pool.put(nil) + default: + dbc.pool.put(dbc) + } +} + +func (dbc *Pooled[C]) Taint() { + if dbc.pool == nil { + return + } + dbc.pool.put(nil) + dbc.pool = nil +} diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go new file mode 100644 index 00000000000..7c10d6ba4b0 --- /dev/null +++ b/go/pools/smartconnpool/pool.go @@ -0,0 +1,706 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "context" + "slices" + "sync" + "sync/atomic" + "time" + + "vitess.io/vitess/go/hack" + "vitess.io/vitess/go/vt/log" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/servenv" + "vitess.io/vitess/go/vt/vterrors" +) + +var ( + // ErrTimeout is returned if a connection get times out. + ErrTimeout = vterrors.New(vtrpcpb.Code_RESOURCE_EXHAUSTED, "resource pool timed out") + + // ErrCtxTimeout is returned if a ctx is already expired by the time the connection pool is used + ErrCtxTimeout = vterrors.New(vtrpcpb.Code_DEADLINE_EXCEEDED, "resource pool context already expired") +) + +type Metrics struct { + maxLifetimeClosed atomic.Int64 + getCount atomic.Int64 + getWithSettingsCount atomic.Int64 + waitCount atomic.Int64 + waitTime atomic.Int64 + idleClosed atomic.Int64 + diffSetting atomic.Int64 + resetSetting atomic.Int64 +} + +func (m *Metrics) MaxLifetimeClosed() int64 { + return m.maxLifetimeClosed.Load() +} + +func (m *Metrics) GetCount() int64 { + return m.getCount.Load() +} + +func (m *Metrics) GetSettingCount() int64 { + return m.getWithSettingsCount.Load() +} + +func (m *Metrics) WaitCount() int64 { + return m.waitCount.Load() +} + +func (m *Metrics) WaitTime() time.Duration { + return time.Duration(m.waitTime.Load()) +} + +func (m *Metrics) IdleClosed() int64 { + return m.idleClosed.Load() +} + +func (m *Metrics) DiffSettingCount() int64 { + return m.diffSetting.Load() +} + +func (m *Metrics) ResetSettingCount() int64 { + return m.resetSetting.Load() +} + +type Connector[C Connection] func(ctx context.Context) (C, error) +type RefreshCheck func() (bool, error) + +type Config[C Connection] struct { + Capacity int64 + IdleTimeout time.Duration + MaxLifetime time.Duration + RefreshInterval time.Duration + LogWait func(time.Time) +} + +// stackMask is the number of connection setting stacks minus one; +// the number of stacks must always be a power of two +const stackMask = 7 + +// ConnPool is a connection pool for generic connections +type ConnPool[C Connection] struct { + // clean is a connections stack for connections with no Setting applied + clean connStack[C] + // settings are N connection stacks for connections with a Setting applied + // connections are distributed between stacks based on their Setting.bucket + settings [stackMask + 1]connStack[C] + // freshSettingStack is the index in settings to the last stack when a connection + // was pushed, or -1 if no connection with a Setting has been opened in this pool + freshSettingsStack atomic.Int64 + // wait is the list of clients waiting for a connection to be returned to the pool + wait waitlist[C] + + // borrowed is the number of connections that the pool has given out to clients + // and that haven't been returned yet + borrowed atomic.Int64 + // active is the number of connections that the pool has opened; this includes connections + // in the pool and borrowed by clients + active atomic.Int64 + // capacity is the maximum number of connections that this pool can open + capacity atomic.Int64 + + // workers is a waitgroup for all the currently running worker goroutines + 
workers sync.WaitGroup + close chan struct{} + + config struct { + // connect is the callback to create a new connection for the pool + connect Connector[C] + // refresh is the callback to check whether the pool needs to be refreshed + refresh RefreshCheck + + // maxCapacity is the maximum value to which capacity can be set; when the pool + // is re-opened, it defaults to this capacity + maxCapacity int64 + // maxLifetime is the maximum time a connection can be open + maxLifetime atomic.Int64 + // idleTimeout is the maximum time a connection can remain idle + idleTimeout atomic.Int64 + // refreshInterval is how often to call the refresh check + refreshInterval atomic.Int64 + // logWait is called every time a client must block waiting for a connection + logWait func(time.Time) + } + + Metrics Metrics +} + +// NewPool creates a new connection pool with the given Config. +// The pool must be ConnPool.Open before it can start giving out connections +func NewPool[C Connection](config *Config[C]) *ConnPool[C] { + pool := &ConnPool[C]{} + pool.freshSettingsStack.Store(-1) + pool.config.maxCapacity = config.Capacity + pool.config.maxLifetime.Store(config.MaxLifetime.Nanoseconds()) + pool.config.idleTimeout.Store(config.IdleTimeout.Nanoseconds()) + pool.config.refreshInterval.Store(config.RefreshInterval.Nanoseconds()) + pool.config.logWait = config.LogWait + pool.wait.init() + + return pool +} + +func (pool *ConnPool[C]) runWorker(close <-chan struct{}, interval time.Duration, worker func(now time.Time) bool) { + pool.workers.Add(1) + + go func() { + tick := time.NewTicker(interval) + + defer tick.Stop() + defer pool.workers.Done() + + for { + select { + case now := <-tick.C: + if !worker(now) { + return + } + case <-close: + return + } + } + }() +} + +func (pool *ConnPool[C]) open() { + pool.close = make(chan struct{}) + pool.capacity.Store(pool.config.maxCapacity) + + // The expire worker takes care of removing from the waiter list any clients whose + // context has been cancelled. + pool.runWorker(pool.close, 1*time.Second, func(_ time.Time) bool { + pool.wait.expire(false) + return true + }) + + idleTimeout := pool.IdleTimeout() + if idleTimeout != 0 { + // The idle worker takes care of closing connections that have been idle too long + pool.runWorker(pool.close, idleTimeout/10, func(now time.Time) bool { + pool.closeIdleResources(now) + return true + }) + } + + refreshInterval := pool.RefreshInterval() + if refreshInterval != 0 && pool.config.refresh != nil { + // The refresh worker periodically checks the refresh callback in this pool + // to decide whether all the connections in the pool need to be cycled + // (this usually only happens when there's a global DNS change). + pool.runWorker(pool.close, refreshInterval, func(_ time.Time) bool { + refresh, err := pool.config.refresh() + if err != nil { + log.Error(err) + } + if refresh { + go pool.reopen() + return false + } + return true + }) + } +} + +// Open starts the background workers that manage the pool and gets it ready +// to start serving out connections. +func (pool *ConnPool[C]) Open(connect Connector[C], refresh RefreshCheck) *ConnPool[C] { + if pool.close != nil { + // already open + return pool + } + + pool.config.connect = connect + pool.config.refresh = refresh + pool.open() + return pool +} + +// Close shuts down the pool. No connections will be returned from ConnPool.Get after calling this, +// but calling ConnPool.Put is still allowed. This function will not return until all of the pool's +// connections have been returned. 
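// For example (illustrative only): a typical lifecycle is
//
//	pool := NewPool(&Config[*myConn]{Capacity: 5}).Open(connect, nil)
//	// ... Get / Recycle while serving ...
//	pool.Close() // blocks until borrowed connections have been recycled
//
// where *myConn and connect stand in for a concrete Connection implementation
// and its Connector.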
+func (pool *ConnPool[C]) Close() {
+	if pool.close == nil {
+		// already closed
+		return
+	}
+
+	pool.SetCapacity(0)
+
+	close(pool.close)
+	pool.workers.Wait()
+	pool.close = nil
+}
+
+func (pool *ConnPool[C]) reopen() {
+	capacity := pool.capacity.Load()
+	if capacity == 0 {
+		return
+	}
+
+	pool.Close()
+	pool.open()
+	pool.SetCapacity(capacity)
+}
+
+// IsOpen returns whether the pool is open
+func (pool *ConnPool[C]) IsOpen() bool {
+	return pool.close != nil
+}
+
+// Capacity returns the maximum amount of connections that this pool can maintain open
+func (pool *ConnPool[C]) Capacity() int64 {
+	return pool.capacity.Load()
+}
+
+// MaxCapacity returns the maximum value to which Capacity can be set via ConnPool.SetCapacity
+func (pool *ConnPool[C]) MaxCapacity() int64 {
+	return pool.config.maxCapacity
+}
+
+// InUse returns the number of connections that the pool has lent out to clients and that
+// haven't been returned yet.
+func (pool *ConnPool[C]) InUse() int64 {
+	return pool.borrowed.Load()
+}
+
+// Available returns the number of connections that the pool can immediately lend out to
+// clients without blocking.
+func (pool *ConnPool[C]) Available() int64 {
+	return pool.capacity.Load() - pool.borrowed.Load()
+}
+
+// Active returns the number of connections that the pool has currently open.
+func (pool *ConnPool[C]) Active() int64 {
+	return pool.active.Load()
+}
+
+func (pool *ConnPool[D]) IdleTimeout() time.Duration {
+	return time.Duration(pool.config.idleTimeout.Load())
+}
+
+func (pool *ConnPool[C]) SetIdleTimeout(duration time.Duration) {
+	pool.config.idleTimeout.Store(duration.Nanoseconds())
+}
+
+func (pool *ConnPool[D]) RefreshInterval() time.Duration {
+	return time.Duration(pool.config.refreshInterval.Load())
+}
+
+func (pool *ConnPool[C]) recordWait(start time.Time) {
+	pool.Metrics.waitCount.Add(1)
+	pool.Metrics.waitTime.Add(time.Since(start).Nanoseconds())
+	if pool.config.logWait != nil {
+		pool.config.logWait(start)
+	}
+}
+
+// Get returns a connection from the pool with the given Setting applied.
+// If there are no connections in the pool to be returned, Get blocks until one
+// is returned, or until the given ctx is cancelled.
+// The connection must be returned to the pool once it's not needed by calling Pooled.Recycle
+func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C], error) {
+	if ctx.Err() != nil {
+		return nil, ErrCtxTimeout
+	}
+	if pool.capacity.Load() == 0 {
+		return nil, ErrTimeout
+	}
+	if setting == nil {
+		return pool.get(ctx)
+	}
+	return pool.getWithSetting(ctx, setting)
+}
+
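To make the public surface above concrete, here is a minimal usage sketch. It is illustrative only and not part of this diff: demoConn, connect and the example queries are invented, and a real connection would execute the setting queries instead of just recording them.

package example

import (
	"context"

	"vitess.io/vitess/go/pools/smartconnpool"
)

// demoConn is a hypothetical Connection implementation, just enough to drive the pool.
type demoConn struct {
	setting *smartconnpool.Setting
	closed  bool
}

func (c *demoConn) ApplySetting(_ context.Context, s *smartconnpool.Setting) error {
	c.setting = s // a real connection would run s.ApplyQuery() here
	return nil
}

func (c *demoConn) ResetSetting(_ context.Context) error {
	c.setting = nil // a real connection would run the reset query here
	return nil
}

func (c *demoConn) Setting() *smartconnpool.Setting { return c.setting }
func (c *demoConn) IsClosed() bool                  { return c.closed }
func (c *demoConn) Close()                          { c.closed = true }

func useSmartPool(ctx context.Context) error {
	connect := func(ctx context.Context) (*demoConn, error) { return &demoConn{}, nil }

	pool := smartconnpool.NewPool(&smartconnpool.Config[*demoConn]{Capacity: 5}).Open(connect, nil)
	defer pool.Close()

	// A plain connection.
	conn, err := pool.Get(ctx, nil)
	if err != nil {
		return err
	}
	conn.Recycle() // always return the wrapper, not the raw connection

	// A connection with a Setting applied; the pool reuses connections that
	// already carry the same setting when it can.
	setting := smartconnpool.NewSetting("set foo=1", "set foo=default")
	conn, err = pool.Get(ctx, setting)
	if err != nil {
		return err
	}
	defer conn.Recycle()
	return nil
}

+// put returns a connection to the pool. This is a private API.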
+// Return connections to the pool by calling Pooled.Recycle +func (pool *ConnPool[C]) put(conn *Pooled[C]) { + pool.borrowed.Add(-1) + + if conn == nil { + var err error + conn, err = pool.connNew(context.Background()) + if err != nil { + pool.closedConn() + return + } + } else { + conn.timeUsed = time.Now() + + lifetime := pool.extendedMaxLifetime() + if lifetime > 0 && time.Until(conn.timeCreated.Add(lifetime)) < 0 { + pool.Metrics.maxLifetimeClosed.Add(1) + conn.Close() + if err := pool.connReopen(context.Background(), conn, conn.timeUsed); err != nil { + pool.closedConn() + return + } + } + } + + if !pool.wait.tryReturnConn(conn) { + connSetting := conn.Conn.Setting() + if connSetting == nil { + pool.clean.Push(conn) + } else { + stack := connSetting.bucket & stackMask + pool.settings[stack].Push(conn) + pool.freshSettingsStack.Store(int64(stack)) + } + } +} + +func (pool *ConnPool[D]) extendedMaxLifetime() time.Duration { + maxLifetime := pool.config.maxLifetime.Load() + if maxLifetime == 0 { + return 0 + } + extended := hack.FastRand() % uint32(maxLifetime) + return time.Duration(maxLifetime) + time.Duration(extended) +} + +func (pool *ConnPool[C]) connReopen(ctx context.Context, dbconn *Pooled[C], now time.Time) error { + var err error + dbconn.Conn, err = pool.config.connect(ctx) + if err != nil { + return err + } + + dbconn.timeUsed = now + dbconn.timeCreated = now + return nil +} + +func (pool *ConnPool[C]) connNew(ctx context.Context) (*Pooled[C], error) { + conn, err := pool.config.connect(ctx) + if err != nil { + return nil, err + } + now := time.Now() + return &Pooled[C]{ + timeCreated: now, + timeUsed: now, + pool: pool, + Conn: conn, + }, nil +} + +func (pool *ConnPool[C]) getFromSettingsStack(setting *Setting) *Pooled[C] { + fresh := pool.freshSettingsStack.Load() + if fresh < 0 { + return nil + } + + var start uint32 + if setting == nil { + start = uint32(fresh) + } else { + start = setting.bucket + } + + for i := uint32(0); i <= stackMask; i++ { + pos := (i + start) & stackMask + if conn, ok := pool.settings[pos].Pop(); ok { + return conn + } + } + return nil +} + +func (pool *ConnPool[C]) closedConn() { + _ = pool.active.Add(-1) +} + +func (pool *ConnPool[C]) getNew(ctx context.Context) (*Pooled[C], error) { + for { + open := pool.active.Load() + if open >= pool.capacity.Load() { + return nil, nil + } + + if pool.active.CompareAndSwap(open, open+1) { + conn, err := pool.connNew(ctx) + if err != nil { + pool.closedConn() + return nil, err + } + return conn, nil + } + } +} + +// get returns a pooled connection with no Setting applied +func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { + pool.Metrics.getCount.Add(1) + + // best case: if there's a connection in the clean stack, return it right away + if conn, ok := pool.clean.Pop(); ok { + pool.borrowed.Add(1) + return conn, nil + } + + // check if we have enough capacity to open a brand-new connection to return + conn, err := pool.getNew(ctx) + if err != nil { + return nil, err + } + // if we don't have capacity, try popping a connection from any of the setting stacks + if conn == nil { + conn = pool.getFromSettingsStack(nil) + } + // if there are no connections in the setting stacks and we've lent out connections + // to other clients, wait until one of the connections is returned + if conn == nil { + start := time.Now() + conn, err = pool.wait.waitForConn(ctx, nil) + if err != nil { + return nil, ErrTimeout + } + pool.recordWait(start) + } + // no connections available and no connections to wait for 
(pool is closed)
+	if conn == nil {
+		return nil, ErrTimeout
+	}
+
+	// if the connection we've acquired has a Setting applied, we must reset it before returning
+	if conn.Conn.Setting() != nil {
+		pool.Metrics.resetSetting.Add(1)
+
+		err = conn.Conn.ResetSetting(ctx)
+		if err != nil {
+			conn.Close()
+			err = pool.connReopen(ctx, conn, time.Now())
+			if err != nil {
+				pool.closedConn()
+				return nil, err
+			}
+		}
+	}
+
+	pool.borrowed.Add(1)
+	return conn, nil
+}
+
+// getWithSetting returns a connection from the pool with the given Setting applied
+func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) (*Pooled[C], error) {
+	pool.Metrics.getWithSettingsCount.Add(1)
+
+	var err error
+	// best case: check if there's a connection in the setting stack where our Setting belongs
+	conn, _ := pool.settings[setting.bucket&stackMask].Pop()
+	// if there's no connection with our setting, try popping a clean connection
+	if conn == nil {
+		conn, _ = pool.clean.Pop()
+	}
+	// otherwise try opening a brand new connection and we'll apply the setting to it
+	if conn == nil {
+		conn, err = pool.getNew(ctx)
+		if err != nil {
+			return nil, err
+		}
+	}
+	// try on the _other_ setting stacks, even if we have to reset the Setting for the returned
+	// connection
+	if conn == nil {
+		conn = pool.getFromSettingsStack(setting)
+	}
+	// no connections anywhere in the pool; if we've lent out connections to other clients,
+	// wait for one of them
+	if conn == nil {
+		start := time.Now()
+		conn, err = pool.wait.waitForConn(ctx, setting)
+		if err != nil {
+			return nil, ErrTimeout
+		}
+		pool.recordWait(start)
+	}
+	// no connections available and no connections to wait for (pool is closed)
+	if conn == nil {
+		return nil, ErrTimeout
+	}
+
+	// ensure that the setting applied to the connection matches the one we want
+	connSetting := conn.Conn.Setting()
+	if connSetting != setting {
+		// if there's another setting applied, reset it before applying our setting
+		if connSetting != nil {
+			pool.Metrics.diffSetting.Add(1)
+
+			err = conn.Conn.ResetSetting(ctx)
+			if err != nil {
+				conn.Close()
+				err = pool.connReopen(ctx, conn, time.Now())
+				if err != nil {
+					pool.closedConn()
+					return nil, err
+				}
+			}
+		}
+		// apply our setting now; if we can't we assume that the conn is broken
+		// and close it without returning to the pool
+		if err := conn.Conn.ApplySetting(ctx, setting); err != nil {
+			conn.Close()
+			pool.closedConn()
+			return nil, err
+		}
+	}
+
+	pool.borrowed.Add(1)
+	return conn, nil
+}
+
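// To make the bucketing above concrete (illustrative note, not part of this diff):
// stackMask is 7, so there are 8 setting stacks, and a Setting whose bucket is,
// say, 10 lives in settings[10&7] == settings[2]. The lookup order implemented by
// getWithSetting is therefore:
//
//	1. settings[setting.bucket&stackMask] (a connection that already carries this setting)
//	2. clean (apply the setting to a clean connection)
//	3. a brand-new connection, if the pool still has capacity
//	4. any other setting stack (reset the old setting, then apply ours)
//	5. wait for a borrowed connection to be returned

+// SetCapacity changes the capacity (number of open connections) on the pool.
+// If the capacity is smaller than the number of connections currently open,
+// we'll close enough connections before returning, even if that means waiting
+// for clients to return connections to the pool.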
+func (pool *ConnPool[C]) SetCapacity(newcap int64) { + if newcap < 0 { + panic("negative capacity") + } + + oldcap := pool.capacity.Swap(newcap) + if oldcap == newcap { + return + } + + backoff := 1 * time.Millisecond + + // close connections until we're under capacity + for pool.active.Load() > newcap { + // try closing from connections which are currently idle in the stacks + conn := pool.getFromSettingsStack(nil) + if conn == nil { + conn, _ = pool.clean.Pop() + } + if conn == nil { + time.Sleep(backoff) + backoff += 1 * time.Millisecond + continue + } + conn.Close() + pool.closedConn() + } + + // if we're closing down the pool, wake up any blocked waiters because no connections + // are going to be returned in the future + if newcap == 0 { + pool.wait.expire(true) + } +} + +func (pool *ConnPool[C]) closeIdleResources(now time.Time) { + timeout := pool.IdleTimeout() + if timeout == 0 { + return + } + if pool.Capacity() == 0 { + return + } + + var conns []*Pooled[C] + + closeInStack := func(s *connStack[C]) { + conns = s.PopAll(conns[:0]) + slices.Reverse(conns) + + for _, conn := range conns { + if conn.timeUsed.Add(timeout).Sub(now) < 0 { + pool.Metrics.idleClosed.Add(1) + conn.Close() + pool.closedConn() + continue + } + + s.Push(conn) + } + } + + for i := 0; i <= stackMask; i++ { + closeInStack(&pool.settings[i]) + } + closeInStack(&pool.clean) +} + +func (pool *ConnPool[C]) StatsJSON() map[string]any { + return map[string]any{ + "Capacity": int(pool.Capacity()), + "Available": int(pool.Available()), + "Active": int(pool.active.Load()), + "InUse": int(pool.InUse()), + "WaitCount": int(pool.Metrics.WaitCount()), + "WaitTime": pool.Metrics.WaitTime(), + "IdleTimeout": pool.IdleTimeout(), + "IdleClosed": int(pool.Metrics.IdleClosed()), + "MaxLifetimeClosed": int(pool.Metrics.MaxLifetimeClosed()), + } +} + +// RegisterStats registers this pool's metrics into a stats Exporter +func (pool *ConnPool[C]) RegisterStats(stats *servenv.Exporter, name string) { + if stats == nil || name == "" { + return + } + + stats.NewGaugeFunc(name+"Capacity", "Tablet server conn pool capacity", func() int64 { + return pool.Capacity() + }) + stats.NewGaugeFunc(name+"Available", "Tablet server conn pool available", func() int64 { + return pool.Available() + }) + stats.NewGaugeFunc(name+"Active", "Tablet server conn pool active", func() int64 { + return pool.Active() + }) + stats.NewGaugeFunc(name+"InUse", "Tablet server conn pool in use", func() int64 { + return pool.InUse() + }) + stats.NewGaugeFunc(name+"MaxCap", "Tablet server conn pool max cap", func() int64 { + // the smartconnpool doesn't have a maximum capacity + return pool.Capacity() + }) + stats.NewCounterFunc(name+"WaitCount", "Tablet server conn pool wait count", func() int64 { + return pool.Metrics.WaitCount() + }) + stats.NewCounterDurationFunc(name+"WaitTime", "Tablet server wait time", func() time.Duration { + return pool.Metrics.WaitTime() + }) + stats.NewGaugeDurationFunc(name+"IdleTimeout", "Tablet server idle timeout", func() time.Duration { + return pool.IdleTimeout() + }) + stats.NewCounterFunc(name+"IdleClosed", "Tablet server conn pool idle closed", func() int64 { + return pool.Metrics.IdleClosed() + }) + stats.NewCounterFunc(name+"MaxLifetimeClosed", "Tablet server conn pool refresh closed", func() int64 { + return pool.Metrics.MaxLifetimeClosed() + }) + stats.NewCounterFunc(name+"Get", "Tablet server conn pool get count", func() int64 { + return pool.Metrics.GetCount() + }) + stats.NewCounterFunc(name+"GetSetting", "Tablet server 
conn pool get with setting count", func() int64 { + return pool.Metrics.GetSettingCount() + }) + stats.NewCounterFunc(name+"DiffSetting", "Number of times pool applied different setting", func() int64 { + return pool.Metrics.DiffSettingCount() + }) + stats.NewCounterFunc(name+"ResetSetting", "Number of times pool reset the setting", func() int64 { + return pool.Metrics.ResetSettingCount() + }) +} diff --git a/go/pools/smartconnpool/pool_test.go b/go/pools/smartconnpool/pool_test.go new file mode 100644 index 00000000000..c9c2235d90f --- /dev/null +++ b/go/pools/smartconnpool/pool_test.go @@ -0,0 +1,1028 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "context" + "fmt" + "reflect" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + sFoo = &Setting{queryApply: "set foo=1"} + sBar = &Setting{queryApply: "set bar=1"} +) + +type TestState struct { + lastID, open, close, reset atomic.Int64 + waits []time.Time + + chaos struct { + delayConnect time.Duration + failConnect bool + failApply bool + } +} + +func (ts *TestState) LogWait(start time.Time) { + ts.waits = append(ts.waits, start) +} + +type TestConn struct { + counts *TestState + onClose chan struct{} + + setting *Setting + num int64 + timeCreated time.Time + closed bool + failApply bool +} + +func (tr *TestConn) waitForClose() chan struct{} { + tr.onClose = make(chan struct{}) + return tr.onClose +} + +func (tr *TestConn) IsClosed() bool { + return tr.closed +} + +func (tr *TestConn) Setting() *Setting { + return tr.setting +} + +func (tr *TestConn) ResetSetting(ctx context.Context) error { + tr.counts.reset.Add(1) + tr.setting = nil + return nil +} + +func (tr *TestConn) ApplySetting(ctx context.Context, setting *Setting) error { + if tr.failApply { + return fmt.Errorf("ApplySetting failed") + } + tr.setting = setting + return nil +} + +func (tr *TestConn) Close() { + if !tr.closed { + if tr.onClose != nil { + close(tr.onClose) + } + tr.counts.open.Add(-1) + tr.counts.close.Add(1) + tr.closed = true + } +} + +var _ Connection = (*TestConn)(nil) + +func newConnector(state *TestState) Connector[*TestConn] { + return func(ctx context.Context) (*TestConn, error) { + state.open.Add(1) + if state.chaos.delayConnect != 0 { + time.Sleep(state.chaos.delayConnect) + } + if state.chaos.failConnect { + return nil, fmt.Errorf("failed to connect: forced failure") + } + return &TestConn{ + num: state.lastID.Add(1), + timeCreated: time.Now(), + counts: state, + failApply: state.chaos.failApply, + }, nil + } +} + +func TestOpen(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + // Test Get + for i := 0; i < 5; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + 
r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + assert.EqualValues(t, 5-i-1, p.Available()) + assert.Zero(t, p.Metrics.WaitCount()) + assert.Zero(t, len(state.waits)) + assert.Zero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, i+1, state.lastID.Load()) + assert.EqualValues(t, i+1, state.open.Load()) + } + + // Test that Get waits + ch := make(chan bool) + go func() { + for i := 0; i < 5; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + ch <- true + }() + for i := 0; i < 5; i++ { + // Sleep to ensure the goroutine waits + time.Sleep(10 * time.Millisecond) + p.put(resources[i]) + } + <-ch + assert.EqualValues(t, 5, p.Metrics.WaitCount()) + assert.Equal(t, 5, len(state.waits)) + // verify start times are monotonic increasing + for i := 1; i < len(state.waits); i++ { + if state.waits[i].Before(state.waits[i-1]) { + t.Errorf("Expecting monotonic increasing start times") + } + } + assert.NotZero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, 5, state.lastID.Load()) + // Test Close resource + r, err = p.Get(ctx, nil) + require.NoError(t, err) + r.Close() + // A nil Put should cause the resource to be reopened. + p.put(nil) + assert.EqualValues(t, 5, state.open.Load()) + assert.EqualValues(t, 6, state.lastID.Load()) + + for i := 0; i < 5; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + assert.EqualValues(t, 5, state.open.Load()) + assert.EqualValues(t, 6, state.lastID.Load()) + + // SetCapacity + p.SetCapacity(3) + assert.EqualValues(t, 3, state.open.Load()) + assert.EqualValues(t, 6, state.lastID.Load()) + assert.EqualValues(t, 3, p.Capacity()) + assert.EqualValues(t, 3, p.Available()) + + p.SetCapacity(6) + assert.EqualValues(t, 6, p.Capacity()) + assert.EqualValues(t, 6, p.Available()) + + for i := 0; i < 6; i++ { + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 6; i++ { + p.put(resources[i]) + } + assert.EqualValues(t, 6, state.open.Load()) + assert.EqualValues(t, 9, state.lastID.Load()) + + // Close + p.Close() + assert.EqualValues(t, 0, p.Capacity()) + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestShrinking(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + // Leave one empty slot in the pool + for i := 0; i < 4; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + done := make(chan bool) + go func() { + p.SetCapacity(3) + done <- true + }() + expected := map[string]any{ + "Capacity": 3, + "Available": -1, // negative because we've borrowed past our capacity + "Active": 4, + "InUse": 4, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + for i := 0; i < 10; i++ { + time.Sleep(10 * time.Millisecond) + stats := p.StatsJSON() + if reflect.DeepEqual(expected, stats) { + break + } + if i == 9 { + 
assert.Equal(t, expected, stats) + } + } + // There are already 2 resources available in the pool. + // So, returning one should be enough for SetCapacity to complete. + p.put(resources[3]) + <-done + // Return the rest of the resources + for i := 0; i < 3; i++ { + p.put(resources[i]) + } + stats := p.StatsJSON() + expected = map[string]any{ + "Capacity": 3, + "Available": 3, + "Active": 3, + "InUse": 0, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + assert.EqualValues(t, 3, state.open.Load()) + + // Ensure no deadlock if SetCapacity is called after we start + // waiting for a resource + var err error + for i := 0; i < 3; i++ { + var r *Pooled[*TestConn] + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + // This will wait because pool is empty + go func() { + r, err := p.Get(ctx, nil) + require.NoError(t, err) + p.put(r) + done <- true + }() + + // This will also wait + go func() { + p.SetCapacity(2) + done <- true + }() + time.Sleep(10 * time.Millisecond) + + // This should not hang + for i := 0; i < 3; i++ { + p.put(resources[i]) + } + <-done + <-done + assert.EqualValues(t, 2, p.Capacity()) + assert.EqualValues(t, 2, p.Available()) + assert.EqualValues(t, 1, p.Metrics.WaitCount()) + assert.EqualValues(t, p.Metrics.WaitCount(), len(state.waits)) + assert.EqualValues(t, 2, state.open.Load()) + + // Test race condition of SetCapacity with itself + p.SetCapacity(3) + for i := 0; i < 3; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + // This will wait because pool is empty + go func() { + r, err := p.Get(ctx, nil) + require.NoError(t, err) + p.put(r) + done <- true + }() + time.Sleep(10 * time.Millisecond) + + // This will wait till we Put + go p.SetCapacity(2) + time.Sleep(10 * time.Millisecond) + go p.SetCapacity(4) + time.Sleep(10 * time.Millisecond) + + // This should not hang + for i := 0; i < 3; i++ { + p.put(resources[i]) + } + <-done + + assert.Panics(t, func() { + p.SetCapacity(-1) + }) + + assert.EqualValues(t, 4, p.Capacity()) + assert.EqualValues(t, 4, p.Available()) +} + +func TestClosing(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + for i := 0; i < 5; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + ch := make(chan bool) + go func() { + p.Close() + ch <- true + }() + + // Wait for goroutine to call Close + time.Sleep(10 * time.Millisecond) + stats := p.StatsJSON() + expected := map[string]any{ + "Capacity": 0, + "Available": -5, + "Active": 5, + "InUse": 5, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + + // Put is allowed when closing + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + + // Wait for Close to return + <-ch + + stats = p.StatsJSON() + expected = map[string]any{ + "Capacity": 0, + "Available": 0, + "Active": 0, + "InUse": 0, + "WaitCount": 0, + "WaitTime": 
time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + assert.EqualValues(t, 5, state.lastID.Load()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestReopen(t *testing.T) { + var state TestState + var refreshed atomic.Bool + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + RefreshInterval: 500 * time.Millisecond, + }).Open(newConnector(&state), func() (bool, error) { + refreshed.Store(true) + return true, nil + }) + + var resources [10]*Pooled[*TestConn] + for i := 0; i < 5; i++ { + var r *Pooled[*TestConn] + var err error + if i%2 == 0 { + r, err = p.Get(ctx, nil) + } else { + r, err = p.Get(ctx, sFoo) + } + require.NoError(t, err) + resources[i] = r + } + + time.Sleep(10 * time.Millisecond) + stats := p.StatsJSON() + expected := map[string]any{ + "Capacity": 5, + "Available": 0, + "Active": 5, + "InUse": 5, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + + time.Sleep(1 * time.Second) + assert.Truef(t, refreshed.Load(), "did not refresh") + + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + time.Sleep(50 * time.Millisecond) + stats = p.StatsJSON() + expected = map[string]any{ + "Capacity": 5, + "Available": 5, + "Active": 0, + "InUse": 0, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + assert.EqualValues(t, 5, state.lastID.Load()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestIdleTimeout(t *testing.T) { + testTimeout := func(t *testing.T, setting *Setting) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + var conns []*Pooled[*TestConn] + for i := 0; i < 5; i++ { + r, err := p.Get(ctx, setting) + require.NoError(t, err) + assert.EqualValues(t, i+1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.IdleClosed()) + + conns = append(conns, r) + } + + // wait a long while; ensure that none of the conns have been closed + time.Sleep(1 * time.Second) + + var closers []chan struct{} + for _, conn := range conns { + assert.Falsef(t, conn.Conn.IsClosed(), "connection was idle-closed while outside the pool") + closers = append(closers, conn.Conn.waitForClose()) + p.put(conn) + } + + for _, closed := range closers { + <-closed + } + + // no need to assert anything: all the connections in the pool should are idle-closed + // now and if they're not the test will timeout and fail + } + + t.Run("WithoutSettings", func(t *testing.T) { testTimeout(t, nil) }) + t.Run("WithSettings", func(t *testing.T) { testTimeout(t, sFoo) }) +} + +func TestIdleTimeoutCreateFail(t *testing.T) { + var state TestState + var connector = newConnector(&state) + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(connector, nil) + + defer p.Close() + + for _, setting := range []*Setting{nil, sFoo} { + r, err := p.Get(ctx, setting) + require.NoError(t, err) + // Change the factory before putting back + // to prevent race with the idle closer, who will + // try to use it. 
+ state.chaos.failConnect = true + p.put(r) + timeout := time.After(1 * time.Second) + for p.Active() != 0 { + select { + case <-timeout: + t.Errorf("Timed out waiting for resource to be closed by idle timeout") + default: + } + } + // reset factory for next run. + state.chaos.failConnect = false + } +} + +func TestMaxLifetime(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 10 * time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + r, err := p.Get(ctx, nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + time.Sleep(10 * time.Millisecond) + + p.put(r) + assert.EqualValues(t, 1, state.lastID.Load()) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + p.Close() + + // maxLifetime > 0 + state.lastID.Store(0) + state.open.Store(0) + + p = NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 10 * time.Second, + MaxLifetime: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + r, err = p.Get(ctx, nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + time.Sleep(5 * time.Millisecond) + + p.put(r) + assert.EqualValues(t, 1, state.lastID.Load()) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + r, err = p.Get(ctx, nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 0, p.Metrics.MaxLifetimeClosed()) + + time.Sleep(10 * time.Millisecond * 2) + + p.put(r) + assert.EqualValues(t, 2, state.lastID.Load()) + assert.EqualValues(t, 1, state.open.Load()) + assert.EqualValues(t, 1, p.Metrics.MaxLifetimeClosed()) +} + +func TestExtendedLifetimeTimeout(t *testing.T) { + var state TestState + var connector = newConnector(&state) + var config = &Config[*TestConn]{ + Capacity: 1, + IdleTimeout: time.Second, + MaxLifetime: 0, + LogWait: state.LogWait, + } + + // maxLifetime 0 + p := NewPool(config).Open(connector, nil) + assert.Zero(t, p.extendedMaxLifetime()) + p.Close() + + // maxLifetime > 0 + config.MaxLifetime = 10 * time.Millisecond + for i := 0; i < 10; i++ { + p = NewPool(config).Open(connector, nil) + assert.LessOrEqual(t, config.MaxLifetime, p.extendedMaxLifetime()) + assert.Greater(t, 2*config.MaxLifetime, p.extendedMaxLifetime()) + p.Close() + } +} + +func TestCreateFail(t *testing.T) { + var state TestState + state.chaos.failConnect = true + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + for _, setting := range []*Setting{nil, sFoo} { + if _, err := p.Get(ctx, setting); err.Error() != "failed to connect: forced failure" { + t.Errorf("Expecting Failed, received %v", err) + } + stats := p.StatsJSON() + expected := map[string]any{ + "Capacity": 5, + "Available": 5, + "Active": 0, + "InUse": 0, + "WaitCount": 0, + "WaitTime": time.Duration(0), + "IdleTimeout": 1 * time.Second, + "IdleClosed": 0, + "MaxLifetimeClosed": 0, + } + assert.Equal(t, expected, stats) + } +} + +func TestCreateFailOnPut(t *testing.T) { + var state TestState + var connector = newConnector(&state) + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, 
+ }).Open(connector, nil) + + defer p.Close() + + for _, setting := range []*Setting{nil, sFoo} { + _, err := p.Get(ctx, setting) + require.NoError(t, err) + + // change factory to fail the put. + state.chaos.failConnect = true + p.put(nil) + assert.Zero(t, p.Active()) + + // change back for next iteration. + state.chaos.failConnect = false + } +} + +func TestSlowCreateFail(t *testing.T) { + var state TestState + state.chaos.delayConnect = 10 * time.Millisecond + + ctx := context.Background() + ch := make(chan *Pooled[*TestConn]) + + for _, setting := range []*Setting{nil, sFoo} { + p := NewPool(&Config[*TestConn]{ + Capacity: 2, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + state.chaos.failConnect = true + + for i := 0; i < 3; i++ { + go func() { + conn, _ := p.Get(ctx, setting) + ch <- conn + }() + } + assert.Nil(t, <-ch) + assert.Nil(t, <-ch) + assert.Equalf(t, p.Capacity(), int64(2), "pool should not be out of capacity") + assert.Equalf(t, p.Available(), int64(2), "pool should not be out of availability") + + select { + case <-ch: + assert.Fail(t, "there should be no capacity for a third connection") + default: + } + + state.chaos.failConnect = false + conn, err := p.Get(ctx, setting) + require.NoError(t, err) + + p.put(conn) + conn = <-ch + assert.NotNil(t, conn) + p.put(conn) + p.Close() + } +} + +func TestTimeout(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + // take the only connection available + r, err := p.Get(ctx, nil) + require.NoError(t, err) + + for _, setting := range []*Setting{nil, sFoo} { + // trying to get the connection without a timeout. + newctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + _, err = p.Get(newctx, setting) + cancel() + assert.EqualError(t, err, "resource pool timed out") + + } + + // put the connection take was taken initially. 
+ p.put(r) +} + +func TestExpired(t *testing.T) { + var state TestState + + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + for _, setting := range []*Setting{nil, sFoo} { + // expired context + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + _, err := p.Get(ctx, setting) + cancel() + require.EqualError(t, err, "resource pool context already expired") + } +} + +func TestMultiSettings(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} + + // Test Get + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + resources[i] = r + assert.EqualValues(t, 5-i-1, p.Available()) + assert.Zero(t, p.Metrics.WaitCount()) + assert.Zero(t, len(state.waits)) + assert.Zero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, i+1, state.lastID.Load()) + assert.EqualValues(t, i+1, state.open.Load()) + } + + // Test that Get waits + ch := make(chan bool) + go func() { + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + resources[i] = r + } + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + ch <- true + }() + for i := 0; i < 5; i++ { + // Sleep to ensure the goroutine waits + time.Sleep(10 * time.Millisecond) + p.put(resources[i]) + } + <-ch + assert.EqualValues(t, 5, p.Metrics.WaitCount()) + assert.Equal(t, 5, len(state.waits)) + // verify start times are monotonic increasing + for i := 1; i < len(state.waits); i++ { + if state.waits[i].Before(state.waits[i-1]) { + t.Errorf("Expecting monotonic increasing start times") + } + } + assert.NotZero(t, p.Metrics.WaitTime()) + assert.EqualValues(t, 5, state.lastID.Load()) + + // Close + p.Close() + assert.EqualValues(t, 0, p.Capacity()) + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestMultiSettingsWithReset(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources [10]*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} + + // Test Get + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + resources[i] = r + assert.EqualValues(t, 5-i-1, p.Available()) + assert.EqualValues(t, i+1, state.lastID.Load()) + assert.EqualValues(t, i+1, state.open.Load()) + } + + // Put all of them back + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + + // Getting all with same setting. 
+ for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[1]) // {foo} + require.NoError(t, err) + assert.Truef(t, r.Conn.setting == settings[1], "setting was not properly applied") + resources[i] = r + } + assert.EqualValues(t, 2, state.reset.Load()) // when setting was {bar} and getting for {foo} + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 5, state.lastID.Load()) + assert.EqualValues(t, 5, state.open.Load()) + + for i := 0; i < 5; i++ { + p.put(resources[i]) + } + + // Close + p.Close() + assert.EqualValues(t, 0, p.Capacity()) + assert.EqualValues(t, 0, p.Available()) + assert.EqualValues(t, 0, state.open.Load()) +} + +func TestApplySettingsFailure(t *testing.T) { + var state TestState + + ctx := context.Background() + p := NewPool(&Config[*TestConn]{ + Capacity: 5, + IdleTimeout: time.Second, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + var resources []*Pooled[*TestConn] + var r *Pooled[*TestConn] + var err error + + settings := []*Setting{nil, sFoo, sBar, sBar, sFoo} + // get the resource and mark for failure + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[i]) + require.NoError(t, err) + r.Conn.failApply = true + resources = append(resources, r) + } + // put them back + for _, r = range resources { + p.put(r) + } + + // any new connection created will fail to apply setting + state.chaos.failApply = true + + // Get the resource with "foo" setting + // For an applied connection if the setting are same it will be returned as-is. + // Otherwise, will fail to get the resource. + var failCount int + resources = nil + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, settings[1]) + if err != nil { + failCount++ + assert.EqualError(t, err, "ApplySetting failed") + continue + } + resources = append(resources, r) + } + // put them back + for _, r = range resources { + p.put(r) + } + require.Equal(t, 3, failCount) + + // should be able to get all the resource with no setting + resources = nil + for i := 0; i < 5; i++ { + r, err = p.Get(ctx, nil) + require.NoError(t, err) + resources = append(resources, r) + } + // put them back + for _, r = range resources { + p.put(r) + } +} diff --git a/go/pools/smartconnpool/sema.s b/go/pools/smartconnpool/sema.s new file mode 100644 index 00000000000..e69de29bb2d diff --git a/go/pools/smartconnpool/sema_norace.go b/go/pools/smartconnpool/sema_norace.go new file mode 100644 index 00000000000..63afe8082c1 --- /dev/null +++ b/go/pools/smartconnpool/sema_norace.go @@ -0,0 +1,40 @@ +//go:build !race + +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import _ "unsafe" + +//go:linkname sync_runtime_Semacquire sync.runtime_Semacquire +func sync_runtime_Semacquire(addr *uint32) + +//go:linkname sync_runtime_Semrelease sync.runtime_Semrelease +func sync_runtime_Semrelease(addr *uint32, handoff bool, skipframes int) + +// semaphore is a single-use synchronization primitive that allows a Goroutine +// to wait until signaled. We use the Go runtime's internal implementation. 
+type semaphore struct { + f uint32 +} + +func (s *semaphore) wait() { + sync_runtime_Semacquire(&s.f) +} +func (s *semaphore) notify(handoff bool) { + sync_runtime_Semrelease(&s.f, handoff, 0) +} diff --git a/go/pools/smartconnpool/sema_race.go b/go/pools/smartconnpool/sema_race.go new file mode 100644 index 00000000000..a31cfaa85c5 --- /dev/null +++ b/go/pools/smartconnpool/sema_race.go @@ -0,0 +1,42 @@ +//go:build race + +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "sync/atomic" + "time" +) + +// semaphore is a slow implementation of a single-use synchronization primitive. +// We use this inefficient implementation when running under the race detector +// because the detector doesn't understand the synchronization performed by the +// runtime's semaphore. +type semaphore struct { + b atomic.Bool +} + +func (s *semaphore) wait() { + for !s.b.CompareAndSwap(true, false) { + time.Sleep(time.Millisecond) + } +} + +func (s *semaphore) notify(_ bool) { + s.b.Store(true) +} diff --git a/go/pools/smartconnpool/settings.go b/go/pools/smartconnpool/settings.go new file mode 100644 index 00000000000..3ab2350aad6 --- /dev/null +++ b/go/pools/smartconnpool/settings.go @@ -0,0 +1,45 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "sync/atomic" +) + +// Setting is a setting applied to a connection in this pool. +// Setting values must be interned for optimal usage (i.e. a Setting +// that represents a specific set of SQL connection settings should +// always have the same pointer value). +type Setting struct { + queryApply string + queryReset string + bucket uint32 +} + +func (s *Setting) ApplyQuery() string { + return s.queryApply +} + +func (s *Setting) ResetQuery() string { + return s.queryReset +} + +var globalSettingsCounter atomic.Uint32 + +func NewSetting(apply, reset string) *Setting { + return &Setting{apply, reset, globalSettingsCounter.Add(1)} +} diff --git a/go/pools/smartconnpool/stack.go b/go/pools/smartconnpool/stack.go new file mode 100644 index 00000000000..ea7ae50201e --- /dev/null +++ b/go/pools/smartconnpool/stack.go @@ -0,0 +1,77 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
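// A minimal interning sketch for the Setting type defined in settings.go above,
// assuming only its exported NewSetting/ApplyQuery/ResetQuery API. The package
// name, the InternSetting helper and its cache are hypothetical; they only
// illustrate how callers can guarantee that the same logical settings always map
// to the same *Setting pointer, which is what the interning comment asks for.
package settingexample

import (
	"sync"

	"vitess.io/vitess/go/pools/smartconnpool"
)

var (
	internMu sync.Mutex
	interned = make(map[string]*smartconnpool.Setting)
)

// InternSetting returns a shared *Setting for an apply/reset query pair,
// creating it on first use, so repeated requests reuse the same pointer
// (and therefore the same bucket) instead of calling NewSetting per request.
func InternSetting(apply, reset string) *smartconnpool.Setting {
	internMu.Lock()
	defer internMu.Unlock()
	key := apply + "\x00" + reset
	if s, ok := interned[key]; ok {
		return s
	}
	s := smartconnpool.NewSetting(apply, reset)
	interned[key] = s
	return s
}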
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "runtime" + + "vitess.io/vitess/go/atomic2" +) + +// connStack is a lock-free stack for Connection objects. It is safe to +// use from several goroutines. +type connStack[C Connection] struct { + top atomic2.PointerAndUint64[Pooled[C]] +} + +func (s *connStack[C]) Push(item *Pooled[C]) { + for { + oldHead, popCount := s.top.Load() + item.next.Store(oldHead) + if s.top.CompareAndSwap(oldHead, popCount, item, popCount) { + return + } + runtime.Gosched() + } +} + +func (s *connStack[C]) Pop() (*Pooled[C], bool) { + for { + oldHead, popCount := s.top.Load() + if oldHead == nil { + return nil, false + } + + newHead := oldHead.next.Load() + if s.top.CompareAndSwap(oldHead, popCount, newHead, popCount+1) { + return oldHead, true + } + runtime.Gosched() + } +} + +func (s *connStack[C]) PopAll(out []*Pooled[C]) []*Pooled[C] { + var oldHead *Pooled[C] + + for { + var popCount uint64 + oldHead, popCount = s.top.Load() + if oldHead == nil { + return out + } + if s.top.CompareAndSwap(oldHead, popCount, nil, popCount+1) { + break + } + runtime.Gosched() + } + + for oldHead != nil { + out = append(out, oldHead) + oldHead = oldHead.next.Load() + } + return out +} diff --git a/go/pools/smartconnpool/stress_test.go b/go/pools/smartconnpool/stress_test.go new file mode 100644 index 00000000000..c1a1c3cfd58 --- /dev/null +++ b/go/pools/smartconnpool/stress_test.go @@ -0,0 +1,163 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
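// A short usage sketch for the lock-free connStack above, assuming it sits in
// package smartconnpool next to the stack itself (connStack, Pooled and
// Connection are the types defined in this diff; closeAll is hypothetical).
// The uint64 carried alongside the head pointer is a pop counter: it is bumped
// on every successful Pop, so a stale CompareAndSwap cannot succeed after the
// same node has been popped and pushed back (the classic ABA case).

// closeAll drains the stack in one detach-and-walk pass and closes every
// connection it finds. The real pool would also update its accounting; this
// only shows why PopAll is cheaper than calling Pop in a loop.
func closeAll[C Connection](s *connStack[C]) {
	// PopAll detaches the whole list with a single CAS on (head, popCount).
	for _, conn := range s.PopAll(nil) {
		conn.Conn.Close()
	}
}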
+*/ + +package smartconnpool + +import ( + "context" + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +type StressConn struct { + setting *Setting + owner atomic.Int32 + closed atomic.Bool +} + +func (b *StressConn) Expired(_ time.Duration) bool { + return false +} + +func (b *StressConn) IsSettingApplied() bool { + return b.setting != nil +} + +func (b *StressConn) IsSameSetting(setting string) bool { + return b.setting != nil && b.setting.ApplyQuery() == setting +} + +var _ Connection = (*StressConn)(nil) + +func (b *StressConn) ApplySetting(ctx context.Context, setting *Setting) error { + b.setting = setting + return nil +} + +func (b *StressConn) ResetSetting(ctx context.Context) error { + b.setting = nil + return nil +} + +func (b *StressConn) Setting() *Setting { + return b.setting +} + +func (b *StressConn) IsClosed() bool { + return b.closed.Load() +} + +func (b *StressConn) Close() { + b.closed.Store(true) +} + +func TestStackRace(t *testing.T) { + const Count = 64 + const Procs = 32 + + var wg sync.WaitGroup + var stack connStack[*StressConn] + var done atomic.Bool + + for c := 0; c < Count; c++ { + stack.Push(&Pooled[*StressConn]{Conn: &StressConn{}}) + } + + for i := 0; i < Procs; i++ { + wg.Add(1) + go func(tid int32) { + defer wg.Done() + for !done.Load() { + if conn, ok := stack.Pop(); ok { + previousOwner := conn.Conn.owner.Swap(tid) + if previousOwner != 0 { + panic(fmt.Errorf("owner race: %d with %d", tid, previousOwner)) + } + runtime.Gosched() + previousOwner = conn.Conn.owner.Swap(0) + if previousOwner != tid { + panic(fmt.Errorf("owner race: %d with %d", previousOwner, tid)) + } + stack.Push(conn) + } + } + }(int32(i + 1)) + } + + time.Sleep(5 * time.Second) + done.Store(true) + wg.Wait() + + for c := 0; c < Count; c++ { + conn, ok := stack.Pop() + require.NotNil(t, conn) + require.True(t, ok) + } +} + +func TestStress(t *testing.T) { + const Capacity = 64 + const P = 8 + + connect := func(ctx context.Context) (*StressConn, error) { + return &StressConn{}, nil + } + + pool := NewPool[*StressConn](&Config[*StressConn]{ + Capacity: Capacity, + }).Open(connect, nil) + + var wg errgroup.Group + var stop atomic.Bool + + for p := 0; p < P; p++ { + tid := int32(p + 1) + wg.Go(func() error { + ctx := context.Background() + for !stop.Load() { + conn, err := pool.get(ctx) + if err != nil { + return err + } + + previousOwner := conn.Conn.owner.Swap(tid) + if previousOwner != 0 { + return fmt.Errorf("owner race: %d with %d", tid, previousOwner) + } + runtime.Gosched() + previousOwner = conn.Conn.owner.Swap(0) + if previousOwner != tid { + return fmt.Errorf("owner race: %d with %d", previousOwner, tid) + } + conn.Recycle() + } + return nil + }) + } + + time.Sleep(5 * time.Second) + stop.Store(true) + if err := wg.Wait(); err != nil { + t.Fatal(err) + } +} diff --git a/go/pools/smartconnpool/waitlist.go b/go/pools/smartconnpool/waitlist.go new file mode 100644 index 00000000000..d4abeade0ac --- /dev/null +++ b/go/pools/smartconnpool/waitlist.go @@ -0,0 +1,166 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smartconnpool + +import ( + "context" + "sync" + + "vitess.io/vitess/go/list" +) + +// waiter represents a client waiting for a connection in the waitlist +type waiter[C Connection] struct { + // setting is the connection Setting that we'd like, or nil if we'd like a + // a connection with no Setting applied + setting *Setting + // conn will be set by another client to hand over the connection to use + conn *Pooled[C] + // ctx is the context of the waiting client to check for expiration + ctx context.Context + // sema is a synchronization primitive that allows us to block until our request + // has been fulfilled + sema semaphore + // age is the amount of cycles this client has been on the waitlist + age uint32 +} + +type waitlist[C Connection] struct { + nodes sync.Pool + mu sync.Mutex + list list.List[waiter[C]] +} + +// waitForConn blocks until a connection with the given Setting is returned by another client, +// or until the given context expires. +// The returned connection may _not_ have the requested Setting. This function can +// also return a `nil` connection even if our context has expired, if the pool has +// forced an expiration of all waiters in the waitlist. +func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting) (*Pooled[C], error) { + elem := wl.nodes.Get().(*list.Element[waiter[C]]) + elem.Value = waiter[C]{setting: setting, conn: nil, ctx: ctx} + + wl.mu.Lock() + // add ourselves as a waiter at the end of the waitlist + wl.list.PushBackValue(elem) + wl.mu.Unlock() + + // block on our waiter's semaphore until somebody can hand over a connection to us + elem.Value.sema.wait() + + // we're awake -- the conn in our waiter contains the connection that was handed + // over to us, or nothing if we've been waken up forcefully. save the conn before + // we return our waiter to the pool of waiters for reuse. + conn := elem.Value.conn + wl.nodes.Put(elem) + + if conn != nil { + return conn, nil + } + return nil, ctx.Err() +} + +// expire removes and wakes any expired waiter in the waitlist. +// if force is true, it'll wake and remove all the waiters. +func (wl *waitlist[C]) expire(force bool) { + if wl.list.Len() == 0 { + return + } + + var expired []*list.Element[waiter[C]] + + wl.mu.Lock() + // iterate the waitlist looking for waiters with an expired Context, + // or remove everything if force is true + for e := wl.list.Front(); e != nil; e = e.Next() { + if force || e.Value.ctx.Err() != nil { + wl.list.Remove(e) + expired = append(expired, e) + continue + } + } + wl.mu.Unlock() + + // once all the expired waiters have been removed from the waitlist, wake them up one by one + for _, e := range expired { + e.Value.sema.notify(false) + } +} + +// tryReturnConn tries handing over a connection to one of the waiters in the pool. 
+func (wl *waitlist[D]) tryReturnConn(conn *Pooled[D]) bool { + // fast path: if there's nobody waiting there's nothing to do + if wl.list.Len() == 0 { + return false + } + // split the slow path into a separate function to enable inlining + return wl.tryReturnConnSlow(conn) +} + +func (wl *waitlist[D]) tryReturnConnSlow(conn *Pooled[D]) bool { + const maxAge = 8 + var ( + target *list.Element[waiter[D]] + connSetting = conn.Conn.Setting() + ) + + wl.mu.Lock() + target = wl.list.Front() + // iterate through the waitlist looking for either waiters that have been + // here too long, or a waiter that is looking exactly for the same Setting + // as the one we have in our connection. + for e := target; e != nil; e = e.Next() { + if e.Value.age > maxAge || e.Value.setting == connSetting { + target = e + break + } + // this only ages the waiters that are being skipped over: we'll start + // aging the waiters in the back once they get to the front of the pool. + // the maxAge of 8 has been set empirically: smaller values cause clients + // with a specific setting to slightly starve, and aging all the clients + // in the list every time leads to unfairness when the system is at capacity + e.Value.age++ + } + if target != nil { + wl.list.Remove(target) + } + wl.mu.Unlock() + + // maybe there isn't anybody to hand over the connection to, because we've + // raced with another client returning another connection + if target == nil { + return false + } + + // if we have a target to return the connection to, simply write the connection + // into the waiter and signal their semaphore. they'll wake up to pick up the + // connection. + target.Value.conn = conn + target.Value.sema.notify(true) + return true +} + +func (wl *waitlist[C]) init() { + wl.nodes.New = func() any { + return &list.Element[waiter[C]]{} + } + wl.list.Init() +} + +func (wl *waitlist[C]) waiting() int { + return wl.list.Len() +} diff --git a/go/vt/dbconnpool/connection.go b/go/vt/dbconnpool/connection.go index 8e9a0f4a5c0..f6bce0a7d5c 100644 --- a/go/vt/dbconnpool/connection.go +++ b/go/vt/dbconnpool/connection.go @@ -18,19 +18,38 @@ package dbconnpool import ( "context" + "errors" "fmt" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/sqlerror" + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" ) +type PooledDBConnection = smartconnpool.Pooled[*DBConnection] + // DBConnection re-exposes mysql.Conn with some wrapping to implement // most of PoolConnection interface, except Recycle. That way it can be used // by itself. (Recycle needs to know about the Pool). type DBConnection struct { *mysql.Conn + info dbconfigs.Connector +} + +var errSettingNotSupported = errors.New("DBConnection does not support connection settings") + +func (dbc *DBConnection) ApplySetting(ctx context.Context, setting *smartconnpool.Setting) error { + return errSettingNotSupported +} + +func (dbc *DBConnection) ResetSetting(ctx context.Context) error { + return errSettingNotSupported +} + +func (dbc *DBConnection) Setting() *smartconnpool.Setting { + return nil } // NewDBConnection returns a new DBConnection based on the ConnParams @@ -40,7 +59,19 @@ func NewDBConnection(ctx context.Context, info dbconfigs.Connector) (*DBConnecti if err != nil { return nil, err } - return &DBConnection{Conn: c}, nil + return &DBConnection{Conn: c, info: info}, nil +} + +// Reconnect replaces the existing underlying connection with a new one, +// if possible. Recycle should still be called afterwards. 
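// A hand-off sketch for the waitlist above, assuming it lives in package
// smartconnpool alongside waitlist and connStack. The pool's actual put path is
// not part of this hunk, so returnToPool is a hypothetical outline of the
// ordering the comments describe: try a direct hand-off to a waiter first, and
// only park the connection on an idle stack if nobody is waiting.
func returnToPool[C Connection](wl *waitlist[C], idle *connStack[C], conn *Pooled[C]) {
	// tryReturnConn hands the connection to a waiter that wants this exact
	// Setting (or one that has aged past maxAge) if it finds one, otherwise to
	// the waiter at the front of the list, waking it with hand-off semantics.
	if wl.tryReturnConn(conn) {
		return
	}
	// Nobody was waiting, or we raced with another client returning a
	// connection: keep it on the lock-free idle stack for the next Get.
	idle.Push(conn)
}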
+func (dbc *DBConnection) Reconnect(ctx context.Context) error { + dbc.Conn.Close() + newConn, err := dbc.info.Connect(ctx) + if err != nil { + return err + } + dbc.Conn = newConn + return nil } // ExecuteFetch overwrites mysql.Conn.ExecuteFetch. diff --git a/go/vt/dbconnpool/connection_pool.go b/go/vt/dbconnpool/connection_pool.go index e8e4acce017..9865efdada3 100644 --- a/go/vt/dbconnpool/connection_pool.go +++ b/go/vt/dbconnpool/connection_pool.go @@ -23,71 +23,31 @@ package dbconnpool import ( "context" - "errors" "net" - "sync" "time" "vitess.io/vitess/go/netutil" - - "vitess.io/vitess/go/pools" - "vitess.io/vitess/go/stats" + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/vt/dbconfigs" -) - -var ( - // ErrConnPoolClosed is returned if the connection pool is closed. - ErrConnPoolClosed = errors.New("connection pool is closed") - // usedNames is for preventing expvar from panicking. Tests - // create pool objects multiple time. If a name was previously - // used, expvar initialization is skipped. - // TODO(sougou): Find a way to still crash if this happened - // through non-test code. - usedNames = make(map[string]bool) + "vitess.io/vitess/go/vt/servenv" ) // ConnectionPool re-exposes ResourcePool as a pool of // PooledDBConnection objects. type ConnectionPool struct { - mu sync.Mutex - connections pools.IResourcePool - capacity int - idleTimeout time.Duration - maxLifetime time.Duration - resolutionFrequency time.Duration - - // info is set at Open() time - info dbconfigs.Connector - name string + *smartconnpool.ConnPool[*DBConnection] } // NewConnectionPool creates a new ConnectionPool. The name is used // to publish stats only. -func NewConnectionPool(name string, capacity int, idleTimeout time.Duration, maxLifetime time.Duration, dnsResolutionFrequency time.Duration) *ConnectionPool { - cp := &ConnectionPool{name: name, capacity: capacity, idleTimeout: idleTimeout, maxLifetime: maxLifetime, resolutionFrequency: dnsResolutionFrequency} - if name == "" || usedNames[name] { - return cp +func NewConnectionPool(name string, stats *servenv.Exporter, capacity int, idleTimeout time.Duration, maxLifetime time.Duration, dnsResolutionFrequency time.Duration) *ConnectionPool { + config := smartconnpool.Config[*DBConnection]{ + Capacity: int64(capacity), + IdleTimeout: idleTimeout, + MaxLifetime: maxLifetime, + RefreshInterval: dnsResolutionFrequency, } - usedNames[name] = true - stats.NewGaugeFunc(name+"Capacity", "Connection pool capacity", cp.Capacity) - stats.NewGaugeFunc(name+"Available", "Connection pool available", cp.Available) - stats.NewGaugeFunc(name+"Active", "Connection pool active", cp.Active) - stats.NewGaugeFunc(name+"InUse", "Connection pool in-use", cp.InUse) - stats.NewGaugeFunc(name+"MaxCap", "Connection pool max cap", cp.MaxCap) - stats.NewCounterFunc(name+"WaitCount", "Connection pool wait count", cp.WaitCount) - stats.NewCounterDurationFunc(name+"WaitTime", "Connection pool wait time", cp.WaitTime) - stats.NewGaugeDurationFunc(name+"IdleTimeout", "Connection pool idle timeout", cp.IdleTimeout) - stats.NewGaugeFunc(name+"IdleClosed", "Connection pool idle closed", cp.IdleClosed) - stats.NewGaugeFunc(name+"MaxLifetimeClosed", "Connection pool refresh closed", cp.MaxLifetimeClosed) - stats.NewCounterFunc(name+"Exhausted", "Number of times pool had zero available slots", cp.Exhausted) - return cp -} - -func (cp *ConnectionPool) pool() (p pools.IResourcePool) { - cp.mu.Lock() - p = cp.connections - cp.mu.Unlock() - return p + return &ConnectionPool{ConnPool: 
smartconnpool.NewPool(&config)} } // Open must be called before starting to use the pool. @@ -99,205 +59,18 @@ func (cp *ConnectionPool) pool() (p pools.IResourcePool) { // conn, err := pool.Get() // ... func (cp *ConnectionPool) Open(info dbconfigs.Connector) { - var refreshCheck pools.RefreshCheck + var refresh smartconnpool.RefreshCheck if net.ParseIP(info.Host()) == nil { - refreshCheck = netutil.DNSTracker(info.Host()) - } else { - refreshCheck = nil + refresh = netutil.DNSTracker(info.Host()) } - cp.mu.Lock() - defer cp.mu.Unlock() - cp.info = info - cp.connections = pools.NewResourcePool(cp.connect, cp.capacity, cp.capacity, cp.idleTimeout, cp.maxLifetime, nil, refreshCheck, cp.resolutionFrequency) -} -// connect is used by the resource pool to create a new Resource. -func (cp *ConnectionPool) connect(ctx context.Context) (pools.Resource, error) { - c, err := NewDBConnection(ctx, cp.info) - if err != nil { - return nil, err + connect := func(ctx context.Context) (*DBConnection, error) { + return NewDBConnection(ctx, info) } - return &PooledDBConnection{ - DBConnection: c, - timeCreated: time.Now(), - pool: cp, - }, nil -} -// Close will close the pool and wait for connections to be returned before -// exiting. -func (cp *ConnectionPool) Close() { - p := cp.pool() - if p == nil { - return - } - // We should not hold the lock while calling Close - // because it waits for connections to be returned. - p.Close() - cp.mu.Lock() - cp.connections = nil - cp.mu.Unlock() + cp.ConnPool.Open(connect, refresh) } -// Get returns a connection. -// You must call Recycle on the PooledDBConnection once done. func (cp *ConnectionPool) Get(ctx context.Context) (*PooledDBConnection, error) { - p := cp.pool() - if p == nil { - return nil, ErrConnPoolClosed - } - r, err := p.Get(ctx, nil) - if err != nil { - return nil, err - } - - return r.(*PooledDBConnection), nil -} - -// Put puts a connection into the pool. -func (cp *ConnectionPool) Put(conn *PooledDBConnection) { - p := cp.pool() - if p == nil { - panic(ErrConnPoolClosed) - } - if conn == nil { - // conn has a type, if we just Put(conn), we end up - // putting an interface with a nil value, that is not - // equal to a nil value. So just put a plain nil. - p.Put(nil) - return - } - p.Put(conn) -} - -// SetCapacity alters the size of the pool at runtime. -func (cp *ConnectionPool) SetCapacity(capacity int) (err error) { - cp.mu.Lock() - defer cp.mu.Unlock() - if cp.connections != nil { - err = cp.connections.SetCapacity(capacity) - if err != nil { - return err - } - } - cp.capacity = capacity - return nil -} - -// SetIdleTimeout sets the idleTimeout on the pool. -func (cp *ConnectionPool) SetIdleTimeout(idleTimeout time.Duration) { - cp.mu.Lock() - defer cp.mu.Unlock() - if cp.connections != nil { - cp.connections.SetIdleTimeout(idleTimeout) - } - cp.idleTimeout = idleTimeout -} - -// StatsJSON returns the pool stats as a JSOn object. -func (cp *ConnectionPool) StatsJSON() string { - p := cp.pool() - if p == nil { - return "{}" - } - return p.StatsJSON() -} - -// Capacity returns the pool capacity. 
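// A usage sketch for the slimmed-down ConnectionPool above. The package and
// function names are illustrative only, connParams is a placeholder
// dbconfigs.Connector, and the pool name/sizes are arbitrary; the nil exporter
// matches the call sites updated later in this diff (mysqld.go,
// fakemysqldaemon.go). Everything else is the API defined in this file.
package poolexample

import (
	"context"
	"time"

	"vitess.io/vitess/go/vt/dbconfigs"
	"vitess.io/vitess/go/vt/dbconnpool"
)

func runSelect(ctx context.Context, connParams dbconfigs.Connector) error {
	pool := dbconnpool.NewConnectionPool("ExamplePool", nil, 5, time.Minute, 0, 0)
	pool.Open(connParams)
	defer pool.Close()

	conn, err := pool.Get(ctx)
	if err != nil {
		return err
	}
	// conn is a *PooledDBConnection, i.e. smartconnpool.Pooled[*DBConnection]:
	// the raw connection is reached through conn.Conn and Recycle returns it to
	// the pool, which is why later hunks rewrite conn.X into conn.Conn.X.
	defer conn.Recycle()

	_, err = conn.Conn.ExecuteFetch("SELECT 1", 1, false)
	return err
}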
-func (cp *ConnectionPool) Capacity() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Capacity() -} - -// Available returns the number of available connections in the pool -func (cp *ConnectionPool) Available() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Available() -} - -// Active returns the number of active connections in the pool -func (cp *ConnectionPool) Active() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Active() -} - -// InUse returns the number of in-use connections in the pool -func (cp *ConnectionPool) InUse() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.InUse() -} - -// MaxCap returns the maximum size of the pool -func (cp *ConnectionPool) MaxCap() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.MaxCap() -} - -// WaitCount returns how many clients are waiting for a connection -func (cp *ConnectionPool) WaitCount() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.WaitCount() -} - -// WaitTime return the pool WaitTime. -func (cp *ConnectionPool) WaitTime() time.Duration { - p := cp.pool() - if p == nil { - return 0 - } - return p.WaitTime() -} - -// IdleTimeout returns the idle timeout for the pool. -func (cp *ConnectionPool) IdleTimeout() time.Duration { - p := cp.pool() - if p == nil { - return 0 - } - return p.IdleTimeout() -} - -// IdleClosed returns the number of connections closed due to idle timeout for the pool. -func (cp *ConnectionPool) IdleClosed() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.IdleClosed() -} - -// MaxLifetimeClosed returns the number of connections closed due to refresh timeout for the pool. -func (cp *ConnectionPool) MaxLifetimeClosed() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.MaxLifetimeClosed() -} - -// Exhausted returns the number of times available went to zero for the pool. -func (cp *ConnectionPool) Exhausted() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Exhausted() + return cp.ConnPool.Get(ctx, nil) } diff --git a/go/vt/dbconnpool/pooled_connection.go b/go/vt/dbconnpool/pooled_connection.go deleted file mode 100644 index b4ca428e973..00000000000 --- a/go/vt/dbconnpool/pooled_connection.go +++ /dev/null @@ -1,74 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package dbconnpool - -import ( - "context" - "time" - - "vitess.io/vitess/go/pools" -) - -// PooledDBConnection re-exposes DBConnection to be used by ConnectionPool. 
-type PooledDBConnection struct { - *DBConnection - timeCreated time.Time - pool *ConnectionPool -} - -func (pc *PooledDBConnection) Expired(lifetimeTimeout time.Duration) bool { - return lifetimeTimeout > 0 && time.Until(pc.timeCreated.Add(lifetimeTimeout)) < 0 -} - -func (pc *PooledDBConnection) ApplySetting(context.Context, *pools.Setting) error { - //TODO implement me - panic("implement me") -} - -func (pc *PooledDBConnection) IsSettingApplied() bool { - return false -} - -func (pc *PooledDBConnection) IsSameSetting(string) bool { - return true -} - -func (pc *PooledDBConnection) ResetSetting(context.Context) error { - //TODO implement me - panic("implement me") -} - -// Recycle should be called to return the PooledDBConnection to the pool. -func (pc *PooledDBConnection) Recycle() { - if pc.IsClosed() { - pc.pool.Put(nil) - } else { - pc.pool.Put(pc) - } -} - -// Reconnect replaces the existing underlying connection with a new one, -// if possible. Recycle should still be called afterwards. -func (pc *PooledDBConnection) Reconnect(ctx context.Context) error { - pc.DBConnection.Close() - newConn, err := NewDBConnection(ctx, pc.pool.info) - if err != nil { - return err - } - pc.DBConnection = newConn - return nil -} diff --git a/go/vt/mysqlctl/fakemysqldaemon.go b/go/vt/mysqlctl/fakemysqldaemon.go index de16e100fa8..791b43da583 100644 --- a/go/vt/mysqlctl/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon.go @@ -195,7 +195,7 @@ func NewFakeMysqlDaemon(db *fakesqldb.DB) *FakeMysqlDaemon { Version: "8.0.32", } if db != nil { - result.appPool = dbconnpool.NewConnectionPool("AppConnPool", 5, time.Minute, 0, 0) + result.appPool = dbconnpool.NewConnectionPool("AppConnPool", nil, 5, time.Minute, 0, 0) result.appPool.Open(db.ConnParams()) } return result diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index 989963479f4..ba4ccf755b3 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -147,11 +147,11 @@ func NewMysqld(dbcfgs *dbconfigs.DBConfigs) *Mysqld { } // Create and open the connection pool for dba access. - result.dbaPool = dbconnpool.NewConnectionPool("DbaConnPool", dbaPoolSize, DbaIdleTimeout, 0, PoolDynamicHostnameResolution) + result.dbaPool = dbconnpool.NewConnectionPool("DbaConnPool", nil, dbaPoolSize, DbaIdleTimeout, 0, PoolDynamicHostnameResolution) result.dbaPool.Open(dbcfgs.DbaWithDB()) // Create and open the connection pool for app access. - result.appPool = dbconnpool.NewConnectionPool("AppConnPool", appPoolSize, appIdleTimeout, 0, PoolDynamicHostnameResolution) + result.appPool = dbconnpool.NewConnectionPool("AppConnPool", nil, appPoolSize, appIdleTimeout, 0, PoolDynamicHostnameResolution) result.appPool.Open(dbcfgs.AppWithDB()) /* diff --git a/go/vt/mysqlctl/query.go b/go/vt/mysqlctl/query.go index ceed3f58e03..5e21913c617 100644 --- a/go/vt/mysqlctl/query.go +++ b/go/vt/mysqlctl/query.go @@ -36,10 +36,10 @@ func getPoolReconnect(ctx context.Context, pool *dbconnpool.ConnectionPool) (*db return conn, err } // Run a test query to see if this connection is still good. - if _, err := conn.ExecuteFetch("SELECT 1", 1, false); err != nil { + if _, err := conn.Conn.ExecuteFetch("SELECT 1", 1, false); err != nil { // If we get a connection error, try to reconnect. 
if sqlErr, ok := err.(*sqlerror.SQLError); ok && (sqlErr.Number() == sqlerror.CRServerGone || sqlErr.Number() == sqlerror.CRServerLost) { - if err := conn.Reconnect(ctx); err != nil { + if err := conn.Conn.Reconnect(ctx); err != nil { conn.Recycle() return nil, err } @@ -117,7 +117,7 @@ func (mysqld *Mysqld) executeFetchContext(ctx context.Context, conn *dbconnpool. go func() { defer close(done) - qr, executeErr = conn.ExecuteFetch(query, maxrows, wantfields) + qr, executeErr = conn.Conn.ExecuteFetch(query, maxrows, wantfields) }() // Wait for either the query or the context to be done. @@ -136,7 +136,7 @@ func (mysqld *Mysqld) executeFetchContext(ctx context.Context, conn *dbconnpool. // The context expired or was canceled. // Try to kill the connection to effectively cancel the ExecuteFetch(). - connID := conn.ID() + connID := conn.Conn.ID() log.Infof("Mysqld.executeFetchContext(): killing connID %v due to timeout of query: %v", connID, query) if killErr := mysqld.killConnection(connID); killErr != nil { // Log it, but go ahead and wait for the query anyway. @@ -172,7 +172,7 @@ func (mysqld *Mysqld) killConnection(connID int64) error { if poolConn, connErr := getPoolReconnect(ctx, mysqld.dbaPool); connErr == nil { // We got a pool connection. defer poolConn.Recycle() - killConn = poolConn + killConn = poolConn.Conn } else { // We couldn't get a connection from the pool. // It might be because the connection pool is exhausted, diff --git a/go/vt/mysqlctl/reparent.go b/go/vt/mysqlctl/reparent.go index b76e342d0cd..0cd89c59ab3 100644 --- a/go/vt/mysqlctl/reparent.go +++ b/go/vt/mysqlctl/reparent.go @@ -95,7 +95,7 @@ func (mysqld *Mysqld) Promote(hookExtraEnv map[string]string) (replication.Posit // Since we handle replication, just stop it. cmds := []string{ - conn.StopReplicationCommand(), + conn.Conn.StopReplicationCommand(), "RESET SLAVE ALL", // "ALL" makes it forget primary host:port. // When using semi-sync and GTID, a replica first connects to the new primary with a given GTID set, // it can take a long time to scan the current binlog file to find the corresponding position. 
@@ -108,5 +108,5 @@ func (mysqld *Mysqld) Promote(hookExtraEnv map[string]string) (replication.Posit if err := mysqld.executeSuperQueryListConn(ctx, conn, cmds); err != nil { return replication.Position{}, err } - return conn.PrimaryPosition() + return conn.Conn.PrimaryPosition() } diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index 2b92f5d961d..23b19669f16 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -75,7 +75,7 @@ func (mysqld *Mysqld) StartReplication(hookExtraEnv map[string]string) error { } defer conn.Recycle() - if err := mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.StartReplicationCommand()}); err != nil { + if err := mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.Conn.StartReplicationCommand()}); err != nil { return err } @@ -92,7 +92,7 @@ func (mysqld *Mysqld) StartReplicationUntilAfter(ctx context.Context, targetPos } defer conn.Recycle() - queries := []string{conn.StartReplicationUntilAfterCommand(targetPos)} + queries := []string{conn.Conn.StartReplicationUntilAfterCommand(targetPos)} return mysqld.executeSuperQueryListConn(ctx, conn, queries) } @@ -105,7 +105,7 @@ func (mysqld *Mysqld) StartSQLThreadUntilAfter(ctx context.Context, targetPos re } defer conn.Recycle() - queries := []string{conn.StartSQLThreadUntilAfterCommand(targetPos)} + queries := []string{conn.Conn.StartSQLThreadUntilAfterCommand(targetPos)} return mysqld.executeSuperQueryListConn(ctx, conn, queries) } @@ -124,7 +124,7 @@ func (mysqld *Mysqld) StopReplication(hookExtraEnv map[string]string) error { } defer conn.Recycle() - return mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.StopReplicationCommand()}) + return mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.Conn.StopReplicationCommand()}) } // StopIOThread stops a replica's IO thread only. @@ -135,7 +135,7 @@ func (mysqld *Mysqld) StopIOThread(ctx context.Context) error { } defer conn.Recycle() - return mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.StopIOThreadCommand()}) + return mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.Conn.StopIOThreadCommand()}) } // StopSQLThread stops a replica's SQL thread(s) only. @@ -146,7 +146,7 @@ func (mysqld *Mysqld) StopSQLThread(ctx context.Context) error { } defer conn.Recycle() - return mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.StopSQLThreadCommand()}) + return mysqld.executeSuperQueryListConn(ctx, conn, []string{conn.Conn.StopSQLThreadCommand()}) } // RestartReplication stops, resets and starts replication. @@ -163,7 +163,7 @@ func (mysqld *Mysqld) RestartReplication(hookExtraEnv map[string]string) error { } defer conn.Recycle() - if err := mysqld.executeSuperQueryListConn(ctx, conn, conn.RestartReplicationCommands()); err != nil { + if err := mysqld.executeSuperQueryListConn(ctx, conn, conn.Conn.RestartReplicationCommands()); err != nil { return err } @@ -212,7 +212,7 @@ func (mysqld *Mysqld) GetServerUUID(ctx context.Context) (string, error) { } defer conn.Recycle() - return conn.GetServerUUID() + return conn.Conn.GetServerUUID() } // IsReadOnly return true if the instance is read only @@ -332,7 +332,7 @@ func (mysqld *Mysqld) WaitSourcePos(ctx context.Context, targetPos replication.P // If we are the primary, WaitUntilFilePositionCommand will fail. // But position is most likely reached. So, check the position // first. 
- mpos, err := conn.PrimaryFilePosition() + mpos, err := conn.Conn.PrimaryFilePosition() if err != nil { return fmt.Errorf("WaitSourcePos: PrimaryFilePosition failed: %v", err) } @@ -341,7 +341,7 @@ func (mysqld *Mysqld) WaitSourcePos(ctx context.Context, targetPos replication.P } // Find the query to run, run it. - query, err = conn.WaitUntilFilePositionCommand(ctx, targetPos) + query, err = conn.Conn.WaitUntilFilePositionCommand(ctx, targetPos) if err != nil { return err } @@ -350,7 +350,7 @@ func (mysqld *Mysqld) WaitSourcePos(ctx context.Context, targetPos replication.P // If we are the primary, WaitUntilPositionCommand will fail. // But position is most likely reached. So, check the position // first. - mpos, err := conn.PrimaryPosition() + mpos, err := conn.Conn.PrimaryPosition() if err != nil { return fmt.Errorf("WaitSourcePos: PrimaryPosition failed: %v", err) } @@ -359,7 +359,7 @@ func (mysqld *Mysqld) WaitSourcePos(ctx context.Context, targetPos replication.P } // Find the query to run, run it. - query, err = conn.WaitUntilPositionCommand(ctx, targetPos) + query, err = conn.Conn.WaitUntilPositionCommand(ctx, targetPos) if err != nil { return err } @@ -391,7 +391,7 @@ func (mysqld *Mysqld) ReplicationStatus() (replication.ReplicationStatus, error) } defer conn.Recycle() - return conn.ShowReplicationStatus() + return conn.Conn.ShowReplicationStatus() } // PrimaryStatus returns the primary replication statuses @@ -402,7 +402,7 @@ func (mysqld *Mysqld) PrimaryStatus(ctx context.Context) (replication.PrimarySta } defer conn.Recycle() - return conn.ShowPrimaryStatus() + return conn.Conn.ShowPrimaryStatus() } // GetGTIDPurged returns the gtid purged statuses @@ -413,7 +413,7 @@ func (mysqld *Mysqld) GetGTIDPurged(ctx context.Context) (replication.Position, } defer conn.Recycle() - return conn.GetGTIDPurged() + return conn.Conn.GetGTIDPurged() } // PrimaryPosition returns the primary replication position. 
@@ -424,7 +424,7 @@ func (mysqld *Mysqld) PrimaryPosition() (replication.Position, error) { } defer conn.Recycle() - return conn.PrimaryPosition() + return conn.Conn.PrimaryPosition() } // SetReplicationPosition sets the replication position at which the replica will resume @@ -436,7 +436,7 @@ func (mysqld *Mysqld) SetReplicationPosition(ctx context.Context, pos replicatio } defer conn.Recycle() - cmds := conn.SetReplicationPositionCommands(pos) + cmds := conn.Conn.SetReplicationPositionCommands(pos) log.Infof("Executing commands to set replication position: %v", cmds) return mysqld.executeSuperQueryListConn(ctx, conn, cmds) } @@ -456,12 +456,12 @@ func (mysqld *Mysqld) SetReplicationSource(ctx context.Context, host string, por var cmds []string if stopReplicationBefore { - cmds = append(cmds, conn.StopReplicationCommand()) + cmds = append(cmds, conn.Conn.StopReplicationCommand()) } - smc := conn.SetReplicationSourceCommand(params, host, port, int(replicationConnectRetry.Seconds())) + smc := conn.Conn.SetReplicationSourceCommand(params, host, port, int(replicationConnectRetry.Seconds())) cmds = append(cmds, smc) if startReplicationAfter { - cmds = append(cmds, conn.StartReplicationCommand()) + cmds = append(cmds, conn.Conn.StartReplicationCommand()) } return mysqld.executeSuperQueryListConn(ctx, conn, cmds) } @@ -474,7 +474,7 @@ func (mysqld *Mysqld) ResetReplication(ctx context.Context) error { } defer conn.Recycle() - cmds := conn.ResetReplicationCommands() + cmds := conn.Conn.ResetReplicationCommands() return mysqld.executeSuperQueryListConn(ctx, conn, cmds) } @@ -486,7 +486,7 @@ func (mysqld *Mysqld) ResetReplicationParameters(ctx context.Context) error { } defer conn.Recycle() - cmds := conn.ResetReplicationParametersCommands() + cmds := conn.Conn.ResetReplicationParametersCommands() return mysqld.executeSuperQueryListConn(ctx, conn, cmds) } @@ -582,7 +582,7 @@ func (mysqld *Mysqld) GetGTIDMode(ctx context.Context) (string, error) { } defer conn.Recycle() - return conn.GetGTIDMode() + return conn.Conn.GetGTIDMode() } // FlushBinaryLogs is part of the MysqlDaemon interface. diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 8a14d26acce..6f1c7c19570 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -366,7 +366,7 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]* return nil, nil, err } defer conn.Recycle() - return GetColumns(dbName, table, conn.ExecuteFetch) + return GetColumns(dbName, table, conn.Conn.ExecuteFetch) } // GetPrimaryKeyColumns returns the primary key columns of table. 
@@ -397,7 +397,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary' ORDER BY table_name, SEQ_IN_INDEX` sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList) - qr, err := conn.ExecuteFetch(sql, len(tables)*100, true) + qr, err := conn.Conn.ExecuteFetch(sql, len(tables)*100, true) if err != nil { return nil, err } @@ -629,7 +629,7 @@ func (mysqld *Mysqld) GetPrimaryKeyEquivalentColumns(ctx context.Context, dbName encodedDbName := encodeEntityName(dbName) encodedTable := encodeEntityName(table) sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable) - qr, err := conn.ExecuteFetch(sql, 1000, true) + qr, err := conn.Conn.ExecuteFetch(sql, 1000, true) if err != nil { return nil, "", err } diff --git a/go/vt/vttablet/endtoend/settings_test.go b/go/vt/vttablet/endtoend/settings_test.go index 322819ade8e..d0a3b4987dd 100644 --- a/go/vt/vttablet/endtoend/settings_test.go +++ b/go/vt/vttablet/endtoend/settings_test.go @@ -103,28 +103,15 @@ func TestSetttingsReuseConnWithSettings(t *testing.T) { require.NoError(t, err) // We iterate in a loop and try to get a connection with the same settings as before - // but only 1 at a time. So we expect the two connections to be reused, and we should be seeing both of them. - reusedConnection1 := false - reusedConnection2 := false - for i := 0; i < 100; i++ { + // but only 1 at a time. We're only going to see connection 2 here because the pool is LIFO + for i := 0; i < 8; i++ { res, err = client.ReserveBeginExecute(connectionIDQuery, []string{setting}, nil, nil) require.NoError(t, err) - if connectionIDRes.Equal(res) { - reusedConnection1 = true - } else if connectionIDRes2.Equal(res) { - reusedConnection2 = true - } else { - t.Fatalf("The connection should be either of the already created connections") - } + require.Truef(t, connectionIDRes2.Equal(res), "connection pool was not LIFO") err = client.Rollback() require.NoError(t, err) - if reusedConnection2 && reusedConnection1 { - break - } } - require.True(t, reusedConnection1) - require.True(t, reusedConnection2) } // resetTxConnPool resets the settings pool by fetching all the connections from the pool with no settings. 
diff --git a/go/vt/vttablet/onlineddl/executor.go b/go/vt/vttablet/onlineddl/executor.go index 7b216f0b1ad..746edbd6948 100644 --- a/go/vt/vttablet/onlineddl/executor.go +++ b/go/vt/vttablet/onlineddl/executor.go @@ -286,7 +286,7 @@ func (e *Executor) executeQuery(ctx context.Context, query string) (result *sqlt } defer conn.Recycle() - return conn.Exec(ctx, query, math.MaxInt32, true) + return conn.Conn.Exec(ctx, query, math.MaxInt32, true) } func (e *Executor) executeQueryWithSidecarDBReplacement(ctx context.Context, query string) (result *sqltypes.Result, err error) { @@ -303,7 +303,7 @@ func (e *Executor) executeQueryWithSidecarDBReplacement(ctx context.Context, que if err != nil { return nil, err } - return conn.Exec(ctx, uq, math.MaxInt32, true) + return conn.Conn.Exec(ctx, uq, math.MaxInt32, true) } // TabletAliasString returns tablet alias as string (duh) @@ -506,7 +506,7 @@ func (e *Executor) readMySQLVariables(ctx context.Context) (variables *mysqlVari } defer conn.Recycle() - tm, err := conn.Exec(ctx, `select + tm, err := conn.Conn.Exec(ctx, `select @@global.hostname as hostname, @@global.port as port, @@global.read_only as read_only, @@ -861,7 +861,7 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream) er return err } defer lockConn.Recycle() - defer lockConn.Exec(ctx, sqlUnlockTables, 1, false) + defer lockConn.Conn.Exec(ctx, sqlUnlockTables, 1, false) renameCompleteChan := make(chan error) renameWasSuccessful := false @@ -872,7 +872,7 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream) er defer renameConn.Recycle() defer func() { if !renameWasSuccessful { - renameConn.Kill("premature exit while renaming tables", 0) + renameConn.Conn.Kill("premature exit while renaming tables", 0) } }() renameQuery := sqlparser.BuildParsedQuery(sqlSwapTables, onlineDDL.Table, sentryTableName, vreplTable, onlineDDL.Table, sentryTableName, vreplTable) @@ -885,7 +885,7 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream) er defer cancel() for { - renameProcessFound, err := e.doesConnectionInfoMatch(renameWaitCtx, renameConn.ID(), "rename") + renameProcessFound, err := e.doesConnectionInfoMatch(renameWaitCtx, renameConn.Conn.ID(), "rename") if err != nil { return err } @@ -968,14 +968,14 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream) er lockCtx, cancel := context.WithTimeout(ctx, migrationCutOverThreshold) defer cancel() lockTableQuery := sqlparser.BuildParsedQuery(sqlLockTwoTablesWrite, sentryTableName, onlineDDL.Table) - if _, err := lockConn.Exec(lockCtx, lockTableQuery.Query, 1, false); err != nil { + if _, err := lockConn.Conn.Exec(lockCtx, lockTableQuery.Query, 1, false); err != nil { return err } e.updateMigrationStage(ctx, onlineDDL.UUID, "renaming tables") go func() { defer close(renameCompleteChan) - _, err := renameConn.Exec(ctx, renameQuery.Query, 1, false) + _, err := renameConn.Conn.Exec(ctx, renameQuery.Query, 1, false) renameCompleteChan <- err }() // the rename should block, because of the LOCK. Wait for it to show up. 
@@ -1040,7 +1040,7 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream) er dropTableQuery := sqlparser.BuildParsedQuery(sqlDropTable, sentryTableName) lockCtx, cancel := context.WithTimeout(ctx, migrationCutOverThreshold) defer cancel() - if _, err := lockConn.Exec(lockCtx, dropTableQuery.Query, 1, false); err != nil { + if _, err := lockConn.Conn.Exec(lockCtx, dropTableQuery.Query, 1, false); err != nil { return err } } @@ -1048,7 +1048,7 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream) er lockCtx, cancel := context.WithTimeout(ctx, migrationCutOverThreshold) defer cancel() e.updateMigrationStage(ctx, onlineDDL.UUID, "unlocking tables") - if _, err := lockConn.Exec(lockCtx, sqlUnlockTables, 1, false); err != nil { + if _, err := lockConn.Conn.Exec(lockCtx, sqlUnlockTables, 1, false); err != nil { return err } } diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 0d21cee7677..8b8ac605893 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -135,7 +135,7 @@ func (tm *TabletManager) ExecuteFetchAsApp(ctx context.Context, req *tabletmanag if err != nil { return nil, err } - result, err := conn.ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/) + result, err := conn.Conn.ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/) return sqltypes.ResultToProto3(result), err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/vcopier.go b/go/vt/vttablet/tabletmanager/vreplication/vcopier.go index 2df808c3a77..cbf524c54c3 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vcopier.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vcopier.go @@ -758,7 +758,7 @@ func (vcq *vcopierCopyWorkQueue) enqueue(ctx context.Context, currT *vcopierCopy } // Get a handle on an unused worker. - poolH, err := vcq.workerPool.Get(ctx, nil) + poolH, err := vcq.workerPool.Get(ctx) if err != nil { return fmt.Errorf("failed to get a worker from pool: %s", err.Error()) } @@ -1018,11 +1018,6 @@ func (vts vcopierCopyTaskState) String() string { return fmt.Sprintf("undefined(%d)", int(vts)) } -// ApplySetting implements pools.Resource. -func (vbc *vcopierCopyWorker) ApplySetting(context.Context, *pools.Setting) error { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "[BUG] vcopierCopyWorker does not implement ApplySetting") -} - // Close implements pool.Resource. func (vbc *vcopierCopyWorker) Close() { if !vbc.isOpen { @@ -1040,21 +1035,6 @@ func (vbc *vcopierCopyWorker) Expired(time.Duration) bool { return false } -// IsSameSetting implements pools.Resource. -func (vbc *vcopierCopyWorker) IsSameSetting(string) bool { - return true -} - -// IsSettingApplied implements pools.Resource. -func (vbc *vcopierCopyWorker) IsSettingApplied() bool { - return false -} - -// ResetSetting implements pools.Resource. -func (vbc *vcopierCopyWorker) ResetSetting(context.Context) error { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "[BUG] vcopierCopyWorker does not implement ResetSetting") -} - // execute advances a task through each state until it is done (= canceled, // completed, failed). 
func (vbc *vcopierCopyWorker) execute(ctx context.Context, task *vcopierCopyTask) *vcopierCopyTaskResult { diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index cc81bf39910..63f4c73520e 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -25,79 +25,71 @@ import ( "time" "vitess.io/vitess/go/mysql/sqlerror" - "vitess.io/vitess/go/pools" - "vitess.io/vitess/go/vt/dbconfigs" - "vitess.io/vitess/go/vt/servenv" - "vitess.io/vitess/go/vt/vterrors" - + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/trace" + "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/dbconnpool" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) -// DBConn is a db connection for tabletserver. +// Conn is a db connection for tabletserver. // It performs automatic reconnects as needed. // Its Execute function has a timeout that can kill // its own queries and the underlying connection. // It will also trigger a CheckMySQL whenever applicable. -type DBConn struct { - conn *dbconnpool.DBConnection - info dbconfigs.Connector - pool *Pool - dbaPool *dbconnpool.ConnectionPool - stats *tabletenv.Stats - current atomic.Value - timeCreated time.Time - setting string - resetSetting string +type Conn struct { + conn *dbconnpool.DBConnection + setting *smartconnpool.Setting + + env tabletenv.Env + dbaPool *dbconnpool.ConnectionPool + stats *tabletenv.Stats + current atomic.Value // err will be set if a query is killed through a Kill. errmu sync.Mutex err error } -// NewDBConn creates a new DBConn. It triggers a CheckMySQL if creation fails. -func NewDBConn(ctx context.Context, cp *Pool, appParams dbconfigs.Connector) (*DBConn, error) { +// NewConnection creates a new DBConn. It triggers a CheckMySQL if creation fails. +func newPooledConn(ctx context.Context, pool *Pool, appParams dbconfigs.Connector) (*Conn, error) { start := time.Now() - defer cp.env.Stats().MySQLTimings.Record("Connect", start) + defer pool.env.Stats().MySQLTimings.Record("Connect", start) c, err := dbconnpool.NewDBConnection(ctx, appParams) if err != nil { - cp.env.Stats().MySQLTimings.Record("ConnectError", start) - cp.env.CheckMySQL() + pool.env.Stats().MySQLTimings.Record("ConnectError", start) + pool.env.CheckMySQL() return nil, err } - db := &DBConn{ - conn: c, - info: appParams, - pool: cp, - dbaPool: cp.dbaPool, - timeCreated: time.Now(), - stats: cp.env.Stats(), + db := &Conn{ + conn: c, + env: pool.env, + stats: pool.env.Stats(), + dbaPool: pool.dbaPool, } db.current.Store("") return db, nil } -// NewDBConnNoPool creates a new DBConn without a pool. -func NewDBConnNoPool(ctx context.Context, params dbconfigs.Connector, dbaPool *dbconnpool.ConnectionPool, setting *pools.Setting) (*DBConn, error) { +// NewConn creates a new Conn without a pool. 
+func NewConn(ctx context.Context, params dbconfigs.Connector, dbaPool *dbconnpool.ConnectionPool, setting *smartconnpool.Setting) (*Conn, error) { c, err := dbconnpool.NewDBConnection(ctx, params) if err != nil { return nil, err } - dbconn := &DBConn{ - conn: c, - info: params, - dbaPool: dbaPool, - pool: nil, - timeCreated: time.Now(), - stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), + dbconn := &Conn{ + conn: c, + dbaPool: dbaPool, + stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), } dbconn.current.Store("") if setting == nil { @@ -112,7 +104,7 @@ func NewDBConnNoPool(ctx context.Context, params dbconfigs.Connector, dbaPool *d // Err returns an error if there was a client initiated error // like a query kill. -func (dbc *DBConn) Err() error { +func (dbc *Conn) Err() error { dbc.errmu.Lock() defer dbc.errmu.Unlock() return dbc.err @@ -120,7 +112,7 @@ func (dbc *DBConn) Err() error { // Exec executes the specified query. If there is a connection error, it will reconnect // and retry. A failed reconnect will trigger a CheckMySQL. -func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { +func (dbc *Conn) Exec(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { span, ctx := trace.NewSpan(ctx, "DBConn.Exec") defer span.Finish() @@ -141,15 +133,15 @@ func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfiel return nil, err } - // Connection error. Retry if context has not expired. + // Conn error. Retry if context has not expired. select { case <-ctx.Done(): return nil, err default: } - if reconnectErr := dbc.reconnect(ctx); reconnectErr != nil { - dbc.pool.env.CheckMySQL() + if reconnectErr := dbc.Reconnect(ctx); reconnectErr != nil { + dbc.env.CheckMySQL() // Return the error of the reconnect and not the original connection error. return nil, reconnectErr } @@ -159,16 +151,14 @@ func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfiel panic("unreachable") } -func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { +func (dbc *Conn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { dbc.current.Store(query) defer dbc.current.Store("") // Check if the context is already past its deadline before // trying to execute the query. - select { - case <-ctx.Done(): - return nil, fmt.Errorf("%v before execution started", ctx.Err()) - default: + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("%v before execution started", err) } defer dbc.stats.MySQLTimings.Record("Exec", time.Now()) @@ -187,18 +177,16 @@ func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, want } // ExecOnce executes the specified query, but does not retry on connection errors. -func (dbc *DBConn) ExecOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { +func (dbc *Conn) ExecOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { return dbc.execOnce(ctx, query, maxrows, wantfields) } // FetchNext returns the next result set. 
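A small note on the control-flow change in execOnce above, which reappears in FetchNext and Stream below: the non-blocking select on ctx.Done() is replaced by a direct ctx.Err() check. The two forms are equivalent because ctx.Err() is non-nil exactly once the Done channel is closed; a self-contained comparison:

package main

import (
	"context"
	"fmt"
)

// checkBefore is the old idiom: a non-blocking select on ctx.Done().
func checkBefore(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return fmt.Errorf("%v before execution started", ctx.Err())
	default:
		return nil
	}
}

// checkAfter is the new idiom: ctx.Err() returns nil until the context
// is canceled or times out, so the select adds nothing.
func checkAfter(ctx context.Context) error {
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("%v before execution started", err)
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println(checkBefore(ctx)) // context canceled before execution started
	fmt.Println(checkAfter(ctx))  // same result, less ceremony
}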
-func (dbc *DBConn) FetchNext(ctx context.Context, maxrows int, wantfields bool) (*sqltypes.Result, error) { +func (dbc *Conn) FetchNext(ctx context.Context, maxrows int, wantfields bool) (*sqltypes.Result, error) { // Check if the context is already past its deadline before // trying to fetch the next result. - select { - case <-ctx.Done(): - return nil, fmt.Errorf("%v before reading next result set", ctx.Err()) - default: + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("%v before reading next result set", err) } res, _, _, err := dbc.conn.ReadQueryResult(maxrows, wantfields) if err != nil { @@ -209,7 +197,7 @@ func (dbc *DBConn) FetchNext(ctx context.Context, maxrows int, wantfields bool) } // Stream executes the query and streams the results. -func (dbc *DBConn) Stream(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int, includedFields querypb.ExecuteOptions_IncludedFields) error { +func (dbc *Conn) Stream(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int, includedFields querypb.ExecuteOptions_IncludedFields) error { span, ctx := trace.NewSpan(ctx, "DBConn.Stream") trace.AnnotateSQL(span, sqlparser.Preview(query)) defer span.Finish() @@ -247,14 +235,12 @@ func (dbc *DBConn) Stream(ctx context.Context, query string, callback func(*sqlt return err } - // Connection error. Retry if context has not expired. - select { - case <-ctx.Done(): + // Conn error. Retry if context has not expired. + if ctx.Err() != nil { return err - default: } - if reconnectErr := dbc.reconnect(ctx); reconnectErr != nil { - dbc.pool.env.CheckMySQL() + if reconnectErr := dbc.Reconnect(ctx); reconnectErr != nil { + dbc.env.CheckMySQL() // Return the error of the reconnect and not the original connection error. return reconnectErr } @@ -262,7 +248,7 @@ func (dbc *DBConn) Stream(ctx context.Context, query string, callback func(*sqlt panic("unreachable") } -func (dbc *DBConn) streamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int) error { +func (dbc *Conn) streamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int) error { defer dbc.stats.MySQLTimings.Record("ExecStream", time.Now()) dbc.current.Store(query) @@ -282,7 +268,7 @@ func (dbc *DBConn) streamOnce(ctx context.Context, query string, callback func(* } // StreamOnce executes the query and streams the results. But, does not retry on connection errors. -func (dbc *DBConn) StreamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int, includedFields querypb.ExecuteOptions_IncludedFields) error { +func (dbc *Conn) StreamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int, includedFields querypb.ExecuteOptions_IncludedFields) error { resultSent := false return dbc.streamOnce( ctx, @@ -307,7 +293,7 @@ var ( // VerifyMode is a helper method to verify mysql is running with // sql_mode = STRICT_TRANS_TABLES or STRICT_ALL_TABLES and autocommit=ON. 
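The Stream path above keeps the same retry shape as Exec: run the attempt once, and only when the failure is a connection error and the context is still alive, reconnect and try again; a failed reconnect surfaces its own error after triggering CheckMySQL. A compact stand-alone rendering of that shape, with placeholder helpers in place of the real connection and error classification:

package sketch

import (
	"context"
	"errors"
)

var errConnLost = errors.New("connection lost")

// isConnErr is a placeholder for the real connection-error classification.
func isConnErr(err error) bool { return errors.Is(err, errConnLost) }

// withReconnectRetry runs attempt once and, if it fails with a connection
// error while the context is still alive, reconnects and runs it once more.
func withReconnectRetry(ctx context.Context, attempt func() error, reconnect func(context.Context) error) error {
	var err error
	for tries := 0; tries < 2; tries++ {
		err = attempt()
		if err == nil || !isConnErr(err) {
			return err // success, or a query error surfaced unchanged
		}
		if ctx.Err() != nil {
			return err // context expired: do not retry
		}
		if rerr := reconnect(ctx); rerr != nil {
			// The real code also triggers env.CheckMySQL() here and returns
			// the reconnect error rather than the original connection error.
			return rerr
		}
	}
	return err
}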
-func (dbc *DBConn) VerifyMode(strictTransTables bool) error { +func (dbc *Conn) VerifyMode(strictTransTables bool) error { if strictTransTables { qr, err := dbc.conn.ExecuteFetch(getModeSQL, 2, false) if err != nil { @@ -345,78 +331,41 @@ func (dbc *DBConn) VerifyMode(strictTransTables bool) error { } // Close closes the DBConn. -func (dbc *DBConn) Close() { +func (dbc *Conn) Close() { dbc.conn.Close() } // ApplySetting implements the pools.Resource interface. -func (dbc *DBConn) ApplySetting(ctx context.Context, setting *pools.Setting) error { - query := setting.GetQuery() - if _, err := dbc.execOnce(ctx, query, 1, false); err != nil { +func (dbc *Conn) ApplySetting(ctx context.Context, setting *smartconnpool.Setting) error { + if _, err := dbc.execOnce(ctx, setting.ApplyQuery(), 1, false); err != nil { return err } - dbc.setting = query - dbc.resetSetting = setting.GetResetQuery() + dbc.setting = setting return nil } -// IsSettingApplied implements the pools.Resource interface. -func (dbc *DBConn) IsSettingApplied() bool { - return dbc.setting != "" -} - -// IsSameSetting implements the pools.Resource interface. -func (dbc *DBConn) IsSameSetting(setting string) bool { - return strings.EqualFold(setting, dbc.setting) -} - // ResetSetting implements the pools.Resource interface. -func (dbc *DBConn) ResetSetting(ctx context.Context) error { - if _, err := dbc.execOnce(ctx, dbc.resetSetting, 1, false); err != nil { +func (dbc *Conn) ResetSetting(ctx context.Context) error { + if _, err := dbc.execOnce(ctx, dbc.setting.ResetQuery(), 1, false); err != nil { return err } - dbc.setting = "" - dbc.resetSetting = "" + dbc.setting = nil return nil } -var _ pools.Resource = (*DBConn)(nil) +func (dbc *Conn) Setting() *smartconnpool.Setting { + return dbc.setting +} // IsClosed returns true if DBConn is closed. -func (dbc *DBConn) IsClosed() bool { +func (dbc *Conn) IsClosed() bool { return dbc.conn.IsClosed() } -// Expired returns whether a connection has passed its lifetime -func (dbc *DBConn) Expired(lifetimeTimeout time.Duration) bool { - return lifetimeTimeout > 0 && time.Until(dbc.timeCreated.Add(lifetimeTimeout)) < 0 -} - -// Recycle returns the DBConn to the pool. -func (dbc *DBConn) Recycle() { - switch { - case dbc.pool == nil: - dbc.Close() - case dbc.conn.IsClosed(): - dbc.pool.Put(nil) - default: - dbc.pool.Put(dbc) - } -} - -// Taint unregister connection from original pool and taints the connection. -func (dbc *DBConn) Taint() { - if dbc.pool == nil { - return - } - dbc.pool.Put(nil) - dbc.pool = nil -} - // Kill kills the currently executing query both on MySQL side // and on the connection side. If no query is executing, it's a no-op. // Kill will also not kill a query more than once. 
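ApplySetting and ResetSetting above now store a single *smartconnpool.Setting on the connection instead of two raw query strings, and the new Setting() accessor exposes it. The sketch below shows that bookkeeping with placeholder types; the real Setting is constructed with smartconnpool.NewSetting(applyQuery, resetQuery), as the updated tests later in this diff show.

package main

import "fmt"

// Setting is a placeholder mirroring smartconnpool.Setting: it pairs the
// query that applies a session setting with the query that undoes it.
type Setting struct{ apply, reset string }

func NewSetting(apply, reset string) *Setting { return &Setting{apply, reset} }
func (s *Setting) ApplyQuery() string         { return s.apply }
func (s *Setting) ResetQuery() string         { return s.reset }

// conn mimics the new bookkeeping on connpool.Conn: the connection keeps
// the *Setting it applied (instead of two strings) and nils it on reset,
// which also gives the pool a single Setting() accessor to inspect.
type conn struct{ setting *Setting }

func (c *conn) ApplySetting(s *Setting) error {
	fmt.Println("exec:", s.ApplyQuery())
	c.setting = s
	return nil
}

func (c *conn) ResetSetting() error {
	fmt.Println("exec:", c.setting.ResetQuery())
	c.setting = nil
	return nil
}

func (c *conn) Setting() *Setting { return c.setting }

func main() {
	c := &conn{}
	_ = c.ApplySetting(NewSetting("set @@sql_mode='ANSI_QUOTES'", "set @@sql_mode = default"))
	_ = c.ResetSetting()
	fmt.Println(c.Setting() == nil) // true: connection is back to its default state
}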
-func (dbc *DBConn) Kill(reason string, elapsed time.Duration) error { +func (dbc *Conn) Kill(reason string, elapsed time.Duration) error { dbc.stats.KillCounters.Add("Queries", 1) log.Infof("Due to %s, elapsed time: %v, killing query ID %v %s", reason, elapsed, dbc.conn.ID(), dbc.CurrentForLogging()) @@ -434,7 +383,7 @@ func (dbc *DBConn) Kill(reason string, elapsed time.Duration) error { } defer killConn.Recycle() sql := fmt.Sprintf("kill %d", dbc.conn.ID()) - _, err = killConn.ExecuteFetch(sql, 10000, false) + _, err = killConn.Conn.ExecuteFetch(sql, 10000, false) if err != nil { log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), dbc.CurrentForLogging(), err) @@ -444,34 +393,38 @@ func (dbc *DBConn) Kill(reason string, elapsed time.Duration) error { } // Current returns the currently executing query. -func (dbc *DBConn) Current() string { +func (dbc *Conn) Current() string { return dbc.current.Load().(string) } // ID returns the connection id. -func (dbc *DBConn) ID() int64 { +func (dbc *Conn) ID() int64 { return dbc.conn.ID() } // BaseShowTables returns a query that shows tables -func (dbc *DBConn) BaseShowTables() string { +func (dbc *Conn) BaseShowTables() string { return dbc.conn.BaseShowTables() } // BaseShowTablesWithSizes returns a query that shows tables and their sizes -func (dbc *DBConn) BaseShowTablesWithSizes() string { +func (dbc *Conn) BaseShowTablesWithSizes() string { return dbc.conn.BaseShowTablesWithSizes() } -func (dbc *DBConn) reconnect(ctx context.Context) error { - dbc.conn.Close() - // Reuse MySQLTimings from dbc.conn. - newConn, err := dbconnpool.NewDBConnection(ctx, dbc.info) +func (dbc *Conn) ConnCheck(ctx context.Context) error { + if err := dbc.conn.ConnCheck(); err != nil { + return dbc.Reconnect(ctx) + } + return nil +} + +func (dbc *Conn) Reconnect(ctx context.Context) error { + err := dbc.conn.Reconnect(ctx) if err != nil { return err } - dbc.conn = newConn - if dbc.IsSettingApplied() { + if dbc.setting != nil { err = dbc.applySameSetting(ctx) if err != nil { return err @@ -487,7 +440,7 @@ func (dbc *DBConn) reconnect(ctx context.Context) error { // if the deadline is exceeded. It returns a channel and a waitgroup. After the // query is done executing, the caller is required to close the done channel // and wait for the waitgroup to make sure that the necessary cleanup is done. -func (dbc *DBConn) setDeadline(ctx context.Context) (chan bool, *sync.WaitGroup) { +func (dbc *Conn) setDeadline(ctx context.Context) (chan bool, *sync.WaitGroup) { if ctx.Done() == nil { return nil, nil } @@ -525,9 +478,9 @@ func (dbc *DBConn) setDeadline(ctx context.Context) (chan bool, *sync.WaitGroup) // CurrentForLogging applies transformations to the query making it suitable to log. // It applies sanitization rules based on tablet settings and limits the max length of // queries. 
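ConnCheck and Reconnect above replace the old unexported reconnect: a liveness probe that falls back to reconnecting, and a reconnect that replays any stored setting so session state survives the new connection. The approximation below uses invented ping, redial and exec stand-ins for the real dbconnpool calls.

package main

import (
	"context"
	"fmt"
)

type setting struct{ applyQuery string }

type conn struct {
	alive   bool
	setting *setting
}

// ping stands in for the server-side liveness check on the raw connection.
func (c *conn) ping() error {
	if !c.alive {
		return fmt.Errorf("connection is dead")
	}
	return nil
}

// redial stands in for re-establishing the underlying MySQL connection.
func (c *conn) redial(ctx context.Context) error { c.alive = true; return nil }

func (c *conn) exec(ctx context.Context, query string) error {
	fmt.Println("exec:", query)
	return nil
}

// ConnCheck mirrors the flow added in this diff: probe first, and only
// reconnect when the probe fails.
func (c *conn) ConnCheck(ctx context.Context) error {
	if err := c.ping(); err != nil {
		return c.Reconnect(ctx)
	}
	return nil
}

// Reconnect redials and, if a setting had been applied, replays it so the
// fresh connection has the same session state as the one it replaces.
func (c *conn) Reconnect(ctx context.Context) error {
	if err := c.redial(ctx); err != nil {
		return err
	}
	if c.setting != nil {
		return c.exec(ctx, c.setting.applyQuery)
	}
	return nil
}

func main() {
	c := &conn{alive: false, setting: &setting{applyQuery: "set @@sql_mode='ANSI_QUOTES'"}}
	fmt.Println(c.ConnCheck(context.Background())) // <nil>: reconnected and setting replayed
}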
-func (dbc *DBConn) CurrentForLogging() string { +func (dbc *Conn) CurrentForLogging() string { var queryToLog string - if dbc.pool != nil && dbc.pool.env != nil && dbc.pool.env.Config() != nil && !dbc.pool.env.Config().SanitizeLogMessages { + if dbc.env != nil && dbc.env.Config() != nil && !dbc.env.Config().SanitizeLogMessages { queryToLog = dbc.Current() } else { queryToLog, _ = sqlparser.RedactSQLQuery(dbc.Current()) @@ -535,7 +488,7 @@ func (dbc *DBConn) CurrentForLogging() string { return sqlparser.TruncateForLog(queryToLog) } -func (dbc *DBConn) applySameSetting(ctx context.Context) (err error) { - _, err = dbc.execOnce(ctx, dbc.setting, 1, false) +func (dbc *Conn) applySameSetting(ctx context.Context) (err error) { + _, err = dbc.execOnce(ctx, dbc.setting.ApplyQuery(), 1, false) return } diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go index 54792e17fa5..9717c95d9f7 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go @@ -28,9 +28,9 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/sqlerror" + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/mysql/fakesqldb" - "vitess.io/vitess/go/pools" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -66,7 +66,7 @@ func TestDBConnExec(t *testing.T) { defer connPool.Close() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) defer cancel() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) if dbConn != nil { defer dbConn.Close() } @@ -139,7 +139,7 @@ func TestDBConnExecLost(t *testing.T) { defer connPool.Close() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) defer cancel() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) if dbConn != nil { defer dbConn.Close() } @@ -200,7 +200,7 @@ func TestDBConnDeadline(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) defer cancel() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) if dbConn != nil { defer dbConn.Close() } @@ -253,7 +253,7 @@ func TestDBConnKill(t *testing.T) { connPool := newPool() connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) if dbConn != nil { defer dbConn.Close() } @@ -277,7 +277,7 @@ func TestDBConnKill(t *testing.T) { t.Fatalf("kill should succeed, but got error: %v", err) } - err = dbConn.reconnect(context.Background()) + err = dbConn.Reconnect(context.Background()) if err != nil { t.Fatalf("reconnect should succeed, but got error: %v", err) } @@ -299,7 +299,7 @@ func TestDBConnClose(t *testing.T) { connPool := newPool() connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) require.NoError(t, err) defer dbConn.Close() @@ -324,7 +324,7 @@ func 
TestDBNoPoolConnKill(t *testing.T) { connPool := newPool() connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() - dbConn, err := NewDBConnNoPool(context.Background(), db.ConnParams(), connPool.dbaPool, nil) + dbConn, err := NewConn(context.Background(), db.ConnParams(), connPool.dbaPool, nil) if dbConn != nil { defer dbConn.Close() } @@ -348,7 +348,7 @@ func TestDBNoPoolConnKill(t *testing.T) { t.Fatalf("kill should succeed, but got error: %v", err) } - err = dbConn.reconnect(context.Background()) + err = dbConn.Reconnect(context.Background()) if err != nil { t.Fatalf("reconnect should succeed, but got error: %v", err) } @@ -380,7 +380,7 @@ func TestDBConnStream(t *testing.T) { defer connPool.Close() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second)) defer cancel() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) if dbConn != nil { defer dbConn.Close() } @@ -438,7 +438,7 @@ func TestDBConnStreamKill(t *testing.T) { connPool := newPool() connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) require.NoError(t, err) defer dbConn.Close() @@ -468,7 +468,7 @@ func TestDBConnReconnect(t *testing.T) { connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() - dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams()) + dbConn, err := newPooledConn(context.Background(), connPool, db.ConnParams()) require.NoError(t, err) defer dbConn.Close() @@ -494,14 +494,14 @@ func TestDBConnReApplySetting(t *testing.T) { defer connPool.Close() ctx := context.Background() - dbConn, err := NewDBConn(ctx, connPool, db.ConnParams()) + dbConn, err := newPooledConn(ctx, connPool, db.ConnParams()) require.NoError(t, err) defer dbConn.Close() // apply system settings. setQ := "set @@sql_mode='ANSI_QUOTES'" db.AddExpectedQuery(setQ, nil) - err = dbConn.ApplySetting(ctx, pools.NewSetting(setQ, "set @@sql_mode = default")) + err = dbConn.ApplySetting(ctx, smartconnpool.NewSetting(setQ, "set @@sql_mode = default")) require.NoError(t, err) // close the connection and let the dbconn reconnect to start a new connection when required. diff --git a/go/vt/vttablet/tabletserver/connpool/pool.go b/go/vt/vttablet/tabletserver/connpool/pool.go index d2f8efb7af0..6f8b72870e0 100644 --- a/go/vt/vttablet/tabletserver/connpool/pool.go +++ b/go/vt/vttablet/tabletserver/connpool/pool.go @@ -18,20 +18,17 @@ package connpool import ( "context" - "fmt" + "encoding/json" "net" "strings" - "sync" - "sync/atomic" "time" "vitess.io/vitess/go/netutil" - "vitess.io/vitess/go/pools" + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/trace" "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/dbconnpool" - "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" @@ -48,6 +45,8 @@ const ( getWithS = "GetWithSettings" ) +type PooledConn = smartconnpool.Pooled[*Conn] + // Pool implements a custom connection pool for tabletserver. // It's similar to dbconnpool.ConnPool, but the connections it creates // come with built-in ability to kill in-flight queries. 
These connections @@ -55,148 +54,88 @@ const ( // Other than the connection type, ConnPool maintains an additional // pool of dba connections that are used to kill connections. type Pool struct { - env tabletenv.Env - name string - mu sync.Mutex - connections pools.IResourcePool - capacity int - timeout time.Duration - idleTimeout time.Duration - maxLifetime time.Duration - waiterCap int64 - waiterCount atomic.Int64 - waiterQueueFull atomic.Int64 - dbaPool *dbconnpool.ConnectionPool - appDebugParams dbconfigs.Connector - getConnTime *servenv.TimingsWrapper + *smartconnpool.ConnPool[*Conn] + dbaPool *dbconnpool.ConnectionPool + + timeout time.Duration + env tabletenv.Env + + appDebugParams dbconfigs.Connector + getConnTime *servenv.TimingsWrapper } // NewPool creates a new Pool. The name is used // to publish stats only. func NewPool(env tabletenv.Env, name string, cfg tabletenv.ConnPoolConfig) *Pool { - idleTimeout := cfg.IdleTimeoutSeconds.Get() - maxLifetime := cfg.MaxLifetimeSeconds.Get() cp := &Pool{ - env: env, - name: name, - capacity: cfg.Size, - timeout: cfg.TimeoutSeconds.Get(), - idleTimeout: idleTimeout, - maxLifetime: maxLifetime, - waiterCap: int64(cfg.MaxWaiters), - dbaPool: dbconnpool.NewConnectionPool("", 1, idleTimeout, maxLifetime, 0), + timeout: cfg.TimeoutSeconds.Get(), + env: env, } - if name == "" { - return cp + + config := smartconnpool.Config[*Conn]{ + Capacity: int64(cfg.Size), + IdleTimeout: cfg.IdleTimeoutSeconds.Get(), + MaxLifetime: cfg.MaxLifetimeSeconds.Get(), + RefreshInterval: mysqlctl.PoolDynamicHostnameResolution, } - env.Exporter().NewGaugeFunc(name+"Capacity", "Tablet server conn pool capacity", cp.Capacity) - env.Exporter().NewGaugeFunc(name+"Available", "Tablet server conn pool available", cp.Available) - env.Exporter().NewGaugeFunc(name+"Active", "Tablet server conn pool active", cp.Active) - env.Exporter().NewGaugeFunc(name+"InUse", "Tablet server conn pool in use", cp.InUse) - env.Exporter().NewGaugeFunc(name+"MaxCap", "Tablet server conn pool max cap", cp.MaxCap) - env.Exporter().NewCounterFunc(name+"WaitCount", "Tablet server conn pool wait count", cp.WaitCount) - env.Exporter().NewCounterDurationFunc(name+"WaitTime", "Tablet server wait time", cp.WaitTime) - env.Exporter().NewGaugeDurationFunc(name+"IdleTimeout", "Tablet server idle timeout", cp.IdleTimeout) - env.Exporter().NewCounterFunc(name+"IdleClosed", "Tablet server conn pool idle closed", cp.IdleClosed) - env.Exporter().NewCounterFunc(name+"MaxLifetimeClosed", "Tablet server conn pool refresh closed", cp.MaxLifetimeClosed) - env.Exporter().NewCounterFunc(name+"Exhausted", "Number of times pool had zero available slots", cp.Exhausted) - env.Exporter().NewCounterFunc(name+"WaiterQueueFull", "Number of times the waiter queue was full", cp.waiterQueueFull.Load) - env.Exporter().NewCounterFunc(name+"Get", "Tablet server conn pool get count", cp.GetCount) - env.Exporter().NewCounterFunc(name+"GetSetting", "Tablet server conn pool get with setting count", cp.GetSettingCount) - env.Exporter().NewCounterFunc(name+"DiffSetting", "Number of times pool applied different setting", cp.DiffSettingCount) - env.Exporter().NewCounterFunc(name+"ResetSetting", "Number of times pool reset the setting", cp.ResetSettingCount) - cp.getConnTime = env.Exporter().NewTimings(name+"GetConnTime", "Tracks the amount of time it takes to get a connection", "Settings") - return cp -} + if name != "" { + config.LogWait = func(start time.Time) { + env.Stats().WaitTimings.Record(name+"ResourceWaitTime", start) + } + + 
cp.getConnTime = env.Exporter().NewTimings(name+"GetConnTime", "Tracks the amount of time it takes to get a connection", "Settings") + } + + cp.ConnPool = smartconnpool.NewPool(&config) + cp.ConnPool.RegisterStats(env.Exporter(), name) + + cp.dbaPool = dbconnpool.NewConnectionPool("", env.Exporter(), 1, config.IdleTimeout, config.MaxLifetime, 0) -func (cp *Pool) pool() (p pools.IResourcePool) { - cp.mu.Lock() - p = cp.connections - cp.mu.Unlock() - return p + return cp } // Open must be called before starting to use the pool. func (cp *Pool) Open(appParams, dbaParams, appDebugParams dbconfigs.Connector) { - cp.mu.Lock() - defer cp.mu.Unlock() - - f := func(ctx context.Context) (pools.Resource, error) { - return NewDBConn(ctx, cp, appParams) - } + cp.appDebugParams = appDebugParams - var refreshCheck pools.RefreshCheck + var refresh smartconnpool.RefreshCheck if net.ParseIP(appParams.Host()) == nil { - refreshCheck = netutil.DNSTracker(appParams.Host()) + refresh = netutil.DNSTracker(appParams.Host()) } - cp.connections = pools.NewResourcePool(f, cp.capacity, cp.capacity, cp.idleTimeout, cp.maxLifetime, cp.getLogWaitCallback(), refreshCheck, mysqlctl.PoolDynamicHostnameResolution) - cp.appDebugParams = appDebugParams + connect := func(ctx context.Context) (*Conn, error) { + return newPooledConn(ctx, cp, appParams) + } + cp.ConnPool.Open(connect, refresh) cp.dbaPool.Open(dbaParams) } -func (cp *Pool) getLogWaitCallback() func(time.Time) { - if cp.name == "" { - return func(start time.Time) {} // no op - } - return func(start time.Time) { - cp.env.Stats().WaitTimings.Record(cp.name+"ResourceWaitTime", start) - } -} - // Close will close the pool and wait for connections to be returned before // exiting. func (cp *Pool) Close() { - log.Infof("connpool - started execution of Close") - p := cp.pool() - log.Infof("connpool - found the pool") - if p == nil { - log.Infof("connpool - pool is empty") - return - } - // We should not hold the lock while calling Close - // because it waits for connections to be returned. - log.Infof("connpool - calling close on the pool") - p.Close() - log.Infof("connpool - acquiring lock") - cp.mu.Lock() - log.Infof("connpool - acquired lock") - cp.connections.Close() - cp.connections = nil - cp.mu.Unlock() - log.Infof("connpool - closing dbaPool") + cp.ConnPool.Close() cp.dbaPool.Close() - log.Infof("connpool - finished execution of Close") } // Get returns a connection. // You must call Recycle on DBConn once done. 
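NewPool above now delegates everything it used to track by hand (capacity, idle and lifetime timeouts, waiter accounting, per-stat gauges) to the embedded smartconnpool, configured once through a Config struct and opened later with a connect callback plus an optional DNS refresh check. The skeleton below shows that configure-then-open shape with placeholder types, not the real smartconnpool API.

package sketch

import (
	"context"
	"time"
)

// Config mirrors the shape of the new pool configuration: static limits
// set up front, waiting behavior injected as a callback.
type Config[C any] struct {
	Capacity        int64
	IdleTimeout     time.Duration
	MaxLifetime     time.Duration
	RefreshInterval time.Duration
	LogWait         func(start time.Time)
}

// Connect produces a new connection on demand.
type Connect[C any] func(ctx context.Context) (C, error)

// RefreshCheck reports whether pooled connections should be discarded,
// e.g. after the backend hostname resolves to a new address.
type RefreshCheck func() (bool, error)

type pool[C any] struct {
	cfg     Config[C]
	connect Connect[C]
	refresh RefreshCheck
}

// newPool only records the configuration; no connections are made yet.
func newPool[C any](cfg *Config[C]) *pool[C] { return &pool[C]{cfg: *cfg} }

// Open supplies the callbacks, matching the split in the diff where
// connpool.Pool.Open hands the inner pool a connect func and a refresh check.
func (p *pool[C]) Open(connect Connect[C], refresh RefreshCheck) {
	p.connect = connect
	p.refresh = refresh
}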
-func (cp *Pool) Get(ctx context.Context, setting *pools.Setting) (*DBConn, error) { +func (cp *Pool) Get(ctx context.Context, setting *smartconnpool.Setting) (*PooledConn, error) { span, ctx := trace.NewSpan(ctx, "Pool.Get") defer span.Finish() - if cp.waiterCap > 0 { - waiterCount := cp.waiterCount.Add(1) - defer cp.waiterCount.Add(-1) - if waiterCount > cp.waiterCap { - cp.waiterQueueFull.Add(1) - return nil, vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "pool %s waiter count exceeded", cp.name) - } - } - if cp.isCallerIDAppDebug(ctx) { - return NewDBConnNoPool(ctx, cp.appDebugParams, cp.dbaPool, setting) - } - p := cp.pool() - if p == nil { - return nil, ErrConnPoolClosed + conn, err := NewConn(ctx, cp.appDebugParams, cp.dbaPool, setting) + if err != nil { + return nil, err + } + return &smartconnpool.Pooled[*Conn]{Conn: conn}, nil } - span.Annotate("capacity", p.Capacity()) - span.Annotate("in_use", p.InUse()) - span.Annotate("available", p.Available()) - span.Annotate("active", p.Active()) + span.Annotate("capacity", cp.Capacity()) + span.Annotate("in_use", cp.InUse()) + span.Annotate("available", cp.Available()) + span.Annotate("active", cp.Active()) if cp.timeout != 0 { var cancel context.CancelFunc @@ -205,7 +144,7 @@ func (cp *Pool) Get(ctx context.Context, setting *pools.Setting) (*DBConn, error } start := time.Now() - r, err := p.Get(ctx, setting) + conn, err := cp.ConnPool.Get(ctx, setting) if err != nil { return nil, err } @@ -216,194 +155,25 @@ func (cp *Pool) Get(ctx context.Context, setting *pools.Setting) (*DBConn, error cp.getConnTime.Record(getWithS, start) } } - return r.(*DBConn), nil -} - -// Put puts a connection into the pool. -func (cp *Pool) Put(conn *DBConn) { - p := cp.pool() - if p == nil { - panic(ErrConnPoolClosed) - } - if conn == nil { - p.Put(nil) - } else { - p.Put(conn) - } -} - -// SetCapacity alters the size of the pool at runtime. -func (cp *Pool) SetCapacity(capacity int) (err error) { - cp.mu.Lock() - defer cp.mu.Unlock() - if cp.connections != nil { - err = cp.connections.SetCapacity(capacity) - if err != nil { - return err - } - } - cp.capacity = capacity - return nil + return conn, nil } // SetIdleTimeout sets the idleTimeout on the pool. func (cp *Pool) SetIdleTimeout(idleTimeout time.Duration) { - cp.mu.Lock() - defer cp.mu.Unlock() - if cp.connections != nil { - cp.connections.SetIdleTimeout(idleTimeout) - } + cp.ConnPool.SetIdleTimeout(idleTimeout) cp.dbaPool.SetIdleTimeout(idleTimeout) - cp.idleTimeout = idleTimeout } // StatsJSON returns the pool stats as a JSON object. func (cp *Pool) StatsJSON() string { - p := cp.pool() - if p == nil { + if !cp.ConnPool.IsOpen() { return "{}" } - res := p.StatsJSON() - closingBraceIndex := strings.LastIndex(res, "}") - if closingBraceIndex == -1 { // unexpected... - return res - } - return fmt.Sprintf(`%s, "WaiterQueueFull": %v}`, res[:closingBraceIndex], cp.waiterQueueFull.Load()) -} -// Capacity returns the pool capacity. 
-func (cp *Pool) Capacity() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Capacity() -} - -// Available returns the number of available connections in the pool -func (cp *Pool) Available() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Available() -} - -// Active returns the number of active connections in the pool -func (cp *Pool) Active() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Active() -} - -// InUse returns the number of in-use connections in the pool -func (cp *Pool) InUse() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.InUse() -} - -// MaxCap returns the maximum size of the pool -func (cp *Pool) MaxCap() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.MaxCap() -} - -// WaitCount returns how many clients are waiting for a connection -func (cp *Pool) WaitCount() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.WaitCount() -} - -// WaitTime return the pool WaitTime. -func (cp *Pool) WaitTime() time.Duration { - p := cp.pool() - if p == nil { - return 0 - } - return p.WaitTime() -} - -// IdleTimeout returns the idle timeout for the pool. -func (cp *Pool) IdleTimeout() time.Duration { - p := cp.pool() - if p == nil { - return 0 - } - return p.IdleTimeout() -} - -// IdleClosed returns the number of closed connections for the pool. -func (cp *Pool) IdleClosed() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.IdleClosed() -} - -// MaxLifetimeClosed returns the number of connections closed to refresh timeout for the pool. -func (cp *Pool) MaxLifetimeClosed() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.MaxLifetimeClosed() -} - -// Exhausted returns the number of times available went to zero for the pool. -func (cp *Pool) Exhausted() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.Exhausted() -} - -// GetCount returns the number of times get was called -func (cp *Pool) GetCount() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.GetCount() -} - -// GetSettingCount returns the number of times getWithSettings was called -func (cp *Pool) GetSettingCount() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.GetSettingCount() -} - -// DiffSettingCount returns the number of times different settings were applied on the resource. -func (cp *Pool) DiffSettingCount() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.DiffSettingCount() -} - -// ResetSettingCount returns the number of times settings were reset on the resource. 
-func (cp *Pool) ResetSettingCount() int64 { - p := cp.pool() - if p == nil { - return 0 - } - return p.ResetSettingCount() + var buf strings.Builder + enc := json.NewEncoder(&buf) + _ = enc.Encode(cp.ConnPool.StatsJSON()) + return buf.String() } func (cp *Pool) isCallerIDAppDebug(ctx context.Context) bool { diff --git a/go/vt/vttablet/tabletserver/connpool/pool_test.go b/go/vt/vttablet/tabletserver/connpool/pool_test.go index 43c27fa817a..ecdd2df4465 100644 --- a/go/vt/vttablet/tabletserver/connpool/pool_test.go +++ b/go/vt/vttablet/tabletserver/connpool/pool_test.go @@ -18,8 +18,6 @@ package connpool import ( "context" - "runtime" - "sync" "testing" "time" @@ -27,7 +25,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/fakesqldb" - "vitess.io/vitess/go/pools" + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" @@ -46,10 +44,6 @@ func TestConnPoolGet(t *testing.T) { if dbConn == nil { t.Fatalf("db conn should not be nil") } - // There is no context, it should not use appdebug connection - if dbConn.pool == nil { - t.Fatalf("db conn pool should not be nil") - } dbConn.Recycle() } @@ -72,47 +66,6 @@ func TestConnPoolTimeout(t *testing.T) { assert.EqualError(t, err, "resource pool timed out") } -func TestConnPoolMaxWaiters(t *testing.T) { - db := fakesqldb.New(t) - defer db.Close() - connPool := NewPool(tabletenv.NewEnv(nil, "PoolTest"), "TestPool", tabletenv.ConnPoolConfig{ - Size: 1, - MaxWaiters: 1, - }) - connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) - defer connPool.Close() - dbConn, err := connPool.Get(context.Background(), nil) - require.NoError(t, err) - - // waiter 1 - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - c1, err := connPool.Get(context.Background(), nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - c1.Recycle() - }() - // Wait for the first waiter to increment count. - for { - runtime.Gosched() - if connPool.waiterCount.Load() == 1 { - break - } - } - - // waiter 2 - _, err = connPool.Get(context.Background(), nil) - assert.EqualError(t, err, "pool TestPool waiter count exceeded") - - // This recycle will make waiter1 succeed. 
- dbConn.Recycle() - wg.Wait() -} - func TestConnPoolGetEmptyDebugConfig(t *testing.T) { db := fakesqldb.New(t) debugConn := db.ConnParamsWithUname("") @@ -131,10 +84,6 @@ func TestConnPoolGetEmptyDebugConfig(t *testing.T) { if dbConn == nil { t.Fatalf("db conn should not be nil") } - // Context is empty, it should not use appdebug connection - if dbConn.pool == nil { - t.Fatalf("db conn pool should not be nil") - } dbConn.Recycle() } @@ -156,39 +105,23 @@ func TestConnPoolGetAppDebug(t *testing.T) { if dbConn == nil { t.Fatalf("db conn should not be nil") } - if dbConn.pool != nil { - t.Fatalf("db conn pool should be nil for appDebug") - } dbConn.Recycle() - if !dbConn.IsClosed() { + if !dbConn.Conn.IsClosed() { t.Fatalf("db conn should be closed after recycle") } } -func TestConnPoolPutWhilePoolIsClosed(t *testing.T) { - connPool := newPool() - defer func() { - if recover() == nil { - t.Fatalf("pool is closed, should get an error") - } - }() - connPool.Put(nil) -} - func TestConnPoolSetCapacity(t *testing.T) { db := fakesqldb.New(t) defer db.Close() connPool := newPool() connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() - err := connPool.SetCapacity(-10) - if err == nil { - t.Fatalf("set capacity should return error for negative capacity") - } - err = connPool.SetCapacity(10) - if err != nil { - t.Fatalf("set capacity should succeed") - } + + assert.Panics(t, func() { + connPool.SetCapacity(-10) + }) + connPool.SetCapacity(10) if connPool.Capacity() != 10 { t.Fatalf("capacity should be 10") } @@ -199,7 +132,7 @@ func TestConnPoolStatJSON(t *testing.T) { defer db.Close() connPool := newPool() if connPool.StatsJSON() != "{}" { - t.Fatalf("pool is closed, stats json should be empty: {}") + t.Fatalf("pool is closed, stats json should be empty; was: %q", connPool.StatsJSON()) } connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() @@ -213,10 +146,6 @@ func TestConnPoolStateWhilePoolIsClosed(t *testing.T) { connPool := newPool() assert.EqualValues(t, 0, connPool.Capacity(), "pool capacity should be 0 because it is still closed") assert.EqualValues(t, 0, connPool.Available(), "pool available connections should be 0 because it is still closed") - assert.EqualValues(t, 0, connPool.MaxCap(), "pool max capacity should be 0 because it is still closed") - assert.EqualValues(t, 0, connPool.WaitCount(), "pool wait count should be 0 because it is still closed") - assert.EqualValues(t, 0, connPool.WaitTime(), "pool wait time should be 0 because it is still closed") - assert.EqualValues(t, 0, connPool.IdleTimeout(), "pool idle timeout should be 0 because it is still closed") } func TestConnPoolStateWhilePoolIsOpen(t *testing.T) { @@ -227,9 +156,8 @@ func TestConnPoolStateWhilePoolIsOpen(t *testing.T) { connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams()) defer connPool.Close() assert.EqualValues(t, 100, connPool.Capacity(), "pool capacity should be 100") - assert.EqualValues(t, 100, connPool.MaxCap(), "pool max capacity should be 100") - assert.EqualValues(t, 0, connPool.WaitTime(), "pool wait time should be 0") - assert.EqualValues(t, 0, connPool.WaitCount(), "pool wait count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.WaitTime(), "pool wait time should be 0") + assert.EqualValues(t, 0, connPool.Metrics.WaitCount(), "pool wait count should be 0") assert.EqualValues(t, idleTimeout, connPool.IdleTimeout(), "pool idle timeout should be 0") assert.EqualValues(t, 100, connPool.Available(), "pool available 
connections should be 100") assert.EqualValues(t, 0, connPool.Active(), "pool active connections should be 0") @@ -256,50 +184,50 @@ func TestConnPoolStateWithSettings(t *testing.T) { assert.EqualValues(t, 5, connPool.Available(), "pool available connections should be 5") assert.EqualValues(t, 0, connPool.Active(), "pool active connections should be 0") assert.EqualValues(t, 0, connPool.InUse(), "pool inUse connections should be 0") - assert.EqualValues(t, 0, connPool.GetCount(), "pool get count should be 0") - assert.EqualValues(t, 0, connPool.GetSettingCount(), "pool get with settings should be 0") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.GetCount(), "pool get count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.GetSettingCount(), "pool get with settings should be 0") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") dbConn, err := connPool.Get(context.Background(), nil) require.NoError(t, err) assert.EqualValues(t, 4, connPool.Available(), "pool available connections should be 4") assert.EqualValues(t, 1, connPool.Active(), "pool active connections should be 1") assert.EqualValues(t, 1, connPool.InUse(), "pool inUse connections should be 1") - assert.EqualValues(t, 1, connPool.GetCount(), "pool get count should be 1") - assert.EqualValues(t, 0, connPool.GetSettingCount(), "pool get with settings should be 0") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.GetCount(), "pool get count should be 1") + assert.EqualValues(t, 0, connPool.Metrics.GetSettingCount(), "pool get with settings should be 0") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") dbConn.Recycle() assert.EqualValues(t, 5, connPool.Available(), "pool available connections should be 5") assert.EqualValues(t, 1, connPool.Active(), "pool active connections should be 1") assert.EqualValues(t, 0, connPool.InUse(), "pool inUse connections should be 0") - assert.EqualValues(t, 1, connPool.GetCount(), "pool get count should be 0") - assert.EqualValues(t, 0, connPool.GetSettingCount(), "pool get with settings should be 0") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.GetCount(), "pool get count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.GetSettingCount(), "pool get with settings should be 0") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") db.AddQuery("a", &sqltypes.Result{}) - sa := pools.NewSetting("a", "") + sa := smartconnpool.NewSetting("a", "") dbConn, err = connPool.Get(context.Background(), sa) require.NoError(t, err) 
assert.EqualValues(t, 4, connPool.Available(), "pool available connections should be 4") - assert.EqualValues(t, 2, connPool.Active(), "pool active connections should be 2") + assert.EqualValues(t, 1, connPool.Active(), "pool active connections should be 1") assert.EqualValues(t, 1, connPool.InUse(), "pool inUse connections should be 1") - assert.EqualValues(t, 1, connPool.GetCount(), "pool get count should be 1") - assert.EqualValues(t, 1, connPool.GetSettingCount(), "pool get with settings should be 1") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.GetCount(), "pool get count should be 1") + assert.EqualValues(t, 1, connPool.Metrics.GetSettingCount(), "pool get with settings should be 1") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") dbConn.Recycle() assert.EqualValues(t, 5, connPool.Available(), "pool available connections should be 5") - assert.EqualValues(t, 2, connPool.Active(), "pool active connections should be 2") + assert.EqualValues(t, 1, connPool.Active(), "pool active connections should be 1") assert.EqualValues(t, 0, connPool.InUse(), "pool inUse connections should be 0") - assert.EqualValues(t, 1, connPool.GetCount(), "pool get count should be 1") - assert.EqualValues(t, 1, connPool.GetSettingCount(), "pool get with settings should be 1") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.GetCount(), "pool get count should be 1") + assert.EqualValues(t, 1, connPool.Metrics.GetSettingCount(), "pool get with settings should be 1") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") // now showcasing diff and reset setting. 
// Steps 1: acquire all connection with same setting @@ -308,7 +236,7 @@ func TestConnPoolStateWithSettings(t *testing.T) { // Steps 4: acquire a connection with different setting - this will show diff setting count // Step 1 - var conns []*DBConn + var conns []*PooledConn for i := 0; i < capacity; i++ { dbConn, err = connPool.Get(context.Background(), sa) require.NoError(t, err) @@ -317,10 +245,10 @@ func TestConnPoolStateWithSettings(t *testing.T) { assert.EqualValues(t, 0, connPool.Available(), "pool available connections should be 0") assert.EqualValues(t, 5, connPool.Active(), "pool active connections should be 5") assert.EqualValues(t, 5, connPool.InUse(), "pool inUse connections should be 5") - assert.EqualValues(t, 1, connPool.GetCount(), "pool get count should be 1") - assert.EqualValues(t, 6, connPool.GetSettingCount(), "pool get with settings should be 6") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.GetCount(), "pool get count should be 1") + assert.EqualValues(t, 6, connPool.Metrics.GetSettingCount(), "pool get with settings should be 6") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") // Step 2 for _, conn := range conns { @@ -329,10 +257,10 @@ func TestConnPoolStateWithSettings(t *testing.T) { assert.EqualValues(t, 5, connPool.Available(), "pool available connections should be 5") assert.EqualValues(t, 5, connPool.Active(), "pool active connections should be 5") assert.EqualValues(t, 0, connPool.InUse(), "pool inUse connections should be 0") - assert.EqualValues(t, 1, connPool.GetCount(), "pool get count should be 1") - assert.EqualValues(t, 6, connPool.GetSettingCount(), "pool get with settings should be 6") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 0, connPool.ResetSettingCount(), "pool reset settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.GetCount(), "pool get count should be 1") + assert.EqualValues(t, 6, connPool.Metrics.GetSettingCount(), "pool get with settings should be 6") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 0, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 0") // Step 3 dbConn, err = connPool.Get(context.Background(), nil) @@ -340,24 +268,24 @@ func TestConnPoolStateWithSettings(t *testing.T) { assert.EqualValues(t, 4, connPool.Available(), "pool available connections should be 4") assert.EqualValues(t, 5, connPool.Active(), "pool active connections should be 5") assert.EqualValues(t, 1, connPool.InUse(), "pool inUse connections should be 1") - assert.EqualValues(t, 2, connPool.GetCount(), "pool get count should be 2") - assert.EqualValues(t, 6, connPool.GetSettingCount(), "pool get with settings should be 6") - assert.EqualValues(t, 0, connPool.DiffSettingCount(), "pool different settings count should be 0") - assert.EqualValues(t, 1, connPool.ResetSettingCount(), "pool reset settings count should be 1") + assert.EqualValues(t, 2, connPool.Metrics.GetCount(), "pool get count should be 2") + assert.EqualValues(t, 6, connPool.Metrics.GetSettingCount(), "pool get with settings 
should be 6") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 1") dbConn.Recycle() // Step 4 db.AddQuery("b", &sqltypes.Result{}) - sb := pools.NewSetting("b", "") + sb := smartconnpool.NewSetting("b", "") dbConn, err = connPool.Get(context.Background(), sb) require.NoError(t, err) assert.EqualValues(t, 4, connPool.Available(), "pool available connections should be 4") assert.EqualValues(t, 5, connPool.Active(), "pool active connections should be 5") assert.EqualValues(t, 1, connPool.InUse(), "pool inUse connections should be 1") - assert.EqualValues(t, 2, connPool.GetCount(), "pool get count should be 2") - assert.EqualValues(t, 7, connPool.GetSettingCount(), "pool get with settings should be 7") - assert.EqualValues(t, 1, connPool.DiffSettingCount(), "pool different settings count should be 1") - assert.EqualValues(t, 1, connPool.ResetSettingCount(), "pool reset settings count should be 1") + assert.EqualValues(t, 2, connPool.Metrics.GetCount(), "pool get count should be 2") + assert.EqualValues(t, 7, connPool.Metrics.GetSettingCount(), "pool get with settings should be 7") + assert.EqualValues(t, 0, connPool.Metrics.DiffSettingCount(), "pool different settings count should be 0") + assert.EqualValues(t, 1, connPool.Metrics.ResetSettingCount(), "pool reset settings count should be 1") dbConn.Recycle() } @@ -383,7 +311,7 @@ func TestPoolGetConnTime(t *testing.T) { assert.Zero(t, getTimeMap["PoolTest.GetWithSettings"]) db.AddQuery("b", &sqltypes.Result{}) - sb := pools.NewSetting("b", "") + sb := smartconnpool.NewSetting("b", "") dbConn, err = connPool.Get(context.Background(), sb) require.NoError(t, err) defer dbConn.Recycle() diff --git a/go/vt/vttablet/tabletserver/gc/tablegc.go b/go/vt/vttablet/tabletserver/gc/tablegc.go index cf3595c973c..4947fd9c97a 100644 --- a/go/vt/vttablet/tabletserver/gc/tablegc.go +++ b/go/vt/vttablet/tabletserver/gc/tablegc.go @@ -388,7 +388,7 @@ func (collector *TableGC) readTables(ctx context.Context) (gcTables []*gcTable, } defer conn.Recycle() - res, err := conn.Exec(ctx, sqlShowVtTables, math.MaxInt32, true) + res, err := conn.Conn.Exec(ctx, sqlShowVtTables, math.MaxInt32, true) if err != nil { return nil, err } @@ -556,7 +556,7 @@ func (collector *TableGC) dropTable(ctx context.Context, tableName string, isBas parsed := sqlparser.BuildParsedQuery(sqlDrop, tableName) log.Infof("TableGC: dropping table: %s", tableName) - _, err = conn.ExecuteFetch(parsed.Query, 1, false) + _, err = conn.Conn.ExecuteFetch(parsed.Query, 1, false) if err != nil { return err } @@ -593,7 +593,7 @@ func (collector *TableGC) transitionTable(ctx context.Context, transition *trans } log.Infof("TableGC: renaming table: %s to %s", transition.fromTableName, toTableName) - _, err = conn.Exec(ctx, renameStatement, 1, true) + _, err = conn.Conn.Exec(ctx, renameStatement, 1, true) if err != nil { return err } diff --git a/go/vt/vttablet/tabletserver/health_streamer.go b/go/vt/vttablet/tabletserver/health_streamer.go index 3ecdc180600..87c70a7133d 100644 --- a/go/vt/vttablet/tabletserver/health_streamer.go +++ b/go/vt/vttablet/tabletserver/health_streamer.go @@ -374,7 +374,7 @@ func (hs *healthStreamer) reload(full map[string]*schema.Table, created, altered // Reload the tables and views. // This stores the data that is used by VTGates upto v17. So, we can remove this reload of // tables and views in v19. 
- err = hs.reloadTables(ctx, conn, tables) + err = hs.reloadTables(ctx, conn.Conn, tables) if err != nil { return err } @@ -394,7 +394,7 @@ func (hs *healthStreamer) reload(full map[string]*schema.Table, created, altered return nil } -func (hs *healthStreamer) reloadTables(ctx context.Context, conn *connpool.DBConn, tableNames []string) error { +func (hs *healthStreamer) reloadTables(ctx context.Context, conn *connpool.Conn, tableNames []string) error { if len(tableNames) == 0 { return nil } diff --git a/go/vt/vttablet/tabletserver/query_engine.go b/go/vt/vttablet/tabletserver/query_engine.go index 1881dd4091a..7f83a29fc51 100644 --- a/go/vt/vttablet/tabletserver/query_engine.go +++ b/go/vt/vttablet/tabletserver/query_engine.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "net/http" + "strings" "sync" "sync/atomic" "time" @@ -30,7 +31,7 @@ import ( "vitess.io/vitess/go/acl" "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/mysql/sqlerror" - "vitess.io/vitess/go/pools" + "vitess.io/vitess/go/pools/smartconnpool" "vitess.io/vitess/go/stats" "vitess.io/vitess/go/streamlog" "vitess.io/vitess/go/sync2" @@ -44,7 +45,6 @@ import ( "vitess.io/vitess/go/vt/tableacl" tacl "vitess.io/vitess/go/vt/tableacl/acl" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vthash" "vitess.io/vitess/go/vt/vttablet/tabletserver/connpool" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" @@ -126,8 +126,8 @@ func isValid(planType planbuilder.PlanType, hasReservedCon bool, hasSysSettings type PlanCacheKey = theine.StringKey type PlanCache = theine.Store[PlanCacheKey, *TabletPlan] -type SettingsCacheKey = theine.HashKey256 -type SettingsCache = theine.Store[SettingsCacheKey, *pools.Setting] +type SettingsCacheKey = theine.StringKey +type SettingsCache = theine.Store[SettingsCacheKey, *smartconnpool.Setting] type currentSchema struct { tables map[string]*schema.Table @@ -217,7 +217,7 @@ func NewQueryEngine(env tabletenv.Env, se *schema.Engine) *QueryEngine { // not use a doorkeeper because custom connection settings are rarely one-off and we always // want to cache them var settingsCacheMemory = config.QueryCacheMemory / 4 - qe.settings = theine.NewStore[SettingsCacheKey, *pools.Setting](settingsCacheMemory, false) + qe.settings = theine.NewStore[SettingsCacheKey, *smartconnpool.Setting](settingsCacheMemory, false) qe.schema.Store(¤tSchema{ tables: make(map[string]*schema.Table), @@ -320,7 +320,7 @@ func (qe *QueryEngine) Open() error { qe.conns.Close() return err } - err = conn.VerifyMode(qe.strictTransTables) + err = conn.Conn.VerifyMode(qe.strictTransTables) // Recycle needs to happen before error check. // Otherwise, qe.conns.Close will hang. conn.Recycle() @@ -466,25 +466,24 @@ func (qe *QueryEngine) GetMessageStreamPlan(name string) (*TabletPlan, error) { } // GetConnSetting returns system settings for the connection. 
-func (qe *QueryEngine) GetConnSetting(ctx context.Context, settings []string) (*pools.Setting, error) {
+func (qe *QueryEngine) GetConnSetting(ctx context.Context, settings []string) (*smartconnpool.Setting, error) {
 	span, _ := trace.NewSpan(ctx, "QueryEngine.GetConnSetting")
 	defer span.Finish()
 
-	hasher := vthash.New256()
+	var buf strings.Builder
 	for _, q := range settings {
-		_, _ = hasher.WriteString(q)
+		_, _ = buf.WriteString(q)
+		_ = buf.WriteByte(';')
 	}
-	var cacheKey SettingsCacheKey
-	hasher.Sum(cacheKey[:0])
-
-	connSetting, _, err := qe.settings.GetOrLoad(cacheKey, 0, func() (*pools.Setting, error) {
+	cacheKey := SettingsCacheKey(buf.String())
+	connSetting, _, err := qe.settings.GetOrLoad(cacheKey, 0, func() (*smartconnpool.Setting, error) {
 		// build the setting queries
 		query, resetQuery, err := planbuilder.BuildSettingQuery(settings)
 		if err != nil {
 			return nil, err
 		}
-		return pools.NewSetting(query, resetQuery), nil
+		return smartconnpool.NewSetting(query, resetQuery), nil
 	})
 	return connSetting, err
 }
diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go
index 3d1d9b4b87c..63dcd42d0a8 100644
--- a/go/vt/vttablet/tabletserver/query_executor.go
+++ b/go/vt/vttablet/tabletserver/query_executor.go
@@ -26,10 +26,10 @@ import (
 	"vitess.io/vitess/go/mysql/replication"
 	"vitess.io/vitess/go/mysql/sqlerror"
+	"vitess.io/vitess/go/pools/smartconnpool"
 
 	"vitess.io/vitess/go/mysql/collations"
-	"vitess.io/vitess/go/pools"
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/trace"
 	"vitess.io/vitess/go/vt/callerid"
@@ -63,7 +63,7 @@ type QueryExecutor struct {
 	logStats   *tabletenv.LogStats
 	tsv        *TabletServer
 	tabletType topodatapb.TabletType
-	setting    *pools.Setting
+	setting    *smartconnpool.Setting
 }
 
 const (
@@ -362,7 +362,7 @@ func (qre *QueryExecutor) Stream(callback StreamCallback) error {
 	}
 
 	// if we have a transaction id, let's use the txPool for this query
-	var conn *connpool.DBConn
+	var conn *connpool.PooledConn
 	if qre.connID != 0 {
 		txConn, err := qre.tsv.te.txPool.GetAndLock(qre.connID, "for streaming query")
 		if err != nil {
@@ -703,7 +703,7 @@ func (qre *QueryExecutor) execSelect() (*sqltypes.Result, error) {
 			q.SetErr(err)
 		} else {
 			defer conn.Recycle()
-			res, err := qre.execDBConn(conn, sql, true)
+			res, err := qre.execDBConn(conn.Conn, sql, true)
 			q.SetResult(res)
 			q.SetErr(err)
 		}
@@ -723,7 +723,7 @@ func (qre *QueryExecutor) execSelect() (*sqltypes.Result, error) {
 		return nil, err
 	}
 	defer conn.Recycle()
-	res, err := qre.execDBConn(conn, sql, true)
+	res, err := qre.execDBConn(conn.Conn, sql, true)
 	if err != nil {
 		return nil, err
 	}
@@ -765,10 +765,10 @@ func (qre *QueryExecutor) execOther() (*sqltypes.Result, error) {
 		return nil, err
 	}
 	defer conn.Recycle()
-	return qre.execDBConn(conn, qre.query, true)
+	return qre.execDBConn(conn.Conn, qre.query, true)
 }
 
-func (qre *QueryExecutor) getConn() (*connpool.DBConn, error) {
+func (qre *QueryExecutor) getConn() (*connpool.PooledConn, error) {
 	span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.getConn")
 	defer span.Finish()
 
@@ -785,7 +785,7 @@ func (qre *QueryExecutor) getConn() (*connpool.DBConn, error) {
 	return nil, err
 }
 
-func (qre *QueryExecutor) getStreamConn() (*connpool.DBConn, error) {
+func (qre *QueryExecutor) getStreamConn() (*connpool.PooledConn, error) {
 	span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.getStreamConn")
 	defer span.Finish()
 
@@ -874,7 +874,7 @@ func (qre *QueryExecutor) execCallProc() (*sqltypes.Result, error) {
 		return nil, err
 	}
 
-	qr, err := qre.execDBConn(conn, sql, true)
+	qr, err := qre.execDBConn(conn.Conn, sql, true)
 	if err != nil {
 		return nil, rewriteOUTParamError(err)
 	}
@@ -885,7 +885,7 @@ func (qre *QueryExecutor) execCallProc() (*sqltypes.Result, error) {
 		}
 		return qr, nil
 	}
-	err = qre.drainResultSetOnConn(conn)
+	err = qre.drainResultSetOnConn(conn.Conn)
 	if err != nil {
 		return nil, err
 	}
@@ -910,7 +910,7 @@ func (qre *QueryExecutor) execProc(conn *StatefulConnection) (*sqltypes.Result,
 		}
 		return qr, nil
 	}
-	err = qre.drainResultSetOnConn(conn.UnderlyingDBConn())
+	err = qre.drainResultSetOnConn(conn.UnderlyingDBConn().Conn)
 	if err != nil {
 		return nil, err
 	}
@@ -1047,7 +1047,7 @@ func (qre *QueryExecutor) execShowThrottlerStatus() (*sqltypes.Result, error) {
 	return result, nil
 }
 
-func (qre *QueryExecutor) drainResultSetOnConn(conn *connpool.DBConn) error {
+func (qre *QueryExecutor) drainResultSetOnConn(conn *connpool.Conn) error {
 	more := true
 	for more {
 		qr, err := conn.FetchNext(qre.ctx, int(qre.getSelectLimit()), true)
@@ -1063,7 +1063,7 @@ func (qre *QueryExecutor) getSelectLimit() int64 {
 	return qre.tsv.qe.maxResultSize.Load()
 }
 
-func (qre *QueryExecutor) execDBConn(conn *connpool.DBConn, sql string, wantfields bool) (*sqltypes.Result, error) {
+func (qre *QueryExecutor) execDBConn(conn *connpool.Conn, sql string, wantfields bool) (*sqltypes.Result, error) {
 	span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.execDBConn")
 	defer span.Finish()
 
@@ -1089,7 +1089,7 @@ func (qre *QueryExecutor) execStatefulConn(conn *StatefulConnection, sql string,
 	return conn.Exec(ctx, sql, int(qre.tsv.qe.maxResultSize.Load()), wantfields)
 }
 
-func (qre *QueryExecutor) execStreamSQL(conn *connpool.DBConn, isTransaction bool, sql string, callback func(*sqltypes.Result) error) error {
+func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction bool, sql string, callback func(*sqltypes.Result) error) error {
 	span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.execStreamSQL")
 	trace.AnnotateSQL(span, sqlparser.Preview(sql))
 	callBackClosingSpan := func(result *sqltypes.Result) error {
@@ -1105,15 +1105,15 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.DBConn, isTransaction boo
 	// weren't getting cleaned up during unserveCommon>handleShutdownGracePeriod in state_manager.go.
 	// This change will ensure that long-running streaming stateful queries get gracefully shutdown during ServingTypeChange
 	// once their grace period is over.
-	qd := NewQueryDetail(qre.logStats.Ctx, conn)
+	qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn)
 	if isTransaction {
 		qre.tsv.statefulql.Add(qd)
 		defer qre.tsv.statefulql.Remove(qd)
-		return conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
+		return conn.Conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
 	}
 	qre.tsv.olapql.Add(qd)
 	defer qre.tsv.olapql.Remove(qd)
-	return conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
+	return conn.Conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
 }
 
 func (qre *QueryExecutor) recordUserQuery(queryType string, duration int64) {
diff --git a/go/vt/vttablet/tabletserver/repltracker/reader.go b/go/vt/vttablet/tabletserver/repltracker/reader.go
index fc42a367989..fe469bb2e31 100644
--- a/go/vt/vttablet/tabletserver/repltracker/reader.go
+++ b/go/vt/vttablet/tabletserver/repltracker/reader.go
@@ -180,7 +180,7 @@ func (r *heartbeatReader) fetchMostRecentHeartbeat(ctx context.Context) (*sqltyp
 	if err != nil {
 		return nil, err
 	}
-	return conn.Exec(ctx, sel, 1, false)
+	return conn.Conn.Exec(ctx, sel, 1, false)
 }
 
 // bindHeartbeatFetch takes a heartbeat read and adds the necessary
diff --git a/go/vt/vttablet/tabletserver/repltracker/writer.go b/go/vt/vttablet/tabletserver/repltracker/writer.go
index 2b7dcd1ff2e..b13b78b59b7 100644
--- a/go/vt/vttablet/tabletserver/repltracker/writer.go
+++ b/go/vt/vttablet/tabletserver/repltracker/writer.go
@@ -87,8 +87,8 @@ func newHeartbeatWriter(env tabletenv.Env, alias *topodatapb.TabletAlias) *heart
 		errorLog: logutil.NewThrottledLogger("HeartbeatWriter", 60*time.Second),
 		// We make this pool size 2; to prevent pool exhausted
 		// stats from incrementing continually, and causing concern
-		appPool:      dbconnpool.NewConnectionPool("HeartbeatWriteAppPool", 2, mysqlctl.DbaIdleTimeout, 0, mysqlctl.PoolDynamicHostnameResolution),
-		allPrivsPool: dbconnpool.NewConnectionPool("HeartbeatWriteAllPrivsPool", 2, mysqlctl.DbaIdleTimeout, 0, mysqlctl.PoolDynamicHostnameResolution),
+		appPool:      dbconnpool.NewConnectionPool("HeartbeatWriteAppPool", env.Exporter(), 2, mysqlctl.DbaIdleTimeout, 0, mysqlctl.PoolDynamicHostnameResolution),
+		allPrivsPool: dbconnpool.NewConnectionPool("HeartbeatWriteAllPrivsPool", env.Exporter(), 2, mysqlctl.DbaIdleTimeout, 0, mysqlctl.PoolDynamicHostnameResolution),
 	}
 	if w.onDemandDuration > 0 {
 		// see RequestHeartbeats() for use of onDemandRequestTicks
@@ -207,7 +207,7 @@ func (w *heartbeatWriter) write() error {
 		return err
 	}
 	defer appConn.Recycle()
-	_, err = appConn.ExecuteFetch(upsert, 1, false)
+	_, err = appConn.Conn.ExecuteFetch(upsert, 1, false)
 	if err != nil {
 		return err
 	}
diff --git a/go/vt/vttablet/tabletserver/schema/db.go b/go/vt/vttablet/tabletserver/schema/db.go
index 85ebf3b1457..5699ffc1bde 100644
--- a/go/vt/vttablet/tabletserver/schema/db.go
+++ b/go/vt/vttablet/tabletserver/schema/db.go
@@ -89,7 +89,7 @@ where table_schema = database() and table_name in ::viewNames`
 )
 
 // reloadTablesDataInDB reloads teh tables information we have stored in our database we use for schema-tracking.
-func reloadTablesDataInDB(ctx context.Context, conn *connpool.DBConn, tables []*Table, droppedTables []string) error {
+func reloadTablesDataInDB(ctx context.Context, conn *connpool.Conn, tables []*Table, droppedTables []string) error {
 	// No need to do anything if we have no tables to refresh or drop.
 	if len(tables) == 0 && len(droppedTables) == 0 {
 		return nil
 	}
@@ -174,7 +174,7 @@ func generateFullQuery(query string) (*sqlparser.ParsedQuery, error) {
 }
 
 // reloadViewsDataInDB reloads teh views information we have stored in our database we use for schema-tracking.
-func reloadViewsDataInDB(ctx context.Context, conn *connpool.DBConn, views []*Table, droppedViews []string) error {
+func reloadViewsDataInDB(ctx context.Context, conn *connpool.Conn, views []*Table, droppedViews []string) error {
 	// No need to do anything if we have no views to refresh or drop.
 	if len(views) == 0 && len(droppedViews) == 0 {
 		return nil
 	}
@@ -266,7 +266,7 @@ func reloadViewsDataInDB(ctx context.Context, conn *connpool.DBConn, views []*Ta
 }
 
 // getViewDefinition gets the viewDefinition for the given views.
-func getViewDefinition(ctx context.Context, conn *connpool.DBConn, bv map[string]*querypb.BindVariable, callback func(qr *sqltypes.Result) error, alloc func() *sqltypes.Result, bufferSize int) error {
+func getViewDefinition(ctx context.Context, conn *connpool.Conn, bv map[string]*querypb.BindVariable, callback func(qr *sqltypes.Result) error, alloc func() *sqltypes.Result, bufferSize int) error {
 	viewsDefParsedQuery, err := generateFullQuery(fetchViewDefinitions)
 	if err != nil {
 		return err
@@ -279,7 +279,7 @@ func getViewDefinition(ctx context.Context, conn *connpool.DBConn, bv map[string
 }
 
 // getCreateStatement gets the create-statement for the given view/table.
-func getCreateStatement(ctx context.Context, conn *connpool.DBConn, tableName string) (string, error) {
+func getCreateStatement(ctx context.Context, conn *connpool.Conn, tableName string) (string, error) {
 	res, err := conn.Exec(ctx, sqlparser.BuildParsedQuery(fetchCreateStatement, tableName).Query, 1, false)
 	if err != nil {
 		return "", err
@@ -288,7 +288,7 @@ func getCreateStatement(ctx context.Context, conn *connpool.DBConn, tableName st
 }
 
 // getChangedViewNames gets the list of views that have their definitions changed.
-func getChangedViewNames(ctx context.Context, conn *connpool.DBConn, isServingPrimary bool) (map[string]any, error) {
+func getChangedViewNames(ctx context.Context, conn *connpool.Conn, isServingPrimary bool) (map[string]any, error) {
 	/* Retrieve changed views */
 	views := make(map[string]any)
 	if !isServingPrimary {
@@ -314,7 +314,7 @@ func getChangedViewNames(ctx context.Context, conn *connpool.DBConn, isServingPr
 }
 
 // getMismatchedTableNames gets the tables that do not align with the tables information we have in the cache.
-func (se *Engine) getMismatchedTableNames(ctx context.Context, conn *connpool.DBConn, isServingPrimary bool) (map[string]any, error) {
+func (se *Engine) getMismatchedTableNames(ctx context.Context, conn *connpool.Conn, isServingPrimary bool) (map[string]any, error) {
 	tablesMismatched := make(map[string]any)
 	if !isServingPrimary {
 		return tablesMismatched, nil
 	}
@@ -358,7 +358,7 @@ func (se *Engine) getMismatchedTableNames(ctx context.Context, conn *connpool.DB
 }
 
 // reloadDataInDB reloads the schema tracking data in the database
-func reloadDataInDB(ctx context.Context, conn *connpool.DBConn, altered []*Table, created []*Table, dropped []*Table) error {
+func reloadDataInDB(ctx context.Context, conn *connpool.Conn, altered []*Table, created []*Table, dropped []*Table) error {
 	// tablesToReload and viewsToReload stores the tables and views that need reloading and storing in our MySQL database.
 	var tablesToReload, viewsToReload []*Table
 	// droppedTables, droppedViews stores the list of tables and views we need to delete, respectively.
diff --git a/go/vt/vttablet/tabletserver/schema/db_test.go b/go/vt/vttablet/tabletserver/schema/db_test.go
index 44a3fd0c687..ac6999d309a 100644
--- a/go/vt/vttablet/tabletserver/schema/db_test.go
+++ b/go/vt/vttablet/tabletserver/schema/db_test.go
@@ -96,7 +96,7 @@ func TestGenerateFullQuery(t *testing.T) {
 
 func TestGetCreateStatement(t *testing.T) {
 	db := fakesqldb.New(t)
-	conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+	conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 	require.NoError(t, err)
 
 	// Success view
@@ -131,7 +131,7 @@ func TestGetCreateStatement(t *testing.T) {
 
 func TestGetChangedViewNames(t *testing.T) {
 	db := fakesqldb.New(t)
-	conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+	conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 	require.NoError(t, err)
 
 	// Success
@@ -164,7 +164,7 @@ func TestGetChangedViewNames(t *testing.T) {
 
 func TestGetViewDefinition(t *testing.T) {
 	db := fakesqldb.New(t)
-	conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+	conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 	require.NoError(t, err)
 
 	viewsBV, err := sqltypes.BuildBindVariable([]string{"v1", "lead"})
@@ -200,7 +200,7 @@ func TestGetViewDefinition(t *testing.T) {
 	require.Len(t, got, 0)
 }
 
-func collectGetViewDefinitions(conn *connpool.DBConn, bv map[string]*querypb.BindVariable) (map[string]string, error) {
+func collectGetViewDefinitions(conn *connpool.Conn, bv map[string]*querypb.BindVariable) (map[string]string, error) {
 	viewDefinitions := make(map[string]string)
 	err := getViewDefinition(context.Background(), conn, bv, func(qr *sqltypes.Result) error {
 		for _, row := range qr.Rows {
@@ -336,7 +336,7 @@ func TestGetMismatchedTableNames(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 
 			if tc.dbError != "" {
@@ -456,7 +456,7 @@ func TestReloadTablesInDB(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 
 			// Add queries with the expected results and errors.
@@ -588,7 +588,7 @@ func TestReloadViewsInDB(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 
 			// Add queries with the expected results and errors.
@@ -878,7 +878,7 @@ func TestReloadDataInDB(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 
 			// Add queries with the expected results and errors.
diff --git a/go/vt/vttablet/tabletserver/schema/engine.go b/go/vt/vttablet/tabletserver/schema/engine.go
index 1ef6d071b7c..ae50b460a96 100644
--- a/go/vt/vttablet/tabletserver/schema/engine.go
+++ b/go/vt/vttablet/tabletserver/schema/engine.go
@@ -414,12 +414,12 @@ func (se *Engine) reload(ctx context.Context, includeStats bool) error {
 	defer conn.Recycle()
 
 	// curTime will be saved into lastChange after schema is loaded.
-	curTime, err := se.mysqlTime(ctx, conn)
+	curTime, err := se.mysqlTime(ctx, conn.Conn)
 	if err != nil {
 		return err
 	}
 
-	tableData, err := getTableData(ctx, conn, includeStats)
+	tableData, err := getTableData(ctx, conn.Conn, includeStats)
 	if err != nil {
 		return vterrors.Wrapf(err, "in Engine.reload(), reading tables")
 	}
@@ -428,19 +428,19 @@
 	// changedViews are the views that have changed. We can't use the same createTime logic for views because, MySQL
 	// doesn't update the create_time field for views when they are altered. This is annoying, but something we have to work around.
-	changedViews, err := getChangedViewNames(ctx, conn, shouldUseDatabase)
+	changedViews, err := getChangedViewNames(ctx, conn.Conn, shouldUseDatabase)
 	if err != nil {
 		return err
 	}
 	// mismatchTables stores the tables whose createTime in our cache doesn't match the createTime stored in the database.
 	// This can happen if a primary crashed right after a DML succeeded, before it could reload its state. If all the replicas
 	// are able to reload their cache before one of them is promoted, then the database information would be out of sync.
-	mismatchTables, err := se.getMismatchedTableNames(ctx, conn, shouldUseDatabase)
+	mismatchTables, err := se.getMismatchedTableNames(ctx, conn.Conn, shouldUseDatabase)
 	if err != nil {
 		return err
 	}
 
-	err = se.updateInnoDBRowsRead(ctx, conn)
+	err = se.updateInnoDBRowsRead(ctx, conn.Conn)
 	if err != nil {
 		return err
 	}
@@ -526,7 +526,7 @@ func (se *Engine) reload(ctx context.Context, includeStats bool) error {
 	dropped := se.getDroppedTables(curTables, changedViews, mismatchTables)
 
 	// Populate PKColumns for changed tables.
-	if err := se.populatePrimaryKeys(ctx, conn, changedTables); err != nil {
+	if err := se.populatePrimaryKeys(ctx, conn.Conn, changedTables); err != nil {
 		return err
 	}
@@ -534,7 +534,7 @@ func (se *Engine) reload(ctx context.Context, includeStats bool) error {
 	if shouldUseDatabase {
 		// If reloadDataInDB succeeds, then we don't want to prevent sending the broadcast notification.
 		// So, we do this step in the end when we can receive no more errors that fail the reload operation.
-		err = reloadDataInDB(ctx, conn, altered, created, dropped)
+		err = reloadDataInDB(ctx, conn.Conn, altered, created, dropped)
 		if err != nil {
 			log.Errorf("error in updating schema information in Engine.reload() - %v", err)
 		}
@@ -589,7 +589,7 @@ func (se *Engine) getDroppedTables(curTables map[string]bool, changedViews map[s
 	return maps2.Values(dropped)
 }
 
-func getTableData(ctx context.Context, conn *connpool.DBConn, includeStats bool) (*sqltypes.Result, error) {
+func getTableData(ctx context.Context, conn *connpool.Conn, includeStats bool) (*sqltypes.Result, error) {
 	var showTablesQuery string
 	if includeStats {
 		showTablesQuery = conn.BaseShowTablesWithSizes()
@@ -599,7 +599,7 @@ func getTableData(ctx context.Context, conn *connpool.DBConn, includeStats bool)
 	return conn.Exec(ctx, showTablesQuery, maxTableCount, false)
 }
 
-func (se *Engine) updateInnoDBRowsRead(ctx context.Context, conn *connpool.DBConn) error {
+func (se *Engine) updateInnoDBRowsRead(ctx context.Context, conn *connpool.Conn) error {
 	readRowsData, err := conn.Exec(ctx, mysql.ShowRowsRead, 10, false)
 	if err != nil {
 		return err
@@ -618,7 +618,7 @@ func (se *Engine) updateInnoDBRowsRead(ctx context.Context, conn *connpool.DBCon
 	return nil
 }
 
-func (se *Engine) mysqlTime(ctx context.Context, conn *connpool.DBConn) (int64, error) {
+func (se *Engine) mysqlTime(ctx context.Context, conn *connpool.Conn) (int64, error) {
 	// Keep `SELECT UNIX_TIMESTAMP` is in uppercase because binlog server queries are case sensitive and expect it to be so.
 	tm, err := conn.Exec(ctx, "SELECT UNIX_TIMESTAMP()", 1, false)
 	if err != nil {
@@ -635,7 +635,7 @@ func (se *Engine) mysqlTime(ctx context.Context, conn *connpool.DBConn) (int64,
 }
 
 // populatePrimaryKeys populates the PKColumns for the specified tables.
-func (se *Engine) populatePrimaryKeys(ctx context.Context, conn *connpool.DBConn, tables map[string]*Table) error {
+func (se *Engine) populatePrimaryKeys(ctx context.Context, conn *connpool.Conn, tables map[string]*Table) error {
 	pkData, err := conn.Exec(ctx, mysql.BaseShowPrimary, maxTableCount, false)
 	if err != nil {
 		return vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "could not get table primary key info: %v", err)
@@ -789,7 +789,7 @@ func newMinimalTable(st *Table) *binlogdatapb.MinimalTable {
 }
 
 // GetConnection returns a connection from the pool
-func (se *Engine) GetConnection(ctx context.Context) (*connpool.DBConn, error) {
+func (se *Engine) GetConnection(ctx context.Context) (*connpool.PooledConn, error) {
 	return se.conns.Get(ctx, nil)
 }
diff --git a/go/vt/vttablet/tabletserver/schema/engine_test.go b/go/vt/vttablet/tabletserver/schema/engine_test.go
index 4000795d9d0..0a98a6ee676 100644
--- a/go/vt/vttablet/tabletserver/schema/engine_test.go
+++ b/go/vt/vttablet/tabletserver/schema/engine_test.go
@@ -742,7 +742,7 @@ func TestEngineMysqlTime(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			se := &Engine{}
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 
 			if tt.timeStampErr != nil {
@@ -848,7 +848,7 @@ func TestEnginePopulatePrimaryKeys(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 
 			se := &Engine{}
@@ -909,7 +909,7 @@ func TestEngineUpdateInnoDBRowsRead(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			db := fakesqldb.New(t)
-			conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+			conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 			require.NoError(t, err)
 			se := &Engine{}
 			se.innoDbReadRowsCounter = stats.NewCounter("TestEngineUpdateInnoDBRowsRead-"+tt.name, "")
@@ -936,7 +936,7 @@ func TestEngineUpdateInnoDBRowsRead(t *testing.T) {
 
 // TestEngineGetTableData tests the functionality of getTableData function
 func TestEngineGetTableData(t *testing.T) {
 	db := fakesqldb.New(t)
-	conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+	conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 	require.NoError(t, err)
 
 	tests := []struct {
@@ -1110,7 +1110,7 @@ func TestEngineReload(t *testing.T) {
 	cfg := tabletenv.NewDefaultConfig()
 	cfg.DB = newDBConfigs(db)
 	cfg.SignalWhenSchemaChange = true
-	conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil)
+	conn, err := connpool.NewConn(context.Background(), db.ConnParams(), nil, nil)
 	require.NoError(t, err)
 
 	se := newEngine(10*time.Second, 10*time.Second, 0, db)
diff --git a/go/vt/vttablet/tabletserver/schema/historian.go b/go/vt/vttablet/tabletserver/schema/historian.go
index e40777c6fe5..b65ab514585 100644
--- a/go/vt/vttablet/tabletserver/schema/historian.go
+++ b/go/vt/vttablet/tabletserver/schema/historian.go
@@ -171,10 +171,10 @@ func (h *historian) loadFromDB(ctx context.Context) error {
 	var tableData *sqltypes.Result
 	if h.lastID == 0 && h.schemaMaxAgeSeconds > 0 { // only at vttablet start
 		schemaMaxAge := time.Now().UTC().Add(time.Duration(-h.schemaMaxAgeSeconds) * time.Second)
-		tableData, err = conn.Exec(ctx, sqlparser.BuildParsedQuery(getInitialSchemaVersions, sidecar.GetIdentifier(),
+		tableData, err = conn.Conn.Exec(ctx, sqlparser.BuildParsedQuery(getInitialSchemaVersions, sidecar.GetIdentifier(),
 			schemaMaxAge.Unix()).Query, 10000, true)
 	} else {
-		tableData, err = conn.Exec(ctx, sqlparser.BuildParsedQuery(getNextSchemaVersions, sidecar.GetIdentifier(),
+		tableData, err = conn.Conn.Exec(ctx, sqlparser.BuildParsedQuery(getNextSchemaVersions, sidecar.GetIdentifier(),
 			h.lastID).Query, 10000, true)
 	}
diff --git a/go/vt/vttablet/tabletserver/schema/load_table.go b/go/vt/vttablet/tabletserver/schema/load_table.go
index 08e70fc321d..687672a4a02 100644
--- a/go/vt/vttablet/tabletserver/schema/load_table.go
+++ b/go/vt/vttablet/tabletserver/schema/load_table.go
@@ -34,7 +34,7 @@ import (
 )
 
 // LoadTable creates a Table from the schema info in the database.
-func LoadTable(conn *connpool.DBConn, databaseName, tableName, tableType string, comment string) (*Table, error) {
+func LoadTable(conn *connpool.PooledConn, databaseName, tableName, tableType string, comment string) (*Table, error) {
 	ta := NewTable(tableName, NoType)
 	sqlTableName := sqlparser.String(ta.Name)
 	if err := fetchColumns(ta, conn, databaseName, sqlTableName); err != nil {
@@ -55,10 +55,10 @@ func LoadTable(conn *connpool.DBConn, databaseName, tableName, tableType string,
 	return ta, nil
 }
 
-func fetchColumns(ta *Table, conn *connpool.DBConn, databaseName, sqlTableName string) error {
+func fetchColumns(ta *Table, conn *connpool.PooledConn, databaseName, sqlTableName string) error {
 	ctx := context.Background()
 	exec := func(query string, maxRows int, wantFields bool) (*sqltypes.Result, error) {
-		return conn.Exec(ctx, query, maxRows, wantFields)
+		return conn.Conn.Exec(ctx, query, maxRows, wantFields)
 	}
 	fields, _, err := mysqlctl.GetColumns(databaseName, sqlTableName, exec)
 	if err != nil {
diff --git a/go/vt/vttablet/tabletserver/schema/tracker.go b/go/vt/vttablet/tabletserver/schema/tracker.go
index 9e036bb5139..9b4deaff6c4 100644
--- a/go/vt/vttablet/tabletserver/schema/tracker.go
+++ b/go/vt/vttablet/tabletserver/schema/tracker.go
@@ -170,7 +170,7 @@ func (tr *Tracker) isSchemaVersionTableEmpty(ctx context.Context) (bool, error)
 		return false, err
 	}
 	defer conn.Recycle()
-	result, err := conn.Exec(ctx, sqlparser.BuildParsedQuery("select id from %s.schema_version limit 1",
+	result, err := conn.Conn.Exec(ctx, sqlparser.BuildParsedQuery("select id from %s.schema_version limit 1",
 		sidecar.GetIdentifier()).Query, 1, false)
 	if err != nil {
 		return false, err
@@ -234,7 +234,7 @@ func (tr *Tracker) saveCurrentSchemaToDb(ctx context.Context, gtid, ddl string,
 		"(pos, ddl, schemax, time_updated) "+
 		"values (%s, %s, %s, %d)", sidecar.GetIdentifier(), encodeString(gtid), encodeString(ddl), encodeString(string(blob)), timestamp).Query
-	_, err = conn.Exec(ctx, query, 1, false)
+	_, err = conn.Conn.Exec(ctx, query, 1, false)
 	if err != nil {
 		return err
 	}
diff --git a/go/vt/vttablet/tabletserver/stateful_connection.go b/go/vt/vttablet/tabletserver/stateful_connection.go
index 97d20f594c9..739ed5c4295 100644
--- a/go/vt/vttablet/tabletserver/stateful_connection.go
+++ b/go/vt/vttablet/tabletserver/stateful_connection.go
@@ -22,7 +22,7 @@ import (
 	"time"
 
 	"vitess.io/vitess/go/mysql/sqlerror"
-	"vitess.io/vitess/go/pools"
+	"vitess.io/vitess/go/pools/smartconnpool"
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/vt/callerid"
 	"vitess.io/vitess/go/vt/servenv"
@@ -40,7 +40,7 @@ import (
 // NOTE: After use, if must be returned either by doing a Unlock() or a Release().
 type StatefulConnection struct {
 	pool    *StatefulConnectionPool
-	dbConn  *connpool.DBConn
+	dbConn  *connpool.PooledConn
 	ConnID  tx.ConnID
 	env     tabletenv.Env
 	txProps *tx.Properties
@@ -68,7 +68,7 @@ func (sc *StatefulConnection) Close() {
 }
 
 // IsClosed returns true when the connection is still operational
 func (sc *StatefulConnection) IsClosed() bool {
-	return sc.dbConn == nil || sc.dbConn.IsClosed()
+	return sc.dbConn == nil || sc.dbConn.Conn.IsClosed()
 }
 
 // IsInTransaction returns true when the connection has tx state
@@ -94,7 +94,7 @@ func (sc *StatefulConnection) Exec(ctx context.Context, query string, maxrows in
 		}
 		return nil, vterrors.New(vtrpcpb.Code_ABORTED, "connection was aborted")
 	}
-	r, err := sc.dbConn.ExecOnce(ctx, query, maxrows, wantfields)
+	r, err := sc.dbConn.Conn.ExecOnce(ctx, query, maxrows, wantfields)
 	if err != nil {
 		if sqlerror.IsConnErr(err) {
 			select {
@@ -115,7 +115,7 @@ func (sc *StatefulConnection) execWithRetry(ctx context.Context, query string, m
 	if sc.IsClosed() {
 		return "", vterrors.New(vtrpcpb.Code_CANCELED, "connection is closed")
 	}
-	res, err := sc.dbConn.Exec(ctx, query, maxrows, wantfields)
+	res, err := sc.dbConn.Conn.Exec(ctx, query, maxrows, wantfields)
 	if err != nil {
 		return "", err
 	}
@@ -127,7 +127,7 @@ func (sc *StatefulConnection) FetchNext(ctx context.Context, maxrows int, wantfi
 	if sc.IsClosed() {
 		return nil, vterrors.New(vtrpcpb.Code_CANCELED, "connection is closed")
 	}
-	return sc.dbConn.FetchNext(ctx, maxrows, wantfields)
+	return sc.dbConn.Conn.FetchNext(ctx, maxrows, wantfields)
 }
 
 // Unlock returns the connection to the pool. The connection remains active.
@@ -148,7 +148,7 @@ func (sc *StatefulConnection) unlock(updateTime bool) {
 	if sc.dbConn == nil {
 		return
 	}
-	if sc.dbConn.IsClosed() {
+	if sc.dbConn.Conn.IsClosed() {
 		sc.Releasef("unlocked closed connection")
 	} else {
 		sc.pool.markAsNotInUse(sc, updateTime)
@@ -194,17 +194,17 @@ func (sc *StatefulConnection) String(sanitize bool) string {
 
 // Current returns the currently executing query
 func (sc *StatefulConnection) Current() string {
-	return sc.dbConn.Current()
+	return sc.dbConn.Conn.Current()
 }
 
 // ID returns the mysql connection ID
 func (sc *StatefulConnection) ID() int64 {
-	return sc.dbConn.ID()
+	return sc.dbConn.Conn.ID()
 }
 
 // Kill kills the currently executing query and connection
 func (sc *StatefulConnection) Kill(reason string, elapsed time.Duration) error {
-	return sc.dbConn.Kill(reason, elapsed)
+	return sc.dbConn.Conn.Kill(reason, elapsed)
 }
 
 // TxProperties returns the transactional properties of the connection
@@ -218,7 +218,7 @@ func (sc *StatefulConnection) ReservedID() tx.ConnID {
 }
 
 // UnderlyingDBConn returns the underlying database connection
-func (sc *StatefulConnection) UnderlyingDBConn() *connpool.DBConn {
+func (sc *StatefulConnection) UnderlyingDBConn() *connpool.PooledConn {
 	return sc.dbConn
 }
 
@@ -304,11 +304,11 @@ func (sc *StatefulConnection) getUsername() string {
 	return callerid.GetUsername(sc.reservedProps.ImmediateCaller)
 }
 
-func (sc *StatefulConnection) ApplySetting(ctx context.Context, setting *pools.Setting) error {
-	if sc.dbConn.IsSameSetting(setting.GetQuery()) {
+func (sc *StatefulConnection) ApplySetting(ctx context.Context, setting *smartconnpool.Setting) error {
+	if sc.dbConn.Conn.Setting() == setting {
 		return nil
 	}
-	return sc.dbConn.ApplySetting(ctx, setting)
+	return sc.dbConn.Conn.ApplySetting(ctx, setting)
 }
 
 func (sc *StatefulConnection) resetExpiryTime() {
diff --git a/go/vt/vttablet/tabletserver/stateful_connection_pool.go b/go/vt/vttablet/tabletserver/stateful_connection_pool.go
index 398ad31dfe0..ce6f917610e 100644
--- a/go/vt/vttablet/tabletserver/stateful_connection_pool.go
+++ b/go/vt/vttablet/tabletserver/stateful_connection_pool.go
@@ -22,6 +22,7 @@ import (
 	"time"
 
 	"vitess.io/vitess/go/pools"
+	"vitess.io/vitess/go/pools/smartconnpool"
 	"vitess.io/vitess/go/vt/dbconfigs"
 	"vitess.io/vitess/go/vt/log"
 	"vitess.io/vitess/go/vt/vttablet/tabletserver/connpool"
@@ -170,8 +171,8 @@ func (sf *StatefulConnectionPool) GetAndLock(id int64, reason string) (*Stateful
 
 // NewConn creates a new StatefulConnection. It will be created from either the normal pool or
 // the found_rows pool, depending on the options provided
-func (sf *StatefulConnectionPool) NewConn(ctx context.Context, options *querypb.ExecuteOptions, setting *pools.Setting) (*StatefulConnection, error) {
-	var conn *connpool.DBConn
+func (sf *StatefulConnectionPool) NewConn(ctx context.Context, options *querypb.ExecuteOptions, setting *smartconnpool.Setting) (*StatefulConnection, error) {
+	var conn *connpool.PooledConn
 	var err error
 
 	if options.GetClientFoundRows() {
@@ -183,6 +184,13 @@ func (sf *StatefulConnectionPool) NewConn(ctx context.Context, options *querypb.
 		return nil, err
 	}
 
+	// A StatefulConnection is usually part of a transaction, so it does not support retries.
+	// Ensure that it's actually a valid connection before we return it or the transaction will fail.
+	if err = conn.Conn.ConnCheck(ctx); err != nil {
+		conn.Recycle()
+		return nil, err
+	}
+
 	connID := sf.lastID.Add(1)
 	sfConn := &StatefulConnection{
 		dbConn: conn,
diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go
index 229ac53e3bc..25eb4da7168 100644
--- a/go/vt/vttablet/tabletserver/tabletserver.go
+++ b/go/vt/vttablet/tabletserver/tabletserver.go
@@ -34,9 +34,9 @@ import (
 	"time"
 
 	"vitess.io/vitess/go/mysql/sqlerror"
+	"vitess.io/vitess/go/pools/smartconnpool"
 
 	"vitess.io/vitess/go/acl"
-	"vitess.io/vitess/go/pools"
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/stats"
 	"vitess.io/vitess/go/tb"
@@ -496,7 +496,7 @@ func (tsv *TabletServer) begin(ctx context.Context, target *querypb.Target, save
 	if tsv.txThrottler.Throttle(tsv.getPriorityFromOptions(options), options.GetWorkloadName()) {
 		return errTxThrottled
 	}
-	var connSetting *pools.Setting
+	var connSetting *smartconnpool.Setting
 	if len(settings) > 0 {
 		connSetting, err = tsv.qe.GetConnSetting(ctx, settings)
 		if err != nil {
@@ -794,7 +794,7 @@ func (tsv *TabletServer) execute(ctx context.Context, target *querypb.Target, sq
 	logStats.ReservedID = reservedID
 	logStats.TransactionID = transactionID
 
-	var connSetting *pools.Setting
+	var connSetting *smartconnpool.Setting
 	if len(settings) > 0 {
 		connSetting, err = tsv.qe.GetConnSetting(ctx, settings)
 		if err != nil {
@@ -896,7 +896,7 @@ func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Targ
 	logStats.ReservedID = reservedID
 	logStats.TransactionID = transactionID
 
-	var connSetting *pools.Setting
+	var connSetting *smartconnpool.Setting
 	if len(settings) > 0 {
 		connSetting, err = tsv.qe.GetConnSetting(ctx, settings)
 		if err != nil {
@@ -1927,7 +1927,7 @@ func (tsv *TabletServer) SetPoolSize(val int) {
 	if val <= 0 {
 		return
 	}
-	tsv.qe.conns.SetCapacity(val)
+	tsv.qe.conns.SetCapacity(int64(val))
 }
 
 // PoolSize returns the pool size.
@@ -1937,7 +1937,7 @@ func (tsv *TabletServer) PoolSize() int {
 
 // SetStreamPoolSize changes the pool size to the specified value.
 func (tsv *TabletServer) SetStreamPoolSize(val int) {
-	tsv.qe.streamConns.SetCapacity(val)
+	tsv.qe.streamConns.SetCapacity(int64(val))
 }
 
 // SetStreamConsolidationBlocking sets whether the stream consolidator should wait for slow clients
@@ -1952,7 +1952,7 @@ func (tsv *TabletServer) StreamPoolSize() int {
 
 // SetTxPoolSize changes the tx pool size to the specified value.
 func (tsv *TabletServer) SetTxPoolSize(val int) {
-	tsv.te.txPool.scp.conns.SetCapacity(val)
+	tsv.te.txPool.scp.conns.SetCapacity(int64(val))
 }
 
 // TxPoolSize returns the tx pool size.
diff --git a/go/vt/vttablet/tabletserver/throttle/throttler.go b/go/vt/vttablet/tabletserver/throttle/throttler.go
index 6558b052c9a..b8d84b1ed5e 100644
--- a/go/vt/vttablet/tabletserver/throttle/throttler.go
+++ b/go/vt/vttablet/tabletserver/throttle/throttler.go
@@ -544,7 +544,7 @@ func (throttler *Throttler) readSelfMySQLThrottleMetric(ctx context.Context, pro
 	}
 	defer conn.Recycle()
 
-	tm, err := conn.Exec(ctx, probe.MetricQuery, 1, true)
+	tm, err := conn.Conn.Exec(ctx, probe.MetricQuery, 1, true)
 	if err != nil {
 		metric.Err = err
 		return metric
diff --git a/go/vt/vttablet/tabletserver/twopc.go b/go/vt/vttablet/tabletserver/twopc.go
index 7784f7f1702..bbc54b8ea57 100644
--- a/go/vt/vttablet/tabletserver/twopc.go
+++ b/go/vt/vttablet/tabletserver/twopc.go
@@ -214,7 +214,7 @@ func (tpc *TwoPC) ReadAllRedo(ctx context.Context) (prepared, failed []*tx.Prepa
 	}
 	defer conn.Recycle()
 
-	qr, err := conn.Exec(ctx, tpc.readAllRedo, 10000, false)
+	qr, err := conn.Conn.Exec(ctx, tpc.readAllRedo, 10000, false)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -261,7 +261,7 @@ func (tpc *TwoPC) CountUnresolvedRedo(ctx context.Context, unresolvedTime time.T
 	bindVars := map[string]*querypb.BindVariable{
 		"time_created": sqltypes.Int64BindVariable(unresolvedTime.UnixNano()),
 	}
-	qr, err := tpc.read(ctx, conn, tpc.countUnresolvedRedo, bindVars)
+	qr, err := tpc.read(ctx, conn.Conn, tpc.countUnresolvedRedo, bindVars)
 	if err != nil {
 		return 0, err
 	}
@@ -347,7 +347,7 @@ func (tpc *TwoPC) ReadTransaction(ctx context.Context, dtid string) (*querypb.Tr
 	bindVars := map[string]*querypb.BindVariable{
 		"dtid": sqltypes.StringBindVariable(dtid),
 	}
-	qr, err := tpc.read(ctx, conn, tpc.readTransaction, bindVars)
+	qr, err := tpc.read(ctx, conn.Conn, tpc.readTransaction, bindVars)
 	if err != nil {
 		return nil, err
 	}
@@ -368,7 +368,7 @@ func (tpc *TwoPC) ReadTransaction(ctx context.Context, dtid string) (*querypb.Tr
 	tm, _ := qr.Rows[0][2].ToCastInt64()
 	result.TimeCreated = tm
 
-	qr, err = tpc.read(ctx, conn, tpc.readParticipants, bindVars)
+	qr, err = tpc.read(ctx, conn.Conn, tpc.readParticipants, bindVars)
 	if err != nil {
 		return nil, err
 	}
@@ -396,7 +396,7 @@ func (tpc *TwoPC) ReadAbandoned(ctx context.Context, abandonTime time.Time) (map
 	bindVars := map[string]*querypb.BindVariable{
 		"time_created": sqltypes.Int64BindVariable(abandonTime.UnixNano()),
 	}
-	qr, err := tpc.read(ctx, conn, tpc.readAbandoned, bindVars)
+	qr, err := tpc.read(ctx, conn.Conn, tpc.readAbandoned, bindVars)
 	if err != nil {
 		return nil, err
 	}
@@ -419,7 +419,7 @@ func (tpc *TwoPC) ReadAllTransactions(ctx context.Context) ([]*tx.DistributedTx,
 	}
 	defer conn.Recycle()
 
-	qr, err := conn.Exec(ctx, tpc.readAllTransactions, 10000, false)
+	qr, err := conn.Conn.Exec(ctx, tpc.readAllTransactions, 10000, false)
 	if err != nil {
 		return nil, err
 	}
@@ -466,7 +466,7 @@ func (tpc *TwoPC) exec(ctx context.Context, conn *StatefulConnection, pq *sqlpar
 	return conn.Exec(ctx, q, 1, false)
 }
 
-func (tpc *TwoPC) read(ctx context.Context, conn *connpool.DBConn, pq *sqlparser.ParsedQuery, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
+func (tpc *TwoPC) read(ctx context.Context, conn *connpool.Conn, pq *sqlparser.ParsedQuery, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
 	q, err := pq.GenerateQuery(bindVars, nil)
 	if err != nil {
 		return nil, err
diff --git a/go/vt/vttablet/tabletserver/tx_engine.go b/go/vt/vttablet/tabletserver/tx_engine.go
index 2369282a7c2..fe8f1aa0b6e 100644
--- a/go/vt/vttablet/tabletserver/tx_engine.go
+++ b/go/vt/vttablet/tabletserver/tx_engine.go
@@ -22,7 +22,7 @@ import (
 	"sync"
 	"time"
 
-	"vitess.io/vitess/go/pools"
+	"vitess.io/vitess/go/pools/smartconnpool"
 	"vitess.io/vitess/go/timer"
 	"vitess.io/vitess/go/trace"
 	"vitess.io/vitess/go/vt/concurrency"
@@ -222,7 +222,7 @@ func (te *TxEngine) isTxPoolAvailable(addToWaitGroup func(int)) error {
 // statement(s) used to execute the begin (if any).
 //
 // Subsequent statements can access the connection through the transaction id.
-func (te *TxEngine) Begin(ctx context.Context, savepointQueries []string, reservedID int64, setting *pools.Setting, options *querypb.ExecuteOptions) (int64, string, string, error) {
+func (te *TxEngine) Begin(ctx context.Context, savepointQueries []string, reservedID int64, setting *smartconnpool.Setting, options *querypb.ExecuteOptions) (int64, string, string, error) {
 	span, ctx := trace.NewSpan(ctx, "TxEngine.Begin")
 	defer span.Finish()
 
diff --git a/go/vt/vttablet/tabletserver/tx_engine_test.go b/go/vt/vttablet/tabletserver/tx_engine_test.go
index 6a86d044bcd..6ddf2f5a9d3 100644
--- a/go/vt/vttablet/tabletserver/tx_engine_test.go
+++ b/go/vt/vttablet/tabletserver/tx_engine_test.go
@@ -218,7 +218,7 @@ func TestTxEngineRenewFails(t *testing.T) {
 	_, _, err = te.Commit(ctx, connID)
 	require.Error(t, err)
 	assert.True(t, conn.IsClosed(), "connection was not closed")
-	assert.True(t, dbConn.IsClosed(), "underlying connection was not closed")
+	assert.True(t, dbConn.Conn.IsClosed(), "underlying connection was not closed")
 }
 
 type TxType int
diff --git a/go/vt/vttablet/tabletserver/tx_pool.go b/go/vt/vttablet/tabletserver/tx_pool.go
index 8af66d4d32d..f42e3c95408 100644
--- a/go/vt/vttablet/tabletserver/tx_pool.go
+++ b/go/vt/vttablet/tabletserver/tx_pool.go
@@ -22,7 +22,7 @@ import (
 	"sync"
 	"time"
 
-	"vitess.io/vitess/go/pools"
+	"vitess.io/vitess/go/pools/smartconnpool"
 	"vitess.io/vitess/go/timer"
 	"vitess.io/vitess/go/trace"
 	"vitess.io/vitess/go/vt/callerid"
@@ -230,7 +230,7 @@ func (tp *TxPool) Rollback(ctx context.Context, txConn *StatefulConnection) erro
 // the statements (if any) executed to initiate the transaction. In autocommit
 // mode the statement will be "".
 // The connection returned is locked for the callee and its responsibility is to unlock the connection.
-func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, reservedID int64, savepointQueries []string, setting *pools.Setting) (*StatefulConnection, string, string, error) {
+func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, reservedID int64, savepointQueries []string, setting *smartconnpool.Setting) (*StatefulConnection, string, string, error) {
 	span, ctx := trace.NewSpan(ctx, "TxPool.Begin")
 	defer span.Finish()
 
@@ -284,15 +284,15 @@ func (tp *TxPool) begin(ctx context.Context, options *querypb.ExecuteOptions, re
 	return beginQueries, sessionStateChanges, nil
 }
 
-func (tp *TxPool) createConn(ctx context.Context, options *querypb.ExecuteOptions, setting *pools.Setting) (*StatefulConnection, error) {
+func (tp *TxPool) createConn(ctx context.Context, options *querypb.ExecuteOptions, setting *smartconnpool.Setting) (*StatefulConnection, error) {
 	conn, err := tp.scp.NewConn(ctx, options, setting)
 	if err != nil {
 		errCode := vterrors.Code(err)
 		switch err {
-		case pools.ErrCtxTimeout:
+		case smartconnpool.ErrCtxTimeout:
			tp.LogActive()
			err = vterrors.Errorf(errCode, "transaction pool aborting request due to already expired context")
-		case pools.ErrTimeout:
+		case smartconnpool.ErrTimeout:
			tp.LogActive()
			err = vterrors.Errorf(errCode, "transaction pool connection limit exceeded")
		}
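Editor's note (illustrative sketch, not part of the patch): the hunks above converge on one calling pattern. Pool checkouts now return a *connpool.PooledConn wrapper, query helpers such as Exec live on the wrapped raw *connpool.Conn reached through its Conn field, and Recycle still returns the wrapper to the pool. The function name and query below are hypothetical; only GetConnection, Recycle, and Conn.Exec are taken from the diff itself.

package example

import (
	"context"

	"vitess.io/vitess/go/vt/vttablet/tabletserver/schema"
)

// exampleConnUsage is a hypothetical caller showing the PooledConn/Conn split.
func exampleConnUsage(ctx context.Context, se *schema.Engine) error {
	conn, err := se.GetConnection(ctx) // *connpool.PooledConn after this change
	if err != nil {
		return err
	}
	defer conn.Recycle() // return the wrapper to the pool

	// Query helpers such as Exec now live on the raw connection.
	_, err = conn.Conn.Exec(ctx, "SELECT 1", 1, false)
	return err
}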
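Editor's note on the GetConnSetting change (illustrative sketch, not part of the patch): the 256-bit vthash cache key is replaced by the raw setting statements joined with ';', so identical setting lists resolve to the same cached *smartconnpool.Setting. The helper name below is hypothetical; the concatenation mirrors QueryEngine.GetConnSetting in the query_engine.go hunk above.

package example

import "strings"

// settingsCacheKey reproduces the key derivation used by GetConnSetting:
// each setting statement is appended followed by ';' and the resulting
// string is used directly as the theine.StringKey cache key.
func settingsCacheKey(settings []string) string {
	var buf strings.Builder
	for _, q := range settings {
		buf.WriteString(q)
		buf.WriteByte(';')
	}
	return buf.String()
}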