diff --git a/ably/export_test.go b/ably/export_test.go index ee1e6a1d..9c4482bc 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -280,6 +280,11 @@ type DurationFromMsecs = durationFromMsecs type ProtoErrorInfo = errorInfo type ProtoFlag = protoFlag type ProtocolMessage = protocolMessage +type WebsocketErr = websocketErr + +func (w *WebsocketErr) HttpResp() *http.Response { + return w.resp +} const ( DefaultCipherKeyLength = defaultCipherKeyLength diff --git a/ably/realtime_client_integration_test.go b/ably/realtime_client_integration_test.go index 872e749d..857a022e 100644 --- a/ably/realtime_client_integration_test.go +++ b/ably/realtime_client_integration_test.go @@ -272,11 +272,28 @@ func TestRealtime_RTN17_Integration_HostFallback_Internal_Server_Error(t *testin serverURL, err := url.Parse(server.URL) assert.NoError(t, err) + fallbackHost := "sandbox-a-fallback.ably-realtime.com" + connAttempts := 0 + app, realtime := ablytest.NewRealtime( ably.WithAutoConnect(false), ably.WithTLS(false), ably.WithUseTokenAuth(true), - ably.WithFallbackHosts([]string{"sandbox-a-fallback.ably-realtime.com"}), + ably.WithFallbackHosts([]string{fallbackHost}), + ably.WithDial(func(protocol string, u *url.URL, timeout time.Duration) (ably.Conn, error) { + connAttempts += 1 + conn, err := ably.DialWebsocket(protocol, u, timeout) + if connAttempts == 1 { + assert.Equal(t, serverURL.Host, u.Host) + var websocketErr *ably.WebsocketErr + assert.ErrorAs(t, err, &websocketErr) + assert.Equal(t, http.StatusInternalServerError, websocketErr.HttpResp().StatusCode) + } else { + assert.NoError(t, err) + assert.Equal(t, fallbackHost, u.Hostname()) + } + return conn, err + }), ably.WithRealtimeHost(serverURL.Host)) defer safeclose(t, ablytest.FullRealtimeCloser(realtime), app) @@ -284,24 +301,46 @@ func TestRealtime_RTN17_Integration_HostFallback_Internal_Server_Error(t *testin err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) assert.NoError(t, err) - assert.Equal(t, "sandbox-a-fallback.ably-realtime.com", realtime.Rest().ActiveRealtimeHost()) + assert.Equal(t, 2, connAttempts) + assert.Equal(t, fallbackHost, realtime.Rest().ActiveRealtimeHost()) } func TestRealtime_RTN17_Integration_HostFallback_Timeout(t *testing.T) { + timedOut := make(chan bool) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(3 * time.Second) + <-timedOut w.WriteHeader(http.StatusSwitchingProtocols) })) defer server.Close() serverURL, err := url.Parse(server.URL) assert.NoError(t, err) + fallbackHost := "sandbox-a-fallback.ably-realtime.com" + requestTimeout := 2 * time.Second + connAttempts := 0 + app, realtime := ablytest.NewRealtime( ably.WithAutoConnect(false), ably.WithTLS(false), ably.WithUseTokenAuth(true), - ably.WithFallbackHosts([]string{"sandbox-a-fallback.ably-realtime.com"}), - ably.WithRealtimeRequestTimeout(2*time.Second), + ably.WithFallbackHosts([]string{fallbackHost}), + ably.WithRealtimeRequestTimeout(requestTimeout), + ably.WithDial(func(protocol string, u *url.URL, timeout time.Duration) (ably.Conn, error) { + connAttempts += 1 + assert.Equal(t, requestTimeout, timeout) + conn, err := ably.DialWebsocket(protocol, u, timeout) + if connAttempts == 1 { + assert.Equal(t, serverURL.Host, u.Host) + var timeoutError net.Error + assert.ErrorAs(t, err, &timeoutError) + assert.True(t, timeoutError.Timeout()) + timedOut <- true + } else { + assert.NoError(t, err) + assert.Equal(t, fallbackHost, u.Hostname()) + } + return conn, err + }), ably.WithRealtimeHost(serverURL.Host)) defer safeclose(t, ablytest.FullRealtimeCloser(realtime), app) @@ -309,7 +348,8 @@ func TestRealtime_RTN17_Integration_HostFallback_Timeout(t *testing.T) { err = ablytest.Wait(ablytest.ConnWaiter(realtime, realtime.Connect, ably.ConnectionEventConnected), nil) assert.NoError(t, err) - assert.Equal(t, "sandbox-a-fallback.ably-realtime.com", realtime.Rest().ActiveRealtimeHost()) + assert.Equal(t, 2, connAttempts) + assert.Equal(t, fallbackHost, realtime.Rest().ActiveRealtimeHost()) } func checkUnique(ch chan string, typ string, n int) error {