diff --git a/ably/realtime_client_integration_test.go b/ably/realtime_client_integration_test.go index 3cca4dd8..c8e33b27 100644 --- a/ably/realtime_client_integration_test.go +++ b/ably/realtime_client_integration_test.go @@ -266,6 +266,29 @@ func TestRealtime_RTN17_HostFallback(t *testing.T) { }) } +func TestRealtime_RTN17_Integration_HostFallback(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + serverURL, err := url.Parse(server.URL) + assert.NoError(t, err) + + app, realtime := ablytest.NewRealtime( + ably.WithAutoConnect(false), + ably.WithTLS(false), + ably.WithUseTokenAuth(true), + ably.WithFallbackHosts(ably.GetEnvFallbackHosts(ablytest.Environment)), + ably.WithRealtimeHost(serverURL.Host)) + + defer safeclose(t, ablytest.FullRealtimeCloser(realtime), app) + + err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) + if err != nil { + t.Fatalf("Error connecting host with error %v", err) + } +} + func checkUnique(ch chan string, typ string, n int) error { close(ch) uniq := make(map[string]struct{}, n) diff --git a/ably/realtime_conn.go b/ably/realtime_conn.go index 3abcab00..72323f74 100644 --- a/ably/realtime_conn.go +++ b/ably/realtime_conn.go @@ -416,7 +416,7 @@ func (c *Connection) connectWith(arg connArgs) (result, error) { } break } - resp := extractHttpResponseFromConn(c.conn) + resp := extractHttpResponseFromError(err) if hostCounter < len(hosts)-1 && canFallBack(err, resp) && c.opts.hasActiveInternetConnection() { // RTN17d, RTN17c continue } diff --git a/ably/websocket.go b/ably/websocket.go index 84d71ddd..0e8b7756 100644 --- a/ably/websocket.go +++ b/ably/websocket.go @@ -22,7 +22,21 @@ const ( type websocketConn struct { conn *websocket.Conn proto proto - resp *http.Response +} + +type websocketErr struct { + err error + resp *http.Response +} + +// websocketErr implements the builtin error interface. +func (e *websocketErr) Error() string { + return e.err.Error() +} + +// Unwrap implements the implicit interface that errors.Unwrap understands. +func (e *websocketErr) Unwrap() error { + return e.err } func (ws *websocketConn) Send(msg *protocolMessage) error { @@ -90,9 +104,8 @@ func dialWebsocket(proto string, u *url.URL, timeout time.Duration, agents map[s } // Starts a raw websocket connection with server conn, resp, err := dialWebsocketTimeout(u.String(), "https://"+u.Host, timeout, agents) - ws.resp = resp if err != nil { - return nil, err + return nil, &websocketErr{err: err, resp: resp} } ws.conn = conn return ws, nil @@ -126,11 +139,10 @@ func unwrapConn(c conn) conn { return unwrapConn(u.Unwrap()) } -func extractHttpResponseFromConn(c conn) *http.Response { - unwrappedConn := unwrapConn(c) - websocketConn, ok := unwrappedConn.(*websocketConn) +func extractHttpResponseFromError(err error) *http.Response { + wsErr, ok := err.(*websocketErr) if ok { - return websocketConn.resp + return wsErr.resp } return nil }