Skip to content

Commit

Permalink
Upgrade HostProvider (#6)
Browse files Browse the repository at this point in the history
* Upgrade HostProvider

This change fixes the behavior of DNSHostProvider where it does not refresh its cached IP addresses that it resolves once on startup for the configured ZK servers. This new behavior more closely matches the Java client's behavior by randomly selecting an address after resolving the host. It slightly changes the semantics of `HostProvider` with an off-by-one, otherwise the `connect` loop could end up in a situation where it attempts to connect to a stale address. This is fixed by moving the backoff to _before_ getting the address, rather than _after_.

* Bump linter version to support generics

* Fix linter and integration test actions

* Add docs
  • Loading branch information
PapaCharlie authored Feb 2, 2024
1 parent 718d9a2 commit 4dc0808
Show file tree
Hide file tree
Showing 7 changed files with 304 additions and 110 deletions.
39 changes: 19 additions & 20 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,11 @@ type Event struct {
type HostProvider interface {
// Init is called first, with the servers specified in the connection string.
Init(servers []string) error
// Len returns the number of servers.
Len() int
// Next returns the next server to connect to. retryStart will be true if we've looped through
// all known servers without Connected() being called.
// Next returns the next server to connect to. retryStart should be true if this call to Next
// exhausted the list of known servers without Connected being called. If connecting to this final
// host fails, the connect loop will back off before invoking Next again for a fresh server.
Next() (server string, retryStart bool)
// Notify the HostProvider of a successful connection.
// Connected notifies the HostProvider of a successful connection.
Connected()
}

Expand All @@ -203,12 +202,12 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti
srvs := FormatServers(servers)

// Randomize the order of the servers to avoid creating hotspots
stringShuffle(srvs)
shuffleSlice(srvs)

ec := make(chan Event, eventChanSize)
conn := &Conn{
dialer: net.DialTimeout,
hostProvider: &DNSHostProvider{},
hostProvider: new(StaticHostProvider),
conn: nil,
state: StateDisconnected,
eventChan: ec,
Expand Down Expand Up @@ -387,7 +386,7 @@ func (c *Conn) sendEvent(evt Event) {
}
}

func (c *Conn) connect() error {
func (c *Conn) connect() (err error) {
var retryStart bool
for {
c.serverMu.Lock()
Expand All @@ -396,18 +395,6 @@ func (c *Conn) connect() error {

c.setState(StateConnecting)

if retryStart {
c.flushUnsentRequests(ErrNoServer)
select {
case <-time.After(time.Second):
// pass
case <-c.shouldQuit:
c.setState(StateDisconnected)
c.flushUnsentRequests(ErrClosing)
return ErrClosing
}
}

zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
if err == nil {
c.conn = zkConn
Expand All @@ -419,6 +406,18 @@ func (c *Conn) connect() error {
}

c.logger.Printf("failed to connect to %s: %v", c.Server(), err)

if retryStart {
c.flushUnsentRequests(ErrNoServer)
select {
case <-time.After(time.Second):
// pass
case <-c.shouldQuit:
c.setState(StateDisconnected)
c.flushUnsentRequests(ErrClosing)
return ErrClosing
}
}
}
}

Expand Down
47 changes: 22 additions & 25 deletions dnshostprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"sync"
)

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
// DNSHostProvider is a simple implementation of a HostProvider. It resolves the hosts once during
// Init, and iterates through the resolved addresses for every call to Next. Note that if the
// addresses that back the ZK hosts change, those changes will not be reflected.
//
// Deprecated: Because this HostProvider does not attempt to re-read from DNS, it can lead to issues
// if the addresses of the hosts change. It is preserved for backwards compatibility.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
Expand All @@ -30,7 +32,7 @@ func (hp *DNSHostProvider) Init(servers []string) error {
lookupHost = net.LookupHost
}

found := []string{}
var found []string
for _, server := range servers {
host, port, err := net.SplitHostPort(server)
if err != nil {
Expand All @@ -46,43 +48,38 @@ func (hp *DNSHostProvider) Init(servers []string) error {
}

if len(found) == 0 {
return fmt.Errorf("No hosts found for addresses %q", servers)
return fmt.Errorf("zk: no hosts found for addresses %q", servers)
}

// Randomize the order of the servers to avoid creating hotspots
stringShuffle(found)
shuffleSlice(found)

hp.servers = found
hp.curr = -1
hp.last = -1
hp.curr = 0
hp.last = len(hp.servers) - 1

return nil
}

// Len returns the number of servers available
func (hp *DNSHostProvider) Len() int {
hp.mu.Lock()
defer hp.mu.Unlock()
return len(hp.servers)
}

// Next returns the next server to connect to. retryStart will be true
// if we've looped through all known servers without Connected() being
// called.
// Next returns the next server to connect to. retryStart should be true if this call to Next
// exhausted the list of known servers without Connected being called. If connecting to this final
// host fails, the connect loop will back off before invoking Next again for a fresh server.
func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
hp.mu.Lock()
defer hp.mu.Unlock()
hp.curr = (hp.curr + 1) % len(hp.servers)
retryStart = hp.curr == hp.last
if hp.last == -1 {
hp.last = 0
}
return hp.servers[hp.curr], retryStart
server = hp.servers[hp.curr]
hp.curr = (hp.curr + 1) % len(hp.servers)
return server, retryStart
}

// Connected notifies the HostProvider of a successful connection.
func (hp *DNSHostProvider) Connected() {
hp.mu.Lock()
defer hp.mu.Unlock()
hp.last = hp.curr
if hp.curr == 0 {
hp.last = len(hp.servers) - 1
} else {
hp.last = hp.curr - 1
}
}
125 changes: 71 additions & 54 deletions dnshostprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFac
}
}

func (lhpf *localHostPortsFacade) Len() int { return lhpf.inner.Len() }
func (lhpf *localHostPortsFacade) Connected() { lhpf.inner.Connected() }
func (lhpf *localHostPortsFacade) Init(servers []string) error { return lhpf.inner.Init(servers) }
func (lhpf *localHostPortsFacade) Next() (string, bool) {
Expand Down Expand Up @@ -165,60 +164,78 @@ func TestDNSHostProviderReconnect(t *testing.T) {
}
}

// TestDNSHostProviderRetryStart tests the `retryStart` functionality
// of DNSHostProvider.
// It's also probably the clearest visual explanation of exactly how
// it works.
func TestDNSHostProviderRetryStart(t *testing.T) {
// TestHostProvidersRetryStart tests the `retryStart` functionality of DNSHostProvider and
// StaticHostProvider.
// It's also probably the clearest visual explanation of exactly how it works.
func TestHostProvidersRetryStart(t *testing.T) {
t.Parallel()

hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
}}

if err := hp.Init([]string{"foo.example.com:12345"}); err != nil {
t.Fatal(err)
}

testdata := []struct {
retryStartWant bool
callConnected bool
}{
// Repeated failures.
{false, false},
{false, false},
{false, false},
{true, false},
{false, false},
{false, false},
{true, true},

// One success offsets things.
{false, false},
{false, true},
{false, true},

// Repeated successes.
{false, true},
{false, true},
{false, true},
{false, true},
{false, true},

// And some more failures.
{false, false},
{false, false},
{true, false}, // Looped back to last known good server: all alternates failed.
{false, false},
}

for i, td := range testdata {
_, retryStartGot := hp.Next()
if retryStartGot != td.retryStartWant {
t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant)
}
if td.callConnected {
hp.Connected()
}
lookupHost := func(host string) ([]string, error) {
return []string{host}, nil
}

providers := []HostProvider{
&DNSHostProvider{
lookupHost: lookupHost,
},
&StaticHostProvider{
lookupHost: lookupHost,
},
}

for _, hp := range providers {
t.Run(fmt.Sprintf("%T", hp), func(t *testing.T) {
if err := hp.Init([]string{"foo.com:2121", "bar.com:2121", "baz.com:2121"}); err != nil {
t.Fatal(err)
}

testdata := []struct {
retryStartWant bool
callConnected bool
}{
// Repeated failures.
{false, false},
{false, false},
{true, false},
{false, false},
{false, false},
{true, false},
{false, true},

// One success offsets things.
{false, false},
{false, true},
{false, true},

// Repeated successes.
{false, true},
{false, true},
{false, true},
{false, true},
{false, true},

// And some more failures.
{false, false},
{false, false},
{true, false}, // Looped back to last known good server: all alternates failed.
{false, false},
{false, false},
{true, false},
{false, false},
{false, false},
{true, false},
{false, false},
}

for i, td := range testdata {
_, retryStartGot := hp.Next()
if retryStartGot != td.retryStartWant {
t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant)
}
if td.callConnected {
hp.Connected()
}
}
})
}
}
Loading

0 comments on commit 4dc0808

Please sign in to comment.