From b7fbcead0a52d4dbe6aa818012091943514851f7 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 14 Oct 2024 11:45:43 +0200 Subject: [PATCH 01/31] chore: wip add epoll --- .../graphql_datasource/graphql_datasource.go | 18 ++ .../graphql_datasource_test.go | 8 + .../graphql_subscription_client.go | 190 ++++++++++++++++-- .../graphql_subscription_client_test.go | 146 ++++++++++++++ .../graphql_datasource/graphql_tws_handler.go | 107 +++++++++- .../graphql_datasource/graphql_ws_handler.go | 87 +++++++- v2/pkg/engine/resolve/datasource.go | 6 + v2/pkg/engine/resolve/resolve.go | 14 +- 8 files changed, 537 insertions(+), 39 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index ef450bc52..334e0f298 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -1734,6 +1734,8 @@ type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) + SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error + Unsubscribe(id uint64) } type GraphQLSubscriptionOptions struct { @@ -1759,6 +1761,22 @@ type SubscriptionSource struct { client GraphQLSubscriptionClient } +func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input []byte, updater resolve.SubscriptionUpdater) error { + var options GraphQLSubscriptionOptions + err := json.Unmarshal(input, &options) + if err != nil { + return err + } + if options.Body.Query == "" { + return resolve.ErrUnableToResolve + } + return s.client.Subscribe(ctx, options, updater) +} + +func (s *SubscriptionSource) AsyncStop(id uint64) { + s.client.Unsubscribe(id) +} + // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 7bf460484..227e7cea7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -8834,6 +8834,14 @@ var errSubscriptionClientFail = errors.New("subscription client fail error") type FailingSubscriptionClient struct{} +func (f *FailingSubscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { + return errSubscriptionClientFail +} + +func (f *FailingSubscriptionClient) Unsubscribe(id uint64) { + +} + func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { return errSubscriptionClientFail } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 6909ab5c5..e7e76a5fe 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -4,17 +4,20 @@ import ( "context" "fmt" "math" + "net" "net/http" "net/textproto" + "strings" "sync" + "syscall" "time" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + ws "github.com/gorilla/websocket" "github.com/jensneuse/abstractlogger" - "nhooyr.io/websocket" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/epoller" ) const ackWaitTimeout = 30 * time.Second @@ -31,6 +34,44 @@ type subscriptionClient struct { onWsConnectionInitCallback *OnWsConnectionInitCallback readTimeout time.Duration + + epoll epoller.Poller + + connections map[int]ConnectionHandler + connectionsMu sync.RWMutex + + triggers map[uint64]int +} + +func (c *subscriptionClient) SubscribeAsync(reqCtx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { + if options.UseSSE { + return c.subscribeSSE(reqCtx, options, updater) + } + + if strings.HasPrefix(options.URL, "https") { + options.URL = "wss" + options.URL[5:] + } else if strings.HasPrefix(options.URL, "http") { + options.URL = "ws" + options.URL[4:] + } + + return c.asyncSubscribeWS(reqCtx, id, options, updater) +} + +func (c *subscriptionClient) Unsubscribe(id uint64) { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + fd, ok := c.triggers[id] + if !ok { + return + } + delete(c.triggers, id) + handler, ok := c.connections[fd] + if !ok { + return + } + handler.ClientClose() + delete(c.connections, fd) + _ = c.epoll.Remove(handler.NetConn()) } type InvalidWsSubprotocolError struct { @@ -92,7 +133,11 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi for _, option := range options { option(op) } - return &subscriptionClient{ + epoll, err := epoller.NewPoller(1024, time.Second) + if err != nil { + fmt.Printf("failed to create epoll: %v\n", err) + } + client := &subscriptionClient{ httpClient: httpClient, streamingClient: streamingClient, engineCtx: engineCtx, @@ -104,7 +149,11 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi }, }, onWsConnectionInitCallback: op.onWsConnectionInitCallback, + epoll: epoll, + connections: make(map[int]ConnectionHandler), + triggers: make(map[uint64]int), } + return client } // Subscribe initiates a new GraphQL Subscription with the origin @@ -116,6 +165,12 @@ func (c *subscriptionClient) Subscribe(reqCtx *resolve.Context, options GraphQLS return c.subscribeSSE(reqCtx, options, updater) } + if strings.HasPrefix(options.URL, "https") { + options.URL = "wss" + options.URL[5:] + } else if strings.HasPrefix(options.URL, "http") { + options.URL = "ws" + options.URL[4:] + } + return c.subscribeWS(reqCtx, options, updater) } @@ -179,6 +234,45 @@ func (c *subscriptionClient) subscribeWS(reqCtx *resolve.Context, options GraphQ return nil } +func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { + if c.httpClient == nil { + return fmt.Errorf("http client is nil") + } + + sub := Subscription{ + ctx: reqCtx.Context(), + options: options, + updater: updater, + } + + handler, err := c.newWSConnectionHandler(reqCtx.Context(), options) + if err != nil { + return err + } + + netConn := handler.NetConn() + if err := c.epoll.Add(netConn); err != nil { + return err + } + + c.connectionsMu.Lock() + fd := socketFd(netConn) + c.connections[fd] = handler + c.triggers[id] = fd + count := len(c.connections) + c.connectionsMu.Unlock() + + if count == 1 { + go c.runEpoll(c.engineCtx) + } + + fmt.Printf("added connection to epoll\n") + + handler.Subscribe(sub) + + return nil +} + // generateHandlerIDHash generates a Hash based on: URL and Headers to uniquely identify Upgrade Requests func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSubscriptionOptions, xxh *xxhash.Digest) (err error) { if _, err = xxh.WriteString(options.URL); err != nil { @@ -251,17 +345,16 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti subProtocols = []string{options.WsSubProtocol} } - conn, upgradeResponse, err := websocket.Dial(reqCtx, options.URL, &websocket.DialOptions{ - HTTPClient: c.httpClient, - HTTPHeader: options.Header, - CompressionMode: websocket.CompressionDisabled, - Subprotocols: subProtocols, - }) + dialer := ws.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 10, + Subprotocols: subProtocols, + } + + conn, upgradeResponse, err := dialer.DialContext(reqCtx, options.URL, options.Header) if err != nil { return nil, err } - // Disable the maximum message size limit. Don't use MaxInt64 since - // the nhooyr.io/websocket doesn't handle it correctly on 32 bit systems. conn.SetReadLimit(math.MaxInt32) if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("upgrade unsuccessful") @@ -287,7 +380,7 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti } // init + ack - err = conn.Write(reqCtx, websocket.MessageText, connectionInitMessage) + err = conn.WriteMessage(ws.TextMessage, connectionInitMessage) if err != nil { return nil, err } @@ -300,7 +393,7 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti } } - if err := waitForAck(reqCtx, conn); err != nil { + if err := waitForAck(conn); err != nil { return nil, err } @@ -340,6 +433,11 @@ func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url s type ConnectionHandler interface { StartBlocking(sub Subscription) + NetConn() net.Conn + ReadMessage() (done bool) + ServerClose() + ClientClose() + Subscribe(sub Subscription) } type Subscription struct { @@ -348,7 +446,7 @@ type Subscription struct { updater resolve.SubscriptionUpdater } -func waitForAck(ctx context.Context, conn *websocket.Conn) error { +func waitForAck(conn *ws.Conn) error { timer := time.NewTimer(ackWaitTimeout) for { select { @@ -357,11 +455,11 @@ func waitForAck(ctx context.Context, conn *websocket.Conn) error { default: } - msgType, msg, err := conn.Read(ctx) + msgType, msg, err := conn.ReadMessage() if err != nil { return err } - if msgType != websocket.MessageText { + if msgType != ws.TextMessage { return fmt.Errorf("unexpected message type") } @@ -374,11 +472,10 @@ func waitForAck(ctx context.Context, conn *websocket.Conn) error { case messageTypeConnectionKeepAlive: continue case messageTypePing: - err := conn.Write(ctx, websocket.MessageText, []byte(pongMessage)) + err := conn.WriteMessage(ws.TextMessage, []byte(pongMessage)) if err != nil { return fmt.Errorf("failed to send pong message: %w", err) } - continue case messageTypeConnectionAck: return nil @@ -387,3 +484,60 @@ func waitForAck(ctx context.Context, conn *websocket.Conn) error { } } } + +func (c *subscriptionClient) runEpoll(ctx context.Context) { + for { + if ctx.Err() != nil { + return + } + connections, err := c.epoll.Wait(50) + if err != nil { + c.log.Error("epoll.Wait", abstractlogger.Error(err)) + return + } + fmt.Printf("ePoll - time: %v, connections: %d\n", time.Now(), len(connections)) + c.connectionsMu.RLock() + for _, conn := range connections { + id := socketFd(conn) + handler, ok := c.connections[id] + if !ok { + continue + } + c.handleConnection(id, handler, conn) + } + c.connectionsMu.RUnlock() + } +} + +func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { + fmt.Printf("handling connection %d\n", id) + done := handler.ReadMessage() + if done { + fmt.Printf("connection %d done\n", id) + c.connectionsMu.Lock() + delete(c.connections, id) + c.connectionsMu.Unlock() + handler.ServerClose() + _ = c.epoll.Remove(conn) + return + } + fmt.Printf("no new messages for connection %d, returning to ePoll\n", id) +} + +func socketFd(conn net.Conn) int { + if con, ok := conn.(syscall.Conn); ok { + raw, err := con.SyscallConn() + if err != nil { + return 0 + } + sfd := 0 + _ = raw.Control(func(fd uintptr) { + sfd = int(fd) + }) + return sfd + } + if con, ok := conn.(epoller.ConnImpl); ok { + return con.GetFD() + } + return 0 +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index b93304763..01e33149b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -477,3 +477,149 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() } + +func TestSubscribeAsync(t *testing.T) { + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Millisecond) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Millisecond) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Millisecond) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() +} + +func TestSubscribeAsyncServerTimeout(t *testing.T) { + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Millisecond) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Millisecond) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index fc5259c64..d80e8cb85 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -5,22 +5,24 @@ import ( "encoding/json" "errors" "fmt" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "io" "net" "strconv" + "strings" "time" + ws "github.com/gorilla/websocket" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/buger/jsonparser" log "github.com/jensneuse/abstractlogger" - "nhooyr.io/websocket" ) // gqlTWSConnectionHandler is responsible for handling a connection to an origin // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlTWSConnectionHandler struct { - conn *websocket.Conn + conn *ws.Conn ctx context.Context log log.Logger subscribeCh chan Subscription @@ -29,7 +31,90 @@ type gqlTWSConnectionHandler struct { readTimeout time.Duration } -func newGQLTWSConnectionHandler(ctx context.Context, conn *websocket.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler { +func (h *gqlTWSConnectionHandler) ServerClose() { + fmt.Printf("ServerClose\n") + for _, sub := range h.subscriptions { + sub.updater.Done() + } +} + +func (h *gqlTWSConnectionHandler) ClientClose() { + fmt.Printf("ClientClose\n") + for k, v := range h.subscriptions { + v.updater.Done() + delete(h.subscriptions, k) + + req := fmt.Sprintf(completeMessage, k) + err := h.conn.WriteMessage(ws.TextMessage, []byte(req)) + if err != nil { + h.log.Error("failed to write complete message", log.Error(err)) + } + } + _ = h.conn.Close() +} + +func (h *gqlTWSConnectionHandler) Subscribe(sub Subscription) { + h.subscribe(sub) +} + +func (h *gqlTWSConnectionHandler) ReadMessage() (done bool) { + fmt.Printf("ReadMessage\n") + + err := h.conn.SetReadDeadline(time.Now().Add(time.Second * 5)) + if err != nil { + fmt.Printf("SetReadDeadline error: %v\n", err) + return h.isConnectionClosed(err) + } + msgType, data, err := h.conn.ReadMessage() + if err != nil { + fmt.Printf("ReadMessage error: %v\n", err) + return h.isConnectionClosed(err) + } + fmt.Printf("ReadMessage messageType %v, data: %v\n", msgType, string(data)) + if msgType != ws.TextMessage { + return false + } + messageType, err := jsonparser.GetString(data, "type") + if err != nil { + return false + } + switch messageType { + case messageTypePing: + h.handleMessageTypePing() + return false + case messageTypeNext: + h.handleMessageTypeNext(data) + return false + case messageTypeComplete: + h.handleMessageTypeComplete(data) + return true + case messageTypeError: + h.handleMessageTypeError(data) + return false + case messageTypeConnectionKeepAlive: + return false + case messageTypeData, messageTypeConnectionError: + h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") + return true + default: + h.log.Error("unknown message type", log.String("type", messageType)) + return false + } +} + +func (h *gqlTWSConnectionHandler) isConnectionClosed(err error) bool { + if strings.HasSuffix(err.Error(), "use of closed network connection") { + return true + } + fmt.Printf("isConnectionClosed: %v\n", err) + return false +} + +func (h *gqlTWSConnectionHandler) NetConn() net.Conn { + return h.conn.NetConn() +} + +func newGQLTWSConnectionHandler(ctx context.Context, conn *ws.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler { return &gqlTWSConnectionHandler{ conn: conn, ctx: ctx, @@ -117,7 +202,7 @@ func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { for id := range h.subscriptions { h.unsubscribe(id) } - _ = h.conn.Close(websocket.StatusNormalClosure, "") + _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { @@ -129,7 +214,7 @@ func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { delete(h.subscriptions, subscriptionID) req := fmt.Sprintf(completeMessage, subscriptionID) - err := h.conn.Write(h.ctx, websocket.MessageText, []byte(req)) + err := h.conn.WriteMessage(ws.TextMessage, []byte(req)) if err != nil { h.log.Error("failed to write complete message", log.Error(err)) } @@ -147,8 +232,10 @@ func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) { subscriptionID := strconv.Itoa(h.nextSubscriptionID) + fmt.Printf("subscribe with subscriptionID: %s\n", subscriptionID) + subscribeRequest := fmt.Sprintf(subscribeMessage, subscriptionID, string(graphQLBody)) - err = h.conn.Write(h.ctx, websocket.MessageText, []byte(subscribeRequest)) + err = h.conn.WriteMessage(ws.TextMessage, []byte(subscribeRequest)) if err != nil { h.log.Error("failed to write subscribe message", log.Error(err)) return @@ -218,7 +305,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { } func (h *gqlTWSConnectionHandler) handleMessageTypePing() { - err := h.conn.Write(h.ctx, websocket.MessageText, []byte(pongMessage)) + err := h.conn.WriteMessage(ws.TextMessage, []byte(pongMessage)) if err != nil { h.log.Error("failed to write pong message", log.Error(err)) } @@ -252,7 +339,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { // we'll block forever on reading until the context of the gqlTWSConnectionHandler stops func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { - msgType, data, err := h.conn.Read(ctx) + msgType, data, err := h.conn.ReadMessage() if err != nil { select { case errCh <- err: @@ -260,7 +347,7 @@ func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan } return } - if msgType != websocket.MessageText { + if msgType != ws.TextMessage { continue } select { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 838bb3995..421477ef8 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -5,22 +5,24 @@ import ( "encoding/json" "errors" "fmt" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "io" "net" "strconv" + "strings" "time" + ws "github.com/gorilla/websocket" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/buger/jsonparser" "github.com/jensneuse/abstractlogger" - "nhooyr.io/websocket" ) // gqlWSConnectionHandler is responsible for handling a connection to an origin // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlWSConnectionHandler struct { - conn *websocket.Conn + conn *ws.Conn ctx context.Context log abstractlogger.Logger // log slog.Logger @@ -30,7 +32,74 @@ type gqlWSConnectionHandler struct { readTimeout time.Duration } -func newGQLWSConnectionHandler(ctx context.Context, conn *websocket.Conn, readTimeout time.Duration, log abstractlogger.Logger) *gqlWSConnectionHandler { +func (h *gqlWSConnectionHandler) ServerClose() { + for _, sub := range h.subscriptions { + sub.updater.Done() + } +} + +func (h *gqlWSConnectionHandler) ClientClose() { + for k, v := range h.subscriptions { + v.updater.Done() + delete(h.subscriptions, k) + stopRequest := fmt.Sprintf(stopMessage, k) + _ = h.conn.WriteMessage(ws.TextMessage, []byte(stopRequest)) + } + _ = h.conn.Close() +} + +func (h *gqlWSConnectionHandler) Subscribe(sub Subscription) { + h.subscribe(sub) +} + +func (h *gqlWSConnectionHandler) ReadMessage() (done bool) { + for { + err := h.conn.SetReadDeadline(time.Now().Add(time.Second * 5)) + if err != nil { + return h.isConnectionClosed(err) + } + msgType, data, err := h.conn.ReadMessage() + if err != nil { + return h.isConnectionClosed(err) + } + if msgType != ws.TextMessage { + return false + } + messageType, err := jsonparser.GetString(data, "type") + if err != nil { + return false + } + switch messageType { + case messageTypeData: + h.handleMessageTypeData(data) + continue + case messageTypeComplete: + h.handleMessageTypeComplete(data) + return true + case messageTypeConnectionError: + h.handleMessageTypeConnectionError() + return true + case messageTypeError: + h.handleMessageTypeError(data) + continue + default: + return false + } + } +} + +func (h *gqlWSConnectionHandler) isConnectionClosed(err error) bool { + if strings.HasSuffix(err.Error(), "use of closed network connection") { + return true + } + return false +} + +func (h *gqlWSConnectionHandler) NetConn() net.Conn { + return h.conn.NetConn() +} + +func newGQLWSConnectionHandler(ctx context.Context, conn *ws.Conn, readTimeout time.Duration, log abstractlogger.Logger) *gqlWSConnectionHandler { return &gqlWSConnectionHandler{ conn: conn, ctx: ctx, @@ -122,7 +191,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { // we'll block forever on reading until the context of the gqlWSConnectionHandler stops func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { - msgType, data, err := h.conn.Read(ctx) + msgType, data, err := h.conn.ReadMessage() if err != nil { select { case errCh <- err: @@ -130,7 +199,7 @@ func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan [ } return } - if msgType != websocket.MessageText { + if msgType != ws.TextMessage { continue } select { @@ -145,7 +214,7 @@ func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { for id := range h.subscriptions { h.unsubscribe(id) } - _ = h.conn.Close(websocket.StatusNormalClosure, "") + _ = h.conn.Close() } // subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin @@ -160,7 +229,7 @@ func (h *gqlWSConnectionHandler) subscribe(sub Subscription) { subscriptionID := strconv.Itoa(h.nextSubscriptionID) startRequest := fmt.Sprintf(startMessage, subscriptionID, string(graphQLBody)) - err = h.conn.Write(h.ctx, websocket.MessageText, []byte(startRequest)) + err = h.conn.WriteMessage(ws.TextMessage, []byte(startRequest)) if err != nil { return } @@ -255,7 +324,7 @@ func (h *gqlWSConnectionHandler) unsubscribe(subscriptionID string) { sub.updater.Done() delete(h.subscriptions, subscriptionID) stopRequest := fmt.Sprintf(stopMessage, subscriptionID) - _ = h.conn.Write(h.ctx, websocket.MessageText, []byte(stopRequest)) + _ = h.conn.WriteMessage(ws.TextMessage, []byte(stopRequest)) } func (h *gqlWSConnectionHandler) checkActiveSubscriptions() (hasActiveSubscriptions bool) { diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index bba3295b1..f1b3a6293 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -19,3 +19,9 @@ type SubscriptionDataSource interface { Start(ctx *Context, input []byte, updater SubscriptionUpdater) error UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) } + +type AsyncSubscriptionDataSource interface { + AsyncStart(ctx *Context, id uint64, input []byte, updater SubscriptionUpdater) error + AsyncStop(id uint64) + UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) +} diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index ce00b064a..927087df9 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -6,11 +6,12 @@ import ( "bytes" "context" "fmt" - "golang.org/x/sync/semaphore" "io" "sync" "time" + "golang.org/x/sync/semaphore" + "github.com/buger/jsonparser" "github.com/pkg/errors" "go.uber.org/atomic" @@ -574,7 +575,15 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) fmt.Printf("resolver:trigger:start:%d\n", triggerID) } - err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { + trig.cancel = func() { + async.AsyncStop(triggerID) + cancel() + } + err = async.AsyncStart(cloneCtx, triggerID, add.input, updater) + } else { + err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + } if err != nil { if r.options.Debug { fmt.Printf("resolver:trigger:failed:%d\n", triggerID) @@ -718,6 +727,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } func (r *Resolver) shutdownTrigger(id uint64) { + fmt.Printf("resolver:trigger:shutdown:%d\n", id) trig, ok := r.triggers[id] if !ok { return From 0d470571f543bdb836be660d2d7770939bbce95a Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 14 Oct 2024 12:27:32 +0200 Subject: [PATCH 02/31] chore: wip add epoll --- .../graphql_subscription_client.go | 61 +++++-- .../graphql_subscription_client_test.go | 6 - v2/pkg/internal/epoller/conn.go | 27 +++ v2/pkg/internal/epoller/epoll.go | 35 ++++ v2/pkg/internal/epoller/epoll_bsd.go | 158 ++++++++++++++++ v2/pkg/internal/epoller/epoll_linux.go | 130 ++++++++++++++ v2/pkg/internal/epoller/epoll_test.go | 170 ++++++++++++++++++ v2/pkg/internal/epoller/epoll_unsupported.go | 14 ++ v2/pkg/internal/epoller/fd_test.go | 55 ++++++ 9 files changed, 633 insertions(+), 23 deletions(-) create mode 100644 v2/pkg/internal/epoller/conn.go create mode 100644 v2/pkg/internal/epoller/epoll.go create mode 100755 v2/pkg/internal/epoller/epoll_bsd.go create mode 100755 v2/pkg/internal/epoller/epoll_linux.go create mode 100644 v2/pkg/internal/epoller/epoll_test.go create mode 100644 v2/pkg/internal/epoller/epoll_unsupported.go create mode 100644 v2/pkg/internal/epoller/fd_test.go diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index e7e76a5fe..a20d66d82 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -38,7 +38,10 @@ type subscriptionClient struct { epoll epoller.Poller connections map[int]ConnectionHandler - connectionsMu sync.RWMutex + connectionsMu sync.Mutex + + activeConnections map[int]int + activeConnectionsMu sync.Mutex triggers map[uint64]int } @@ -486,42 +489,66 @@ func waitForAck(conn *ws.Conn) error { } func (c *subscriptionClient) runEpoll(ctx context.Context) { + done := ctx.Done() + tick := time.NewTicker(time.Millisecond * 10) for { - if ctx.Err() != nil { - return - } connections, err := c.epoll.Wait(50) if err != nil { c.log.Error("epoll.Wait", abstractlogger.Error(err)) return } fmt.Printf("ePoll - time: %v, connections: %d\n", time.Now(), len(connections)) - c.connectionsMu.RLock() + c.connectionsMu.Lock() for _, conn := range connections { id := socketFd(conn) handler, ok := c.connections[id] if !ok { continue } - c.handleConnection(id, handler, conn) + c.activeConnectionsMu.Lock() + if i, ok := c.activeConnections[id]; ok { + fmt.Printf("connection %d is active, queueing event\n", id) + c.activeConnections[id] = i + 1 + } + c.activeConnectionsMu.Unlock() + go c.handleConnection(id, handler, conn) + } + c.connectionsMu.Unlock() + select { + case <-done: + return + case <-tick.C: + continue } - c.connectionsMu.RUnlock() } } func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { fmt.Printf("handling connection %d\n", id) - done := handler.ReadMessage() - if done { - fmt.Printf("connection %d done\n", id) - c.connectionsMu.Lock() - delete(c.connections, id) - c.connectionsMu.Unlock() - handler.ServerClose() - _ = c.epoll.Remove(conn) - return + for { + done := handler.ReadMessage() + if done { + fmt.Printf("connection %d done\n", id) + c.connectionsMu.Lock() + delete(c.connections, id) + c.connectionsMu.Unlock() + handler.ServerClose() + _ = c.epoll.Remove(conn) + return + } + c.activeConnectionsMu.Lock() + if i, ok := c.activeConnections[id]; ok { + if i == 0 { + delete(c.activeConnections, id) + c.activeConnectionsMu.Unlock() + fmt.Printf("handleConnection: event queue empty, returning to ePoll for connection %d\n", id) + return + } + c.activeConnections[id] = i - 1 + fmt.Printf("handleConnection: event queue not empty, processing next event for connection %d\n", id) + } + c.activeConnectionsMu.Unlock() } - fmt.Printf("no new messages for connection %d, returning to ePoll\n", id) } func socketFd(conn net.Conn) int { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 01e33149b..7f141ac4c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -498,18 +498,12 @@ func TestSubscribeAsync(t *testing.T) { err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) assert.NoError(t, err) - time.Sleep(time.Millisecond) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) assert.NoError(t, err) - time.Sleep(time.Millisecond) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) assert.NoError(t, err) - time.Sleep(time.Millisecond) - msgType, data, err = conn.Read(ctx) assert.NoError(t, err) assert.Equal(t, websocket.MessageText, msgType) diff --git a/v2/pkg/internal/epoller/conn.go b/v2/pkg/internal/epoller/conn.go new file mode 100644 index 000000000..74649e64a --- /dev/null +++ b/v2/pkg/internal/epoller/conn.go @@ -0,0 +1,27 @@ +package epoller + +import ( + "net" +) + +// newConnImpl returns a net.Conn with GetFD() method. +func newConnImpl(in net.Conn) ConnImpl { + if ci, ok := in.(ConnImpl); ok { + return ci + } + + return ConnImpl{ + Conn: in, + fd: socketFD(in), + } +} + +// ConnImpl is a net.Conn with GetFD() method. +type ConnImpl struct { + net.Conn + fd int +} + +func (c ConnImpl) GetFD() int { + return c.fd +} diff --git a/v2/pkg/internal/epoller/epoll.go b/v2/pkg/internal/epoller/epoll.go new file mode 100644 index 000000000..efaa8439b --- /dev/null +++ b/v2/pkg/internal/epoller/epoll.go @@ -0,0 +1,35 @@ +package epoller + +import ( + "net" + "syscall" +) + +// Poller is the interface for epoll/kqueue poller, special for network connections. +type Poller interface { + // Add adds the connection to poller. + Add(conn net.Conn) error + // Remove removes the connection from poller and closes it. + Remove(conn net.Conn) error + // Wait waits for at most count events and returns the connections. + Wait(count int) ([]net.Conn, error) + // Close closes the poller. If closeConns is true, it will close all the connections. + Close(closeConns bool) error +} + +func socketFD(conn net.Conn) int { + if con, ok := conn.(syscall.Conn); ok { + raw, err := con.SyscallConn() + if err != nil { + return 0 + } + sfd := 0 + raw.Control(func(fd uintptr) { + sfd = int(fd) + }) + return sfd + } else if con, ok := conn.(ConnImpl); ok { + return con.fd + } + return 0 +} diff --git a/v2/pkg/internal/epoller/epoll_bsd.go b/v2/pkg/internal/epoller/epoll_bsd.go new file mode 100755 index 000000000..a9d23e982 --- /dev/null +++ b/v2/pkg/internal/epoller/epoll_bsd.go @@ -0,0 +1,158 @@ +//go:build darwin || netbsd || freebsd || openbsd || dragonfly +// +build darwin netbsd freebsd openbsd dragonfly + +package epoller + +import ( + "errors" + "net" + "sync" + "syscall" + "time" +) + +var _ Poller = (*Epoll)(nil) + +// Epoll is an epoll based poller. +type Epoll struct { + fd int + ts syscall.Timespec + + connBufferSize int + mu *sync.RWMutex + changes []syscall.Kevent_t + conns map[int]net.Conn + connbuf []net.Conn +} + +// NewPoller creates a new poller instance. +func NewPoller(connBufferSize int, pollTimeout time.Duration) (*Epoll, error) { + return newPollerWithBuffer(connBufferSize, pollTimeout) +} + +// newPollerWithBuffer creates a new poller instance with buffer size. +func newPollerWithBuffer(count int, pollTimeout time.Duration) (*Epoll, error) { + p, err := syscall.Kqueue() + if err != nil { + panic(err) + } + _, err = syscall.Kevent(p, []syscall.Kevent_t{{ + Ident: 0, + Filter: syscall.EVFILT_USER, + Flags: syscall.EV_ADD | syscall.EV_CLEAR, + }}, nil, nil) + if err != nil { + panic(err) + } + + return &Epoll{ + fd: p, + ts: syscall.NsecToTimespec(pollTimeout.Nanoseconds()), + connBufferSize: count, + mu: &sync.RWMutex{}, + conns: make(map[int]net.Conn), + connbuf: make([]net.Conn, count), + }, nil +} + +// Close closes the poller. +func (e *Epoll) Close(closeConns bool) error { + e.mu.Lock() + defer e.mu.Unlock() + + if closeConns { + for _, conn := range e.conns { + conn.Close() + } + } + + e.conns = nil + e.changes = nil + e.connbuf = e.connbuf[:0] + + return syscall.Close(e.fd) +} + +// Add adds a network connection to the poller. +func (e *Epoll) Add(conn net.Conn) error { + conn = newConnImpl(conn) + fd := socketFD(conn) + if e := syscall.SetNonblock(int(fd), true); e != nil { + return errors.New("udev: unix.SetNonblock failed") + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.changes = append(e.changes, + syscall.Kevent_t{ + Ident: uint64(fd), Flags: syscall.EV_ADD | syscall.EV_EOF, Filter: syscall.EVFILT_READ, + }, + ) + + e.conns[fd] = conn + + return nil +} + +// Remove removes a connection from the poller. +// If close is true, the connection will be closed. +func (e *Epoll) Remove(conn net.Conn) error { + fd := socketFD(conn) + + e.mu.Lock() + defer e.mu.Unlock() + + if len(e.changes) <= 1 { + e.changes = nil + } else { + changes := make([]syscall.Kevent_t, 0, len(e.changes)-1) + ident := uint64(fd) + for _, ke := range e.changes { + if ke.Ident != ident { + changes = append(changes, ke) + } + } + e.changes = changes + } + + delete(e.conns, fd) + + return nil +} + +// Wait waits for events and returns the connections. +func (e *Epoll) Wait(count int) ([]net.Conn, error) { + events := make([]syscall.Kevent_t, count) + + e.mu.RLock() + changes := e.changes + e.mu.RUnlock() + +retry: + n, err := syscall.Kevent(e.fd, changes, events, &e.ts) + if err != nil { + if err == syscall.EINTR { + goto retry + } + return nil, err + } + + var conns []net.Conn + if e.connBufferSize == 0 { + conns = make([]net.Conn, 0, n) + } else { + conns = e.connbuf[:0] + } + + e.mu.RLock() + for i := 0; i < n; i++ { + conn := e.conns[int(events[i].Ident)] + if conn != nil { + conns = append(conns, conn) + } + } + e.mu.RUnlock() + + return conns, nil +} diff --git a/v2/pkg/internal/epoller/epoll_linux.go b/v2/pkg/internal/epoller/epoll_linux.go new file mode 100755 index 000000000..ee32c6c57 --- /dev/null +++ b/v2/pkg/internal/epoller/epoll_linux.go @@ -0,0 +1,130 @@ +//go:build linux +// +build linux + +package epoller + +import ( + "errors" + "net" + "sync" + "syscall" + "time" + + "golang.org/x/sys/unix" +) + +var _ Poller = (*Epoll)(nil) + +// Epoll is an epoll based poller. +type Epoll struct { + fd int + + connBufferSize int + lock *sync.RWMutex + conns map[int]net.Conn + connbuf []net.Conn + + timeoutMsec int +} + +// NewPoller creates a new epoll poller. +func NewPoller(connBufferSize int, pollTimeout time.Duration) (*Epoll, error) { + return newPollerWithBuffer(connBufferSize, pollTimeout) +} + +// newPollerWithBuffer creates a new epoll poller with a buffer. +func newPollerWithBuffer(count int, pollTimeout time.Duration) (*Epoll, error) { + fd, err := unix.EpollCreate1(0) + if err != nil { + return nil, err + } + return &Epoll{ + fd: fd, + connBufferSize: count, + lock: &sync.RWMutex{}, + conns: make(map[int]net.Conn), + connbuf: make([]net.Conn, count), + timeoutMsec: int(pollTimeout.Milliseconds()), + }, nil +} + +// Close closes the poller. If closeConns is true, it will close all the connections. +func (e *Epoll) Close(closeConns bool) error { + e.lock.Lock() + defer e.lock.Unlock() + + if closeConns { + for _, conn := range e.conns { + conn.Close() + } + } + + e.conns = nil + e.connbuf = e.connbuf[:0] + + return unix.Close(e.fd) +} + +// Add adds a connection to the poller. +func (e *Epoll) Add(conn net.Conn) error { + conn = newConnImpl(conn) + fd := socketFD(conn) + if e := syscall.SetNonblock(int(fd), true); e != nil { + return errors.New("udev: unix.SetNonblock failed") + } + + e.lock.Lock() + defer e.lock.Unlock() + + err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_ADD, fd, &unix.EpollEvent{Events: unix.POLLIN | unix.POLLHUP, Fd: int32(fd)}) + if err != nil { + return err + } + e.conns[fd] = conn + return nil +} + +// Remove removes a connection from the poller. +func (e *Epoll) Remove(conn net.Conn) error { + fd := socketFD(conn) + err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_DEL, fd, nil) + if err != nil { + return err + } + e.lock.Lock() + defer e.lock.Unlock() + delete(e.conns, fd) + + return nil +} + +// Wait waits for at most count events and returns the connections. +func (e *Epoll) Wait(count int) ([]net.Conn, error) { + events := make([]unix.EpollEvent, count) + +retry: + n, err := unix.EpollWait(e.fd, events, e.timeoutMsec) + if err != nil { + if err == unix.EINTR { + goto retry + } + return nil, err + } + + var conns []net.Conn + if e.connBufferSize == 0 { + conns = make([]net.Conn, 0, n) + } else { + conns = e.connbuf[:0] + } + e.lock.RLock() + for i := 0; i < n; i++ { + conn := e.conns[int(events[i].Fd)] + if conn != nil { + conns = append(conns, conn) + } + } + e.lock.RUnlock() + + return conns, nil +} diff --git a/v2/pkg/internal/epoller/epoll_test.go b/v2/pkg/internal/epoller/epoll_test.go new file mode 100644 index 000000000..dd5336306 --- /dev/null +++ b/v2/pkg/internal/epoller/epoll_test.go @@ -0,0 +1,170 @@ +package epoller + +import ( + "errors" + "io" + "log" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPoller(t *testing.T) { + // connections + num := 10 + // msg per connection + msgPerConn := 10 + + poller, err := NewPoller(0, time.Second) + require.NoError(t, err) + + // start server + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + + poller.Add(conn) + } + }() + + // create num connections and send msgPerConn messages per connection + for i := 0; i < num; i++ { + go func() { + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Error(err) + return + } + time.Sleep(time.Second) + for i := 0; i < msgPerConn; i++ { + n, err := conn.Write([]byte("hello world")) + if err != nil { + t.Error(err) + } + if n != len("hello world") { + t.Errorf("expect to write %d bytes but got %d bytes", len("hello world"), n) + } + } + conn.Close() + }() + } + + // read those num * msgPerConn messages, and each message (hello world) contains 11 bytes. + done := make(chan struct{}) + errs := make(chan error) + var total int + var count int + + expected := num * msgPerConn * len("hello world") + go func(errs chan error) { + for { + conns, err := poller.Wait(128) + if err != nil { + t.Log(err) + errs <- err // fatal errors (i.e t.Fatal()) must be reported in the main test goroutine + return + } + if len(conns) == 0 { + continue + } + count++ + buf := make([]byte, 1024) + for _, conn := range conns { + n, err := conn.Read(buf) + if err != nil { + if err == io.EOF || errors.Is(err, net.ErrClosed) { + poller.Remove(conn) + conn.Close() + } else { + t.Error(err) + } + } + total += n + } + + if total == expected { + break + } + } + + t.Logf("read all %d bytes, count: %d", total, count) + close(done) + }(errs) + + select { + case <-done: + case <-time.After(2 * time.Second): + case err := <-errs: + t.Fatal(err) + } + + if total != expected { + t.Fatalf("epoller does not work. expect %d bytes but got %d bytes", expected, total) + } +} + +type netPoller struct { + Poller Poller + WriteReq chan uint64 +} + +func TestPoller_growstack(t *testing.T) { + var nps []netPoller + for i := 0; i < 2; i++ { + poller, err := NewPoller(128, time.Second) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + // the following line cause goroutine stack grow and copy local variables to new allocated stack and switch to new stack + // but runtime.adjustpointers will check whether pointers bigger than runtime.minLegalPointer(4096) or throw a panic + // fatal error: invalid pointer found on stack (runtime/stack.go:599@go1.14.3) + // since NewEpoller return A pointer created by CreateIoCompletionPort may less than 4096 + np := netPoller{ + Poller: poller, + WriteReq: make(chan uint64, 1000000), + } + + nps = append(nps, np) + } + + poller := nps[0].Poller + // start server + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + log.Fatal(err) + } + defer ln.Close() + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + + poller.Add(conn) + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Error(err) + return + } + time.Sleep(200 * time.Millisecond) + for i := 0; i < 100; i++ { + conn.Write([]byte("hello world")) + } + conn.Close() +} diff --git a/v2/pkg/internal/epoller/epoll_unsupported.go b/v2/pkg/internal/epoller/epoll_unsupported.go new file mode 100644 index 000000000..a840ef7d6 --- /dev/null +++ b/v2/pkg/internal/epoller/epoll_unsupported.go @@ -0,0 +1,14 @@ +//go:build windows +// +build windows + +package epoller + +import ( + "errors" + "time" +) + +// NewPoller creates a new epoll poller. +func NewPoller(connBufferSize int, _ time.Duration) (Poller, error) { + return nil, errors.New("epoll is not supported on windows") +} diff --git a/v2/pkg/internal/epoller/fd_test.go b/v2/pkg/internal/epoller/fd_test.go new file mode 100644 index 000000000..2595f04f8 --- /dev/null +++ b/v2/pkg/internal/epoller/fd_test.go @@ -0,0 +1,55 @@ +package epoller + +import ( + "net" + "reflect" + "runtime" + "syscall" + "testing" +) + +func reflectSocketFDAsUint(conn net.Conn) uint64 { + tcpConn := reflect.Indirect(reflect.ValueOf(conn)).FieldByName("conn") + fdVal := tcpConn.FieldByName("fd") + pfdVal := reflect.Indirect(fdVal).FieldByName("pfd") + + return pfdVal.FieldByName("Sysfd").Uint() +} + +func rawSocketFD(conn net.Conn) uint64 { + if con, ok := conn.(syscall.Conn); ok { + raw, err := con.SyscallConn() + if err != nil { + return 0 + } + sfd := uint64(0) + raw.Control(func(fd uintptr) { + sfd = uint64(fd) + }) + return sfd + } + return 0 +} + +func BenchmarkSocketFdReflect(b *testing.B) { + var con, _ = net.Dial(`udp`, "8.8.8.8:53") + fd := uint64(0) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + fd = reflectSocketFDAsUint(con) + } + runtime.KeepAlive(fd) +} + +func BenchmarkSocketFdRaw(b *testing.B) { + con, _ := net.Dial(`udp`, "8.8.8.8:53") + fd := uint64(0) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + fd = rawSocketFD(con) + } + runtime.KeepAlive(fd) +} From 5fae3c800b3dc543561bb5273d669d814b9366e6 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 09:01:00 +0200 Subject: [PATCH 03/31] chore: finish epoll implementation for subgraph requests --- v2/go.mod | 1 + v2/go.sum | 2 + .../graphql_datasource/graphql_datasource.go | 2 +- .../graphql_subscription_client.go | 215 ++- .../graphql_subscription_client_test.go | 1214 +++++++++++++++-- .../graphql_datasource/graphql_tws_handler.go | 138 +- .../graphql_tws_handler_test.go | 94 +- .../graphql_datasource/graphql_ws_handler.go | 68 +- .../graphql_ws_handler_test.go | 2 +- v2/pkg/engine/resolve/resolve.go | 2 +- 10 files changed, 1371 insertions(+), 367 deletions(-) diff --git a/v2/go.mod b/v2/go.mod index 327cfb450..7834335fd 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -37,6 +37,7 @@ require ( require ( github.com/agnivade/levenshtein v1.1.1 // indirect + github.com/coder/websocket v1.8.12 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect diff --git a/v2/go.sum b/v2/go.sum index 4704c405e..01a1844eb 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -13,6 +13,8 @@ github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMU github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 334e0f298..92ec8e9b7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -1770,7 +1770,7 @@ func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input [ if options.Body.Query == "" { return resolve.ErrUnableToResolve } - return s.client.Subscribe(ctx, options, updater) + return s.client.SubscribeAsync(ctx, id, options, updater) } func (s *SubscriptionSource) AsyncStop(id uint64) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index a20d66d82..385f0ef95 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -12,9 +12,11 @@ import ( "syscall" "time" + "github.com/gobwas/ws/wsutil" + "github.com/gorilla/websocket" + "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" - ws "github.com/gorilla/websocket" "github.com/jensneuse/abstractlogger" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/epoller" @@ -26,8 +28,11 @@ const ackWaitTimeout = 30 * time.Second // It takes care of de-duplicating connections to the same origin under certain circumstances // If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used type subscriptionClient struct { - streamingClient *http.Client - httpClient *http.Client + streamingClient *http.Client + httpClient *http.Client + + useHttpClientWithSkipRoundTrip bool + engineCtx context.Context log abstractlogger.Logger hashPool sync.Pool @@ -35,12 +40,13 @@ type subscriptionClient struct { readTimeout time.Duration - epoll epoller.Poller + epoll epoller.Poller + stopEpollSignal chan struct{} connections map[int]ConnectionHandler connectionsMu sync.Mutex - activeConnections map[int]int + activeConnections map[int]struct{} activeConnectionsMu sync.Mutex triggers map[uint64]int @@ -75,6 +81,9 @@ func (c *subscriptionClient) Unsubscribe(id uint64) { handler.ClientClose() delete(c.connections, fd) _ = c.epoll.Remove(handler.NetConn()) + if len(c.connections) == 0 { + close(c.stopEpollSignal) + } } type InvalidWsSubprotocolError struct { @@ -105,10 +114,17 @@ func WithReadTimeout(timeout time.Duration) Options { } } +func UseHttpClientWithSkipRoundTrip() Options { + return func(options *opts) { + options.useHttpClientWithSkipRoundTrip = true + } +} + type opts struct { - readTimeout time.Duration - log abstractlogger.Logger - onWsConnectionInitCallback *OnWsConnectionInitCallback + readTimeout time.Duration + log abstractlogger.Logger + onWsConnectionInitCallback *OnWsConnectionInitCallback + useHttpClientWithSkipRoundTrip bool } // GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. @@ -136,10 +152,8 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi for _, option := range options { option(op) } - epoll, err := epoller.NewPoller(1024, time.Second) - if err != nil { - fmt.Printf("failed to create epoll: %v\n", err) - } + // ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client + epoll, _ := epoller.NewPoller(1024, time.Millisecond*100) client := &subscriptionClient{ httpClient: httpClient, streamingClient: streamingClient, @@ -151,10 +165,12 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi return xxhash.New() }, }, - onWsConnectionInitCallback: op.onWsConnectionInitCallback, - epoll: epoll, - connections: make(map[int]ConnectionHandler), - triggers: make(map[uint64]int), + onWsConnectionInitCallback: op.onWsConnectionInitCallback, + epoll: epoll, + connections: make(map[int]ConnectionHandler), + activeConnections: make(map[int]struct{}), + triggers: make(map[uint64]int), + useHttpClientWithSkipRoundTrip: op.useHttpClientWithSkipRoundTrip, } return client } @@ -168,12 +184,6 @@ func (c *subscriptionClient) Subscribe(reqCtx *resolve.Context, options GraphQLS return c.subscribeSSE(reqCtx, options, updater) } - if strings.HasPrefix(options.URL, "https") { - options.URL = "wss" + options.URL[5:] - } else if strings.HasPrefix(options.URL, "http") { - options.URL = "ws" + options.URL[4:] - } - return c.subscribeWS(reqCtx, options, updater) } @@ -266,11 +276,10 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64 c.connectionsMu.Unlock() if count == 1 { + c.stopEpollSignal = make(chan struct{}) go c.runEpoll(c.engineCtx) } - fmt.Printf("added connection to epoll\n") - handler.Subscribe(sub) return nil @@ -342,27 +351,88 @@ func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSu return nil } +type UpgradeRequestError struct { + URL string + StatusCode int +} + +func (u *UpgradeRequestError) Error() string { + return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", u.URL, u.StatusCode) +} + func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, options GraphQLSubscriptionOptions) (ConnectionHandler, error) { + + var ( + upgradeRequestHeader http.Header + subgraphHttpURL string + upgradeRequestURL string + ) + subProtocols := []string{ProtocolGraphQLWS, ProtocolGraphQLTWS} if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" { subProtocols = []string{options.WsSubProtocol} } - dialer := ws.Dialer{ + if strings.HasPrefix(options.URL, "https") { + upgradeRequestURL = "wss" + options.URL[5:] + subgraphHttpURL = options.URL + } else if strings.HasPrefix(options.URL, "http") { + upgradeRequestURL = "ws" + options.URL[4:] + subgraphHttpURL = options.URL + } else if strings.HasPrefix(options.URL, "wss") { + upgradeRequestURL = options.URL + subgraphHttpURL = "https" + options.URL[3:] + } else if strings.HasPrefix(options.URL, "ws") { + upgradeRequestURL = options.URL + subgraphHttpURL = "http" + options.URL[2:] + } + + if c.useHttpClientWithSkipRoundTrip { + // gorilla websocket does not support using the http.Client directly + // but we need to use our existing client, or the transport more specifically + // to be able to forward headers in the upgrade request + // + // as a workaround we create a "dummy" request which we run through the http.Client with the context + // we set the "SkipRoundTrip" header to true to signal the http.Client to not perform the request + // but only to modify the request Headers + req, err := http.NewRequestWithContext(reqCtx, "GET", options.URL, nil) + if err != nil { + return nil, err + } + if strings.HasPrefix(options.URL, "ws") { + req.URL.Scheme = "http" + } else { + req.URL.Scheme = "https" + } + if options.Header != nil { + req.Header = options.Header + } + req.Header.Set("SkipRoundTrip", "true") + _, _ = c.httpClient.Do(req) + req.Header.Del("SkipRoundTrip") + upgradeRequestHeader = req.Header + subgraphHttpURL = req.URL.String() + } else { + upgradeRequestHeader = options.Header + } + + dialer := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: time.Second * 10, Subprotocols: subProtocols, } - conn, upgradeResponse, err := dialer.DialContext(reqCtx, options.URL, options.Header) + conn, upgradeResponse, err := dialer.DialContext(reqCtx, upgradeRequestURL, upgradeRequestHeader) if err != nil { + if upgradeResponse != nil && upgradeResponse.StatusCode != http.StatusSwitchingProtocols { + return nil, &UpgradeRequestError{ + URL: subgraphHttpURL, + StatusCode: upgradeResponse.StatusCode, + } + } return nil, err } conn.SetReadLimit(math.MaxInt32) - if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("upgrade unsuccessful") - } - connectionInitMessage, err := c.getConnectionInitMessage(reqCtx, options.URL, options.Header) if err != nil { return nil, err @@ -382,8 +452,10 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti } } + netConn := conn.NetConn() + // init + ack - err = conn.WriteMessage(ws.TextMessage, connectionInitMessage) + err = wsutil.WriteClientText(netConn, connectionInitMessage) if err != nil { return nil, err } @@ -396,15 +468,15 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti } } - if err := waitForAck(conn); err != nil { + if err := waitForAck(netConn); err != nil { return nil, err } switch wsSubProtocol { case ProtocolGraphQLWS: - return newGQLWSConnectionHandler(c.engineCtx, conn, c.readTimeout, c.log), nil + return newGQLWSConnectionHandler(c.engineCtx, netConn, c.readTimeout, c.log), nil case ProtocolGraphQLTWS: - return newGQLTWSConnectionHandler(c.engineCtx, conn, c.readTimeout, c.log), nil + return newGQLTWSConnectionHandler(c.engineCtx, netConn, c.readTimeout, c.log), nil default: return nil, NewInvalidWsSubprotocolError(wsSubProtocol) } @@ -437,7 +509,7 @@ func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url s type ConnectionHandler interface { StartBlocking(sub Subscription) NetConn() net.Conn - ReadMessage() (done bool) + ReadMessage() (done, timeout bool) ServerClose() ClientClose() Subscribe(sub Subscription) @@ -449,7 +521,7 @@ type Subscription struct { updater resolve.SubscriptionUpdater } -func waitForAck(conn *ws.Conn) error { +func waitForAck(conn net.Conn) error { timer := time.NewTimer(ackWaitTimeout) for { select { @@ -458,15 +530,12 @@ func waitForAck(conn *ws.Conn) error { default: } - msgType, msg, err := conn.ReadMessage() + data, err := wsutil.ReadServerText(conn) if err != nil { - return err - } - if msgType != ws.TextMessage { - return fmt.Errorf("unexpected message type") + return fmt.Errorf("failed to read message: %w", err) } - respType, err := jsonparser.GetString(msg, "type") + respType, err := jsonparser.GetString(data, "type") if err != nil { return err } @@ -475,7 +544,7 @@ func waitForAck(conn *ws.Conn) error { case messageTypeConnectionKeepAlive: continue case messageTypePing: - err := conn.WriteMessage(ws.TextMessage, []byte(pongMessage)) + err := wsutil.WriteClientText(conn, []byte(pongMessage)) if err != nil { return fmt.Errorf("failed to send pong message: %w", err) } @@ -490,14 +559,13 @@ func waitForAck(conn *ws.Conn) error { func (c *subscriptionClient) runEpoll(ctx context.Context) { done := ctx.Done() - tick := time.NewTicker(time.Millisecond * 10) + tick := time.NewTicker(time.Millisecond * 50) for { connections, err := c.epoll.Wait(50) if err != nil { c.log.Error("epoll.Wait", abstractlogger.Error(err)) return } - fmt.Printf("ePoll - time: %v, connections: %d\n", time.Now(), len(connections)) c.connectionsMu.Lock() for _, conn := range connections { id := socketFd(conn) @@ -506,48 +574,55 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { continue } c.activeConnectionsMu.Lock() - if i, ok := c.activeConnections[id]; ok { - fmt.Printf("connection %d is active, queueing event\n", id) - c.activeConnections[id] = i + 1 + _, active := c.activeConnections[id] + if !active { + c.activeConnections[id] = struct{}{} } c.activeConnectionsMu.Unlock() + if active { + continue + } go c.handleConnection(id, handler, conn) } c.connectionsMu.Unlock() + + if len(connections) == 50 { + // we have more connections to process, + continue + } + select { case <-done: return case <-tick.C: continue + case <-c.stopEpollSignal: + return } } } func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { - fmt.Printf("handling connection %d\n", id) - for { - done := handler.ReadMessage() - if done { - fmt.Printf("connection %d done\n", id) - c.connectionsMu.Lock() - delete(c.connections, id) - c.connectionsMu.Unlock() - handler.ServerClose() - _ = c.epoll.Remove(conn) - return - } + done, timeout := handler.ReadMessage() + if timeout { c.activeConnectionsMu.Lock() - if i, ok := c.activeConnections[id]; ok { - if i == 0 { - delete(c.activeConnections, id) - c.activeConnectionsMu.Unlock() - fmt.Printf("handleConnection: event queue empty, returning to ePoll for connection %d\n", id) - return - } - c.activeConnections[id] = i - 1 - fmt.Printf("handleConnection: event queue not empty, processing next event for connection %d\n", id) - } + delete(c.activeConnections, id) c.activeConnectionsMu.Unlock() + return + } + + if done { + c.activeConnectionsMu.Lock() + delete(c.activeConnections, id) + c.activeConnectionsMu.Unlock() + + c.connectionsMu.Lock() + delete(c.connections, id) + c.connectionsMu.Unlock() + + handler.ServerClose() + _ = c.epoll.Remove(conn) + return } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 7f141ac4c..1f385d2fb 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -115,8 +115,14 @@ func TestWebsocketSubscriptionClientWithServerDisconnect(t *testing.T) { err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) + _, _, err = conn.Read(ctx) assert.Error(t, err) + close(serverDone) })) defer server.Close() @@ -478,142 +484,1122 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { serverCancel() } -func TestSubscribeAsync(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) +func TestAsyncSubscribe(t *testing.T) { + t.Parallel() + t.Run("subscribe async", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) + time.Sleep(time.Second * 1) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() + time.Sleep(time.Second * 1) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Second), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) + time.Sleep(time.Second * 1) - updater.AwaitUpdates(t, time.Second*5, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) -func TestSubscribeAsyncServerTimeout(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("server timeout", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 2) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("server complete", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"complete","id":"1"}`)) + assert.NoError(t, err) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) assert.NoError(t, err) - time.Sleep(time.Second * 1) + updater.AwaitUpdates(t, time.Second*10, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("server ka", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"complete","id":"1"}`)) + assert.NoError(t, err) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) assert.NoError(t, err) - time.Sleep(time.Millisecond) + updater.AwaitUpdates(t, time.Second*10, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("long timeout", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + defer conn.Close(websocket.StatusNormalClosure, "done") + + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 2) + + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) assert.NoError(t, err) - time.Sleep(time.Millisecond) + updater.AwaitUpdates(t, time.Second*10, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second*5, time.Millisecond*10, "server did not close") + time.Sleep(time.Second) + client.connectionsMu.Lock() + defer client.connectionsMu.Unlock() + assert.Equal(t, 0, len(client.connections)) + }) + t.Run("forever timeout", func(t *testing.T) { + t.Parallel() + globalCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) + defer conn.Close(websocket.StatusNormalClosure, "done") + + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + <-globalCtx.Done() + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Second), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + updater.AwaitUpdates(t, time.Second*3, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + time.Sleep(time.Second * 2) + client.activeConnectionsMu.Lock() + defer client.activeConnectionsMu.Unlock() + assert.Equal(t, 0, len(client.activeConnections)) + }) + t.Run("graphql-ws", func(t *testing.T) { + t.Parallel() + t.Run("happy path", func(t *testing.T) { + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) + time.Sleep(time.Second * 1) - updater.AwaitUpdates(t, time.Second*5, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("connection error", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"connection_error"}`)) + assert.NoError(t, err) + + _ = conn.Close(websocket.StatusNormalClosure, "done") + + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"errors":[{"message":"connection error"}]}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + client.connectionsMu.Lock() + defer client.connectionsMu.Unlock() + assert.Equal(t, 0, len(client.connections)) + }) + t.Run("error object", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":{"message":"ws error"}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("error array", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":[{"message":"ws error"}]}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + }) + t.Run("graphql-transport-ws", func(t *testing.T) { + t.Parallel() + t.Run("happy path", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("ping", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"pong"}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("ka", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("error object", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":{"message":"ws error"}}`)) + assert.NoError(t, err) + + _ = conn.Close(websocket.StatusNormalClosure, "done") + + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 2) + assert.Equal(t, 2, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[1]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("error array", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":[{"message":"ws error"}]}`)) + assert.NoError(t, err) + + _ = conn.Close(websocket.StatusNormalClosure, "done") + + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*5, 2) + assert.Equal(t, 2, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[1]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("data error", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + updater.AwaitUpdates(t, time.Second*5, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("connection error", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"connection_error"}`)) + assert.NoError(t, err) + + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + updater.AwaitUpdates(t, time.Second*5, 1) + assert.Equal(t, 1, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + }) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index d80e8cb85..d7c0559a1 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -1,6 +1,7 @@ package graphql_datasource import ( + "bufio" "context" "encoding/json" "errors" @@ -11,7 +12,7 @@ import ( "strings" "time" - ws "github.com/gorilla/websocket" + "github.com/gobwas/ws/wsutil" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/buger/jsonparser" @@ -22,7 +23,7 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlTWSConnectionHandler struct { - conn *ws.Conn + conn net.Conn ctx context.Context log log.Logger subscribeCh chan Subscription @@ -32,20 +33,19 @@ type gqlTWSConnectionHandler struct { } func (h *gqlTWSConnectionHandler) ServerClose() { - fmt.Printf("ServerClose\n") for _, sub := range h.subscriptions { sub.updater.Done() } + _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) ClientClose() { - fmt.Printf("ClientClose\n") for k, v := range h.subscriptions { v.updater.Done() delete(h.subscriptions, k) req := fmt.Sprintf(completeMessage, k) - err := h.conn.WriteMessage(ws.TextMessage, []byte(req)) + err := wsutil.WriteClientText(h.conn, []byte(req)) if err != nil { h.log.Error("failed to write complete message", log.Error(err)) } @@ -57,64 +57,78 @@ func (h *gqlTWSConnectionHandler) Subscribe(sub Subscription) { h.subscribe(sub) } -func (h *gqlTWSConnectionHandler) ReadMessage() (done bool) { - fmt.Printf("ReadMessage\n") +func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { - err := h.conn.SetReadDeadline(time.Now().Add(time.Second * 5)) - if err != nil { - fmt.Printf("SetReadDeadline error: %v\n", err) - return h.isConnectionClosed(err) - } - msgType, data, err := h.conn.ReadMessage() - if err != nil { - fmt.Printf("ReadMessage error: %v\n", err) - return h.isConnectionClosed(err) + r := bufio.NewReader(h.conn) + wr := bufio.NewWriter(h.conn) + rwr := bufio.NewReadWriter(r, wr) + + for { + + err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) + if err != nil { + return h.handleConnectionError(err) + } + + data, err := wsutil.ReadServerText(rwr) + if err != nil { + return h.handleConnectionError(err) + } + + messageType, err := jsonparser.GetString(data, "type") + if err != nil { + return false, false + } + switch messageType { + case messageTypePing: + h.handleMessageTypePing() + continue + case messageTypeNext: + h.handleMessageTypeNext(data) + continue + case messageTypeComplete: + h.handleMessageTypeComplete(data) + return true, false + case messageTypeError: + h.handleMessageTypeError(data) + continue + case messageTypeConnectionKeepAlive: + continue + case messageTypeData, messageTypeConnectionError: + h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") + return true, false + default: + h.log.Error("unknown message type", log.String("type", messageType)) + return false, false + } } - fmt.Printf("ReadMessage messageType %v, data: %v\n", msgType, string(data)) - if msgType != ws.TextMessage { - return false +} + +func (h *gqlTWSConnectionHandler) handleConnectionError(err error) (closed, timeout bool) { + if errors.Is(err, context.DeadlineExceeded) { + return false, true } - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - return false + netOpErr := &net.OpError{} + if errors.As(err, &netOpErr) { + if netOpErr.Timeout() { + return false, true + } + return true, false } - switch messageType { - case messageTypePing: - h.handleMessageTypePing() - return false - case messageTypeNext: - h.handleMessageTypeNext(data) - return false - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return true - case messageTypeError: - h.handleMessageTypeError(data) - return false - case messageTypeConnectionKeepAlive: - return false - case messageTypeData, messageTypeConnectionError: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return true - default: - h.log.Error("unknown message type", log.String("type", messageType)) - return false + if errors.As(err, &wsutil.ClosedError{}) { + return true, false } -} - -func (h *gqlTWSConnectionHandler) isConnectionClosed(err error) bool { if strings.HasSuffix(err.Error(), "use of closed network connection") { - return true + return true, false } - fmt.Printf("isConnectionClosed: %v\n", err) - return false + return false, false } func (h *gqlTWSConnectionHandler) NetConn() net.Conn { - return h.conn.NetConn() + return h.conn } -func newGQLTWSConnectionHandler(ctx context.Context, conn *ws.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler { +func newGQLTWSConnectionHandler(ctx context.Context, conn net.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler { return &gqlTWSConnectionHandler{ conn: conn, ctx: ctx, @@ -214,7 +228,7 @@ func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { delete(h.subscriptions, subscriptionID) req := fmt.Sprintf(completeMessage, subscriptionID) - err := h.conn.WriteMessage(ws.TextMessage, []byte(req)) + err := wsutil.WriteClientText(h.conn, []byte(req)) if err != nil { h.log.Error("failed to write complete message", log.Error(err)) } @@ -231,11 +245,8 @@ func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) { h.nextSubscriptionID++ subscriptionID := strconv.Itoa(h.nextSubscriptionID) - - fmt.Printf("subscribe with subscriptionID: %s\n", subscriptionID) - subscribeRequest := fmt.Sprintf(subscribeMessage, subscriptionID, string(graphQLBody)) - err = h.conn.WriteMessage(ws.TextMessage, []byte(subscribeRequest)) + err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest)) if err != nil { h.log.Error("failed to write subscribe message", log.Error(err)) return @@ -299,13 +310,21 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { return } sub.updater.Update(response) + case jsonparser.Object: + response := []byte(`{"errors":[]}`) + response, err = jsonparser.Set(response, value, "errors", "[0]") + if err != nil { + sub.updater.Update([]byte(internalError)) + return + } + sub.updater.Update(response) default: sub.updater.Update([]byte(internalError)) } } func (h *gqlTWSConnectionHandler) handleMessageTypePing() { - err := h.conn.WriteMessage(ws.TextMessage, []byte(pongMessage)) + err := wsutil.WriteClientText(h.conn, []byte(pongMessage)) if err != nil { h.log.Error("failed to write pong message", log.Error(err)) } @@ -339,7 +358,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { // we'll block forever on reading until the context of the gqlTWSConnectionHandler stops func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { - msgType, data, err := h.conn.ReadMessage() + data, err := wsutil.ReadServerText(h.conn) if err != nil { select { case errCh <- err: @@ -347,9 +366,6 @@ func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan } return } - if msgType != ws.TextMessage { - continue - } select { case dataCh <- data: case <-ctx.Done(): diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go index 1cdb941dc..81bc59306 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" @@ -240,96 +238,6 @@ func TestWebsocketSubscriptionClientError_GQLTWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") } -func TestWebSocketSubscriptionClientInitIncludePing_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - assertion := require.New(t) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assertion.NoError(err) - - // write "ping" every second - go func() { - for { - err := conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - if err != nil { - break - } - time.Sleep(time.Second) - } - }() - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assertion.NoError(err) - - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assertion.NoError(err) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"pong"}`, string(data)) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assertion.NoError(err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assertion.NoError(err) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - serverCtx, serverCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assertion.NoError(err) - }() - - updater.AwaitUpdates(t, time.Second, 2) - assertion.Equal(2, len(updater.updates)) - assertion.Equal(`{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assertion.Equal(`{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - - clientCancel() - assertion.Eventuallyf(func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) { if flags.IsWindows { t.Skip("skipping test on windows") @@ -399,7 +307,7 @@ func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) { // Kill the upstream here. We should get an End-of-File error. assert.NoError(t, wrappedListener.underlyingConnection.Close()) updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, `{"errors":[{"message":"failed to get reader: failed to read frame header: EOF"}]}`, updater.updates[1]) + assert.Equal(t, `{"errors":[{"message":"EOF"}]}`, updater.updates[1]) clientCancel() serverCancel() diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 421477ef8..ec63d3df5 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -1,6 +1,7 @@ package graphql_datasource import ( + "bufio" "context" "encoding/json" "errors" @@ -11,7 +12,7 @@ import ( "strings" "time" - ws "github.com/gorilla/websocket" + "github.com/gobwas/ws/wsutil" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/buger/jsonparser" @@ -22,7 +23,7 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlWSConnectionHandler struct { - conn *ws.Conn + conn net.Conn ctx context.Context log abstractlogger.Logger // log slog.Logger @@ -36,6 +37,7 @@ func (h *gqlWSConnectionHandler) ServerClose() { for _, sub := range h.subscriptions { sub.updater.Done() } + _ = h.conn.Close() } func (h *gqlWSConnectionHandler) ClientClose() { @@ -43,7 +45,7 @@ func (h *gqlWSConnectionHandler) ClientClose() { v.updater.Done() delete(h.subscriptions, k) stopRequest := fmt.Sprintf(stopMessage, k) - _ = h.conn.WriteMessage(ws.TextMessage, []byte(stopRequest)) + _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) } _ = h.conn.Close() } @@ -52,54 +54,71 @@ func (h *gqlWSConnectionHandler) Subscribe(sub Subscription) { h.subscribe(sub) } -func (h *gqlWSConnectionHandler) ReadMessage() (done bool) { +func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { + + r := bufio.NewReader(h.conn) + wr := bufio.NewWriter(h.conn) + rwr := bufio.NewReadWriter(r, wr) + for { - err := h.conn.SetReadDeadline(time.Now().Add(time.Second * 5)) + err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) if err != nil { - return h.isConnectionClosed(err) + return h.handleConnectionError(err) } - msgType, data, err := h.conn.ReadMessage() + data, err := wsutil.ReadServerText(rwr) if err != nil { - return h.isConnectionClosed(err) - } - if msgType != ws.TextMessage { - return false + return h.handleConnectionError(err) } messageType, err := jsonparser.GetString(data, "type") if err != nil { - return false + return false, false } switch messageType { + case messageTypeConnectionKeepAlive: + continue case messageTypeData: h.handleMessageTypeData(data) continue case messageTypeComplete: h.handleMessageTypeComplete(data) - return true + return true, false case messageTypeConnectionError: h.handleMessageTypeConnectionError() - return true + return true, false case messageTypeError: h.handleMessageTypeError(data) continue default: - return false + return true, false } } } -func (h *gqlWSConnectionHandler) isConnectionClosed(err error) bool { +func (h *gqlWSConnectionHandler) handleConnectionError(err error) (closed, timeout bool) { + if errors.Is(err, context.DeadlineExceeded) { + return false, true + } + netOpErr := &net.OpError{} + if errors.As(err, &netOpErr) { + if netOpErr.Timeout() { + return false, true + } + return true, false + } + if errors.As(err, &wsutil.ClosedError{}) { + return true, false + } if strings.HasSuffix(err.Error(), "use of closed network connection") { - return true + return true, false } - return false + return false, false } func (h *gqlWSConnectionHandler) NetConn() net.Conn { - return h.conn.NetConn() + return h.conn } -func newGQLWSConnectionHandler(ctx context.Context, conn *ws.Conn, readTimeout time.Duration, log abstractlogger.Logger) *gqlWSConnectionHandler { +func newGQLWSConnectionHandler(ctx context.Context, conn net.Conn, readTimeout time.Duration, log abstractlogger.Logger) *gqlWSConnectionHandler { return &gqlWSConnectionHandler{ conn: conn, ctx: ctx, @@ -191,7 +210,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { // we'll block forever on reading until the context of the gqlWSConnectionHandler stops func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { - msgType, data, err := h.conn.ReadMessage() + data, err := wsutil.ReadServerText(h.conn) if err != nil { select { case errCh <- err: @@ -199,9 +218,6 @@ func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan [ } return } - if msgType != ws.TextMessage { - continue - } select { case dataCh <- data: case <-ctx.Done(): @@ -229,7 +245,7 @@ func (h *gqlWSConnectionHandler) subscribe(sub Subscription) { subscriptionID := strconv.Itoa(h.nextSubscriptionID) startRequest := fmt.Sprintf(startMessage, subscriptionID, string(graphQLBody)) - err = h.conn.WriteMessage(ws.TextMessage, []byte(startRequest)) + err = wsutil.WriteClientText(h.conn, []byte(startRequest)) if err != nil { return } @@ -324,7 +340,7 @@ func (h *gqlWSConnectionHandler) unsubscribe(subscriptionID string) { sub.updater.Done() delete(h.subscriptions, subscriptionID) stopRequest := fmt.Sprintf(stopMessage, subscriptionID) - _ = h.conn.WriteMessage(ws.TextMessage, []byte(stopRequest)) + _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) } func (h *gqlWSConnectionHandler) checkActiveSubscriptions() (hasActiveSubscriptions bool) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go index 8255642f9..23242c19e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go @@ -343,7 +343,7 @@ func TestWebsocketSubscriptionClient_GQLWS_Upstream_Dies(t *testing.T) { // Kill the upstream here. We should get an End-of-File error. assert.NoError(t, wrappedListener.underlyingConnection.Close()) updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, `{"errors":[{"message":"failed to get reader: failed to read frame header: EOF"}]}`, updater.updates[1]) + assert.Equal(t, `{"errors":[{"message":"EOF"}]}`, updater.updates[1]) serverCancel() clientCancel() diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 927087df9..020767d7e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -577,8 +577,8 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { trig.cancel = func() { - async.AsyncStop(triggerID) cancel() + async.AsyncStop(triggerID) } err = async.AsyncStart(cloneCtx, triggerID, add.input, updater) } else { From bbdd192aa5477dd94cd98c551452574d9e653c48 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 19:32:45 +0200 Subject: [PATCH 04/31] chore: add error handling to subscribe --- .../graphql_subscription_client.go | 9 ++++++--- .../graphql_datasource/graphql_tws_handler.go | 13 ++++++------- .../graphql_datasource/graphql_ws_handler.go | 11 ++++++----- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 385f0ef95..424351f9a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -263,6 +263,11 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64 return err } + err = handler.Subscribe(sub) + if err != nil { + return err + } + netConn := handler.NetConn() if err := c.epoll.Add(netConn); err != nil { return err @@ -280,8 +285,6 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64 go c.runEpoll(c.engineCtx) } - handler.Subscribe(sub) - return nil } @@ -512,7 +515,7 @@ type ConnectionHandler interface { ReadMessage() (done, timeout bool) ServerClose() ClientClose() - Subscribe(sub Subscription) + Subscribe(sub Subscription) error } type Subscription struct { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index d7c0559a1..4a15d241e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -53,8 +53,8 @@ func (h *gqlTWSConnectionHandler) ClientClose() { _ = h.conn.Close() } -func (h *gqlTWSConnectionHandler) Subscribe(sub Subscription) { - h.subscribe(sub) +func (h *gqlTWSConnectionHandler) Subscribe(sub Subscription) error { + return h.subscribe(sub) } func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { @@ -235,11 +235,10 @@ func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { } // subscribe adds a new Subscription to the gqlTWSConnectionHandler and sends the subscribeMessage to the origin -func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) { +func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) error { graphQLBody, err := json.Marshal(sub.options.Body) if err != nil { - h.log.Error("failed to marshal GraphQL body", log.Error(err)) - return + return err } h.nextSubscriptionID++ @@ -248,11 +247,11 @@ func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) { subscribeRequest := fmt.Sprintf(subscribeMessage, subscriptionID, string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest)) if err != nil { - h.log.Error("failed to write subscribe message", log.Error(err)) - return + return err } h.subscriptions[subscriptionID] = sub + return nil } func (h *gqlTWSConnectionHandler) broadcastErrorMessage(err error) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index ec63d3df5..a05b5173a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -50,8 +50,8 @@ func (h *gqlWSConnectionHandler) ClientClose() { _ = h.conn.Close() } -func (h *gqlWSConnectionHandler) Subscribe(sub Subscription) { - h.subscribe(sub) +func (h *gqlWSConnectionHandler) Subscribe(sub Subscription) error { + return h.subscribe(sub) } func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { @@ -234,10 +234,10 @@ func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { } // subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin -func (h *gqlWSConnectionHandler) subscribe(sub Subscription) { +func (h *gqlWSConnectionHandler) subscribe(sub Subscription) error { graphQLBody, err := json.Marshal(sub.options.Body) if err != nil { - return + return err } h.nextSubscriptionID++ @@ -247,10 +247,11 @@ func (h *gqlWSConnectionHandler) subscribe(sub Subscription) { startRequest := fmt.Sprintf(startMessage, subscriptionID, string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(startRequest)) if err != nil { - return + return err } h.subscriptions[subscriptionID] = sub + return nil } func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { From 9586289281bf227b29c4fa9ca83899223030e4f3 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 22:30:55 +0200 Subject: [PATCH 05/31] chore: cleanup subscriptions impl --- .../graphql_datasource/graphql_sse_handler.go | 61 +++--- .../graphql_subscription_client.go | 75 +++----- .../graphql_datasource/graphql_tws_handler.go | 180 ++++++++---------- .../graphql_datasource/graphql_ws_handler.go | 174 +++++++---------- 4 files changed, 211 insertions(+), 279 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go index 4a98ea89c..dcf687cca 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go @@ -25,33 +25,34 @@ var ( ) type gqlSSEConnectionHandler struct { - conn *http.Client - ctx context.Context - log log.Logger - options GraphQLSubscriptionOptions + conn *http.Client + requestContext, engineContext context.Context + log log.Logger + options GraphQLSubscriptionOptions + updater resolve.SubscriptionUpdater } -func newSSEConnectionHandler(ctx *resolve.Context, conn *http.Client, opts GraphQLSubscriptionOptions, l log.Logger) *gqlSSEConnectionHandler { +func newSSEConnectionHandler(requestContext, engineContext context.Context, conn *http.Client, updater resolve.SubscriptionUpdater, options GraphQLSubscriptionOptions, l log.Logger) *gqlSSEConnectionHandler { return &gqlSSEConnectionHandler{ - conn: conn, - ctx: ctx.Context(), - log: l, - options: opts, + conn: conn, + requestContext: requestContext, + engineContext: engineContext, + log: l, + updater: updater, + options: options, } } -func (h *gqlSSEConnectionHandler) StartBlocking(sub Subscription) { - reqCtx := sub.ctx - +func (h *gqlSSEConnectionHandler) StartBlocking() { dataCh := make(chan []byte) errCh := make(chan []byte) defer func() { close(dataCh) close(errCh) - sub.updater.Done() + h.updater.Done() }() - go h.subscribe(reqCtx, sub, dataCh, errCh) + go h.subscribe(dataCh, errCh) ticker := time.NewTicker(resolve.HearbeatInterval) defer ticker.Stop() @@ -59,31 +60,33 @@ func (h *gqlSSEConnectionHandler) StartBlocking(sub Subscription) { for { select { case <-ticker.C: - sub.updater.Heartbeat() + h.updater.Heartbeat() case data := <-dataCh: ticker.Reset(resolve.HearbeatInterval) - sub.updater.Update(data) + h.updater.Update(data) case data := <-errCh: ticker.Reset(resolve.HearbeatInterval) - sub.updater.Update(data) + h.updater.Update(data) + return + case <-h.requestContext.Done(): return - case <-reqCtx.Done(): + case <-h.engineContext.Done(): return } } } -func (h *gqlSSEConnectionHandler) subscribe(ctx context.Context, sub Subscription, dataCh, errCh chan []byte) { - resp, err := h.performSubscriptionRequest(ctx) +func (h *gqlSSEConnectionHandler) subscribe(dataCh, errCh chan []byte) { + resp, err := h.performSubscriptionRequest() if err != nil { h.log.Error("failed to perform subscription request", log.Error(err)) - if ctx.Err() != nil { + if h.requestContext.Err() != nil { // request context was canceled do not send an error as channel will be closed return } - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } @@ -94,8 +97,12 @@ func (h *gqlSSEConnectionHandler) subscribe(ctx context.Context, sub Subscriptio reader := sse.NewEventStreamReader(resp.Body, math.MaxInt) for { - if ctx.Err() != nil { + select { + case <-h.requestContext.Done(): + return + case <-h.engineContext.Done(): return + default: } msg, err := reader.ReadEvent() @@ -126,7 +133,7 @@ func (h *gqlSSEConnectionHandler) subscribe(ctx context.Context, sub Subscriptio continue } - if ctx.Err() != nil { + if h.requestContext.Err() != nil { // request context was canceled do not send an error as channel will be closed return } @@ -205,16 +212,16 @@ func trim(data []byte) []byte { return data } -func (h *gqlSSEConnectionHandler) performSubscriptionRequest(ctx context.Context) (*http.Response, error) { +func (h *gqlSSEConnectionHandler) performSubscriptionRequest() (*http.Response, error) { var req *http.Request var err error // default to GET requests when SSEMethodPost is not enabled in the SubscriptionConfiguration if h.options.SSEMethodPost { - req, err = h.buildPOSTRequest(ctx) + req, err = h.buildPOSTRequest(h.requestContext) } else { - req, err = h.buildGETRequest(ctx) + req, err = h.buildGETRequest(h.requestContext) } if err != nil { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 424351f9a..ca2501f1c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -52,9 +52,9 @@ type subscriptionClient struct { triggers map[uint64]int } -func (c *subscriptionClient) SubscribeAsync(reqCtx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { +func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { if options.UseSSE { - return c.subscribeSSE(reqCtx, options, updater) + return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) } if strings.HasPrefix(options.URL, "https") { @@ -63,7 +63,7 @@ func (c *subscriptionClient) SubscribeAsync(reqCtx *resolve.Context, id uint64, options.URL = "ws" + options.URL[4:] } - return c.asyncSubscribeWS(reqCtx, id, options, updater) + return c.asyncSubscribeWS(ctx.Context(), c.engineCtx, id, options, updater) } func (c *subscriptionClient) Unsubscribe(id uint64) { @@ -179,12 +179,12 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi // If an existing WS connection with the same ID (Hash) exists, it is being re-used // If connection protocol is SSE, a new connection is always created // If no connection exists, the client initiates a new one -func (c *subscriptionClient) Subscribe(reqCtx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { +func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { if options.UseSSE { - return c.subscribeSSE(reqCtx, options, updater) + return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) } - return c.subscribeWS(reqCtx, options, updater) + return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) } var ( @@ -208,62 +208,49 @@ func (c *subscriptionClient) UniqueRequestID(ctx *resolve.Context, options Graph return c.requestHash(ctx, options, hash) } -func (c *subscriptionClient) subscribeSSE(reqCtx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { +func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { if c.streamingClient == nil { return fmt.Errorf("streaming http client is nil") } - sub := Subscription{ - ctx: reqCtx.Context(), - options: options, - updater: updater, - } - - handler := newSSEConnectionHandler(reqCtx, c.streamingClient, options, c.log) + handler := newSSEConnectionHandler(requestContext, engineContext, c.streamingClient, updater, options, c.log) - go handler.StartBlocking(sub) + go handler.StartBlocking() return nil } -func (c *subscriptionClient) subscribeWS(reqCtx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { +func (c *subscriptionClient) subscribeWS(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { if c.httpClient == nil { return fmt.Errorf("http client is nil") } - sub := Subscription{ - ctx: reqCtx.Context(), - options: options, - updater: updater, - } - - handler, err := c.newWSConnectionHandler(reqCtx.Context(), options) + handler, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) if err != nil { return err } - go handler.StartBlocking(sub) + go func() { + err := handler.StartBlocking() + if err != nil { + c.log.Error("subscriptionClient.subscribeWS", abstractlogger.Error(err)) + } + }() return nil } -func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { +func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext context.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { if c.httpClient == nil { return fmt.Errorf("http client is nil") } - sub := Subscription{ - ctx: reqCtx.Context(), - options: options, - updater: updater, - } - - handler, err := c.newWSConnectionHandler(reqCtx.Context(), options) + handler, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) if err != nil { return err } - err = handler.Subscribe(sub) + err = handler.Subscribe() if err != nil { return err } @@ -363,7 +350,7 @@ func (u *UpgradeRequestError) Error() string { return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", u.URL, u.StatusCode) } -func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, options GraphQLSubscriptionOptions) (ConnectionHandler, error) { +func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (ConnectionHandler, error) { var ( upgradeRequestHeader http.Header @@ -398,7 +385,7 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti // as a workaround we create a "dummy" request which we run through the http.Client with the context // we set the "SkipRoundTrip" header to true to signal the http.Client to not perform the request // but only to modify the request Headers - req, err := http.NewRequestWithContext(reqCtx, "GET", options.URL, nil) + req, err := http.NewRequestWithContext(requestContext, "GET", options.URL, nil) if err != nil { return nil, err } @@ -425,7 +412,7 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti Subprotocols: subProtocols, } - conn, upgradeResponse, err := dialer.DialContext(reqCtx, upgradeRequestURL, upgradeRequestHeader) + conn, upgradeResponse, err := dialer.DialContext(requestContext, upgradeRequestURL, upgradeRequestHeader) if err != nil { if upgradeResponse != nil && upgradeResponse.StatusCode != http.StatusSwitchingProtocols { return nil, &UpgradeRequestError{ @@ -436,7 +423,7 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti return nil, err } conn.SetReadLimit(math.MaxInt32) - connectionInitMessage, err := c.getConnectionInitMessage(reqCtx, options.URL, options.Header) + connectionInitMessage, err := c.getConnectionInitMessage(requestContext, options.URL, options.Header) if err != nil { return nil, err } @@ -477,9 +464,9 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti switch wsSubProtocol { case ProtocolGraphQLWS: - return newGQLWSConnectionHandler(c.engineCtx, netConn, c.readTimeout, c.log), nil + return newGQLWSConnectionHandler(requestContext, engineContext, netConn, options, updater, c.log), nil case ProtocolGraphQLTWS: - return newGQLTWSConnectionHandler(c.engineCtx, netConn, c.readTimeout, c.log), nil + return newGQLTWSConnectionHandler(requestContext, engineContext, netConn, options, updater, c.log), nil default: return nil, NewInvalidWsSubprotocolError(wsSubProtocol) } @@ -510,18 +497,12 @@ func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url s } type ConnectionHandler interface { - StartBlocking(sub Subscription) + StartBlocking() error NetConn() net.Conn ReadMessage() (done, timeout bool) ServerClose() ClientClose() - Subscribe(sub Subscription) error -} - -type Subscription struct { - ctx context.Context - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater + Subscribe() error } func waitForAck(conn net.Conn) error { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 4a15d241e..35e403229 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -6,9 +6,7 @@ import ( "encoding/json" "errors" "fmt" - "io" "net" - "strconv" "strings" "time" @@ -23,38 +21,30 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlTWSConnectionHandler struct { - conn net.Conn - ctx context.Context - log log.Logger - subscribeCh chan Subscription - nextSubscriptionID int - subscriptions map[string]Subscription - readTimeout time.Duration + conn net.Conn + requestContext, engineContext context.Context + log log.Logger + options GraphQLSubscriptionOptions + updater resolve.SubscriptionUpdater } func (h *gqlTWSConnectionHandler) ServerClose() { - for _, sub := range h.subscriptions { - sub.updater.Done() - } + h.updater.Done() _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) ClientClose() { - for k, v := range h.subscriptions { - v.updater.Done() - delete(h.subscriptions, k) - - req := fmt.Sprintf(completeMessage, k) - err := wsutil.WriteClientText(h.conn, []byte(req)) - if err != nil { - h.log.Error("failed to write complete message", log.Error(err)) - } + h.updater.Done() + req := fmt.Sprintf(completeMessage, "1") + err := wsutil.WriteClientText(h.conn, []byte(req)) + if err != nil { + h.log.Error("failed to write complete message", log.Error(err)) } _ = h.conn.Close() } -func (h *gqlTWSConnectionHandler) Subscribe(sub Subscription) error { - return h.subscribe(sub) +func (h *gqlTWSConnectionHandler) Subscribe() error { + return h.subscribe() } func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { @@ -128,28 +118,31 @@ func (h *gqlTWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLTWSConnectionHandler(ctx context.Context, conn net.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler { +func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l log.Logger) *gqlTWSConnectionHandler { return &gqlTWSConnectionHandler{ - conn: conn, - ctx: ctx, - log: l, - nextSubscriptionID: 0, - subscriptions: map[string]Subscription{}, - readTimeout: rt, + conn: conn, + requestContext: requestContext, + engineContext: engineContext, + log: l, + updater: updater, + options: options, } } -func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { - readCtx, cancel := context.WithCancel(h.ctx) +func (h *gqlTWSConnectionHandler) StartBlocking() error { + readCtx, cancel := context.WithCancel(h.requestContext) dataCh := make(chan []byte) errCh := make(chan error) defer func() { - h.unsubscribeAllAndCloseConn() cancel() + h.unsubscribeAllAndCloseConn() }() - h.subscribe(sub) + err := h.subscribe() + if err != nil { + return err + } go h.readBlocking(readCtx, dataCh, errCh) @@ -157,31 +150,17 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { defer ticker.Stop() for { - err := h.ctx.Err() - if err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err)) - } - h.broadcastErrorMessage(err) - return - } - - hasActiveSubscriptions := h.hasActiveSubscriptions() - if !hasActiveSubscriptions { - return - } - select { + case <-h.engineContext.Done(): + return h.engineContext.Err() case <-readCtx.Done(): - return - case <-time.After(h.readTimeout): - continue + return readCtx.Err() case err := <-errCh: h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err)) h.broadcastErrorMessage(err) - return + return err case <-ticker.C: - sub.updater.Heartbeat() + h.updater.Heartbeat() case data := <-dataCh: ticker.Reset(resolve.HearbeatInterval) messageType, err := jsonparser.GetString(data, "type") @@ -192,10 +171,13 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { switch messageType { case messageTypePing: h.handleMessageTypePing() + continue case messageTypeNext: h.handleMessageTypeNext(data) + continue case messageTypeComplete: h.handleMessageTypeComplete(data) + return nil case messageTypeError: h.handleMessageTypeError(data) continue @@ -203,7 +185,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { continue case messageTypeData, messageTypeConnectionError: h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return + return errors.New("invalid subprotocol") default: h.log.Error("unknown message type", log.String("type", messageType)) continue @@ -213,21 +195,13 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { } func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { - for id := range h.subscriptions { - h.unsubscribe(id) - } + h.unsubscribe() _ = h.conn.Close() } -func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { - sub, ok := h.subscriptions[subscriptionID] - if !ok { - return - } - sub.updater.Done() - delete(h.subscriptions, subscriptionID) - - req := fmt.Sprintf(completeMessage, subscriptionID) +func (h *gqlTWSConnectionHandler) unsubscribe() { + h.updater.Done() + req := fmt.Sprintf(completeMessage, "1") err := wsutil.WriteClientText(h.conn, []byte(req)) if err != nil { h.log.Error("failed to write complete message", log.Error(err)) @@ -235,30 +209,22 @@ func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { } // subscribe adds a new Subscription to the gqlTWSConnectionHandler and sends the subscribeMessage to the origin -func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) error { - graphQLBody, err := json.Marshal(sub.options.Body) +func (h *gqlTWSConnectionHandler) subscribe() error { + graphQLBody, err := json.Marshal(h.options.Body) if err != nil { return err } - - h.nextSubscriptionID++ - - subscriptionID := strconv.Itoa(h.nextSubscriptionID) - subscribeRequest := fmt.Sprintf(subscribeMessage, subscriptionID, string(graphQLBody)) + subscribeRequest := fmt.Sprintf(subscribeMessage, "1", string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest)) if err != nil { return err } - - h.subscriptions[subscriptionID] = sub return nil } func (h *gqlTWSConnectionHandler) broadcastErrorMessage(err error) { errMsg := fmt.Sprintf(errorMessageTemplate, err) - for _, sub := range h.subscriptions { - sub.updater.Update([]byte(errMsg)) - } + h.updater.Update([]byte(errMsg)) } func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) { @@ -266,12 +232,10 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { return } - sub.updater.Done() - delete(h.subscriptions, id) + h.updater.Done() } func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { @@ -279,11 +243,9 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { return } - value, valueType, _, err := jsonparser.Get(data, "payload") if err != nil { h.log.Error( @@ -291,7 +253,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { log.Error(err), log.ByteString("raw message", data), ) - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } @@ -305,20 +267,20 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { log.Error(err), log.ByteString("raw message", value), ) - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.updater.Update(response) case jsonparser.Object: response := []byte(`{"errors":[]}`) response, err = jsonparser.Set(response, value, "errors", "[0]") if err != nil { - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.updater.Update(response) default: - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) } } @@ -334,31 +296,48 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { return } - value, _, _, err := jsonparser.Get(data, "payload") if err != nil { h.log.Error( "failed to get payload from next message", log.Error(err), ) - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } - sub.updater.Update(value) + h.updater.Update(value) } // readBlocking is a dedicated loop running in a separate goroutine // because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlTWSConnectionHandler stops func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { + netOpErr := &net.OpError{} for { + err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) + if err != nil { + select { + case errCh <- err: + case <-ctx.Done(): + } + return + } data, err := wsutil.ReadServerText(h.conn) if err != nil { + if errors.As(err, &netOpErr) { + if netOpErr.Timeout() { + select { + case <-ctx.Done(): + return + default: + continue + } + } + } select { case errCh <- err: case <-ctx.Done(): @@ -372,12 +351,3 @@ func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan } } } - -func (h *gqlTWSConnectionHandler) hasActiveSubscriptions() (hasActiveSubscriptions bool) { - for id, sub := range h.subscriptions { - if sub.ctx.Err() != nil { - h.unsubscribe(id) - } - } - return len(h.subscriptions) != 0 -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index a05b5173a..7fd709992 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net" - "strconv" "strings" "time" @@ -23,35 +22,27 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlWSConnectionHandler struct { - conn net.Conn - ctx context.Context - log abstractlogger.Logger - // log slog.Logger - subscribeCh chan Subscription - nextSubscriptionID int - subscriptions map[string]Subscription - readTimeout time.Duration + conn net.Conn + requestContext, engineContext context.Context + log abstractlogger.Logger + options GraphQLSubscriptionOptions + updater resolve.SubscriptionUpdater } func (h *gqlWSConnectionHandler) ServerClose() { - for _, sub := range h.subscriptions { - sub.updater.Done() - } + h.updater.Done() _ = h.conn.Close() } func (h *gqlWSConnectionHandler) ClientClose() { - for k, v := range h.subscriptions { - v.updater.Done() - delete(h.subscriptions, k) - stopRequest := fmt.Sprintf(stopMessage, k) - _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) - } + h.updater.Done() + stopRequest := fmt.Sprintf(stopMessage, "1") + _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) _ = h.conn.Close() } -func (h *gqlWSConnectionHandler) Subscribe(sub Subscription) error { - return h.subscribe(sub) +func (h *gqlWSConnectionHandler) Subscribe() error { + return h.subscribe() } func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { @@ -118,30 +109,33 @@ func (h *gqlWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLWSConnectionHandler(ctx context.Context, conn net.Conn, readTimeout time.Duration, log abstractlogger.Logger) *gqlWSConnectionHandler { +func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *gqlWSConnectionHandler { return &gqlWSConnectionHandler{ - conn: conn, - ctx: ctx, - log: log, - nextSubscriptionID: 0, - subscriptions: map[string]Subscription{}, - readTimeout: readTimeout, + conn: conn, + requestContext: requestContext, + engineContext: engineContext, + log: log, + updater: updater, + options: options, } } // StartBlocking starts the single threaded event loop of the handler // if the global context returns or the websocket connection is terminated, it will stop -func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { +func (h *gqlWSConnectionHandler) StartBlocking() error { dataCh := make(chan []byte) errCh := make(chan error) - readCtx, cancel := context.WithCancel(h.ctx) + readCtx, cancel := context.WithCancel(h.requestContext) defer func() { - h.unsubscribeAllAndCloseConn() cancel() + h.unsubscribeAllAndCloseConn() }() - h.subscribe(sub) + err := h.subscribe() + if err != nil { + return err + } go h.readBlocking(readCtx, dataCh, errCh) @@ -149,31 +143,19 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { defer ticker.Stop() for { - err := h.ctx.Err() - if err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) { - h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) - } - h.broadcastErrorMessage(err) - return - } - hasActiveSubscriptions := h.checkActiveSubscriptions() - if !hasActiveSubscriptions { - return - } select { + case <-h.engineContext.Done(): + return h.engineContext.Err() case <-readCtx.Done(): - return - case <-time.After(h.readTimeout): - continue - case err = <-errCh: + return readCtx.Err() + case err := <-errCh: if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) } h.broadcastErrorMessage(err) - return + return err case <-ticker.C: - sub.updater.Heartbeat() + h.updater.Heartbeat() case data := <-dataCh: ticker.Reset(resolve.HearbeatInterval) @@ -184,11 +166,13 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { switch messageType { case messageTypeData: h.handleMessageTypeData(data) + continue case messageTypeComplete: h.handleMessageTypeComplete(data) + return nil case messageTypeConnectionError: h.handleMessageTypeConnectionError() - return + return nil case messageTypeError: h.handleMessageTypeError(data) continue @@ -196,7 +180,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { continue case messageTypePing, messageTypeNext: h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-ws, but currently it is set to graphql-transport-ws") - return + return errors.New("invalid subprotocol") default: h.log.Error("unknown message type", abstractlogger.String("type", messageType)) continue @@ -209,9 +193,28 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { // because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlWSConnectionHandler stops func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { + netOpErr := &net.OpError{} for { + err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) + if err != nil { + select { + case errCh <- err: + case <-ctx.Done(): + } + return + } data, err := wsutil.ReadServerText(h.conn) if err != nil { + if errors.As(err, &netOpErr) { + if netOpErr.Timeout() { + select { + case <-ctx.Done(): + return + default: + continue + } + } + } select { case errCh <- err: case <-ctx.Done(): @@ -227,30 +230,23 @@ func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan [ } func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { - for id := range h.subscriptions { - h.unsubscribe(id) - } + h.unsubscribe() _ = h.conn.Close() } // subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin -func (h *gqlWSConnectionHandler) subscribe(sub Subscription) error { - graphQLBody, err := json.Marshal(sub.options.Body) +func (h *gqlWSConnectionHandler) subscribe() error { + graphQLBody, err := json.Marshal(h.options.Body) if err != nil { return err } - h.nextSubscriptionID++ - - subscriptionID := strconv.Itoa(h.nextSubscriptionID) - - startRequest := fmt.Sprintf(startMessage, subscriptionID, string(graphQLBody)) + startRequest := fmt.Sprintf(startMessage, "1", string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(startRequest)) if err != nil { return err } - h.subscriptions[subscriptionID] = sub return nil } @@ -259,8 +255,7 @@ func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { return } payload, _, _, err := jsonparser.Get(data, "payload") @@ -268,20 +263,16 @@ func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { return } - sub.updater.Update(payload) + h.updater.Update(payload) } func (h *gqlWSConnectionHandler) handleMessageTypeConnectionError() { - for _, sub := range h.subscriptions { - sub.updater.Update([]byte(connectionError)) - } + h.updater.Update([]byte(connectionError)) } func (h *gqlWSConnectionHandler) broadcastErrorMessage(err error) { errMsg := fmt.Sprintf(errorMessageTemplate, err) - for _, sub := range h.subscriptions { - sub.updater.Update([]byte(errMsg)) - } + h.updater.Update([]byte(errMsg)) } func (h *gqlWSConnectionHandler) handleMessageTypeComplete(data []byte) { @@ -289,12 +280,10 @@ func (h *gqlWSConnectionHandler) handleMessageTypeComplete(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { return } - sub.updater.Done() - delete(h.subscriptions, id) + h.updater.Done() } func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { @@ -302,13 +291,12 @@ func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { return } value, valueType, _, err := jsonparser.Get(data, "payload") if err != nil { - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } switch valueType { @@ -316,39 +304,25 @@ func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { response := []byte(`{}`) response, err = jsonparser.Set(response, value, "errors") if err != nil { - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.updater.Update(response) case jsonparser.Object: response := []byte(`{"errors":[]}`) response, err = jsonparser.Set(response, value, "errors", "[0]") if err != nil { - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.updater.Update(response) default: - sub.updater.Update([]byte(internalError)) + h.updater.Update([]byte(internalError)) } } -func (h *gqlWSConnectionHandler) unsubscribe(subscriptionID string) { - sub, ok := h.subscriptions[subscriptionID] - if !ok { - return - } - sub.updater.Done() - delete(h.subscriptions, subscriptionID) - stopRequest := fmt.Sprintf(stopMessage, subscriptionID) +func (h *gqlWSConnectionHandler) unsubscribe() { + h.updater.Done() + stopRequest := fmt.Sprintf(stopMessage, "1") _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) } - -func (h *gqlWSConnectionHandler) checkActiveSubscriptions() (hasActiveSubscriptions bool) { - for id, sub := range h.subscriptions { - if sub.ctx.Err() != nil { - h.unsubscribe(id) - } - } - return len(h.subscriptions) != 0 -} From 9d854ba5b1182f4e730ae25f2e6e42a7e8fd0e48 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 22:38:56 +0200 Subject: [PATCH 06/31] chore: fix lint --- v2/pkg/internal/epoller/README.md | 8 ++++++++ v2/pkg/internal/epoller/epoll.go | 2 +- v2/pkg/internal/epoller/epoll_test.go | 14 +++++++------- v2/pkg/internal/epoller/fd_test.go | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) create mode 100644 v2/pkg/internal/epoller/README.md diff --git a/v2/pkg/internal/epoller/README.md b/v2/pkg/internal/epoller/README.md new file mode 100644 index 000000000..77b5c2548 --- /dev/null +++ b/v2/pkg/internal/epoller/README.md @@ -0,0 +1,8 @@ +# Epoller + +epoll implementation for connections in Linux, MacOS. + +Its target is implementing a simple epoll lib for network connections, so you should see it only contains few methods about net.Conn: + +This is a copy of [https://github.com/smallnest/epoller](https://github.com/smallnest/epoller) (v1.2.0) to remove Windows support and avoid the need for CGO. +On Windows, we handle websocket messages in a separate goroutine, without epoll. \ No newline at end of file diff --git a/v2/pkg/internal/epoller/epoll.go b/v2/pkg/internal/epoller/epoll.go index efaa8439b..1eea7eeba 100644 --- a/v2/pkg/internal/epoller/epoll.go +++ b/v2/pkg/internal/epoller/epoll.go @@ -24,7 +24,7 @@ func socketFD(conn net.Conn) int { return 0 } sfd := 0 - raw.Control(func(fd uintptr) { + raw.Control(func(fd uintptr) { // nolint: errcheck sfd = int(fd) }) return sfd diff --git a/v2/pkg/internal/epoller/epoll_test.go b/v2/pkg/internal/epoller/epoll_test.go index dd5336306..4230151af 100644 --- a/v2/pkg/internal/epoller/epoll_test.go +++ b/v2/pkg/internal/epoller/epoll_test.go @@ -32,7 +32,7 @@ func TestPoller(t *testing.T) { return } - poller.Add(conn) + poller.Add(conn) // nolint: errcheck } }() @@ -54,7 +54,7 @@ func TestPoller(t *testing.T) { t.Errorf("expect to write %d bytes but got %d bytes", len("hello world"), n) } } - conn.Close() + conn.Close() // nolint: errcheck }() } @@ -82,8 +82,8 @@ func TestPoller(t *testing.T) { n, err := conn.Read(buf) if err != nil { if err == io.EOF || errors.Is(err, net.ErrClosed) { - poller.Remove(conn) - conn.Close() + poller.Remove(conn) // nolint: errcheck + conn.Close() // nolint: errcheck } else { t.Error(err) } @@ -153,7 +153,7 @@ func TestPoller_growstack(t *testing.T) { return } - poller.Add(conn) + poller.Add(conn) // nolint: errcheck } }() @@ -164,7 +164,7 @@ func TestPoller_growstack(t *testing.T) { } time.Sleep(200 * time.Millisecond) for i := 0; i < 100; i++ { - conn.Write([]byte("hello world")) + conn.Write([]byte("hello world")) // nolint: errcheck } - conn.Close() + conn.Close() // nolint: errcheck } diff --git a/v2/pkg/internal/epoller/fd_test.go b/v2/pkg/internal/epoller/fd_test.go index 2595f04f8..db5be542b 100644 --- a/v2/pkg/internal/epoller/fd_test.go +++ b/v2/pkg/internal/epoller/fd_test.go @@ -23,7 +23,7 @@ func rawSocketFD(conn net.Conn) uint64 { return 0 } sfd := uint64(0) - raw.Control(func(fd uintptr) { + raw.Control(func(fd uintptr) { // nolint: errcheck sfd = uint64(fd) }) return sfd From b0bd97a47ab95156b43816683eefe337c0c90e6b Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 22:45:24 +0200 Subject: [PATCH 07/31] chore: skip epoll if not supported --- .../graphql_datasource/graphql_subscription_client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index ca2501f1c..0bb32487e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -56,6 +56,9 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt if options.UseSSE { return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) } + if c.epoll == nil { + return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) + } if strings.HasPrefix(options.URL, "https") { options.URL = "wss" + options.URL[5:] From 406ab3807ace49431f70d1f1e0ddec6b68f2cc15 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 22:52:02 +0200 Subject: [PATCH 08/31] chore: skip epoll tests if not supported --- .../graphql_datasource/graphql_subscription_client.go | 3 --- .../graphql_datasource/graphql_subscription_client_test.go | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 0bb32487e..ca2501f1c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -56,9 +56,6 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt if options.UseSSE { return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) } - if c.epoll == nil { - return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) - } if strings.HasPrefix(options.URL, "https") { options.URL = "wss" + options.URL[5:] diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 1f385d2fb..95ab8c0f2 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "runtime" "testing" "time" @@ -485,6 +486,9 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { } func TestAsyncSubscribe(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } t.Parallel() t.Run("subscribe async", func(t *testing.T) { t.Parallel() From 05944f0b5d67ec74b1de705ea9f6e4016afb8395 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 22:57:22 +0200 Subject: [PATCH 09/31] chore: skip epoll tests if not supported --- v2/pkg/internal/epoller/epoll_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/v2/pkg/internal/epoller/epoll_test.go b/v2/pkg/internal/epoller/epoll_test.go index 4230151af..c4e094594 100644 --- a/v2/pkg/internal/epoller/epoll_test.go +++ b/v2/pkg/internal/epoller/epoll_test.go @@ -5,6 +5,7 @@ import ( "io" "log" "net" + "runtime" "testing" "time" @@ -12,6 +13,11 @@ import ( ) func TestPoller(t *testing.T) { + + if runtime.GOOS == "windows" { + t.SkipNow() + } + // connections num := 10 // msg per connection From 071349cd86b2cc59aca4a650163208e5fe236e9c Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 23:01:04 +0200 Subject: [PATCH 10/31] chore: skip epoll tests if not supported --- v2/pkg/internal/epoller/epoll_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/v2/pkg/internal/epoller/epoll_test.go b/v2/pkg/internal/epoller/epoll_test.go index c4e094594..7f08dae21 100644 --- a/v2/pkg/internal/epoller/epoll_test.go +++ b/v2/pkg/internal/epoller/epoll_test.go @@ -124,6 +124,9 @@ type netPoller struct { } func TestPoller_growstack(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } var nps []netPoller for i := 0; i < 2; i++ { poller, err := NewPoller(128, time.Second) From d9b985cc14a7f3026a6798a05bf0a199f69fd10e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 23:23:44 +0200 Subject: [PATCH 11/31] chore: skip epoll tests if not supported --- .../graphql_datasource/graphql_subscription_client.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index ca2501f1c..5269c2f9a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -255,6 +255,15 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return err } + if c.epoll == nil { + go func() { + err := handler.StartBlocking() + if err != nil { + c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) + } + }() + } + netConn := handler.NetConn() if err := c.epoll.Add(netConn); err != nil { return err From cf48a53e5ade8e0ecc869990974aa17ea5aad95e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 23:37:40 +0200 Subject: [PATCH 12/31] chore: skip epoll tests if not supported --- .../datasource/graphql_datasource/graphql_subscription_client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 5269c2f9a..0bdad2b9c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -262,6 +262,7 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) } }() + return nil } netConn := handler.NetConn() From 7bb424fa6449c4e8b0804c822a15f817ccc00dab Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 18 Oct 2024 09:35:13 +0200 Subject: [PATCH 13/31] chore: fix race --- .../graphql_subscription_client.go | 31 +++++-------------- v2/pkg/engine/resolve/resolve.go | 1 - 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 0bdad2b9c..c2d1c2a7c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -40,8 +40,7 @@ type subscriptionClient struct { readTimeout time.Duration - epoll epoller.Poller - stopEpollSignal chan struct{} + epoll epoller.Poller connections map[int]ConnectionHandler connectionsMu sync.Mutex @@ -81,9 +80,6 @@ func (c *subscriptionClient) Unsubscribe(id uint64) { handler.ClientClose() delete(c.connections, fd) _ = c.epoll.Remove(handler.NetConn()) - if len(c.connections) == 0 { - close(c.stopEpollSignal) - } } type InvalidWsSubprotocolError struct { @@ -172,6 +168,9 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi triggers: make(map[uint64]int), useHttpClientWithSkipRoundTrip: op.useHttpClientWithSkipRoundTrip, } + if epoll != nil { + go client.runEpoll(engineCtx) + } return client } @@ -250,21 +249,15 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return err } + if c.epoll == nil { + return handler.StartBlocking() + } + err = handler.Subscribe() if err != nil { return err } - if c.epoll == nil { - go func() { - err := handler.StartBlocking() - if err != nil { - c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) - } - }() - return nil - } - netConn := handler.NetConn() if err := c.epoll.Add(netConn); err != nil { return err @@ -274,14 +267,8 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont fd := socketFd(netConn) c.connections[fd] = handler c.triggers[id] = fd - count := len(c.connections) c.connectionsMu.Unlock() - if count == 1 { - c.stopEpollSignal = make(chan struct{}) - go c.runEpoll(c.engineCtx) - } - return nil } @@ -590,8 +577,6 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { return case <-tick.C: continue - case <-c.stopEpollSignal: - return } } } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 020767d7e..f7fe7684e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -574,7 +574,6 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } - if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { trig.cancel = func() { cancel() From 7609a0d7024502fdb8cf56ad3661dc7018b8c5a2 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 18 Oct 2024 13:33:19 +0200 Subject: [PATCH 14/31] chore: add tests for epoll disabled --- .../graphql_subscription_client.go | 71 ++++++-- .../graphql_subscription_client_test.go | 155 ++++++++++++++++++ 2 files changed, 215 insertions(+), 11 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c2d1c2a7c..dd81082bc 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -2,6 +2,7 @@ package graphql_datasource import ( "context" + "errors" "fmt" "math" "net" @@ -40,7 +41,8 @@ type subscriptionClient struct { readTimeout time.Duration - epoll epoller.Poller + epoll epoller.Poller + epollConfig EpollConfiguration connections map[int]ConnectionHandler connectionsMu sync.Mutex @@ -116,11 +118,37 @@ func UseHttpClientWithSkipRoundTrip() Options { } } +type EpollConfiguration struct { + Disable bool + BufferSize int + Interval time.Duration + Wait time.Duration +} + +func (e *EpollConfiguration) ApplyDefaults() { + if e.BufferSize == 0 { + e.BufferSize = 1024 + } + if e.Interval == 0 { + e.Interval = time.Millisecond * 100 + } + if e.Wait == 0 { + e.Wait = time.Millisecond * 100 + } +} + +func WithEpollConfiguration(config EpollConfiguration) Options { + return func(options *opts) { + options.epollConfiguration = config + } +} + type opts struct { readTimeout time.Duration log abstractlogger.Logger onWsConnectionInitCallback *OnWsConnectionInitCallback useHttpClientWithSkipRoundTrip bool + epollConfiguration EpollConfiguration } // GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. @@ -148,8 +176,7 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi for _, option := range options { option(op) } - // ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client - epoll, _ := epoller.NewPoller(1024, time.Millisecond*100) + op.epollConfiguration.ApplyDefaults() client := &subscriptionClient{ httpClient: httpClient, streamingClient: streamingClient, @@ -162,14 +189,19 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi }, }, onWsConnectionInitCallback: op.onWsConnectionInitCallback, - epoll: epoll, connections: make(map[int]ConnectionHandler), activeConnections: make(map[int]struct{}), triggers: make(map[uint64]int), useHttpClientWithSkipRoundTrip: op.useHttpClientWithSkipRoundTrip, - } - if epoll != nil { - go client.runEpoll(engineCtx) + epollConfig: op.epollConfiguration, + } + if !op.epollConfiguration.Disable { + // ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client + epoll, _ := epoller.NewPoller(op.epollConfiguration.BufferSize, op.epollConfiguration.Interval) + if epoll != nil { + client.epoll = epoll + go client.runEpoll(engineCtx) + } } return client } @@ -250,7 +282,13 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont } if c.epoll == nil { - return handler.StartBlocking() + go func() { + err := handler.StartBlocking() + if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) + } + }() + return nil } err = handler.Subscribe() @@ -539,8 +577,19 @@ func waitForAck(conn net.Conn) error { } func (c *subscriptionClient) runEpoll(ctx context.Context) { - done := ctx.Done() - tick := time.NewTicker(time.Millisecond * 50) + var ( + ticker <-chan time.Time + done = ctx.Done() + ) + if c.epollConfig.Wait > 0 { + tick := time.NewTicker(time.Millisecond * 50) + defer tick.Stop() + ticker = tick.C + } else { + tick := make(chan time.Time) + close(tick) + ticker = tick + } for { connections, err := c.epoll.Wait(50) if err != nil { @@ -575,7 +624,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { select { case <-done: return - case <-tick.C: + case <-ticker: continue } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 95ab8c0f2..bf9acf221 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "runtime" + "sync" "testing" "time" @@ -1208,6 +1209,160 @@ func TestAsyncSubscribe(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() }) + t.Run("happy path no epoll", func(t *testing.T) { + t.Parallel() + serverDone := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + close(serverDone) + })) + defer server.Close() + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + WithEpollConfiguration(EpollConfiguration{ + Disable: true, + }), + ).(*subscriptionClient) + updater := &testSubscriptionUpdater{} + + err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { + <-serverDone + return true + }, time.Second, time.Millisecond*10, "server did not close") + serverCancel() + }) + t.Run("happy path no epoll two clients", func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + })) + defer server.Close() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + WithEpollConfiguration(EpollConfiguration{ + Disable: true, + }), + ).(*subscriptionClient) + wg := &sync.WaitGroup{} + wg.Add(2) + for i := 0; i < 2; i++ { + go func(i int) { + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + updater := &testSubscriptionUpdater{} + err := client.SubscribeAsync(resolve.NewContext(ctx), uint64(i), GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(uint64(i)) + clientCancel() + wg.Done() + }(i) + } + wg.Wait() + }) t.Run("ping", func(t *testing.T) { t.Parallel() serverDone := make(chan struct{}) From 7b40eb8fc88127ce992b7fb6b71a48b963546f65 Mon Sep 17 00:00:00 2001 From: starptech Date: Sun, 20 Oct 2024 12:52:12 +0200 Subject: [PATCH 15/31] fix: revert to previous library, move epoll lib to engine --- .../graphql_subscription_client.go | 4 +- .../graphql_subscription_client_test.go | 2 +- .../graphql_datasource/graphql_tws_handler.go | 4 +- .../graphql_tws_handler_test.go | 2 +- .../graphql_datasource/graphql_ws_handler.go | 4 +- .../graphql_ws_handler_test.go | 2 +- v2/go.mod | 3 +- .../graphql_subscription_client.go | 111 ++++-------------- .../graphql_subscription_client_test.go | 2 +- .../graphql_datasource/graphql_tws_handler.go | 2 +- .../graphql_tws_handler_test.go | 2 +- .../graphql_datasource/graphql_ws_handler.go | 2 +- .../graphql_ws_handler_test.go | 2 +- v2/pkg/{internal => }/epoller/README.md | 0 v2/pkg/{internal => }/epoller/conn.go | 2 +- v2/pkg/{internal => }/epoller/epoll.go | 2 +- v2/pkg/{internal => }/epoller/epoll_bsd.go | 4 +- v2/pkg/{internal => }/epoller/epoll_linux.go | 0 v2/pkg/{internal => }/epoller/epoll_test.go | 0 .../epoller/epoll_unsupported.go | 0 v2/pkg/{internal => }/epoller/fd_test.go | 0 21 files changed, 43 insertions(+), 107 deletions(-) rename v2/pkg/{internal => }/epoller/README.md (100%) rename v2/pkg/{internal => }/epoller/conn.go (94%) rename v2/pkg/{internal => }/epoller/epoll.go (95%) rename v2/pkg/{internal => }/epoller/epoll_bsd.go (98%) rename v2/pkg/{internal => }/epoller/epoll_linux.go (100%) rename v2/pkg/{internal => }/epoller/epoll_test.go (100%) rename v2/pkg/{internal => }/epoller/epoll_unsupported.go (100%) rename v2/pkg/{internal => }/epoller/fd_test.go (100%) diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 2b8a955f2..f242fb42d 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -10,8 +10,8 @@ import ( "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + "github.com/coder/websocket" "github.com/jensneuse/abstractlogger" - "nhooyr.io/websocket" ) const ackWaitTimeout = 30 * time.Second @@ -217,7 +217,7 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti return nil, err } // Disable the maximum message size limit. Don't use MaxInt64 since - // the nhooyr.io/websocket doesn't handle it correctly on 32 bit systems. + // the github.com/coder/websocket doesn't handle it correctly on 32 bit systems. conn.SetReadLimit(math.MaxInt32) if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("upgrade unsuccessful") diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index d0cfdb09e..59de83ad5 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -14,11 +14,11 @@ import ( "github.com/stretchr/testify/require" "github.com/buger/jsonparser" + "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" "go.uber.org/atomic" "go.uber.org/zap" - "nhooyr.io/websocket" ) func logger() ll.Logger { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index d4883a994..0645702a0 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -8,8 +8,8 @@ import ( "time" "github.com/buger/jsonparser" + "github.com/coder/websocket" log "github.com/jensneuse/abstractlogger" - "nhooyr.io/websocket" ) // gqlTWSConnectionHandler is responsible for handling a connection to an origin @@ -241,7 +241,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { } // readBlocking is a dedicated loop running in a separate goroutine -// because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout +// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlTWSConnectionHandler stops func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go index 997908556..b00e855d9 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go @@ -9,8 +9,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/coder/websocket" "github.com/stretchr/testify/assert" - "nhooyr.io/websocket" ) func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index a84ff7bae..3ffb40687 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -8,8 +8,8 @@ import ( "time" "github.com/buger/jsonparser" + "github.com/coder/websocket" "github.com/jensneuse/abstractlogger" - "nhooyr.io/websocket" ) // gqlWSConnectionHandler is responsible for handling a connection to an origin @@ -97,7 +97,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { } // readBlocking is a dedicated loop running in a separate goroutine -// because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout +// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlWSConnectionHandler stops func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go index 0f8737143..1fcb9c71e 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go @@ -10,8 +10,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/coder/websocket" "github.com/stretchr/testify/assert" - "nhooyr.io/websocket" ) func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) { diff --git a/v2/go.mod b/v2/go.mod index 7834335fd..60a390e61 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -7,6 +7,7 @@ require ( github.com/alitto/pond v1.8.3 github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.2.0 + github.com/coder/websocket v1.8.12 github.com/davecgh/go-spew v1.1.1 github.com/goccy/go-json v0.10.2 github.com/golang/mock v1.6.0 @@ -32,12 +33,10 @@ require ( golang.org/x/sync v0.7.0 gonum.org/v1/gonum v0.14.0 gopkg.in/yaml.v2 v2.4.0 - nhooyr.io/websocket v1.8.11 ) require ( github.com/agnivade/levenshtein v1.1.1 // indirect - github.com/coder/websocket v1.8.12 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index dd81082bc..fec7277de 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -4,23 +4,23 @@ import ( "context" "errors" "fmt" + "github.com/gobwas/ws/wsutil" "math" "net" "net/http" + "net/http/httptrace" "net/textproto" "strings" "sync" - "syscall" "time" - "github.com/gobwas/ws/wsutil" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" "github.com/jensneuse/abstractlogger" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/epoller" + "github.com/wundergraph/graphql-go-tools/v2/pkg/epoller" ) const ackWaitTimeout = 30 * time.Second @@ -302,7 +302,7 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont } c.connectionsMu.Lock() - fd := socketFd(netConn) + fd := epoller.SocketFD(netConn) c.connections[fd] = handler c.triggers[id] = fd c.connectionsMu.Unlock() @@ -386,78 +386,35 @@ func (u *UpgradeRequestError) Error() string { } func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (ConnectionHandler, error) { - - var ( - upgradeRequestHeader http.Header - subgraphHttpURL string - upgradeRequestURL string - ) - subProtocols := []string{ProtocolGraphQLWS, ProtocolGraphQLTWS} if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" { subProtocols = []string{options.WsSubProtocol} } - if strings.HasPrefix(options.URL, "https") { - upgradeRequestURL = "wss" + options.URL[5:] - subgraphHttpURL = options.URL - } else if strings.HasPrefix(options.URL, "http") { - upgradeRequestURL = "ws" + options.URL[4:] - subgraphHttpURL = options.URL - } else if strings.HasPrefix(options.URL, "wss") { - upgradeRequestURL = options.URL - subgraphHttpURL = "https" + options.URL[3:] - } else if strings.HasPrefix(options.URL, "ws") { - upgradeRequestURL = options.URL - subgraphHttpURL = "http" + options.URL[2:] - } - - if c.useHttpClientWithSkipRoundTrip { - // gorilla websocket does not support using the http.Client directly - // but we need to use our existing client, or the transport more specifically - // to be able to forward headers in the upgrade request - // - // as a workaround we create a "dummy" request which we run through the http.Client with the context - // we set the "SkipRoundTrip" header to true to signal the http.Client to not perform the request - // but only to modify the request Headers - req, err := http.NewRequestWithContext(requestContext, "GET", options.URL, nil) - if err != nil { - return nil, err - } - if strings.HasPrefix(options.URL, "ws") { - req.URL.Scheme = "http" - } else { - req.URL.Scheme = "https" - } - if options.Header != nil { - req.Header = options.Header - } - req.Header.Set("SkipRoundTrip", "true") - _, _ = c.httpClient.Do(req) - req.Header.Del("SkipRoundTrip") - upgradeRequestHeader = req.Header - subgraphHttpURL = req.URL.String() - } else { - upgradeRequestHeader = options.Header - } + var netConn net.Conn - dialer := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 10, - Subprotocols: subProtocols, + clientTrace := &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + netConn = info.Conn + }, } - - conn, upgradeResponse, err := dialer.DialContext(requestContext, upgradeRequestURL, upgradeRequestHeader) + clientTraceCtx := httptrace.WithClientTrace(requestContext, clientTrace) + conn, upgradeResponse, err := websocket.Dial(clientTraceCtx, options.URL, &websocket.DialOptions{ + HTTPClient: c.httpClient, + HTTPHeader: options.Header, + CompressionMode: websocket.CompressionDisabled, + Subprotocols: subProtocols, + }) if err != nil { - if upgradeResponse != nil && upgradeResponse.StatusCode != http.StatusSwitchingProtocols { - return nil, &UpgradeRequestError{ - URL: subgraphHttpURL, - StatusCode: upgradeResponse.StatusCode, - } - } return nil, err } + // Disable the maximum message size limit. Don't use MaxInt64 since + // the github.com/coder/websocket doesn't handle it correctly on 32-bit systems. conn.SetReadLimit(math.MaxInt32) + if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { + return nil, fmt.Errorf("upgrade unsuccessful") + } + connectionInitMessage, err := c.getConnectionInitMessage(requestContext, options.URL, options.Header) if err != nil { return nil, err @@ -477,8 +434,6 @@ func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContex } } - netConn := conn.NetConn() - // init + ack err = wsutil.WriteClientText(netConn, connectionInitMessage) if err != nil { @@ -598,7 +553,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { } c.connectionsMu.Lock() for _, conn := range connections { - id := socketFd(conn) + id := epoller.SocketFD(conn) handler, ok := c.connections[id] if !ok { continue @@ -653,21 +608,3 @@ func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, return } } - -func socketFd(conn net.Conn) int { - if con, ok := conn.(syscall.Conn); ok { - raw, err := con.SyscallConn() - if err != nil { - return 0 - } - sfd := 0 - _ = raw.Control(func(fd uintptr) { - sfd = int(fd) - }) - return sfd - } - if con, ok := conn.(epoller.ConnImpl); ok { - return con.GetFD() - } - return 0 -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index bf9acf221..8764ff42b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -14,11 +14,11 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" "go.uber.org/atomic" "go.uber.org/zap" - "nhooyr.io/websocket" ) func logger() ll.Logger { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 35e403229..c508e3717 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -313,7 +313,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { } // readBlocking is a dedicated loop running in a separate goroutine -// because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout +// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlTWSConnectionHandler stops func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { netOpErr := &net.OpError{} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go index 81bc59306..e46175935 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go @@ -12,8 +12,8 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" + "github.com/coder/websocket" "github.com/stretchr/testify/assert" - "nhooyr.io/websocket" ) func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 7fd709992..1b186d669 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -190,7 +190,7 @@ func (h *gqlWSConnectionHandler) StartBlocking() error { } // readBlocking is a dedicated loop running in a separate goroutine -// because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout +// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlWSConnectionHandler stops func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { netOpErr := &net.OpError{} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go index 23242c19e..16d8c6b12 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go @@ -15,8 +15,8 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" + "github.com/coder/websocket" "github.com/stretchr/testify/assert" - "nhooyr.io/websocket" ) func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) { diff --git a/v2/pkg/internal/epoller/README.md b/v2/pkg/epoller/README.md similarity index 100% rename from v2/pkg/internal/epoller/README.md rename to v2/pkg/epoller/README.md diff --git a/v2/pkg/internal/epoller/conn.go b/v2/pkg/epoller/conn.go similarity index 94% rename from v2/pkg/internal/epoller/conn.go rename to v2/pkg/epoller/conn.go index 74649e64a..e2d2e6d5a 100644 --- a/v2/pkg/internal/epoller/conn.go +++ b/v2/pkg/epoller/conn.go @@ -12,7 +12,7 @@ func newConnImpl(in net.Conn) ConnImpl { return ConnImpl{ Conn: in, - fd: socketFD(in), + fd: SocketFD(in), } } diff --git a/v2/pkg/internal/epoller/epoll.go b/v2/pkg/epoller/epoll.go similarity index 95% rename from v2/pkg/internal/epoller/epoll.go rename to v2/pkg/epoller/epoll.go index 1eea7eeba..7ff678fe8 100644 --- a/v2/pkg/internal/epoller/epoll.go +++ b/v2/pkg/epoller/epoll.go @@ -17,7 +17,7 @@ type Poller interface { Close(closeConns bool) error } -func socketFD(conn net.Conn) int { +func SocketFD(conn net.Conn) int { if con, ok := conn.(syscall.Conn); ok { raw, err := con.SyscallConn() if err != nil { diff --git a/v2/pkg/internal/epoller/epoll_bsd.go b/v2/pkg/epoller/epoll_bsd.go similarity index 98% rename from v2/pkg/internal/epoller/epoll_bsd.go rename to v2/pkg/epoller/epoll_bsd.go index a9d23e982..67369fa4f 100755 --- a/v2/pkg/internal/epoller/epoll_bsd.go +++ b/v2/pkg/epoller/epoll_bsd.go @@ -76,7 +76,7 @@ func (e *Epoll) Close(closeConns bool) error { // Add adds a network connection to the poller. func (e *Epoll) Add(conn net.Conn) error { conn = newConnImpl(conn) - fd := socketFD(conn) + fd := SocketFD(conn) if e := syscall.SetNonblock(int(fd), true); e != nil { return errors.New("udev: unix.SetNonblock failed") } @@ -98,7 +98,7 @@ func (e *Epoll) Add(conn net.Conn) error { // Remove removes a connection from the poller. // If close is true, the connection will be closed. func (e *Epoll) Remove(conn net.Conn) error { - fd := socketFD(conn) + fd := SocketFD(conn) e.mu.Lock() defer e.mu.Unlock() diff --git a/v2/pkg/internal/epoller/epoll_linux.go b/v2/pkg/epoller/epoll_linux.go similarity index 100% rename from v2/pkg/internal/epoller/epoll_linux.go rename to v2/pkg/epoller/epoll_linux.go diff --git a/v2/pkg/internal/epoller/epoll_test.go b/v2/pkg/epoller/epoll_test.go similarity index 100% rename from v2/pkg/internal/epoller/epoll_test.go rename to v2/pkg/epoller/epoll_test.go diff --git a/v2/pkg/internal/epoller/epoll_unsupported.go b/v2/pkg/epoller/epoll_unsupported.go similarity index 100% rename from v2/pkg/internal/epoller/epoll_unsupported.go rename to v2/pkg/epoller/epoll_unsupported.go diff --git a/v2/pkg/internal/epoller/fd_test.go b/v2/pkg/epoller/fd_test.go similarity index 100% rename from v2/pkg/internal/epoller/fd_test.go rename to v2/pkg/epoller/fd_test.go From 8b30a1d0a304579b7c8990a1d770148cde4bb567 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 13:01:43 +0200 Subject: [PATCH 16/31] chore: move heartbeat into resolver, address PR feedback --- .../graphql_datasource_test.go | 51 -------- .../graphql_datasource/graphql_sse_handler.go | 8 -- .../graphql_sse_handler_test.go | 67 ---------- .../graphql_subscription_client.go | 59 +++------ .../graphql_datasource/graphql_tws_handler.go | 6 - .../graphql_datasource/graphql_ws_handler.go | 7 - v2/pkg/engine/resolve/resolve.go | 121 ++++++------------ 7 files changed, 60 insertions(+), 259 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 227e7cea7..49a1fb272 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -9037,31 +9037,6 @@ func TestSubscriptionSource_Start(t *testing.T) { assert.Len(t, updater.updates, 1) assert.Equal(t, `{"data":{"messageAdded":{"text":"hello world!","createdBy":"myuser"}}}`, updater.updates[0]) }) - - t.Run("should successfully send heartbeat", func(t *testing.T) { - ctx := resolve.NewContext(context.Background()) - ctx.ExecutionOptions.SendHeartbeat = true - defer ctx.Context().Done() - - updater := &testSubscriptionUpdater{} - - source := newSubscriptionSource(ctx.Context()) - chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - - err := source.Start(ctx, chatSubscriptionOptions, updater) - require.NoError(t, err) - - username := "myuser" - message := "hello world!" - go sendChatMessage(t, username, message) - updater.AwaitUpdates(t, time.Second, 1) - assert.Len(t, updater.updates, 1) - assert.Equal(t, `{"data":{"messageAdded":{"text":"hello world!","createdBy":"myuser"}}}`, updater.updates[0]) - - updater.AwaitUpdates(t, 7*time.Second, 2) - assert.Len(t, updater.updates, 2) - assert.Equal(t, `{}`, updater.updates[1]) - }) } func TestSubscription_GTWS_SubProtocol(t *testing.T) { @@ -9170,32 +9145,6 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { assert.Len(t, updater.updates, 1) assert.Equal(t, `{"data":{"messageAdded":{"text":"hello world!","createdBy":"myuser"}}}`, updater.updates[0]) }) - - t.Run("should successfully send heartbeat", func(t *testing.T) { - ctx := resolve.NewContext(context.Background()) - ctx.ExecutionOptions.SendHeartbeat = true - defer ctx.Context().Done() - - updater := &testSubscriptionUpdater{} - - source := newSubscriptionSource(ctx.Context()) - chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - - err := source.Start(ctx, chatSubscriptionOptions, updater) - require.NoError(t, err) - - username := "myuser" - message := "hello world!" - go sendChatMessage(t, username, message) - - updater.AwaitUpdates(t, time.Second, 1) - assert.Len(t, updater.updates, 1) - assert.Equal(t, `{"data":{"messageAdded":{"text":"hello world!","createdBy":"myuser"}}}`, updater.updates[0]) - - updater.AwaitUpdates(t, 7*time.Second, 2) - assert.Len(t, updater.updates, 2) - assert.Equal(t, `{}`, updater.updates[1]) - }) } type runTestOnTestDefinitionOptions func(planConfig *plan.Configuration) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go index dcf687cca..fce400267 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go @@ -8,7 +8,6 @@ import ( "io" "math" "net/http" - "time" "github.com/buger/jsonparser" log "github.com/jensneuse/abstractlogger" @@ -54,18 +53,11 @@ func (h *gqlSSEConnectionHandler) StartBlocking() { go h.subscribe(dataCh, errCh) - ticker := time.NewTicker(resolve.HearbeatInterval) - defer ticker.Stop() - for { select { - case <-ticker.C: - h.updater.Heartbeat() case data := <-dataCh: - ticker.Reset(resolve.HearbeatInterval) h.updater.Update(data) case data := <-errCh: - ticker.Reset(resolve.HearbeatInterval) h.updater.Update(data) return case <-h.requestContext.Done(): diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go index ea3c9ab8f..af110b440 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go @@ -81,73 +81,6 @@ func TestGraphQLSubscriptionClientSubscribe_SSE(t *testing.T) { serverCancel() } -func TestGraphQLSubscriptionClientSubscribe_Heartbeat(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription {messageAdded(roomName: \"room\"){text}}", urlQuery.Get("query")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := &testSubscriptionUpdater{} - - go func() { - rCtx := resolve.NewContext(ctx) - rCtx.ExecutionOptions.SendHeartbeat = true - err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, 15*time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{}`, updater.updates[2]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - func TestGraphQLSubscriptionClientSubscribe_SSE_RequestAbort(t *testing.T) { if flags.IsWindows { t.Skip("skipping test on windows") diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index fec7277de..9af53d162 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/gobwas/ws/wsutil" "math" "net" "net/http" @@ -14,6 +13,8 @@ import ( "sync" "time" + "github.com/gobwas/ws/wsutil" + "github.com/coder/websocket" "github.com/buger/jsonparser" @@ -112,17 +113,10 @@ func WithReadTimeout(timeout time.Duration) Options { } } -func UseHttpClientWithSkipRoundTrip() Options { - return func(options *opts) { - options.useHttpClientWithSkipRoundTrip = true - } -} - type EpollConfiguration struct { Disable bool BufferSize int Interval time.Duration - Wait time.Duration } func (e *EpollConfiguration) ApplyDefaults() { @@ -132,9 +126,6 @@ func (e *EpollConfiguration) ApplyDefaults() { if e.Interval == 0 { e.Interval = time.Millisecond * 100 } - if e.Wait == 0 { - e.Wait = time.Millisecond * 100 - } } func WithEpollConfiguration(config EpollConfiguration) Options { @@ -144,11 +135,10 @@ func WithEpollConfiguration(config EpollConfiguration) Options { } type opts struct { - readTimeout time.Duration - log abstractlogger.Logger - onWsConnectionInitCallback *OnWsConnectionInitCallback - useHttpClientWithSkipRoundTrip bool - epollConfiguration EpollConfiguration + readTimeout time.Duration + log abstractlogger.Logger + onWsConnectionInitCallback *OnWsConnectionInitCallback + epollConfiguration EpollConfiguration } // GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. @@ -188,12 +178,11 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi return xxhash.New() }, }, - onWsConnectionInitCallback: op.onWsConnectionInitCallback, - connections: make(map[int]ConnectionHandler), - activeConnections: make(map[int]struct{}), - triggers: make(map[uint64]int), - useHttpClientWithSkipRoundTrip: op.useHttpClientWithSkipRoundTrip, - epollConfig: op.epollConfiguration, + onWsConnectionInitCallback: op.onWsConnectionInitCallback, + connections: make(map[int]ConnectionHandler), + activeConnections: make(map[int]struct{}), + triggers: make(map[uint64]int), + epollConfig: op.epollConfiguration, } if !op.epollConfiguration.Disable { // ignore error is ok, it means that epoll is not supported, which is handled gracefully by the client @@ -406,6 +395,12 @@ func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContex Subprotocols: subProtocols, }) if err != nil { + if upgradeResponse != nil && upgradeResponse.StatusCode != 101 { + return nil, &UpgradeRequestError{ + URL: options.URL, + StatusCode: upgradeResponse.StatusCode, + } + } return nil, err } // Disable the maximum message size limit. Don't use MaxInt64 since @@ -533,18 +528,8 @@ func waitForAck(conn net.Conn) error { func (c *subscriptionClient) runEpoll(ctx context.Context) { var ( - ticker <-chan time.Time - done = ctx.Done() + done = ctx.Done() ) - if c.epollConfig.Wait > 0 { - tick := time.NewTicker(time.Millisecond * 50) - defer tick.Stop() - ticker = tick.C - } else { - tick := make(chan time.Time) - close(tick) - ticker = tick - } for { connections, err := c.epoll.Wait(50) if err != nil { @@ -571,16 +556,10 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { } c.connectionsMu.Unlock() - if len(connections) == 50 { - // we have more connections to process, - continue - } - select { case <-done: return - case <-ticker: - continue + default: } } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index c508e3717..4da307612 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -146,9 +146,6 @@ func (h *gqlTWSConnectionHandler) StartBlocking() error { go h.readBlocking(readCtx, dataCh, errCh) - ticker := time.NewTicker(resolve.HearbeatInterval) - defer ticker.Stop() - for { select { case <-h.engineContext.Done(): @@ -159,10 +156,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking() error { h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err)) h.broadcastErrorMessage(err) return err - case <-ticker.C: - h.updater.Heartbeat() case data := <-dataCh: - ticker.Reset(resolve.HearbeatInterval) messageType, err := jsonparser.GetString(data, "type") if err != nil { continue diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 1b186d669..dad8f6b0e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -139,9 +139,6 @@ func (h *gqlWSConnectionHandler) StartBlocking() error { go h.readBlocking(readCtx, dataCh, errCh) - ticker := time.NewTicker(resolve.HearbeatInterval) - defer ticker.Stop() - for { select { case <-h.engineContext.Done(): @@ -154,11 +151,7 @@ func (h *gqlWSConnectionHandler) StartBlocking() error { } h.broadcastErrorMessage(err) return err - case <-ticker.C: - h.updater.Heartbeat() case data := <-dataCh: - ticker.Reset(resolve.HearbeatInterval) - messageType, err := jsonparser.GetString(data, "type") if err != nil { continue diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index f7fe7684e..4594a0cf0 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -51,10 +51,11 @@ type Resolver struct { bufPool sync.Pool maxConcurrency chan struct{} - triggers map[uint64]*trigger - events chan subscriptionEvent - triggerUpdateSem *semaphore.Weighted - triggerUpdateBuf *bytes.Buffer + triggers map[uint64]*trigger + heartbeatSubscriptions map[*Context]*sub + events chan subscriptionEvent + triggerUpdateSem *semaphore.Weighted + triggerUpdateBuf *bytes.Buffer connectionIDs atomic.Int64 @@ -170,6 +171,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, events: make(chan subscriptionEvent), triggers: make(map[uint64]*trigger), + heartbeatSubscriptions: make(map[*Context]*sub), reporter: options.Reporter, asyncErrorWriter: options.AsyncErrorWriter, triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), @@ -292,33 +294,16 @@ type trigger struct { initialized bool } -func (t *trigger) hasPendingUpdates() bool { - for _, s := range t.subscriptions { - s.mux.Lock() - hasUpdates := s.pendingUpdates != 0 - s.mux.Unlock() - if hasUpdates { - return true - } - } - return false -} - type sub struct { - mux sync.Mutex - resolve *GraphQLSubscription - writer SubscriptionResponseWriter - id SubscriptionIdentifier - pendingUpdates int - completed chan struct{} - sendHeartbeat bool + mux sync.Mutex + resolve *GraphQLSubscription + writer SubscriptionResponseWriter + id SubscriptionIdentifier + completed chan struct{} + lastWrite time.Time } func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput []byte) { - sub.mux.Lock() - sub.pendingUpdates++ - sub.mux.Unlock() - if r.options.Debug { fmt.Printf("resolver:trigger:subscription:update:%d\n", sub.id.SubscriptionID) } @@ -331,7 +316,6 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput if err := t.resolvable.InitSubscription(ctx, input, sub.resolve.Trigger.PostProcessing); err != nil { sub.mux.Lock() r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer) - sub.pendingUpdates-- sub.mux.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) @@ -345,7 +329,6 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput if err := t.loader.LoadGraphQLResponseData(ctx, sub.resolve.Response, t.resolvable); err != nil { sub.mux.Lock() r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer) - sub.pendingUpdates-- sub.mux.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) @@ -357,12 +340,10 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput } sub.mux.Lock() - sub.pendingUpdates-- - sub.mux.Unlock() - - sub.mux.Lock() - sub.pendingUpdates-- - defer sub.mux.Unlock() + defer func() { + sub.lastWrite = time.Now() + sub.mux.Unlock() + }() if err := t.resolvable.Resolve(ctx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer) @@ -395,6 +376,8 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput func (r *Resolver) handleEvents() { done := r.ctx.Done() + heatbeat := time.NewTicker(HearbeatInterval) + defer heatbeat.Stop() for { select { case <-done: @@ -402,6 +385,8 @@ func (r *Resolver) handleEvents() { return case event := <-r.events: r.handleEvent(event) + case <-heatbeat.C: + r.handleHeartbeat(multipartHeartbeat) } } } @@ -422,30 +407,29 @@ func (r *Resolver) handleEvent(event subscriptionEvent) { r.handleTriggerInitialized(event.triggerID) case subscriptionEventKindTriggerShutdown: r.handleTriggerShutdown(event) - case subscriptionEventKindHeartbeat: - r.handleHeartbeat(event.triggerID, event.data) case subscriptionEventKindUnknown: panic("unknown event") } } -func (r *Resolver) handleHeartbeat(id uint64, data []byte) { - trig, ok := r.triggers[id] - if !ok { - return - } +func (r *Resolver) handleHeartbeat(data []byte) { if r.options.Debug { - fmt.Printf("resolver:heartbeat:%d\n", id) + fmt.Printf("resolver:heartbeat:%d\n", len(r.heartbeatSubscriptions)) } - for c, s := range trig.subscriptions { - c, s := c, s - // Only send heartbeats to subscriptions who have enabled it - if !s.sendHeartbeat { + now := time.Now() + for c, s := range r.heartbeatSubscriptions { + // check if the last write to the subscription was more than heartbeat interval ago + s.mux.Lock() + skipHeartbeat := now.Sub(s.lastWrite) < HearbeatInterval + s.mux.Unlock() + if skipHeartbeat { continue } + if err := r.triggerUpdateSem.Acquire(r.ctx, 1); err != nil { return } + go func() { defer r.triggerUpdateSem.Release(1) @@ -529,11 +513,14 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) fmt.Printf("resolver:trigger:subscription:add:%d:%d\n", triggerID, add.id.SubscriptionID) } s := &sub{ - resolve: add.resolve, - writer: add.writer, - id: add.id, - completed: add.completed, - sendHeartbeat: add.ctx.ExecutionOptions.SendHeartbeat, + resolve: add.resolve, + writer: add.writer, + id: add.id, + completed: add.completed, + lastWrite: time.Now(), + } + if add.ctx.ExecutionOptions.SendHeartbeat { + r.heartbeatSubscriptions[add.ctx] = s } trig, ok := r.triggers[triggerID] if ok { @@ -644,7 +631,7 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) { if ctx.Context().Err() == nil { s.writer.Complete() } - + delete(r.heartbeatSubscriptions, ctx) delete(trig.subscriptions, ctx) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:removed:%d:%d\n", trig.id, id.SubscriptionID) @@ -739,6 +726,7 @@ func (r *Resolver) shutdownTrigger(id uint64) { if s.completed != nil { close(s.completed) } + delete(r.heartbeatSubscriptions, c) delete(trig.subscriptions, c) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", trig.id, s.id.SubscriptionID) @@ -941,28 +929,6 @@ type subscriptionUpdater struct { ctx context.Context } -func (s *subscriptionUpdater) Heartbeat() { - if s.debug { - fmt.Printf("resolver:subscription_updater:heartbeat:%d\n", s.triggerID) - } - if s.done { - return - } - - select { - case <-s.ctx.Done(): - return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindHeartbeat, - data: multipartHeartbeat, - // Currently, the only heartbeat we support is for multipart subscriptions. If we need to support future types - // of subscriptions, we can evaluate then how we can save on the subscription level what kind of heartbeat it - // requires - }: - } -} - func (s *subscriptionUpdater) Update(data []byte) { if s.debug { fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID) @@ -1027,16 +993,11 @@ const ( subscriptionEventKindRemoveClient subscriptionEventKindTriggerInitialized subscriptionEventKindTriggerShutdown - subscriptionEventKindHeartbeat ) type SubscriptionUpdater interface { // Update sends an update to the client. It is not guaranteed that the update is sent immediately. Update(data []byte) - // Heartbeat sends a heartbeat to the client. It is not guaranteed that the update is sent immediately. When calling, - // clients should reset their heartbeat timer after an Update call to make sure that we don't send needless heartbeats - // downstream - Heartbeat() // Done also takes care of cleaning up the trigger and all subscriptions. No more updates should be sent after calling Done. Done() } From 8bb9a0520008229983c464f6c533127ca5e77976 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 13:07:30 +0200 Subject: [PATCH 17/31] chore: cleanup --- v2/pkg/epoller/epoll_linux.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/epoller/epoll_linux.go b/v2/pkg/epoller/epoll_linux.go index ee32c6c57..a042a6ae6 100755 --- a/v2/pkg/epoller/epoll_linux.go +++ b/v2/pkg/epoller/epoll_linux.go @@ -68,7 +68,7 @@ func (e *Epoll) Close(closeConns bool) error { // Add adds a connection to the poller. func (e *Epoll) Add(conn net.Conn) error { conn = newConnImpl(conn) - fd := socketFD(conn) + fd := SocketFD(conn) if e := syscall.SetNonblock(int(fd), true); e != nil { return errors.New("udev: unix.SetNonblock failed") } @@ -86,7 +86,7 @@ func (e *Epoll) Add(conn net.Conn) error { // Remove removes a connection from the poller. func (e *Epoll) Remove(conn net.Conn) error { - fd := socketFD(conn) + fd := SocketFD(conn) err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_DEL, fd, nil) if err != nil { return err From 8d19b96e80d0296129e7cc2bd48e323db28128ec Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 13:38:37 +0200 Subject: [PATCH 18/31] chore: fix lint --- v2/pkg/engine/resolve/resolve.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 4594a0cf0..4ebdc2a94 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -419,6 +419,7 @@ func (r *Resolver) handleHeartbeat(data []byte) { now := time.Now() for c, s := range r.heartbeatSubscriptions { // check if the last write to the subscription was more than heartbeat interval ago + c, s := c, s s.mux.Lock() skipHeartbeat := now.Sub(s.lastWrite) < HearbeatInterval s.mux.Unlock() From e968003b5597acc42c5583817178937382b7d487 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 14:04:32 +0200 Subject: [PATCH 19/31] chore: fix tests --- .../graphql_datasource/graphql_subscription_client.go | 3 +++ .../graphql_datasource/graphql_tws_handler_test.go | 6 ++---- .../graphql_datasource/graphql_ws_handler_test.go | 6 ++---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 9af53d162..714f300f6 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -253,6 +253,9 @@ func (c *subscriptionClient) subscribeWS(requestContext, engineContext context.C go func() { err := handler.StartBlocking() if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return + } c.log.Error("subscriptionClient.subscribeWS", abstractlogger.Error(err)) } }() diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go index e46175935..7640b485c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go @@ -68,7 +68,6 @@ func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { updater := &testSubscriptionUpdater{} go func() { rCtx := resolve.NewContext(ctx) - rCtx.ExecutionOptions.SendHeartbeat = true err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ URL: server.URL, Body: GraphQLBody{ @@ -77,12 +76,11 @@ func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { }, updater) assert.NoError(t, err) }() - updater.AwaitUpdates(t, 10*time.Second, 4) - assert.Equal(t, 4, len(updater.updates)) + updater.AwaitUpdates(t, time.Second*5, 3) + assert.Equal(t, 3, len(updater.updates)) assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - assert.Equal(t, `{}`, updater.updates[3]) clientCancel() assert.Eventuallyf(t, func() bool { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go index 16d8c6b12..ec81c0295 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go @@ -143,7 +143,6 @@ func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) { updater := &testSubscriptionUpdater{} go func() { rCtx := resolve.NewContext(ctx) - rCtx.ExecutionOptions.SendHeartbeat = true err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ URL: server.URL, Body: GraphQLBody{ @@ -152,12 +151,11 @@ func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) { }, updater) assert.NoError(t, err) }() - updater.AwaitUpdates(t, 10*time.Second, 4) - assert.Equal(t, 4, len(updater.updates)) + updater.AwaitUpdates(t, time.Second*5, 3) + assert.Equal(t, 3, len(updater.updates)) assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - assert.Equal(t, `{}`, updater.updates[3]) clientCancel() assert.Eventuallyf(t, func() bool { <-serverDone From 3c47efcf430429b66678f933d3bea6c122266b75 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 14:33:35 +0200 Subject: [PATCH 20/31] chore: fix tests --- .../graphql_datasource/graphql_subscription_client.go | 4 ++-- .../graphql_datasource/graphql_subscription_client_test.go | 2 +- .../datasource/graphql_datasource/graphql_tws_handler.go | 4 ++-- .../datasource/graphql_datasource/graphql_tws_handler_test.go | 2 +- .../datasource/graphql_datasource/graphql_ws_handler.go | 4 ++-- .../datasource/graphql_datasource/graphql_ws_handler_test.go | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index f242fb42d..2b8a955f2 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -10,8 +10,8 @@ import ( "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" - "github.com/coder/websocket" "github.com/jensneuse/abstractlogger" + "nhooyr.io/websocket" ) const ackWaitTimeout = 30 * time.Second @@ -217,7 +217,7 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti return nil, err } // Disable the maximum message size limit. Don't use MaxInt64 since - // the github.com/coder/websocket doesn't handle it correctly on 32 bit systems. + // the nhooyr.io/websocket doesn't handle it correctly on 32 bit systems. conn.SetReadLimit(math.MaxInt32) if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("upgrade unsuccessful") diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 59de83ad5..d0cfdb09e 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -14,11 +14,11 @@ import ( "github.com/stretchr/testify/require" "github.com/buger/jsonparser" - "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" "go.uber.org/atomic" "go.uber.org/zap" + "nhooyr.io/websocket" ) func logger() ll.Logger { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 0645702a0..d4883a994 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -8,8 +8,8 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/coder/websocket" log "github.com/jensneuse/abstractlogger" + "nhooyr.io/websocket" ) // gqlTWSConnectionHandler is responsible for handling a connection to an origin @@ -241,7 +241,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { } // readBlocking is a dedicated loop running in a separate goroutine -// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout +// because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlTWSConnectionHandler stops func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go index b00e855d9..997908556 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go @@ -9,8 +9,8 @@ import ( "github.com/stretchr/testify/require" - "github.com/coder/websocket" "github.com/stretchr/testify/assert" + "nhooyr.io/websocket" ) func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 3ffb40687..a84ff7bae 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -8,8 +8,8 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/coder/websocket" "github.com/jensneuse/abstractlogger" + "nhooyr.io/websocket" ) // gqlWSConnectionHandler is responsible for handling a connection to an origin @@ -97,7 +97,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { } // readBlocking is a dedicated loop running in a separate goroutine -// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout +// because the library "nhooyr.io/websocket" doesn't allow reading with a context with Timeout // we'll block forever on reading until the context of the gqlWSConnectionHandler stops func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { for { diff --git a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go index 1fcb9c71e..0f8737143 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go @@ -10,8 +10,8 @@ import ( "github.com/stretchr/testify/require" - "github.com/coder/websocket" "github.com/stretchr/testify/assert" + "nhooyr.io/websocket" ) func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) { From 90bf1ec06e6f271a60889d13859943f40b16c6a7 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 15:54:17 +0200 Subject: [PATCH 21/31] chore: fix tests --- .../datasource/graphql_datasource/graphql_tws_handler.go | 8 -------- .../datasource/graphql_datasource/graphql_ws_handler.go | 8 -------- 2 files changed, 16 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 4da307612..17e86d580 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -312,14 +312,6 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { netOpErr := &net.OpError{} for { - err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) - if err != nil { - select { - case errCh <- err: - case <-ctx.Done(): - } - return - } data, err := wsutil.ReadServerText(h.conn) if err != nil { if errors.As(err, &netOpErr) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index dad8f6b0e..6879fa2c2 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -188,14 +188,6 @@ func (h *gqlWSConnectionHandler) StartBlocking() error { func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan []byte, errCh chan error) { netOpErr := &net.OpError{} for { - err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) - if err != nil { - select { - case errCh <- err: - case <-ctx.Done(): - } - return - } data, err := wsutil.ReadServerText(h.conn) if err != nil { if errors.As(err, &netOpErr) { From 90498eb0175449fd52a941fe38ac31b8455161fc Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 21 Oct 2024 16:15:20 +0200 Subject: [PATCH 22/31] chore: cleanup --- v2/pkg/engine/resolve/resolve.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 4ebdc2a94..2eacf07cf 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -426,11 +426,9 @@ func (r *Resolver) handleHeartbeat(data []byte) { if skipHeartbeat { continue } - if err := r.triggerUpdateSem.Acquire(r.ctx, 1); err != nil { return } - go func() { defer r.triggerUpdateSem.Release(1) @@ -650,7 +648,6 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) { } func (r *Resolver) handleRemoveClient(id int64) { - if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } @@ -714,7 +711,9 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { } func (r *Resolver) shutdownTrigger(id uint64) { - fmt.Printf("resolver:trigger:shutdown:%d\n", id) + if r.options.Debug { + fmt.Printf("resolver:trigger:shutdown:%d\n", id) + } trig, ok := r.triggers[id] if !ok { return From 44de1f6f6e849ca6843cc356b7c089f58f909969 Mon Sep 17 00:00:00 2001 From: starptech Date: Mon, 21 Oct 2024 21:06:11 +0200 Subject: [PATCH 23/31] chore: use websocket library to close conn, fix race with trigger cancel --- .../graphql_subscription_client.go | 45 +++++++++++++++---- .../graphql_datasource/graphql_tws_handler.go | 37 +++++---------- .../graphql_datasource/graphql_ws_handler.go | 37 +++++---------- v2/pkg/engine/resolve/resolve.go | 18 +++++--- 4 files changed, 69 insertions(+), 68 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 714f300f6..190204717 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "math" "net" "net/http" @@ -446,15 +447,15 @@ func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContex } } - if err := waitForAck(netConn); err != nil { + if err := waitForAck(requestContext, conn); err != nil { return nil, err } switch wsSubProtocol { case ProtocolGraphQLWS: - return newGQLWSConnectionHandler(requestContext, engineContext, netConn, options, updater, c.log), nil + return newGQLWSConnectionHandler(requestContext, engineContext, conn, netConn, options, updater, c.log), nil case ProtocolGraphQLTWS: - return newGQLTWSConnectionHandler(requestContext, engineContext, netConn, options, updater, c.log), nil + return newGQLTWSConnectionHandler(requestContext, engineContext, conn, netConn, options, updater, c.log), nil default: return nil, NewInvalidWsSubprotocolError(wsSubProtocol) } @@ -493,7 +494,7 @@ type ConnectionHandler interface { Subscribe() error } -func waitForAck(conn net.Conn) error { +func waitForAck(ctx context.Context, conn *websocket.Conn) error { timer := time.NewTimer(ackWaitTimeout) for { select { @@ -502,12 +503,15 @@ func waitForAck(conn net.Conn) error { default: } - data, err := wsutil.ReadServerText(conn) + msgType, msg, err := conn.Read(ctx) if err != nil { - return fmt.Errorf("failed to read message: %w", err) + return err + } + if msgType != websocket.MessageText { + return fmt.Errorf("unexpected message type") } - respType, err := jsonparser.GetString(data, "type") + respType, err := jsonparser.GetString(msg, "type") if err != nil { return err } @@ -516,10 +520,11 @@ func waitForAck(conn net.Conn) error { case messageTypeConnectionKeepAlive: continue case messageTypePing: - err := wsutil.WriteClientText(conn, []byte(pongMessage)) + err := conn.Write(ctx, websocket.MessageText, []byte(pongMessage)) if err != nil { return fmt.Errorf("failed to send pong message: %w", err) } + continue case messageTypeConnectionAck: return nil @@ -590,3 +595,27 @@ func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, return } } + +func handleConnectionError(err error) (closed, timeout bool) { + netOpErr := &net.OpError{} + if errors.As(err, &netOpErr) { + if netOpErr.Timeout() { + return false, true + } + return true, false + } + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true, false + } + if errors.Is(err, context.DeadlineExceeded) { + return false, true + } + if errors.As(err, &wsutil.ClosedError{}) { + return true, false + } + if strings.HasSuffix(err.Error(), "use of closed network connection") { + return true, false + } + return false, false +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 17e86d580..a072cea47 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -6,8 +6,8 @@ import ( "encoding/json" "errors" "fmt" + "github.com/coder/websocket" "net" - "strings" "time" "github.com/gobwas/ws/wsutil" @@ -21,7 +21,9 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlTWSConnectionHandler struct { + // The underlying net.Conn. Only used for epoll. Should not be used to shutdown the connection. conn net.Conn + wsConn *websocket.Conn requestContext, engineContext context.Context log log.Logger options GraphQLSubscriptionOptions @@ -30,7 +32,7 @@ type gqlTWSConnectionHandler struct { func (h *gqlTWSConnectionHandler) ServerClose() { h.updater.Done() - _ = h.conn.Close() + _ = h.wsConn.CloseNow() } func (h *gqlTWSConnectionHandler) ClientClose() { @@ -40,7 +42,7 @@ func (h *gqlTWSConnectionHandler) ClientClose() { if err != nil { h.log.Error("failed to write complete message", log.Error(err)) } - _ = h.conn.Close() + _ = h.wsConn.Close(websocket.StatusNormalClosure, "") } func (h *gqlTWSConnectionHandler) Subscribe() error { @@ -57,12 +59,12 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) if err != nil { - return h.handleConnectionError(err) + return handleConnectionError(err) } data, err := wsutil.ReadServerText(rwr) if err != nil { - return h.handleConnectionError(err) + return handleConnectionError(err) } messageType, err := jsonparser.GetString(data, "type") @@ -94,33 +96,14 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { } } -func (h *gqlTWSConnectionHandler) handleConnectionError(err error) (closed, timeout bool) { - if errors.Is(err, context.DeadlineExceeded) { - return false, true - } - netOpErr := &net.OpError{} - if errors.As(err, &netOpErr) { - if netOpErr.Timeout() { - return false, true - } - return true, false - } - if errors.As(err, &wsutil.ClosedError{}) { - return true, false - } - if strings.HasSuffix(err.Error(), "use of closed network connection") { - return true, false - } - return false, false -} - func (h *gqlTWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l log.Logger) *gqlTWSConnectionHandler { +func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, wsConn *websocket.Conn, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l log.Logger) *gqlTWSConnectionHandler { return &gqlTWSConnectionHandler{ conn: conn, + wsConn: wsConn, requestContext: requestContext, engineContext: engineContext, log: l, @@ -190,7 +173,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking() error { func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { h.unsubscribe() - _ = h.conn.Close() + _ = h.wsConn.Close(websocket.StatusNormalClosure, "") } func (h *gqlTWSConnectionHandler) unsubscribe() { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 6879fa2c2..e6859031b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -6,9 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "github.com/coder/websocket" "io" "net" - "strings" "time" "github.com/gobwas/ws/wsutil" @@ -22,7 +22,9 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlWSConnectionHandler struct { + // The underlying net.Conn. Only used for epoll. Should not be used to shutdown the connection. conn net.Conn + wsConn *websocket.Conn requestContext, engineContext context.Context log abstractlogger.Logger options GraphQLSubscriptionOptions @@ -31,14 +33,14 @@ type gqlWSConnectionHandler struct { func (h *gqlWSConnectionHandler) ServerClose() { h.updater.Done() - _ = h.conn.Close() + _ = h.wsConn.CloseNow() } func (h *gqlWSConnectionHandler) ClientClose() { h.updater.Done() stopRequest := fmt.Sprintf(stopMessage, "1") _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) - _ = h.conn.Close() + _ = h.wsConn.Close(websocket.StatusNormalClosure, "") } func (h *gqlWSConnectionHandler) Subscribe() error { @@ -54,11 +56,11 @@ func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { for { err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) if err != nil { - return h.handleConnectionError(err) + return handleConnectionError(err) } data, err := wsutil.ReadServerText(rwr) if err != nil { - return h.handleConnectionError(err) + return handleConnectionError(err) } messageType, err := jsonparser.GetString(data, "type") if err != nil { @@ -85,33 +87,14 @@ func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { } } -func (h *gqlWSConnectionHandler) handleConnectionError(err error) (closed, timeout bool) { - if errors.Is(err, context.DeadlineExceeded) { - return false, true - } - netOpErr := &net.OpError{} - if errors.As(err, &netOpErr) { - if netOpErr.Timeout() { - return false, true - } - return true, false - } - if errors.As(err, &wsutil.ClosedError{}) { - return true, false - } - if strings.HasSuffix(err.Error(), "use of closed network connection") { - return true, false - } - return false, false -} - func (h *gqlWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *gqlWSConnectionHandler { +func newGQLWSConnectionHandler(requestContext, engineContext context.Context, wsConn *websocket.Conn, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *gqlWSConnectionHandler { return &gqlWSConnectionHandler{ conn: conn, + wsConn: wsConn, requestContext: requestContext, engineContext: engineContext, log: log, @@ -216,7 +199,7 @@ func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan [ func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { h.unsubscribe() - _ = h.conn.Close() + _ = h.wsConn.Close(websocket.StatusNormalClosure, "") } // subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 2eacf07cf..459a08567 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -556,16 +556,22 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) r.reporter.SubscriptionCountInc(1) } + var asyncDataSource AsyncSubscriptionDataSource + + if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { + trig.cancel = func() { + cancel() + async.AsyncStop(triggerID) + } + asyncDataSource = async + } + go func() { if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } - if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { - trig.cancel = func() { - cancel() - async.AsyncStop(triggerID) - } - err = async.AsyncStart(cloneCtx, triggerID, add.input, updater) + if asyncDataSource != nil { + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.input, updater) } else { err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) } From b7035aca1cea5f9a886b2840cdd9dac833dcf930 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 22 Oct 2024 10:38:40 +0200 Subject: [PATCH 24/31] chore: use http client for upgrades --- .../graphql_subscription_client.go | 183 ++++++++++++------ .../graphql_subscription_client_test.go | 159 ++++++++++++++- .../graphql_datasource/graphql_tws_handler.go | 55 +++--- .../graphql_datasource/graphql_ws_handler.go | 28 ++- 4 files changed, 309 insertions(+), 116 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 190204717..e9bd2b167 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -1,11 +1,14 @@ package graphql_datasource import ( + "bufio" "context" + "crypto/rand" + "crypto/sha1" + "encoding/base64" "errors" "fmt" "io" - "math" "net" "net/http" "net/http/httptrace" @@ -14,12 +17,9 @@ import ( "sync" "time" - "github.com/gobwas/ws/wsutil" - - "github.com/coder/websocket" - "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + "github.com/gobwas/ws/wsutil" "github.com/jensneuse/abstractlogger" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/epoller" @@ -379,40 +379,11 @@ func (u *UpgradeRequestError) Error() string { } func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (ConnectionHandler, error) { - subProtocols := []string{ProtocolGraphQLWS, ProtocolGraphQLTWS} - if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" { - subProtocols = []string{options.WsSubProtocol} - } - var netConn net.Conn - - clientTrace := &httptrace.ClientTrace{ - GotConn: func(info httptrace.GotConnInfo) { - netConn = info.Conn - }, - } - clientTraceCtx := httptrace.WithClientTrace(requestContext, clientTrace) - conn, upgradeResponse, err := websocket.Dial(clientTraceCtx, options.URL, &websocket.DialOptions{ - HTTPClient: c.httpClient, - HTTPHeader: options.Header, - CompressionMode: websocket.CompressionDisabled, - Subprotocols: subProtocols, - }) + conn, subProtocol, err := c.dial(requestContext, options) if err != nil { - if upgradeResponse != nil && upgradeResponse.StatusCode != 101 { - return nil, &UpgradeRequestError{ - URL: options.URL, - StatusCode: upgradeResponse.StatusCode, - } - } return nil, err } - // Disable the maximum message size limit. Don't use MaxInt64 since - // the github.com/coder/websocket doesn't handle it correctly on 32-bit systems. - conn.SetReadLimit(math.MaxInt32) - if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("upgrade unsuccessful") - } connectionInitMessage, err := c.getConnectionInitMessage(requestContext, options.URL, options.Header) if err != nil { @@ -434,31 +405,107 @@ func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContex } // init + ack - err = wsutil.WriteClientText(netConn, connectionInitMessage) + err = wsutil.WriteClientText(conn, connectionInitMessage) if err != nil { return nil, err } - wsSubProtocol := subProtocols[0] - if options.WsSubProtocol == "" || options.WsSubProtocol == "auto" { - wsSubProtocol = conn.Subprotocol() - if wsSubProtocol == "" { - wsSubProtocol = ProtocolGraphQLWS - } - } - - if err := waitForAck(requestContext, conn); err != nil { + if err := waitForAck(conn); err != nil { return nil, err } - switch wsSubProtocol { + switch subProtocol { case ProtocolGraphQLWS: - return newGQLWSConnectionHandler(requestContext, engineContext, conn, netConn, options, updater, c.log), nil + return newGQLWSConnectionHandler(requestContext, engineContext, conn, options, updater, c.log), nil case ProtocolGraphQLTWS: - return newGQLTWSConnectionHandler(requestContext, engineContext, conn, netConn, options, updater, c.log), nil + return newGQLTWSConnectionHandler(requestContext, engineContext, conn, options, updater, c.log), nil default: - return nil, NewInvalidWsSubprotocolError(wsSubProtocol) + return nil, NewInvalidWsSubprotocolError(subProtocol) + } +} + +func (c *subscriptionClient) dial(ctx context.Context, options GraphQLSubscriptionOptions) (conn net.Conn, subProtocol string, err error) { + subProtocols := []string{ProtocolGraphQLWS, ProtocolGraphQLTWS} + if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" { + subProtocols = []string{options.WsSubProtocol} + } + + clientTrace := &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + conn = info.Conn + }, + } + clientTraceCtx := httptrace.WithClientTrace(ctx, clientTrace) + u := options.URL + if strings.HasPrefix(options.URL, "wss") { + u = "https" + options.URL[3:] + } else if strings.HasPrefix(options.URL, "ws") { + u = "http" + options.URL[2:] + } + req, err := http.NewRequestWithContext(clientTraceCtx, http.MethodGet, u, nil) + if err != nil { + return nil, "", err + } + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + if options.Header != nil { + req.Header = options.Header + } + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ",")) + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, "", err + } + + req.Header.Set("Sec-WebSocket-Key", challengeKey) + + upgradeResponse, err := c.httpClient.Do(req) + if err != nil { + return nil, "", err + } + if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { + return nil, "", &UpgradeRequestError{ + URL: u, + StatusCode: upgradeResponse.StatusCode, + } + } + + accept := computeAcceptKey(challengeKey) + if upgradeResponse.Header.Get("Sec-WebSocket-Accept") != accept { + return nil, "", fmt.Errorf("invalid Sec-WebSocket-Accept") + } + + subProtocol = subProtocols[0] + if options.WsSubProtocol == "" || options.WsSubProtocol == "auto" { + subProtocol = upgradeResponse.Header.Get("Sec-WebSocket-Protocol") + if subProtocol == "" { + subProtocol = ProtocolGraphQLWS + } + } + + return conn, subProtocol, nil +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err } + return base64.StdEncoding.EncodeToString(p), nil +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() //#nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54 + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) } func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url string, header http.Header) ([]byte, error) { @@ -494,7 +541,7 @@ type ConnectionHandler interface { Subscribe() error } -func waitForAck(ctx context.Context, conn *websocket.Conn) error { +func waitForAck(conn net.Conn) error { timer := time.NewTimer(ackWaitTimeout) for { select { @@ -502,29 +549,22 @@ func waitForAck(ctx context.Context, conn *websocket.Conn) error { return fmt.Errorf("timeout while waiting for connection_ack") default: } - - msgType, msg, err := conn.Read(ctx) + msg, err := wsutil.ReadServerText(conn) if err != nil { return err } - if msgType != websocket.MessageText { - return fmt.Errorf("unexpected message type") - } - respType, err := jsonparser.GetString(msg, "type") if err != nil { return err } - switch respType { case messageTypeConnectionKeepAlive: continue case messageTypePing: - err := conn.Write(ctx, websocket.MessageText, []byte(pongMessage)) + err = wsutil.WriteClientText(conn, []byte(pongMessage)) if err != nil { return fmt.Errorf("failed to send pong message: %w", err) } - continue case messageTypeConnectionAck: return nil @@ -619,3 +659,28 @@ func handleConnectionError(err error) (closed, timeout bool) { } return false, false } + +var ( + readWriterPool = &ReadWriterPool{} +) + +type ReadWriterPool struct { + pool sync.Pool +} + +func (r *ReadWriterPool) Get(rw io.ReadWriter) *bufio.ReadWriter { + v := r.pool.Get() + if v == nil { + return bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw)) + } + rwr := v.(*bufio.ReadWriter) + rwr.Reader.Reset(rw) + rwr.Writer.Reset(rw) + return rwr +} + +func (r *ReadWriterPool) Put(rw *bufio.ReadWriter) { + rw.Reader.Reset(nil) + rw.Writer.Reset(nil) + r.pool.Put(rw) +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 8764ff42b..db4f62c3f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -10,14 +10,13 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/atomic" + "go.uber.org/goleak" "go.uber.org/zap" ) @@ -423,6 +422,9 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -486,6 +488,84 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { serverCancel() } +func TestWebSocketClientLeaks(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() + ctx := context.Background() + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init"}`, string(data)) + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) + assert.NoError(t, err) + + time.Sleep(time.Second * 1) + + err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) + assert.NoError(t, err) + + msgType, data, err = conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + })) + defer server.Close() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + + client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, + WithReadTimeout(time.Second), + WithLogger(logger()), + ).(*subscriptionClient) + wg := &sync.WaitGroup{} + wg.Add(2) + for i := 0; i < 2; i++ { + go func(i int) { + ctx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + updater := &testSubscriptionUpdater{} + err := client.SubscribeAsync(resolve.NewContext(ctx), uint64(i), GraphQLSubscriptionOptions{ + URL: server.URL, + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + WsSubProtocol: ProtocolGraphQLTWS, + }, updater) + assert.NoError(t, err) + + updater.AwaitUpdates(t, time.Second*10, 3) + assert.Equal(t, 3, len(updater.updates)) + assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) + assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) + client.Unsubscribe(uint64(i)) + clientCancel() + wg.Done() + }(i) + } + wg.Wait() +} + func TestAsyncSubscribe(t *testing.T) { if runtime.GOOS == "windows" { t.SkipNow() @@ -497,6 +577,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -571,6 +654,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -640,6 +726,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -697,6 +786,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -760,9 +852,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) - - defer conn.Close(websocket.StatusNormalClosure, "done") - + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -822,9 +914,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) - - defer conn.Close(websocket.StatusNormalClosure, "done") - + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -877,6 +969,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -909,6 +1004,9 @@ func TestAsyncSubscribe(t *testing.T) { assert.NoError(t, err) assert.Equal(t, websocket.MessageText, msgType) assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) + + ctx = conn.CloseRead(ctx) + <-ctx.Done() close(serverDone) })) defer server.Close() @@ -951,6 +1049,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1014,6 +1115,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1076,6 +1180,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1141,6 +1248,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1173,6 +1283,11 @@ func TestAsyncSubscribe(t *testing.T) { assert.NoError(t, err) assert.Equal(t, websocket.MessageText, msgType) assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) + + ctx = conn.CloseRead(ctx) + <-ctx.Done() + close(serverDone) + close(serverDone) })) defer server.Close() @@ -1215,6 +1330,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1291,6 +1409,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1369,6 +1490,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1445,6 +1569,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1518,6 +1645,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1582,6 +1712,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1646,6 +1779,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) @@ -1706,6 +1842,9 @@ func TestAsyncSubscribe(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "done") + }() ctx := context.Background() msgType, data, err := conn.Read(ctx) assert.NoError(t, err) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index a072cea47..02b3af285 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -1,20 +1,19 @@ package graphql_datasource import ( - "bufio" "context" "encoding/json" "errors" "fmt" - "github.com/coder/websocket" "net" "time" + "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/buger/jsonparser" - log "github.com/jensneuse/abstractlogger" + "github.com/jensneuse/abstractlogger" ) // gqlTWSConnectionHandler is responsible for handling a connection to an origin @@ -23,26 +22,23 @@ import ( type gqlTWSConnectionHandler struct { // The underlying net.Conn. Only used for epoll. Should not be used to shutdown the connection. conn net.Conn - wsConn *websocket.Conn requestContext, engineContext context.Context - log log.Logger + log abstractlogger.Logger options GraphQLSubscriptionOptions updater resolve.SubscriptionUpdater } func (h *gqlTWSConnectionHandler) ServerClose() { h.updater.Done() - _ = h.wsConn.CloseNow() + _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) + _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) ClientClose() { h.updater.Done() - req := fmt.Sprintf(completeMessage, "1") - err := wsutil.WriteClientText(h.conn, []byte(req)) - if err != nil { - h.log.Error("failed to write complete message", log.Error(err)) - } - _ = h.wsConn.Close(websocket.StatusNormalClosure, "") + _ = wsutil.WriteClientText(h.conn, []byte(`{"id":"1","type":"complete"}`)) + _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) + _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) Subscribe() error { @@ -51,18 +47,15 @@ func (h *gqlTWSConnectionHandler) Subscribe() error { func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { - r := bufio.NewReader(h.conn) - wr := bufio.NewWriter(h.conn) - rwr := bufio.NewReadWriter(r, wr) + rw := readWriterPool.Get(h.conn) + defer readWriterPool.Put(rw) for { - err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) if err != nil { return handleConnectionError(err) } - - data, err := wsutil.ReadServerText(rwr) + data, err := wsutil.ReadServerText(rw) if err != nil { return handleConnectionError(err) } @@ -90,7 +83,7 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") return true, false default: - h.log.Error("unknown message type", log.String("type", messageType)) + h.log.Error("unknown message type", abstractlogger.String("type", messageType)) return false, false } } @@ -100,10 +93,9 @@ func (h *gqlTWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, wsConn *websocket.Conn, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l log.Logger) *gqlTWSConnectionHandler { +func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l abstractlogger.Logger) *gqlTWSConnectionHandler { return &gqlTWSConnectionHandler{ conn: conn, - wsConn: wsConn, requestContext: requestContext, engineContext: engineContext, log: l, @@ -136,7 +128,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking() error { case <-readCtx.Done(): return readCtx.Err() case err := <-errCh: - h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err)) + h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) h.broadcastErrorMessage(err) return err case data := <-dataCh: @@ -164,7 +156,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking() error { h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") return errors.New("invalid subprotocol") default: - h.log.Error("unknown message type", log.String("type", messageType)) + h.log.Error("unknown message type", abstractlogger.String("type", messageType)) continue } } @@ -173,7 +165,8 @@ func (h *gqlTWSConnectionHandler) StartBlocking() error { func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { h.unsubscribe() - _ = h.wsConn.Close(websocket.StatusNormalClosure, "") + _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) + _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) unsubscribe() { @@ -181,7 +174,7 @@ func (h *gqlTWSConnectionHandler) unsubscribe() { req := fmt.Sprintf(completeMessage, "1") err := wsutil.WriteClientText(h.conn, []byte(req)) if err != nil { - h.log.Error("failed to write complete message", log.Error(err)) + h.log.Error("failed to write complete message", abstractlogger.Error(err)) } } @@ -227,8 +220,8 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { if err != nil { h.log.Error( "failed to get payload from error message", - log.Error(err), - log.ByteString("raw message", data), + abstractlogger.Error(err), + abstractlogger.ByteString("raw message", data), ) h.updater.Update([]byte(internalError)) return @@ -241,8 +234,8 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { if err != nil { h.log.Error( "failed to set errors response", - log.Error(err), - log.ByteString("raw message", value), + abstractlogger.Error(err), + abstractlogger.ByteString("raw message", value), ) h.updater.Update([]byte(internalError)) return @@ -264,7 +257,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { func (h *gqlTWSConnectionHandler) handleMessageTypePing() { err := wsutil.WriteClientText(h.conn, []byte(pongMessage)) if err != nil { - h.log.Error("failed to write pong message", log.Error(err)) + h.log.Error("failed to write pong message", abstractlogger.Error(err)) } } @@ -280,7 +273,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { if err != nil { h.log.Error( "failed to get payload from next message", - log.Error(err), + abstractlogger.Error(err), ) h.updater.Update([]byte(internalError)) return diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index e6859031b..d4d97b389 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -1,16 +1,15 @@ package graphql_datasource import ( - "bufio" "context" "encoding/json" "errors" "fmt" - "github.com/coder/websocket" "io" "net" "time" + "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -24,7 +23,6 @@ import ( type gqlWSConnectionHandler struct { // The underlying net.Conn. Only used for epoll. Should not be used to shutdown the connection. conn net.Conn - wsConn *websocket.Conn requestContext, engineContext context.Context log abstractlogger.Logger options GraphQLSubscriptionOptions @@ -33,14 +31,15 @@ type gqlWSConnectionHandler struct { func (h *gqlWSConnectionHandler) ServerClose() { h.updater.Done() - _ = h.wsConn.CloseNow() + _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) + _ = h.conn.Close() } func (h *gqlWSConnectionHandler) ClientClose() { h.updater.Done() - stopRequest := fmt.Sprintf(stopMessage, "1") - _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) - _ = h.wsConn.Close(websocket.StatusNormalClosure, "") + _ = wsutil.WriteClientText(h.conn, []byte(`{"type":"stop","id":"1"}`)) + _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) + _ = h.conn.Close() } func (h *gqlWSConnectionHandler) Subscribe() error { @@ -49,16 +48,15 @@ func (h *gqlWSConnectionHandler) Subscribe() error { func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { - r := bufio.NewReader(h.conn) - wr := bufio.NewWriter(h.conn) - rwr := bufio.NewReadWriter(r, wr) + rw := readWriterPool.Get(h.conn) + defer readWriterPool.Put(rw) for { err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) if err != nil { return handleConnectionError(err) } - data, err := wsutil.ReadServerText(rwr) + data, err := wsutil.ReadServerText(h.conn) if err != nil { return handleConnectionError(err) } @@ -91,10 +89,9 @@ func (h *gqlWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLWSConnectionHandler(requestContext, engineContext context.Context, wsConn *websocket.Conn, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *gqlWSConnectionHandler { +func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *gqlWSConnectionHandler { return &gqlWSConnectionHandler{ conn: conn, - wsConn: wsConn, requestContext: requestContext, engineContext: engineContext, log: log, @@ -199,7 +196,8 @@ func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan [ func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { h.unsubscribe() - _ = h.wsConn.Close(websocket.StatusNormalClosure, "") + _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) + _ = h.conn.Close() } // subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin @@ -208,13 +206,11 @@ func (h *gqlWSConnectionHandler) subscribe() error { if err != nil { return err } - startRequest := fmt.Sprintf(startMessage, "1", string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(startRequest)) if err != nil { return err } - return nil } From 2281d9d62c9bcb5ba98c928c1bf28aabddeac0d0 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 22 Oct 2024 23:09:41 +0200 Subject: [PATCH 25/31] chore: improve heap usage and epoll unsubscribe --- .../graphql_subscription_client.go | 52 +++++++++++-------- .../graphql_datasource/graphql_tws_handler.go | 1 - .../graphql_datasource/graphql_ws_handler.go | 2 +- v2/pkg/epoller/epoll_bsd.go | 7 ++- v2/pkg/epoller/epoll_linux.go | 6 ++- 5 files changed, 43 insertions(+), 25 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index e9bd2b167..4cadbfeba 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -52,7 +52,8 @@ type subscriptionClient struct { activeConnections map[int]struct{} activeConnectionsMu sync.Mutex - triggers map[uint64]int + triggers map[uint64]int + asyncUnsubscribeTrigger chan uint64 } func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { @@ -70,20 +71,7 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt } func (c *subscriptionClient) Unsubscribe(id uint64) { - c.connectionsMu.Lock() - defer c.connectionsMu.Unlock() - fd, ok := c.triggers[id] - if !ok { - return - } - delete(c.triggers, id) - handler, ok := c.connections[fd] - if !ok { - return - } - handler.ClientClose() - delete(c.connections, fd) - _ = c.epoll.Remove(handler.NetConn()) + c.asyncUnsubscribeTrigger <- id } type InvalidWsSubprotocolError struct { @@ -183,6 +171,7 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi connections: make(map[int]ConnectionHandler), activeConnections: make(map[int]struct{}), triggers: make(map[uint64]int), + asyncUnsubscribeTrigger: make(chan uint64, op.epollConfiguration.BufferSize), epollConfig: op.epollConfiguration, } if !op.epollConfiguration.Disable { @@ -575,14 +564,12 @@ func waitForAck(conn net.Conn) error { } func (c *subscriptionClient) runEpoll(ctx context.Context) { - var ( - done = ctx.Done() - ) + done := ctx.Done() for { - connections, err := c.epoll.Wait(50) + connections, err := c.epoll.Wait(c.epollConfig.BufferSize) if err != nil { c.log.Error("epoll.Wait", abstractlogger.Error(err)) - return + continue } c.connectionsMu.Lock() for _, conn := range connections { @@ -602,6 +589,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { } go c.handleConnection(id, handler, conn) } + c.handlePendingUnsubscribe() c.connectionsMu.Unlock() select { @@ -612,6 +600,28 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { } } +func (c *subscriptionClient) handlePendingUnsubscribe() { + for { + select { + case id := <-c.asyncUnsubscribeTrigger: + fd, ok := c.triggers[id] + if !ok { + continue + } + delete(c.triggers, id) + handler, ok := c.connections[fd] + if !ok { + continue + } + delete(c.connections, fd) + _ = c.epoll.Remove(handler.NetConn()) + handler.ClientClose() + default: + return + } + } +} + func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { done, timeout := handler.ReadMessage() if timeout { @@ -630,8 +640,8 @@ func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, delete(c.connections, id) c.connectionsMu.Unlock() - handler.ServerClose() _ = c.epoll.Remove(conn) + handler.ServerClose() return } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 02b3af285..21637d2fb 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -59,7 +59,6 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { if err != nil { return handleConnectionError(err) } - messageType, err := jsonparser.GetString(data, "type") if err != nil { return false, false diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index d4d97b389..8a5a92c83 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -56,7 +56,7 @@ func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { if err != nil { return handleConnectionError(err) } - data, err := wsutil.ReadServerText(h.conn) + data, err := wsutil.ReadServerText(rw) if err != nil { return handleConnectionError(err) } diff --git a/v2/pkg/epoller/epoll_bsd.go b/v2/pkg/epoller/epoll_bsd.go index 67369fa4f..3f6092645 100755 --- a/v2/pkg/epoller/epoll_bsd.go +++ b/v2/pkg/epoller/epoll_bsd.go @@ -23,6 +23,8 @@ type Epoll struct { changes []syscall.Kevent_t conns map[int]net.Conn connbuf []net.Conn + + events []syscall.Kevent_t } // NewPoller creates a new poller instance. @@ -123,7 +125,10 @@ func (e *Epoll) Remove(conn net.Conn) error { // Wait waits for events and returns the connections. func (e *Epoll) Wait(count int) ([]net.Conn, error) { - events := make([]syscall.Kevent_t, count) + if e.events == nil { + e.events = make([]syscall.Kevent_t, count) + } + events := e.events[:count] e.mu.RLock() changes := e.changes diff --git a/v2/pkg/epoller/epoll_linux.go b/v2/pkg/epoller/epoll_linux.go index a042a6ae6..3607d5856 100755 --- a/v2/pkg/epoller/epoll_linux.go +++ b/v2/pkg/epoller/epoll_linux.go @@ -25,6 +25,7 @@ type Epoll struct { connbuf []net.Conn timeoutMsec int + events []unix.EpollEvent } // NewPoller creates a new epoll poller. @@ -100,7 +101,10 @@ func (e *Epoll) Remove(conn net.Conn) error { // Wait waits for at most count events and returns the connections. func (e *Epoll) Wait(count int) ([]net.Conn, error) { - events := make([]unix.EpollEvent, count) + if e.events == nil { + e.events = make([]unix.EpollEvent, count) + } + events := e.events[:count] retry: n, err := unix.EpollWait(e.fd, events, e.timeoutMsec) From b97e73b678f68d8dd9813a0e615baa4effbf3f6c Mon Sep 17 00:00:00 2001 From: starptech Date: Wed, 23 Oct 2024 16:00:25 +0200 Subject: [PATCH 26/31] feat: limit concurrency, decrease read deadline when reading messages --- .../graphql_subscription_client.go | 71 ++++++++++++------- .../graphql_datasource/graphql_ws_handler.go | 2 +- 2 files changed, 46 insertions(+), 27 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 4cadbfeba..bca746a35 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -15,6 +15,7 @@ import ( "net/textproto" "strings" "sync" + "syscall" "time" "github.com/buger/jsonparser" @@ -47,10 +48,9 @@ type subscriptionClient struct { epollConfig EpollConfiguration connections map[int]ConnectionHandler - connectionsMu sync.Mutex + connectionsMu sync.RWMutex - activeConnections map[int]struct{} - activeConnectionsMu sync.Mutex + activeConnections map[int]struct{} triggers map[uint64]int asyncUnsubscribeTrigger chan uint64 @@ -110,7 +110,7 @@ type EpollConfiguration struct { func (e *EpollConfiguration) ApplyDefaults() { if e.BufferSize == 0 { - e.BufferSize = 1024 + e.BufferSize = 2048 } if e.Interval == 0 { e.Interval = time.Millisecond * 100 @@ -565,37 +565,37 @@ func waitForAck(conn net.Conn) error { func (c *subscriptionClient) runEpoll(ctx context.Context) { done := ctx.Done() + wg := sync.WaitGroup{} + for { connections, err := c.epoll.Wait(c.epollConfig.BufferSize) if err != nil { c.log.Error("epoll.Wait", abstractlogger.Error(err)) continue } - c.connectionsMu.Lock() - for _, conn := range connections { + c.connectionsMu.RLock() + for _, cc := range connections { + conn := cc id := epoller.SocketFD(conn) handler, ok := c.connections[id] if !ok { continue } - c.activeConnectionsMu.Lock() - _, active := c.activeConnections[id] - if !active { - c.activeConnections[id] = struct{}{} - } - c.activeConnectionsMu.Unlock() - if active { - continue - } - go c.handleConnection(id, handler, conn) + wg.Add(1) + go func() { + defer wg.Done() + c.handleConnection(id, handler, conn) + }() } c.handlePendingUnsubscribe() - c.connectionsMu.Unlock() + c.connectionsMu.RUnlock() select { case <-done: + c.log.Debug("epoll done due to context done") return default: + wg.Wait() } } } @@ -604,16 +604,25 @@ func (c *subscriptionClient) handlePendingUnsubscribe() { for { select { case id := <-c.asyncUnsubscribeTrigger: + c.connectionsMu.Lock() + fd, ok := c.triggers[id] if !ok { + c.connectionsMu.Unlock() continue } + delete(c.triggers, id) + handler, ok := c.connections[fd] if !ok { + c.connectionsMu.Unlock() continue } + delete(c.connections, fd) + c.connectionsMu.Unlock() + _ = c.epoll.Remove(handler.NetConn()) handler.ClientClose() default: @@ -625,17 +634,10 @@ func (c *subscriptionClient) handlePendingUnsubscribe() { func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { done, timeout := handler.ReadMessage() if timeout { - c.activeConnectionsMu.Lock() - delete(c.activeConnections, id) - c.activeConnectionsMu.Unlock() return } if done { - c.activeConnectionsMu.Lock() - delete(c.activeConnections, id) - c.activeConnectionsMu.Unlock() - c.connectionsMu.Lock() delete(c.connections, id) c.connectionsMu.Unlock() @@ -655,18 +657,35 @@ func handleConnectionError(err error) (closed, timeout bool) { return true, false } + // Check if we have errors during reading from the connection if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { return true, false } + + // Check if we have a context error if errors.Is(err, context.DeadlineExceeded) { return false, true } - if errors.As(err, &wsutil.ClosedError{}) { + + // Check if the error is a connection reset by peer + if errors.Is(err, syscall.ECONNRESET) { return true, false } - if strings.HasSuffix(err.Error(), "use of closed network connection") { + if errors.Is(err, syscall.EPIPE) { + return true, false + } + + // Check if the error is a closed network connection. Introduced in go 1.16. + // This replaces the string match of "use of closed network connection" + if errors.Is(err, net.ErrClosed) { return true, false } + + // Check if the error is closed websocket connection + if errors.As(err, &wsutil.ClosedError{}) { + return true, false + } + return false, false } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index 8a5a92c83..d1964174e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -52,7 +52,7 @@ func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { defer readWriterPool.Put(rw) for { - err := h.conn.SetReadDeadline(time.Now().Add(time.Second)) + err := h.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) if err != nil { return handleConnectionError(err) } From d006b22e8e285f7721a20e35984c581d1ac786b3 Mon Sep 17 00:00:00 2001 From: starptech Date: Wed, 23 Oct 2024 17:16:04 +0200 Subject: [PATCH 27/31] fix: dead lock, test --- .../graphql_datasource/graphql_subscription_client.go | 6 ++---- .../graphql_datasource/graphql_subscription_client_test.go | 5 ++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index bca746a35..e6b5339a6 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -50,8 +50,6 @@ type subscriptionClient struct { connections map[int]ConnectionHandler connectionsMu sync.RWMutex - activeConnections map[int]struct{} - triggers map[uint64]int asyncUnsubscribeTrigger chan uint64 } @@ -169,7 +167,6 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi }, onWsConnectionInitCallback: op.onWsConnectionInitCallback, connections: make(map[int]ConnectionHandler), - activeConnections: make(map[int]struct{}), triggers: make(map[uint64]int), asyncUnsubscribeTrigger: make(chan uint64, op.epollConfiguration.BufferSize), epollConfig: op.epollConfiguration, @@ -587,9 +584,10 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { c.handleConnection(id, handler, conn) }() } - c.handlePendingUnsubscribe() c.connectionsMu.RUnlock() + c.handlePendingUnsubscribe() + select { case <-done: c.log.Debug("epoll done due to context done") diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index db4f62c3f..548dbee52 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -958,9 +958,6 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, 1, len(updater.updates)) assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) time.Sleep(time.Second * 2) - client.activeConnectionsMu.Lock() - defer client.activeConnectionsMu.Unlock() - assert.Equal(t, 0, len(client.activeConnections)) }) t.Run("graphql-ws", func(t *testing.T) { t.Parallel() @@ -1036,7 +1033,9 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(1) + clientCancel() + assert.Eventuallyf(t, func() bool { <-serverDone return true From 1ed429e92a2cb666f3dd739115b43ff08212b3c0 Mon Sep 17 00:00:00 2001 From: starptech Date: Wed, 23 Oct 2024 17:49:42 +0200 Subject: [PATCH 28/31] fix: tests --- .../graphql_subscription_client.go | 48 +++++++++---------- .../graphql_subscription_client_test.go | 38 +++++++++++++-- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index e6b5339a6..cd7226484 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -565,34 +565,34 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { wg := sync.WaitGroup{} for { - connections, err := c.epoll.Wait(c.epollConfig.BufferSize) - if err != nil { - c.log.Error("epoll.Wait", abstractlogger.Error(err)) - continue - } - c.connectionsMu.RLock() - for _, cc := range connections { - conn := cc - id := epoller.SocketFD(conn) - handler, ok := c.connections[id] - if !ok { - continue - } - wg.Add(1) - go func() { - defer wg.Done() - c.handleConnection(id, handler, conn) - }() - } - c.connectionsMu.RUnlock() - - c.handlePendingUnsubscribe() - select { case <-done: - c.log.Debug("epoll done due to context done") + c.log.Debug("epoll context done", abstractlogger.Error(ctx.Err())) return default: + connections, err := c.epoll.Wait(c.epollConfig.BufferSize) + if err != nil { + c.log.Error("epoll.Wait", abstractlogger.Error(err)) + continue + } + c.connectionsMu.RLock() + for _, cc := range connections { + conn := cc + id := epoller.SocketFD(conn) + handler, ok := c.connections[id] + if !ok { + continue + } + wg.Add(1) + go func() { + defer wg.Done() + c.handleConnection(id, handler, conn) + }() + } + c.connectionsMu.RUnlock() + + c.handlePendingUnsubscribe() + wg.Wait() } } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 548dbee52..abed5757a 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -3,6 +3,7 @@ package graphql_datasource import ( "context" "encoding/json" + "go.uber.org/goleak" "net/http" "net/http/httptest" "runtime" @@ -16,7 +17,6 @@ import ( "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/atomic" - "go.uber.org/goleak" "go.uber.org/zap" ) @@ -490,7 +490,13 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { func TestWebSocketClientLeaks(t *testing.T) { defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + wg := &sync.WaitGroup{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() + conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) defer func() { @@ -537,7 +543,7 @@ func TestWebSocketClientLeaks(t *testing.T) { WithReadTimeout(time.Second), WithLogger(logger()), ).(*subscriptionClient) - wg := &sync.WaitGroup{} + wg.Add(2) for i := 0; i < 2; i++ { go func(i int) { @@ -560,9 +566,13 @@ func TestWebSocketClientLeaks(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(uint64(i)) clientCancel() + + time.Sleep(200 * time.Millisecond) + wg.Done() }(i) } + wg.Wait() } @@ -641,7 +651,11 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(1) + clientCancel() + + time.Sleep(200 * time.Millisecond) + assert.Eventuallyf(t, func() bool { <-serverDone return true @@ -675,7 +689,7 @@ func TestAsyncSubscribe(t *testing.T) { err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) assert.NoError(t, err) - time.Sleep(time.Second * 2) + time.Sleep(time.Second * 4) err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) assert.NoError(t, err) @@ -713,7 +727,11 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(1) + clientCancel() + + time.Sleep(200 * time.Millisecond) + assert.Eventuallyf(t, func() bool { <-serverDone return true @@ -1286,8 +1304,6 @@ func TestAsyncSubscribe(t *testing.T) { ctx = conn.CloseRead(ctx) <-ctx.Done() close(serverDone) - - close(serverDone) })) defer server.Close() ctx, clientCancel := context.WithCancel(context.Background()) @@ -1316,7 +1332,11 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(1) + clientCancel() + + time.Sleep(200 * time.Millisecond) + assert.Eventuallyf(t, func() bool { <-serverDone return true @@ -1555,7 +1575,11 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(1) + clientCancel() + + time.Sleep(200 * time.Millisecond) + assert.Eventuallyf(t, func() bool { <-serverDone return true @@ -1631,7 +1655,11 @@ func TestAsyncSubscribe(t *testing.T) { assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) client.Unsubscribe(1) + clientCancel() + + time.Sleep(200 * time.Millisecond) + assert.Eventuallyf(t, func() bool { <-serverDone return true From b71bef3fad3086050623112480734fcf8967c44e Mon Sep 17 00:00:00 2001 From: starptech Date: Wed, 23 Oct 2024 22:25:22 +0200 Subject: [PATCH 29/31] chore: unsubscribe in seperate routine --- .../graphql_subscription_client.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index cd7226484..e96bce368 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -564,6 +564,8 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { done := ctx.Done() wg := sync.WaitGroup{} + go c.handlePendingUnsubscribe(ctx) + for { select { case <-done: @@ -575,6 +577,7 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { c.log.Error("epoll.Wait", abstractlogger.Error(err)) continue } + c.connectionsMu.RLock() for _, cc := range connections { conn := cc @@ -591,17 +594,19 @@ func (c *subscriptionClient) runEpoll(ctx context.Context) { } c.connectionsMu.RUnlock() - c.handlePendingUnsubscribe() - wg.Wait() } } } -func (c *subscriptionClient) handlePendingUnsubscribe() { +func (c *subscriptionClient) handlePendingUnsubscribe(ctx context.Context) { + for { select { - case id := <-c.asyncUnsubscribeTrigger: + case <-ctx.Done(): + c.log.Debug("handlePendingUnsubscribe context done", abstractlogger.Error(ctx.Err())) + return + case id, ok := <-c.asyncUnsubscribeTrigger: c.connectionsMu.Lock() fd, ok := c.triggers[id] @@ -623,10 +628,9 @@ func (c *subscriptionClient) handlePendingUnsubscribe() { _ = c.epoll.Remove(handler.NetConn()) handler.ClientClose() - default: - return } } + } func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { From 07cc7c4a69822d21896e118b2babecc0b98c8d26 Mon Sep 17 00:00:00 2001 From: starptech Date: Wed, 23 Oct 2024 22:31:52 +0200 Subject: [PATCH 30/31] fix: lint --- .../graphql_datasource/graphql_subscription_client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index e96bce368..f6847e281 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -607,6 +607,10 @@ func (c *subscriptionClient) handlePendingUnsubscribe(ctx context.Context) { c.log.Debug("handlePendingUnsubscribe context done", abstractlogger.Error(ctx.Err())) return case id, ok := <-c.asyncUnsubscribeTrigger: + if !ok { + return + } + c.connectionsMu.Lock() fd, ok := c.triggers[id] From 264dacb635613e31f0aa012e441e622744001e88 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 24 Oct 2024 09:41:26 +0200 Subject: [PATCH 31/31] chore: handle unsubscribe async batched --- .../graphql_subscription_client.go | 156 ++++++++++-------- .../graphql_subscription_client_test.go | 14 +- .../graphql_datasource/graphql_tws_handler.go | 18 +- .../graphql_datasource/graphql_ws_handler.go | 18 +- 4 files changed, 122 insertions(+), 84 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index e6b5339a6..a6e17c043 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -47,11 +47,12 @@ type subscriptionClient struct { epoll epoller.Poller epollConfig EpollConfiguration - connections map[int]ConnectionHandler - connectionsMu sync.RWMutex + connections map[int]*connection + connectionsMu sync.Mutex - triggers map[uint64]int - asyncUnsubscribeTrigger chan uint64 + triggers map[uint64]int + clientUnsubscribe chan uint64 + serverUnsubscribe chan int } func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { @@ -69,7 +70,7 @@ func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, opt } func (c *subscriptionClient) Unsubscribe(id uint64) { - c.asyncUnsubscribeTrigger <- id + c.clientUnsubscribe <- id } type InvalidWsSubprotocolError struct { @@ -108,7 +109,7 @@ type EpollConfiguration struct { func (e *EpollConfiguration) ApplyDefaults() { if e.BufferSize == 0 { - e.BufferSize = 2048 + e.BufferSize = 1024 * 2 } if e.Interval == 0 { e.Interval = time.Millisecond * 100 @@ -166,9 +167,10 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi }, }, onWsConnectionInitCallback: op.onWsConnectionInitCallback, - connections: make(map[int]ConnectionHandler), + connections: make(map[int]*connection), triggers: make(map[uint64]int), - asyncUnsubscribeTrigger: make(chan uint64, op.epollConfiguration.BufferSize), + clientUnsubscribe: make(chan uint64, op.epollConfiguration.BufferSize), + serverUnsubscribe: make(chan int, op.epollConfiguration.BufferSize), epollConfig: op.epollConfiguration, } if !op.epollConfiguration.Disable { @@ -182,6 +184,13 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi return client } +type connection struct { + id uint64 + fd int + conn net.Conn + handler ConnectionHandler +} + // Subscribe initiates a new GraphQL Subscription with the origin // If an existing WS connection with the same ID (Hash) exists, it is being re-used // If connection protocol is SSE, a new connection is always created @@ -232,13 +241,13 @@ func (c *subscriptionClient) subscribeWS(requestContext, engineContext context.C return fmt.Errorf("http client is nil") } - handler, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) + conn, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) if err != nil { return err } go func() { - err := handler.StartBlocking() + err := conn.handler.StartBlocking() if err != nil { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { return @@ -255,14 +264,14 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return fmt.Errorf("http client is nil") } - handler, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) + conn, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) if err != nil { return err } if c.epoll == nil { go func() { - err := handler.StartBlocking() + err := conn.handler.StartBlocking() if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) } @@ -270,19 +279,20 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return nil } - err = handler.Subscribe() + err = conn.handler.Subscribe() if err != nil { return err } - netConn := handler.NetConn() - if err := c.epoll.Add(netConn); err != nil { + if err := c.epoll.Add(conn.conn); err != nil { return err } c.connectionsMu.Lock() - fd := epoller.SocketFD(netConn) - c.connections[fd] = handler + fd := epoller.SocketFD(conn.conn) + conn.id = id + conn.fd = fd + c.connections[fd] = conn c.triggers[id] = fd c.connectionsMu.Unlock() @@ -364,7 +374,7 @@ func (u *UpgradeRequestError) Error() string { return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", u.URL, u.StatusCode) } -func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (ConnectionHandler, error) { +func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (*connection, error) { conn, subProtocol, err := c.dial(requestContext, options) if err != nil { @@ -520,8 +530,7 @@ func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url s type ConnectionHandler interface { StartBlocking() error - NetConn() net.Conn - ReadMessage() (done, timeout bool) + ReadMessage() (done bool) ServerClose() ClientClose() Subscribe() error @@ -563,128 +572,141 @@ func waitForAck(conn net.Conn) error { func (c *subscriptionClient) runEpoll(ctx context.Context) { done := ctx.Done() wg := sync.WaitGroup{} - for { connections, err := c.epoll.Wait(c.epollConfig.BufferSize) if err != nil { c.log.Error("epoll.Wait", abstractlogger.Error(err)) continue } - c.connectionsMu.RLock() - for _, cc := range connections { - conn := cc - id := epoller.SocketFD(conn) - handler, ok := c.connections[id] + c.connectionsMu.Lock() + hasWork := false + for i := range connections { + id := epoller.SocketFD(connections[i]) + conn, ok := c.connections[id] if !ok { continue } + hasWork = true wg.Add(1) go func() { defer wg.Done() - c.handleConnection(id, handler, conn) + c.handleConnection(conn) }() } - c.connectionsMu.RUnlock() - - c.handlePendingUnsubscribe() - + c.connectionsMu.Unlock() + if hasWork { + wg.Wait() + } + c.handlePendingClientUnsubscribe() + c.handlePendingServerUnsubscribe() select { case <-done: c.log.Debug("epoll done due to context done") return default: - wg.Wait() + continue } } } -func (c *subscriptionClient) handlePendingUnsubscribe() { +func (c *subscriptionClient) handlePendingClientUnsubscribe() { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), c.epollConfig.Interval) + defer cancel() for { select { - case id := <-c.asyncUnsubscribeTrigger: - c.connectionsMu.Lock() - + case id := <-c.clientUnsubscribe: fd, ok := c.triggers[id] if !ok { - c.connectionsMu.Unlock() continue } - delete(c.triggers, id) - - handler, ok := c.connections[fd] + conn, ok := c.connections[fd] if !ok { - c.connectionsMu.Unlock() continue } - delete(c.connections, fd) - c.connectionsMu.Unlock() - - _ = c.epoll.Remove(handler.NetConn()) - handler.ClientClose() + _ = c.epoll.Remove(conn.conn) + go conn.handler.ClientClose() + case <-ctx.Done(): + return default: return } } } -func (c *subscriptionClient) handleConnection(id int, handler ConnectionHandler, conn net.Conn) { - done, timeout := handler.ReadMessage() - if timeout { - return +func (c *subscriptionClient) handlePendingServerUnsubscribe() { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), c.epollConfig.Interval) + defer cancel() + for { + select { + case id := <-c.serverUnsubscribe: + conn, ok := c.connections[id] + if !ok { + continue + } + delete(c.connections, id) + delete(c.triggers, conn.id) + _ = c.epoll.Remove(conn.conn) + go conn.handler.ServerClose() + case <-ctx.Done(): + return + default: + return + } } +} +func (c *subscriptionClient) handleConnection(conn *connection) { + done := conn.handler.ReadMessage() if done { - c.connectionsMu.Lock() - delete(c.connections, id) - c.connectionsMu.Unlock() - - _ = c.epoll.Remove(conn) - handler.ServerClose() - return + c.serverUnsubscribe <- conn.fd } } -func handleConnectionError(err error) (closed, timeout bool) { +func handleConnectionError(err error) (closed bool) { netOpErr := &net.OpError{} if errors.As(err, &netOpErr) { if netOpErr.Timeout() { - return false, true + return false } - return true, false + return true } // Check if we have errors during reading from the connection if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - return true, false + return true } // Check if we have a context error if errors.Is(err, context.DeadlineExceeded) { - return false, true + return false } // Check if the error is a connection reset by peer if errors.Is(err, syscall.ECONNRESET) { - return true, false + return true } if errors.Is(err, syscall.EPIPE) { - return true, false + return true } // Check if the error is a closed network connection. Introduced in go 1.16. // This replaces the string match of "use of closed network connection" if errors.Is(err, net.ErrClosed) { - return true, false + return true } // Check if the error is closed websocket connection if errors.As(err, &wsutil.ClosedError{}) { - return true, false + return true } - return false, false + return false } var ( diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 548dbee52..2b7194ec2 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -489,12 +489,17 @@ func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { } func TestWebSocketClientLeaks(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleak.VerifyNone(t, + goleak.IgnoreCurrent(), // ignore the test itself + ) + serverDone := &sync.WaitGroup{} + serverDone.Add(2) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) assert.NoError(t, err) defer func() { _ = conn.Close(websocket.StatusNormalClosure, "done") + serverDone.Done() }() ctx := context.Background() msgType, data, err := conn.Read(ctx) @@ -564,6 +569,9 @@ func TestWebSocketClientLeaks(t *testing.T) { }(i) } wg.Wait() + serverCancel() + time.Sleep(time.Second) + serverDone.Wait() } func TestAsyncSubscribe(t *testing.T) { @@ -656,6 +664,7 @@ func TestAsyncSubscribe(t *testing.T) { assert.NoError(t, err) defer func() { _ = conn.Close(websocket.StatusNormalClosure, "done") + close(serverDone) }() ctx := context.Background() msgType, data, err := conn.Read(ctx) @@ -684,7 +693,6 @@ func TestAsyncSubscribe(t *testing.T) { assert.NoError(t, err) assert.Equal(t, websocket.MessageText, msgType) assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) })) defer server.Close() ctx, clientCancel := context.WithCancel(context.Background()) @@ -717,7 +725,7 @@ func TestAsyncSubscribe(t *testing.T) { assert.Eventuallyf(t, func() bool { <-serverDone return true - }, time.Second, time.Millisecond*10, "server did not close") + }, time.Second*5, time.Millisecond*10, "server did not close") serverCancel() }) t.Run("server complete", func(t *testing.T) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index 21637d2fb..3f8008861 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -45,7 +45,7 @@ func (h *gqlTWSConnectionHandler) Subscribe() error { return h.subscribe() } -func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { +func (h *gqlTWSConnectionHandler) ReadMessage() (done bool) { rw := readWriterPool.Get(h.conn) defer readWriterPool.Put(rw) @@ -61,7 +61,7 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { } messageType, err := jsonparser.GetString(data, "type") if err != nil { - return false, false + return false } switch messageType { case messageTypePing: @@ -72,7 +72,7 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { continue case messageTypeComplete: h.handleMessageTypeComplete(data) - return true, false + return true case messageTypeError: h.handleMessageTypeError(data) continue @@ -80,10 +80,10 @@ func (h *gqlTWSConnectionHandler) ReadMessage() (done, timeout bool) { continue case messageTypeData, messageTypeConnectionError: h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return true, false + return true default: h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - return false, false + return false } } } @@ -92,8 +92,8 @@ func (h *gqlTWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l abstractlogger.Logger) *gqlTWSConnectionHandler { - return &gqlTWSConnectionHandler{ +func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l abstractlogger.Logger) *connection { + handler := &gqlTWSConnectionHandler{ conn: conn, requestContext: requestContext, engineContext: engineContext, @@ -101,6 +101,10 @@ func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, c updater: updater, options: options, } + return &connection{ + handler: handler, + conn: conn, + } } func (h *gqlTWSConnectionHandler) StartBlocking() error { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index d1964174e..fe5e77d15 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -46,7 +46,7 @@ func (h *gqlWSConnectionHandler) Subscribe() error { return h.subscribe() } -func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { +func (h *gqlWSConnectionHandler) ReadMessage() (done bool) { rw := readWriterPool.Get(h.conn) defer readWriterPool.Put(rw) @@ -62,7 +62,7 @@ func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { } messageType, err := jsonparser.GetString(data, "type") if err != nil { - return false, false + return false } switch messageType { case messageTypeConnectionKeepAlive: @@ -72,15 +72,15 @@ func (h *gqlWSConnectionHandler) ReadMessage() (done, timeout bool) { continue case messageTypeComplete: h.handleMessageTypeComplete(data) - return true, false + return true case messageTypeConnectionError: h.handleMessageTypeConnectionError() - return true, false + return true case messageTypeError: h.handleMessageTypeError(data) continue default: - return true, false + return true } } } @@ -89,8 +89,8 @@ func (h *gqlWSConnectionHandler) NetConn() net.Conn { return h.conn } -func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *gqlWSConnectionHandler { - return &gqlWSConnectionHandler{ +func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *connection { + handler := &gqlWSConnectionHandler{ conn: conn, requestContext: requestContext, engineContext: engineContext, @@ -98,6 +98,10 @@ func newGQLWSConnectionHandler(requestContext, engineContext context.Context, co updater: updater, options: options, } + return &connection{ + handler: handler, + conn: conn, + } } // StartBlocking starts the single threaded event loop of the handler