diff --git a/README.md b/README.md index 64247ebf..b56e1c92 100644 --- a/README.md +++ b/README.md @@ -450,8 +450,6 @@ See [jwt auth issue](https://github.com/ably/ably-go/issues/569) for more detail - Inband reauthentication is not supported; expiring tokens will trigger a disconnection and resume of a realtime connection. See [server initiated auth](https://github.com/ably/ably-go/issues/228) for more details. -- Realtime connection failure handling is partially implemented. See [host fallback](https://github.com/ably/ably-go/issues/225) for more details. - - Channel suspended state is partially implemented. See [suspended channel state](https://github.com/ably/ably-go/issues/568). - Realtime Ping function is not implemented. diff --git a/ably/error.go b/ably/error.go index d88ff902..623422ce 100644 --- a/ably/error.go +++ b/ably/error.go @@ -147,6 +147,17 @@ func errFromUnprocessableBody(resp *http.Response) error { return &ErrorInfo{Code: ErrBadRequest, StatusCode: resp.StatusCode, err: err} } +func isTimeoutOrDnsErr(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { // RSC15l2 + return true + } + } + var dnsErr *net.DNSError + return errors.As(err, &dnsErr) // RSC15l1 +} + func checkValidHTTPResponse(resp *http.Response) error { type errorBody struct { Error errorInfo `json:"error,omitempty" codec:"error,omitempty"` diff --git a/ably/export_test.go b/ably/export_test.go index d9e73c3b..9c4482bc 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -36,10 +36,6 @@ func (opts *clientOptions) RestURL() string { return opts.restURL() } -func (opts *clientOptions) RealtimeURL() string { - return opts.realtimeURL() -} - func (c *REST) Post(ctx context.Context, path string, in, out interface{}) (*http.Response, error) { return c.post(ctx, path, in, out) } @@ -93,6 +89,10 @@ func (c *REST) GetCachedFallbackHost() string { return c.hostCache.get() } +func (c *REST) ActiveRealtimeHost() string { + return c.activeRealtimeHost +} + func (c *RealtimeChannel) GetChannelSerial() string { c.mtx.Lock() defer c.mtx.Unlock() @@ -121,6 +121,10 @@ func (opts *clientOptions) GetFallbackRetryTimeout() time.Duration { return opts.fallbackRetryTimeout() } +func (opts *clientOptions) HasActiveInternetConnection() bool { + return opts.hasActiveInternetConnection() +} + func NewErrorInfo(code ErrorCode, err error) *ErrorInfo { return newError(code, err) } @@ -222,6 +226,10 @@ func (c *Connection) SetKey(key string) { c.key = key } +func (r *Realtime) Rest() *REST { + return r.rest +} + func (c *RealtimePresence) Members() map[string]*PresenceMessage { c.mtx.Lock() defer c.mtx.Unlock() @@ -272,6 +280,11 @@ type DurationFromMsecs = durationFromMsecs type ProtoErrorInfo = errorInfo type ProtoFlag = protoFlag type ProtocolMessage = protocolMessage +type WebsocketErr = websocketErr + +func (w *WebsocketErr) HttpResp() *http.Response { + return w.resp +} const ( DefaultCipherKeyLength = defaultCipherKeyLength diff --git a/ably/options.go b/ably/options.go index 5e22a377..1bf26783 100644 --- a/ably/options.go +++ b/ably/options.go @@ -1,9 +1,11 @@ package ably import ( + "bytes" "context" "errors" "fmt" + "io" "log" "net" "net/http" @@ -28,6 +30,10 @@ const ( Port = 80 TLSPort = 443 maxMessageSize = 65536 // 64kb, default value TO3l8 + + // RTN17c + internetCheckUrl = "https://internet-up.ably-realtime.com/is-the-internet-up.txt" + internetCheckOk = "yes" ) var defaultOptions = clientOptions{ @@ -482,10 +488,10 @@ func (opts *clientOptions) restURL() (restUrl string) { return "https://" + baseUrl } -func (opts *clientOptions) realtimeURL() (realtimeUrl string) { - baseUrl := opts.getRealtimeHost() +func (opts *clientOptions) realtimeURL(realtimeHost string) (realtimeUrl string) { + baseUrl := realtimeHost _, _, err := net.SplitHostPort(baseUrl) - if err != nil { // set port if not set in baseUrl + if err != nil { // set port if not set in provided realtimeHost port, _ := opts.activePort() baseUrl = net.JoinHostPort(baseUrl, strconv.Itoa(port)) } @@ -595,6 +601,20 @@ func (opts *clientOptions) idempotentRESTPublishing() bool { return opts.IdempotentRESTPublishing } +// RTN17c +func (opts *clientOptions) hasActiveInternetConnection() bool { + res, err := opts.httpclient().Get(internetCheckUrl) + if err != nil || res.StatusCode != 200 { + return false + } + defer res.Body.Close() + data, err := io.ReadAll(res.Body) + if err != nil { + return false + } + return bytes.Contains(data, []byte(internetCheckOk)) +} + type ScopeParams struct { Start time.Time End time.Time diff --git a/ably/options_test.go b/ably/options_test.go index c622bb54..855076bd 100644 --- a/ably/options_test.go +++ b/ably/options_test.go @@ -14,31 +14,32 @@ import ( ) func TestDefaultFallbacks_RSC15h(t *testing.T) { - t.Run("with env should return environment fallback hosts", func(t *testing.T) { - expectedFallBackHosts := []string{ - "a.ably-realtime.com", - "b.ably-realtime.com", - "c.ably-realtime.com", - "d.ably-realtime.com", - "e.ably-realtime.com", - } - hosts := ably.DefaultFallbackHosts() - assert.Equal(t, expectedFallBackHosts, hosts) - }) + expectedFallBackHosts := []string{ + "a.ably-realtime.com", + "b.ably-realtime.com", + "c.ably-realtime.com", + "d.ably-realtime.com", + "e.ably-realtime.com", + } + hosts := ably.DefaultFallbackHosts() + assert.Equal(t, expectedFallBackHosts, hosts) } func TestEnvFallbackHosts_RSC15i(t *testing.T) { - t.Run("with env should return environment fallback hosts", func(t *testing.T) { - expectedFallBackHosts := []string{ - "sandbox-a-fallback.ably-realtime.com", - "sandbox-b-fallback.ably-realtime.com", - "sandbox-c-fallback.ably-realtime.com", - "sandbox-d-fallback.ably-realtime.com", - "sandbox-e-fallback.ably-realtime.com", - } - hosts := ably.GetEnvFallbackHosts("sandbox") - assert.Equal(t, expectedFallBackHosts, hosts) - }) + expectedFallBackHosts := []string{ + "sandbox-a-fallback.ably-realtime.com", + "sandbox-b-fallback.ably-realtime.com", + "sandbox-c-fallback.ably-realtime.com", + "sandbox-d-fallback.ably-realtime.com", + "sandbox-e-fallback.ably-realtime.com", + } + hosts := ably.GetEnvFallbackHosts("sandbox") + assert.Equal(t, expectedFallBackHosts, hosts) +} + +func TestInternetConnectionCheck_RTN17c(t *testing.T) { + clientOptions := ably.NewClientOptions() + assert.True(t, clientOptions.HasActiveInternetConnection()) } func TestFallbackHosts_RSC15b(t *testing.T) { diff --git a/ably/realtime_client_integration_test.go b/ably/realtime_client_integration_test.go index 925918fa..857a022e 100644 --- a/ably/realtime_client_integration_test.go +++ b/ably/realtime_client_integration_test.go @@ -5,6 +5,7 @@ package ably_test import ( "context" + "errors" "fmt" "net" "net/http" @@ -15,9 +16,11 @@ import ( "time" "github.com/ably/ably-go/ably" + "github.com/ably/ably-go/ably/internal/ablyutil" "github.com/ably/ably-go/ablytest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRealtime_RealtimeHost(t *testing.T) { @@ -132,6 +135,223 @@ func TestRealtime_RSC7_AblyAgent(t *testing.T) { }) } +func TestRealtime_RTN17_HostFallback(t *testing.T) { + t.Parallel() + + getDNSErr := func() *net.DNSError { + return &net.DNSError{ + IsTimeout: false, + } + } + + getTimeoutErr := func() error { + return &errTimeout{} + } + + initClientWithConnError := func(customErr error, opts ...ably.ClientOption) (visitedHosts []string) { + client, err := ably.NewRealtime(append(opts, ably.WithAutoConnect(false), ably.WithKey("fake:key"), + ably.WithDial(func(protocol string, u *url.URL, timeout time.Duration) (ably.Conn, error) { + visitedHosts = append(visitedHosts, u.Hostname()) + return nil, customErr + }))...) + require.NoError(t, err) + ablytest.Wait(ablytest.ConnWaiter(client, client.Connect, ably.ConnectionEventDisconnected), nil) + return + } + + t.Run("RTN17a: First attempt should be made on default primary host", func(t *testing.T) { + visitedHosts := initClientWithConnError(errors.New("host url is wrong")) + assert.Equal(t, "realtime.ably.io", visitedHosts[0]) + }) + + t.Run("RTN17b: Fallback behaviour", func(t *testing.T) { + t.Parallel() + + t.Run("apply when default realtime endpoint is not overridden, port/tlsport not set", func(t *testing.T) { + visitedHosts := initClientWithConnError(getTimeoutErr()) + expectedPrimaryHost := "realtime.ably.io" + expectedFallbackHosts := ably.DefaultFallbackHosts() + + assert.Equal(t, 6, len(visitedHosts)) + assert.Equal(t, expectedPrimaryHost, visitedHosts[0]) + assert.ElementsMatch(t, expectedFallbackHosts, visitedHosts[1:]) + }) + + t.Run("does not apply when the custom realtime endpoint is used", func(t *testing.T) { + visitedHosts := initClientWithConnError(getTimeoutErr(), ably.WithRealtimeHost("custom-realtime.ably.io")) + expectedHost := "custom-realtime.ably.io" + + require.Equal(t, 1, len(visitedHosts)) + assert.Equal(t, expectedHost, visitedHosts[0]) + }) + + t.Run("apply when fallbacks are provided", func(t *testing.T) { + fallbacks := []string{"fallback0", "fallback1", "fallback2"} + visitedHosts := initClientWithConnError(getTimeoutErr(), ably.WithFallbackHosts(fallbacks)) + expectedPrimaryHost := "realtime.ably.io" + + assert.Equal(t, 4, len(visitedHosts)) + assert.Equal(t, expectedPrimaryHost, visitedHosts[0]) + assert.ElementsMatch(t, fallbacks, visitedHosts[1:]) + }) + + t.Run("apply when fallbackHostUseDefault is true, even if env. or host is set", func(t *testing.T) { + visitedHosts := initClientWithConnError( + getTimeoutErr(), + ably.WithFallbackHostsUseDefault(true), + ably.WithEnvironment("custom"), + ably.WithRealtimeHost("custom-ably.realtime.com")) + + expectedPrimaryHost := "custom-ably.realtime.com" + expectedFallbackHosts := ably.DefaultFallbackHosts() + + assert.Equal(t, 6, len(visitedHosts)) + assert.Equal(t, expectedPrimaryHost, visitedHosts[0]) + assert.ElementsMatch(t, expectedFallbackHosts, visitedHosts[1:]) + }) + }) + + t.Run("RTN17c: Verifies internet connection is active in case of error necessitating use of an alternative host", func(t *testing.T) { + t.Parallel() + const internetCheckUrl = "https://internet-up.ably-realtime.com/is-the-internet-up.txt" + rec, optn := ablytest.NewHttpRecorder() + visitedHosts := initClientWithConnError(getDNSErr(), optn...) + assert.Equal(t, 6, len(visitedHosts)) // including primary host + assert.Equal(t, 5, len(rec.Requests())) + for _, request := range rec.Requests() { + assert.Equal(t, request.URL.String(), internetCheckUrl) + } + }) + + t.Run("RTN17d: Check for compatible errors before attempting to reconnect to a fallback host", func(t *testing.T) { + visitedHosts := initClientWithConnError(fmt.Errorf("host url is wrong")) // non-dns or non-timeout error + assert.Equal(t, 1, len(visitedHosts)) + visitedHosts = initClientWithConnError(getDNSErr()) + assert.Equal(t, 6, len(visitedHosts)) + visitedHosts = initClientWithConnError(getTimeoutErr()) + assert.Equal(t, 6, len(visitedHosts)) + }) + + t.Run("RTN17e: Same fallback host should be used for REST as Realtime Fallback Host for a given active connection", func(t *testing.T) { + errCh := make(chan error, 1) + errCh <- getTimeoutErr() + realtimeMsgRecorder := NewMessageRecorder() // websocket recorder + restMsgRecorder, optn := ablytest.NewHttpRecorder() // http recorder + _, client := ablytest.NewRealtime(ably.WithAutoConnect(false), + ably.WithDial(func(protocol string, u *url.URL, timeout time.Duration) (ably.Conn, error) { + err, ok := <-errCh + if ok { + close(errCh) + return nil, err // return timeout error for primary host + } + return realtimeMsgRecorder.Dial(protocol, u, timeout) // return dial for subsequent dials + }), optn[0]) + defer client.Close() + + err := ablytest.Wait(ablytest.ConnWaiter(client, client.Connect, ably.ConnectionEventConnected), nil) + if err != nil { + t.Fatalf("Error connecting host with error %v", err) + } + realtimeSuccessHost := realtimeMsgRecorder.URLs()[0].Hostname() + fallbackHosts := ably.GetEnvFallbackHosts("sandbox") + if !ablyutil.SliceContains(fallbackHosts, realtimeSuccessHost) { + t.Fatalf("realtime host must be one of fallback hosts, received %v", realtimeSuccessHost) + } + + client.Time(context.Background()) // make a rest request + restSuccessHost := restMsgRecorder.Request(1).URL.Hostname() // second request is to get the time, first for internet connection + assert.Equal(t, realtimeSuccessHost, restSuccessHost) + }) +} + +func TestRealtime_RTN17_Integration_HostFallback_Internal_Server_Error(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + serverURL, err := url.Parse(server.URL) + assert.NoError(t, err) + + fallbackHost := "sandbox-a-fallback.ably-realtime.com" + connAttempts := 0 + + app, realtime := ablytest.NewRealtime( + ably.WithAutoConnect(false), + ably.WithTLS(false), + ably.WithUseTokenAuth(true), + ably.WithFallbackHosts([]string{fallbackHost}), + ably.WithDial(func(protocol string, u *url.URL, timeout time.Duration) (ably.Conn, error) { + connAttempts += 1 + conn, err := ably.DialWebsocket(protocol, u, timeout) + if connAttempts == 1 { + assert.Equal(t, serverURL.Host, u.Host) + var websocketErr *ably.WebsocketErr + assert.ErrorAs(t, err, &websocketErr) + assert.Equal(t, http.StatusInternalServerError, websocketErr.HttpResp().StatusCode) + } else { + assert.NoError(t, err) + assert.Equal(t, fallbackHost, u.Hostname()) + } + return conn, err + }), + ably.WithRealtimeHost(serverURL.Host)) + + defer safeclose(t, ablytest.FullRealtimeCloser(realtime), app) + + err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) + + assert.Equal(t, 2, connAttempts) + assert.Equal(t, fallbackHost, realtime.Rest().ActiveRealtimeHost()) +} + +func TestRealtime_RTN17_Integration_HostFallback_Timeout(t *testing.T) { + timedOut := make(chan bool) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-timedOut + w.WriteHeader(http.StatusSwitchingProtocols) + })) + defer server.Close() + serverURL, err := url.Parse(server.URL) + assert.NoError(t, err) + + fallbackHost := "sandbox-a-fallback.ably-realtime.com" + requestTimeout := 2 * time.Second + connAttempts := 0 + + app, realtime := ablytest.NewRealtime( + ably.WithAutoConnect(false), + ably.WithTLS(false), + ably.WithUseTokenAuth(true), + ably.WithFallbackHosts([]string{fallbackHost}), + ably.WithRealtimeRequestTimeout(requestTimeout), + ably.WithDial(func(protocol string, u *url.URL, timeout time.Duration) (ably.Conn, error) { + connAttempts += 1 + assert.Equal(t, requestTimeout, timeout) + conn, err := ably.DialWebsocket(protocol, u, timeout) + if connAttempts == 1 { + assert.Equal(t, serverURL.Host, u.Host) + var timeoutError net.Error + assert.ErrorAs(t, err, &timeoutError) + assert.True(t, timeoutError.Timeout()) + timedOut <- true + } else { + assert.NoError(t, err) + assert.Equal(t, fallbackHost, u.Hostname()) + } + return conn, err + }), + ably.WithRealtimeHost(serverURL.Host)) + + defer safeclose(t, ablytest.FullRealtimeCloser(realtime), app) + + err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) + assert.NoError(t, err) + + assert.Equal(t, 2, connAttempts) + assert.Equal(t, fallbackHost, realtime.Rest().ActiveRealtimeHost()) +} + func checkUnique(ch chan string, typ string, n int) error { close(ch) uniq := make(map[string]struct{}, n) diff --git a/ably/realtime_conn.go b/ably/realtime_conn.go index beb2a358..03ba6d00 100644 --- a/ably/realtime_conn.go +++ b/ably/realtime_conn.go @@ -8,6 +8,8 @@ import ( "strconv" "sync" "time" + + "github.com/ably/ably-go/ably/internal/ablyutil" ) var ( @@ -62,7 +64,6 @@ type Connection struct { msgSerial int64 connStateTTL durationFromMsecs - err error conn conn opts *clientOptions pending pendingEmitter @@ -367,16 +368,15 @@ func (c *Connection) connectWithRetryLoop(arg connArgs) (result, error) { } func (c *Connection) connectWith(arg connArgs) (result, error) { + connectMode := c.getMode() + c.mtx.Lock() // set ably connection state to connecting, connecting state exists regardless of whether raw connection is successful or not if !c.isActive() { // check if already in connecting state c.lockSetState(ConnectionStateConnecting, nil, 0) } c.mtx.Unlock() - u, err := url.Parse(c.opts.realtimeURL()) - if err != nil { - return nil, err - } + var res result if arg.result { res = c.internalEmitter.listenResult( @@ -385,22 +385,47 @@ func (c *Connection) connectWith(arg connArgs) (result, error) { ConnectionStateDisconnected, ) } - connectMode := c.getMode() - query, err := c.params(connectMode) - if err != nil { - return nil, err - } - u.RawQuery = query.Encode() - proto := c.opts.protocol() - if c.State() == ConnectionStateClosed { // RTN12d - if connection is closed by client, don't try to reconnect - return nopResult, nil + var conn conn + primaryHost := c.opts.getRealtimeHost() + hosts := []string{primaryHost} + fallbackHosts, err := c.opts.getFallbackHosts() + if err != nil { + c.log().Warn(err) + } else { + hosts = append(hosts, ablyutil.Shuffle(fallbackHosts)...) } + // Always try primary host first and then fallback hosts for realtime conn + for hostCounter, host := range hosts { + u, err := url.Parse(c.opts.realtimeURL(host)) + if err != nil { + return nil, err + } + query, err := c.params(connectMode) + if err != nil { + return nil, err + } + u.RawQuery = query.Encode() + proto := c.opts.protocol() - // if err is nil, raw connection with server is successful - conn, err := c.dial(proto, u) - if err != nil { - return nil, err + if c.State() == ConnectionStateClosed { // RTN12d - if connection is closed by client, don't try to reconnect + return nopResult, nil + } + // if err is nil, raw connection with server is successful + conn, err = c.dial(proto, u) + if err != nil { + resp := extractHttpResponseFromError(err) + if hostCounter < len(hosts)-1 && canFallBack(err, resp) && c.opts.hasActiveInternetConnection() { // RTN17d, RTN17c + continue + } + return nil, err + } + if host != primaryHost { // RTN17e + c.client.rest.setActiveRealtimeHost(host) + } else if !empty(c.client.rest.activeRealtimeHost) { + c.client.rest.setActiveRealtimeHost("") // reset to default + } + break } c.mtx.Lock() diff --git a/ably/rest_client.go b/ably/rest_client.go index ef506d08..67d05cff 100644 --- a/ably/rest_client.go +++ b/ably/rest_client.go @@ -6,13 +6,11 @@ import ( _ "crypto/sha512" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "io/ioutil" "math/rand" "mime" - "net" "net/http" "net/http/httptrace" "net/url" @@ -134,9 +132,10 @@ type REST struct { //Channels is a [ably.RESTChannels] object (RSN1). Channels *RESTChannels - opts *clientOptions - hostCache *hostCache - log logger + opts *clientOptions + hostCache *hostCache + activeRealtimeHost string // RTN17e + log logger } // NewREST construct a RestClient object using an [ably.ClientOption] object to configure @@ -195,6 +194,10 @@ func (c *REST) Stats(o ...StatsOption) StatsRequest { return StatsRequest{r: c.newPaginatedRequest("/stats", "", params)} } +func (c *REST) setActiveRealtimeHost(realtimeHost string) { + c.activeRealtimeHost = realtimeHost +} + // A StatsOption configures a call to REST.Stats or Realtime.Stats. type StatsOption func(*statsOptions) @@ -641,8 +644,12 @@ func (c *REST) doWithHandle(ctx context.Context, r *request, handle func(*http.R } if h := c.hostCache.get(); h != "" { req.URL.Host = h // RSC15f - c.log.Verbosef("RestClient: setting URL.Host=%q", h) + c.log.Verbosef("RestClient: setting cached URL.Host=%q", h) + } else if !empty(c.activeRealtimeHost) { // RTN17e + req.URL.Host = c.activeRealtimeHost + c.log.Verbosef("RestClient: setting activeRealtimeHost URL.Host=%q", c.activeRealtimeHost) } + if c.opts.Trace != nil { req = req.WithContext(httptrace.WithClientTrace(req.Context(), c.opts.Trace)) c.log.Verbose("RestClient: enabling httptrace") @@ -741,17 +748,6 @@ func canFallBack(err error, res *http.Response) bool { isTimeoutOrDnsErr(err) //RSC15l1, RSC15l2 } -func isTimeoutOrDnsErr(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) { - if netErr.Timeout() { // RSC15l2 - return true - } - } - var dnsErr *net.DNSError - return errors.As(err, &dnsErr) // RSC15l1 -} - // RSC15l3 func isStatusCodeBetween500_504(res *http.Response) bool { return res != nil && diff --git a/ably/websocket.go b/ably/websocket.go index febfc6ae..0e8b7756 100644 --- a/ably/websocket.go +++ b/ably/websocket.go @@ -24,6 +24,21 @@ type websocketConn struct { proto proto } +type websocketErr struct { + err error + resp *http.Response +} + +// websocketErr implements the builtin error interface. +func (e *websocketErr) Error() string { + return e.err.Error() +} + +// Unwrap implements the implicit interface that errors.Unwrap understands. +func (e *websocketErr) Unwrap() error { + return e.err +} + func (ws *websocketConn) Send(msg *protocolMessage) error { switch ws.proto { case jsonProto: @@ -88,16 +103,16 @@ func dialWebsocket(proto string, u *url.URL, timeout time.Duration, agents map[s return nil, errors.New(`invalid protocol "` + proto + `"`) } // Starts a raw websocket connection with server - conn, err := dialWebsocketTimeout(u.String(), "https://"+u.Host, timeout, agents) + conn, resp, err := dialWebsocketTimeout(u.String(), "https://"+u.Host, timeout, agents) if err != nil { - return nil, err + return nil, &websocketErr{err: err, resp: resp} } ws.conn = conn return ws, nil } // dialWebsocketTimeout dials the websocket with a timeout. -func dialWebsocketTimeout(uri, origin string, timeout time.Duration, agents map[string]string) (*websocket.Conn, error) { +func dialWebsocketTimeout(uri, origin string, timeout time.Duration, agents map[string]string) (*websocket.Conn, *http.Response, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -105,13 +120,13 @@ func dialWebsocketTimeout(uri, origin string, timeout time.Duration, agents map[ ops.HTTPHeader = make(http.Header) ops.HTTPHeader.Add(ablyAgentHeader, ablyAgentIdentifier(agents)) - c, _, err := websocket.Dial(ctx, uri, &ops) + c, resp, err := websocket.Dial(ctx, uri, &ops) if err != nil { - return nil, err + return nil, resp, err } - return c, nil + return c, resp, nil } func unwrapConn(c conn) conn { @@ -124,6 +139,14 @@ func unwrapConn(c conn) conn { return unwrapConn(u.Unwrap()) } +func extractHttpResponseFromError(err error) *http.Response { + wsErr, ok := err.(*websocketErr) + if ok { + return wsErr.resp + } + return nil +} + func setConnectionReadLimit(c conn, readLimit int64) error { unwrappedConn := unwrapConn(c) websocketConn, ok := unwrappedConn.(*websocketConn) diff --git a/ablytest/recorders.go b/ablytest/recorders.go index e5af3dcc..2b91c76b 100644 --- a/ablytest/recorders.go +++ b/ablytest/recorders.go @@ -23,6 +23,13 @@ type RoundTripRecorder struct { stopped int32 } +func NewHttpRecorder() (*RoundTripRecorder, []ably.ClientOption) { + rec := &RoundTripRecorder{} + httpClient := &http.Client{Transport: &http.Transport{}} + httpClient.Transport = rec.Hijack(httpClient.Transport) + return rec, []ably.ClientOption{ably.WithHTTPClient(httpClient)} +} + var _ http.RoundTripper = (*RoundTripRecorder)(nil) // Len gives number of recorded request/response pairs. diff --git a/ablytest/sandbox.go b/ablytest/sandbox.go index 59a7f36c..b3a6a40b 100644 --- a/ablytest/sandbox.go +++ b/ablytest/sandbox.go @@ -241,14 +241,13 @@ func (app *Sandbox) Options(opts ...ably.ClientOption) []ably.ClientOption { // If opts want to record round trips inject the recording transport // via TransportHijacker interface. - opt := MergeOptions(opts) - if httpClient := ClientOptionsInspector.HTTPClient(opt); httpClient != nil { + if httpClient := ClientOptionsInspector.HTTPClient(opts); httpClient != nil { if hijacker, ok := httpClient.Transport.(transportHijacker); ok { appHTTPClient.Transport = hijacker.Hijack(appHTTPClient.Transport) - opt = append(opt, ably.WithHTTPClient(appHTTPClient)) + opts = append(opts, ably.WithHTTPClient(appHTTPClient)) } } - appOpts = MergeOptions(appOpts, opt) + appOpts = MergeOptions(appOpts, opts) return appOpts }