diff --git a/.gitignore b/.gitignore index 4a55899..7524d3f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ dist/ /wait-for # IDEs +.vscode/ .idea/ *.iml diff --git a/Makefile b/Makefile index 4e261b6..5db065a 100644 --- a/Makefile +++ b/Makefile @@ -49,6 +49,10 @@ deps: ./bin/tparse go get -v ./... go mod tidy +.PHONY: mocks +mocks: ## generate mocks for interfaces + mockgen -source=waitfor.go -package=waitfor > waitfor_mock_test.go + .PHONY: build build: ## build the application go build -o wait-for ./cmd/wait-for diff --git a/README.md b/README.md index 94ebb4c..29b0607 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,12 @@ or environment. Typically, you would use this to wait on another resource (such as an HTTP resource) to become available before continuing - or timeout and exit with an error. +At the moment, you can wait for a few different kinds of thing. They are: + +* HTTP or HTTPS success response +* TCP or GRPC connection +* DNS IP resolve address change + [![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/dnnrly/wait-for)](https://github.com/dnnrly/wait-for/releases/latest) [![GitHub Workflow Status](https://img.shields.io/github/workflow/status/dnnrly/wait-for/Release%20workflow)](https://github.com/dnnrly/wait-for/actions?query=workflow%3A%22Release+workflow%22) [![codecov](https://codecov.io/gh/dnnrly/wait-for/branch/main/graph/badge.svg?token=s0OfKkTFuI)](https://codecov.io/gh/dnnrly/wait-for) @@ -28,14 +34,14 @@ go install github.com/dnnrly/wait-for/cmd/wait-for@latest If you don't have Go installed (in a Docker container, for example) then you can take advantage of the pre-built versions. Check out the [releases](https://github.com/dnnrly/wait-for/releases) and check out the links for direct downloads. You can download and unpack a release like so: ```shell -wget https://github.com/dnnrly/wait-for/releases/download/v0.0.1/wait-for_0.0.1_linux_386.tar.gz -gunzip wait-for_0.0.1_linux_386.tar.gz -tar -xfv wait-for_0.0.1_linux_386.tar +wget https://github.com/dnnrly/wait-for/releases/download/v0.0.5/wait-for_0.0.5_linux_386.tar.gz +gunzip wait-for_0.0.5_linux_386.tar.gz +tar -xfv wait-for_0.0.5_linux_386.tar ``` In your Dockerfile, you can do this: ```docker -ADD https://github.com/dnnrly/wait-for/releases/download/v0.0.1/wait-for_0.0.1_linux_386.tar.gz wait-for.tar.gz +ADD https://github.com/dnnrly/wait-for/releases/download/v0.0.1/wait-for_0.0.5_linux_386.tar.gz wait-for.tar.gz RUN gunzip wait-for.tar.gz && tar -xf wait-for.tar ``` @@ -50,10 +56,21 @@ $ wait-for http://your-service-here:8080/health https://another-service/ ``` ### Waiting for gRPC services + ```shell script $ wait-for grpc-server:8092 other-grpc-server:9091 ``` +### Waiting for DNS changes + +```shell script +$ wait-for dns:google.com +``` + +This will wait for the list of IP addresses bound to that DNS name to be +updated, regardless of order. You can use this to wait for a DNS update +such as failover or other similar operations. + ### Preconfiguring services to connect to ```shell script @@ -81,6 +98,9 @@ wait-for: snmp-service: type: tcp target: snmp-trap-dns:514 + dns-thing: + type: dns + target: your.r53-entry.com ``` ### Using `wait-for` in Docker Compose diff --git a/cmd/wait-for/main.go b/cmd/wait-for/main.go index 4403142..a8b1523 100644 --- a/cmd/wait-for/main.go +++ b/cmd/wait-for/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "net" "os" waitfor "github.com/dnnrly/wait-for" @@ -40,6 +41,13 @@ func main() { os.Exit(1) } + waitfor.SupportedWaiters = map[string]waitfor.Waiter{ + "http": waitfor.WaiterFunc(waitfor.HTTPWaiter), + "tcp": waitfor.WaiterFunc(waitfor.TCPWaiter), + "grpc": waitfor.WaiterFunc(waitfor.GRPCWaiter), + "dns": waitfor.NewDNSWaiter(net.LookupIP, logger), + } + err = waitfor.WaitOn(config, logger, flag.Args(), waitfor.SupportedWaiters) if err != nil { _, _ = fmt.Fprintf(os.Stderr, "%v", err) diff --git a/config.go b/config.go index ef4bc5e..3f68aa8 100644 --- a/config.go +++ b/config.go @@ -96,6 +96,15 @@ func (c *Config) AddFromString(t string) error { return nil } + if strings.HasPrefix(t, "dns:") { + c.Targets[t] = TargetConfig{ + Target: strings.Replace(t, "dns:", "", 1), + Type: "dns", + Timeout: c.DefaultTimeout, + } + return nil + } + return errors.New("unable to understand target " + t) } diff --git a/config_test.go b/config_test.go index 022065f..c69db64 100644 --- a/config_test.go +++ b/config_test.go @@ -85,9 +85,10 @@ func TestConfig_AddFromString(t *testing.T) { assert.NoError(t, config.AddFromString("https://some-host/endpoint")) assert.NoError(t, config.AddFromString("http://another-host/endpoint")) assert.NoError(t, config.AddFromString("tcp:listener-tcp:9090")) + assert.NoError(t, config.AddFromString("dns:some.dns.com")) assert.Error(t, config.AddFromString("udp:some-listener:9090")) - assert.Equal(t, 4, len(config.Targets)) + assert.Equal(t, 5, len(config.Targets)) assert.Equal(t, "http://some-host/endpoint", config.Targets["http://some-host/endpoint"].Target) assert.Equal(t, "http", config.Targets["http://some-host/endpoint"].Type) @@ -107,6 +108,10 @@ func TestConfig_AddFromString(t *testing.T) { assert.Equal(t, "listener-tcp:9090", config.Targets["tcp:listener-tcp:9090"].Target) assert.Equal(t, "tcp", config.Targets["tcp:listener-tcp:9090"].Type) assert.Equal(t, time.Second*5, config.Targets["tcp:listener-tcp:9090"].Timeout) + + assert.Equal(t, "some.dns.com", config.Targets["dns:some.dns.com"].Target) + assert.Equal(t, "dns", config.Targets["dns:some.dns.com"].Type) + assert.Equal(t, time.Second*5, config.Targets["dns:some.dns.com"].Timeout) } func TestConfig_Filters(t *testing.T) { diff --git a/waitfor.go b/waitfor.go index 974dc87..75ab63e 100644 --- a/waitfor.go +++ b/waitfor.go @@ -3,35 +3,43 @@ package waitfor import ( "context" "fmt" - "google.golang.org/grpc/credentials/insecure" "net" "net/http" + "sort" + "strings" "time" + "google.golang.org/grpc/credentials/insecure" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "github.com/spf13/afero" ) +type Waiter interface { + Wait(name string, target *TargetConfig) error +} + // WaiterFunc is used to implement waiting for a specific type of target. // The name is used in the error and target is the actual destination being tested. type WaiterFunc func(name string, target *TargetConfig) error + +func (w WaiterFunc) Wait(name string, target *TargetConfig) error { + return w(name, target) +} + type Logger func(string, ...interface{}) // NullLogger can be used in place of a real logging function var NullLogger = func(f string, a ...interface{}) {} // SupportedWaiters is a mapping of known protocol names to waiter implementations -var SupportedWaiters = map[string]WaiterFunc{ - "http": HTTPWaiter, - "tcp": TCPWaiter, - "grpc": GRPCWaiter, -} +var SupportedWaiters map[string]Waiter // WaitOn implements waiting for many targets, using the location of config file provided with named targets to wait until // all of those targets are responding as expected -func WaitOn(config *Config, logger Logger, targets []string, waiters map[string]WaiterFunc) error { +func WaitOn(config *Config, logger Logger, targets []string, waiters map[string]Waiter) error { for _, target := range targets { if !config.GotTarget(target) { @@ -80,7 +88,7 @@ func OpenConfig(configFile, defaultTimeout, defaultHTTPTimeout string, fs afero. return config, nil } -func waitOnTargets(logger Logger, targets map[string]TargetConfig, waiters map[string]WaiterFunc) error { +func waitOnTargets(logger Logger, targets map[string]TargetConfig, waiters map[string]Waiter) error { var eg errgroup.Group for name, target := range targets { @@ -108,14 +116,14 @@ func waitOnTargets(logger Logger, targets map[string]TargetConfig, waiters map[s return nil } -func waitOnSingleTarget(name string, logger Logger, target TargetConfig, waiter WaiterFunc) error { +func waitOnSingleTarget(name string, logger Logger, target TargetConfig, waiter Waiter) error { end := time.Now().Add(target.Timeout) - err := waiter(name, &target) + err := waiter.Wait(name, &target) for err != nil && end.After(time.Now()) { logger("error while waiting for %s: %v", name, err) time.Sleep(time.Second) - err = waiter(name, &target) + err = waiter.Wait(name, &target) } if err != nil { @@ -182,3 +190,59 @@ func isSuccess(code int) bool { return true } + +type DNSLookup func(host string) ([]net.IP, error) + +type DNSWaiter struct { + lookup DNSLookup + logger Logger +} + +func NewDNSWaiter(lookup DNSLookup, logger Logger) *DNSWaiter { + return &DNSWaiter{ + lookup: lookup, + logger: logger, + } +} + +type IPList []net.IP + +func (l IPList) Equals(r IPList) bool { + return l.String() == r.String() +} + +func (l IPList) Len() int { + return len(l) +} +func (l IPList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l IPList) Less(i, j int) bool { return strings.Compare(l[i].String(), l[j].String()) < 0 } +func (l IPList) String() string { + sort.Sort(l) + var s []string + for _, v := range l { + s = append(s, v.String()) + } + return strings.Join(s, ",") +} + +func (w *DNSWaiter) Wait(host string, target *TargetConfig) error { + in, _ := w.lookup(target.Target) + initial := IPList(in) + last := initial + + start := time.Now() + now := start + + for now.Sub(start) < target.Timeout { + w.logger("got DNS result %s", last) + time.Sleep(time.Second) + l, _ := w.lookup(target.Target) + last = IPList(l) + + if !initial.Equals(last) { + return nil + } + now = time.Now() + } + return fmt.Errorf("timed out waiting for DNS update to %s", host) +} diff --git a/waitfor_test.go b/waitfor_test.go index 1014742..e3a701d 100644 --- a/waitfor_test.go +++ b/waitfor_test.go @@ -3,17 +3,27 @@ package waitfor import ( "errors" "fmt" - "github.com/phayes/freeport" - "google.golang.org/grpc" "net" "testing" "time" + "github.com/phayes/freeport" + "google.golang.org/grpc" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var ( + ip1 = net.IPv4(byte(0x01), byte(0x02), byte(0x03), byte(0x04)) + ip2 = net.IPv4(byte(0x11), byte(0x12), byte(0x13), byte(0x14)) + ip3 = net.IPv4(byte(0x21), byte(0x22), byte(0x23), byte(0x24)) + ip4 = net.IPv4(byte(0x04), byte(0x05), byte(0x06), byte(0x07)) + ip5 = net.IPv4(byte(0x14), byte(0x15), byte(0x16), byte(0x17)) + ip6 = net.IPv4(byte(0x24), byte(0x22), byte(0x23), byte(0x24)) +) + func Test_isSuccess(t *testing.T) { assert.True(t, isSuccess(200)) assert.True(t, isSuccess(214)) @@ -80,12 +90,12 @@ func TestOpenConfig_defaultHTTPTimeoutCanBeSet(t *testing.T) { } func TestWaitOn_errorsInvalidTarget(t *testing.T) { - err := WaitOn(NewConfig(), NullLogger, []string{"localhost"}, map[string]WaiterFunc{}) + err := WaitOn(NewConfig(), NullLogger, []string{"localhost"}, map[string]Waiter{}) assert.Error(t, err) } func TestRun_errorsOnParseFailure(t *testing.T) { - err := WaitOn(NewConfig(), NullLogger, []string{"http://localhost"}, map[string]WaiterFunc{}) + err := WaitOn(NewConfig(), NullLogger, []string{"http://localhost"}, map[string]Waiter{}) assert.Error(t, err) } @@ -97,7 +107,7 @@ func TestWaitOnSingleTarget_succeedsImmediately(t *testing.T) { "name", doLog, TargetConfig{Timeout: time.Second * 2}, - func(name string, target *TargetConfig) error { return nil }, + WaiterFunc(func(name string, target *TargetConfig) error { return nil }), ) assert.NoError(t, err) @@ -115,12 +125,12 @@ func TestWaitOnSingleTarget_succeedsAfterWaiting(t *testing.T) { "name", doLog, TargetConfig{Timeout: time.Second * 2}, - func(name string, target *TargetConfig) error { + WaiterFunc(func(name string, target *TargetConfig) error { if waitUntil.After(time.Now()) { return fmt.Errorf("there was an error") } return nil - }, + }), ) assert.NoError(t, err) @@ -136,9 +146,9 @@ func TestWaitOnSingleTarget_failsIfTimerExpires(t *testing.T) { "name", doLog, TargetConfig{Timeout: time.Second * 2}, - func(name string, target *TargetConfig) error { + WaiterFunc(func(name string, target *TargetConfig) error { return fmt.Errorf("") - }, + }), ) assert.Error(t, err) @@ -149,7 +159,7 @@ func TestWaitOnTargets_failsForUnknownType(t *testing.T) { err := waitOnTargets( NullLogger, map[string]TargetConfig{"unkown": {Type: "unknown type"}}, - map[string]WaiterFunc{"type": func(string, *TargetConfig) error { return errors.New("") }}, + map[string]Waiter{"type": WaiterFunc(func(string, *TargetConfig) error { return errors.New("") })}, ) require.Error(t, err) @@ -162,9 +172,9 @@ func TestWaitOnTargets_selectsCorrectWaiter(t *testing.T) { map[string]TargetConfig{ "type 1": {Type: "t1"}, }, - map[string]WaiterFunc{ - "t1": func(string, *TargetConfig) error { return nil }, - "t2": func(string, *TargetConfig) error { return errors.New("an error") }, + map[string]Waiter{ + "t1": WaiterFunc(func(string, *TargetConfig) error { return nil }), + "t2": WaiterFunc(func(string, *TargetConfig) error { return errors.New("an error") }), }, ) @@ -178,9 +188,9 @@ func TestWaitOnTargets_failsWhenWaiterFails(t *testing.T) { "type 1": {Type: "t1"}, "type 2": {Type: "t2"}, }, - map[string]WaiterFunc{ - "t1": func(string, *TargetConfig) error { return nil }, - "t2": func(string, *TargetConfig) error { return errors.New("an error") }, + map[string]Waiter{ + "t1": WaiterFunc(func(string, *TargetConfig) error { return nil }), + "t2": WaiterFunc(func(string, *TargetConfig) error { return errors.New("an error") }), }, ) @@ -228,11 +238,27 @@ func TestGRPCWaiter_succeedsImmediately(t *testing.T) { Target: lis.Addr().String(), Timeout: DefaultTimeout, Type: "grpc", - }, SupportedWaiters["grpc"]) + }, WaiterFunc(GRPCWaiter)) assert.Nil(t, err, "error waiting for grpc: %v", err) } +func TestIPList_Equality(t *testing.T) { + l1 := IPList([]net.IP{ip1, ip2, ip3}) + l2 := IPList([]net.IP{ip1, ip3, ip2}) + l3 := IPList([]net.IP{ip3, ip3, ip2}) + l4 := IPList([]net.IP{ip1, ip2, ip3, ip3}) + + assert.Truef(t, l1.Equals(l2), "%s != %s", l1, l2) + assert.Truef(t, l2.Equals(l1), "%s != %s", l2, l1) + assert.Falsef(t, l1.Equals(l3), "%s == %s", l1, l3) + assert.Falsef(t, l1.Equals(l4), "%s == %s", l1, l4) +} + +func TestIPList_String(t *testing.T) { + assert.Equal(t, "1.2.3.4,17.18.19.20,33.34.35.36", IPList{ip1, ip2, ip3}.String()) +} + func TestGRPCWaiter_failsToConnect(t *testing.T) { server, lis, err := setupGrpcServer(t) if err != nil { @@ -244,8 +270,110 @@ func TestGRPCWaiter_failsToConnect(t *testing.T) { Target: "localhost:8081", Timeout: DefaultTimeout, Type: "grpc", - }, SupportedWaiters["grpc"]) + }, WaiterFunc(GRPCWaiter)) assert.NotNil(t, err, "expected error but error was nil") fmt.Println(err) } + +func TestDNSWaiter_resolvesCorrectDNSName(t *testing.T) { + name := "" + w := NewDNSWaiter(func(host string) ([]net.IP, error) { + name = host + return []net.IP{ip1, ip2, ip3}, nil + }, NullLogger) + + _ = w.Wait("dns1", &TargetConfig{ + Target: "dns.name", + }) + assert.Equal(t, "dns.name", name) +} + +func TestDNSWaiter_timesOutOnSameDNS(t *testing.T) { + w := NewDNSWaiter(func(host string) ([]net.IP, error) { return []net.IP{ip1, ip2, ip3}, nil }, NullLogger) + + start := time.Now() + err := w.Wait("dns1", &TargetConfig{ + Target: "dns.name", + Timeout: time.Second, + }) + end := time.Now() + require.Error(t, err) + assert.Equal(t, "timed out waiting for DNS update to dns1", err.Error()) + assert.GreaterOrEqual(t, end.Sub(start), time.Second) +} + +func TestDNSWaiter_successAfterDNSChange(t *testing.T) { + ips := [][]net.IP{ + {ip1, ip2, ip3}, + {ip1, ip2, ip3}, + {ip4, ip5, ip6}, + } + w := NewDNSWaiter(func(host string) ([]net.IP, error) { + next := ips[0] + if len(ips) > 0 { + ips = ips[1:] + } + return next, nil + }, NullLogger) + + err := w.Wait("dns1", &TargetConfig{ + Target: "dns.name", + Type: "dns", + Timeout: time.Second * 3, + }) + require.NoError(t, err) +} + +func TestDNSWaiter_allowsAddressrderChange(t *testing.T) { + ips := [][]net.IP{ + {ip1, ip2, ip3}, + {ip2, ip1, ip3}, + {ip1, ip3, ip2}, + } + w := NewDNSWaiter(func(host string) ([]net.IP, error) { + next := ips[0] + if len(ips) > 0 { + ips = ips[1:] + } + return next, nil + }, NullLogger) + + err := w.Wait("dns1", &TargetConfig{ + Target: "dns.name", + Type: "dns", + Timeout: time.Second * 2, + }) + require.Error(t, err) +} + +func TestDNSWaiter_returnsErrorOnStart(t *testing.T) { + w := NewDNSWaiter(func(host string) ([]net.IP, error) { + return nil, fmt.Errorf("some error") + }, NullLogger) + + err := w.Wait("dns1", &TargetConfig{ + Target: "dns.name", + Type: "dns", + Timeout: time.Second * 2, + }) + assert.Error(t, err) +} + +func TestDNSWaiter_returnsErrorWhenWaitingz(t *testing.T) { + errs := []error{nil, nil, fmt.Errorf("some error")} + w := NewDNSWaiter(func(host string) ([]net.IP, error) { + next := errs[0] + if len(errs) > 0 { + errs = errs[1:] + } + return nil, next + }, NullLogger) + + err := w.Wait("dns1", &TargetConfig{ + Target: "dns.name", + Type: "dns", + Timeout: time.Second * 2, + }) + assert.Error(t, err) +}