From a1de22c99598fea2624ff735391791341b467961 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 17 Oct 2024 13:09:07 +0200 Subject: [PATCH] chore: fix race --- .../graphql_subscription_client.go | 4 +- .../graphql_datasource/graphql_tws_handler.go | 118 ++++++++---------- .../graphql_datasource/graphql_ws_handler.go | 114 ++++++++--------- 3 files changed, 104 insertions(+), 132 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 385f0ef95..41f6b234f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -263,6 +263,8 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64 return err } + handler.Subscribe(sub) + netConn := handler.NetConn() if err := c.epoll.Add(netConn); err != nil { return err @@ -280,8 +282,6 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64 go c.runEpoll(c.engineCtx) } - handler.Subscribe(sub) - return nil } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go index d7c0559a1..e34234d3c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net" - "strconv" "strings" "time" @@ -23,32 +22,26 @@ import ( // it is responsible for managing all subscriptions using the underlying WebSocket connection // if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate type gqlTWSConnectionHandler struct { - conn net.Conn - ctx context.Context - log log.Logger - subscribeCh chan Subscription - nextSubscriptionID int - subscriptions map[string]Subscription - readTimeout time.Duration + conn net.Conn + ctx context.Context + log log.Logger + subscribeCh chan Subscription + subscription *Subscription + readTimeout time.Duration } func (h *gqlTWSConnectionHandler) ServerClose() { - for _, sub := range h.subscriptions { - sub.updater.Done() + if h.subscription != nil { + h.subscription.updater.Done() } _ = h.conn.Close() } func (h *gqlTWSConnectionHandler) ClientClose() { - for k, v := range h.subscriptions { - v.updater.Done() - delete(h.subscriptions, k) - - req := fmt.Sprintf(completeMessage, k) - err := wsutil.WriteClientText(h.conn, []byte(req)) - if err != nil { - h.log.Error("failed to write complete message", log.Error(err)) - } + if h.subscription != nil { + h.subscription.updater.Done() + stopRequest := fmt.Sprintf(completeMessage, "1") + _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) } _ = h.conn.Close() } @@ -130,12 +123,10 @@ func (h *gqlTWSConnectionHandler) NetConn() net.Conn { func newGQLTWSConnectionHandler(ctx context.Context, conn net.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler { return &gqlTWSConnectionHandler{ - conn: conn, - ctx: ctx, - log: l, - nextSubscriptionID: 0, - subscriptions: map[string]Subscription{}, - readTimeout: rt, + conn: conn, + ctx: ctx, + log: l, + readTimeout: rt, } } @@ -162,7 +153,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err)) } - h.broadcastErrorMessage(err) + h.publishErrorMessage(err) return } @@ -178,7 +169,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { continue case err := <-errCh: h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err)) - h.broadcastErrorMessage(err) + h.publishErrorMessage(err) return case <-ticker.C: sub.updater.Heartbeat() @@ -213,21 +204,16 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) { } func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { - for id := range h.subscriptions { - h.unsubscribe(id) - } + h.unsubscribe() _ = h.conn.Close() } -func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) { - sub, ok := h.subscriptions[subscriptionID] - if !ok { +func (h *gqlTWSConnectionHandler) unsubscribe() { + if h.subscription == nil { return } - sub.updater.Done() - delete(h.subscriptions, subscriptionID) - - req := fmt.Sprintf(completeMessage, subscriptionID) + h.subscription.updater.Done() + req := fmt.Sprintf(completeMessage, "1") err := wsutil.WriteClientText(h.conn, []byte(req)) if err != nil { h.log.Error("failed to write complete message", log.Error(err)) @@ -242,24 +228,19 @@ func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) { return } - h.nextSubscriptionID++ - - subscriptionID := strconv.Itoa(h.nextSubscriptionID) - subscribeRequest := fmt.Sprintf(subscribeMessage, subscriptionID, string(graphQLBody)) + subscribeRequest := fmt.Sprintf(subscribeMessage, "1", string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest)) if err != nil { h.log.Error("failed to write subscribe message", log.Error(err)) return } - h.subscriptions[subscriptionID] = sub + h.subscription = &sub } -func (h *gqlTWSConnectionHandler) broadcastErrorMessage(err error) { +func (h *gqlTWSConnectionHandler) publishErrorMessage(err error) { errMsg := fmt.Sprintf(errorMessageTemplate, err) - for _, sub := range h.subscriptions { - sub.updater.Update([]byte(errMsg)) - } + h.subscription.updater.Update([]byte(errMsg)) } func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) { @@ -267,12 +248,13 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { + return + } + if h.subscription == nil { return } - sub.updater.Done() - delete(h.subscriptions, id) + h.subscription.updater.Done() } func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { @@ -280,11 +262,12 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { + return + } + if h.subscription == nil { return } - value, valueType, _, err := jsonparser.Get(data, "payload") if err != nil { h.log.Error( @@ -292,7 +275,7 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { log.Error(err), log.ByteString("raw message", data), ) - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } @@ -306,20 +289,20 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { log.Error(err), log.ByteString("raw message", value), ) - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.subscription.updater.Update(response) case jsonparser.Object: response := []byte(`{"errors":[]}`) response, err = jsonparser.Set(response, value, "errors", "[0]") if err != nil { - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.subscription.updater.Update(response) default: - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) } } @@ -335,8 +318,10 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { + return + } + if h.subscription == nil { return } @@ -346,11 +331,11 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { "failed to get payload from next message", log.Error(err), ) - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } - sub.updater.Update(value) + h.subscription.updater.Update(value) } // readBlocking is a dedicated loop running in a separate goroutine @@ -375,10 +360,5 @@ func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan } func (h *gqlTWSConnectionHandler) hasActiveSubscriptions() (hasActiveSubscriptions bool) { - for id, sub := range h.subscriptions { - if sub.ctx.Err() != nil { - h.unsubscribe(id) - } - } - return len(h.subscriptions) != 0 + return h.subscription != nil } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go index ec63d3df5..042b7a416 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net" - "strconv" "strings" "time" @@ -27,24 +26,22 @@ type gqlWSConnectionHandler struct { ctx context.Context log abstractlogger.Logger // log slog.Logger - subscribeCh chan Subscription - nextSubscriptionID int - subscriptions map[string]Subscription - readTimeout time.Duration + subscribeCh chan Subscription + subscription *Subscription + readTimeout time.Duration } func (h *gqlWSConnectionHandler) ServerClose() { - for _, sub := range h.subscriptions { - sub.updater.Done() + if h.subscription != nil { + h.subscription.updater.Done() } _ = h.conn.Close() } func (h *gqlWSConnectionHandler) ClientClose() { - for k, v := range h.subscriptions { - v.updater.Done() - delete(h.subscriptions, k) - stopRequest := fmt.Sprintf(stopMessage, k) + if h.subscription != nil { + h.subscription.updater.Done() + stopRequest := fmt.Sprintf(stopMessage, "1") _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) } _ = h.conn.Close() @@ -120,12 +117,10 @@ func (h *gqlWSConnectionHandler) NetConn() net.Conn { func newGQLWSConnectionHandler(ctx context.Context, conn net.Conn, readTimeout time.Duration, log abstractlogger.Logger) *gqlWSConnectionHandler { return &gqlWSConnectionHandler{ - conn: conn, - ctx: ctx, - log: log, - nextSubscriptionID: 0, - subscriptions: map[string]Subscription{}, - readTimeout: readTimeout, + conn: conn, + ctx: ctx, + log: log, + readTimeout: readTimeout, } } @@ -154,7 +149,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) { h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) } - h.broadcastErrorMessage(err) + h.publishErrorMessage(err) return } hasActiveSubscriptions := h.checkActiveSubscriptions() @@ -170,7 +165,7 @@ func (h *gqlWSConnectionHandler) StartBlocking(sub Subscription) { if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) } - h.broadcastErrorMessage(err) + h.publishErrorMessage(err) return case <-ticker.C: sub.updater.Heartbeat() @@ -227,8 +222,8 @@ func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan [ } func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { - for id := range h.subscriptions { - h.unsubscribe(id) + if h.subscription != nil { + h.unsubscribe() } _ = h.conn.Close() } @@ -240,17 +235,13 @@ func (h *gqlWSConnectionHandler) subscribe(sub Subscription) { return } - h.nextSubscriptionID++ - - subscriptionID := strconv.Itoa(h.nextSubscriptionID) - - startRequest := fmt.Sprintf(startMessage, subscriptionID, string(graphQLBody)) + startRequest := fmt.Sprintf(startMessage, "1", string(graphQLBody)) err = wsutil.WriteClientText(h.conn, []byte(startRequest)) if err != nil { return } - h.subscriptions[subscriptionID] = sub + h.subscription = &sub } func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { @@ -258,29 +249,32 @@ func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { + return + } + if h.subscription == nil { return } payload, _, _, err := jsonparser.Get(data, "payload") if err != nil { return } - - sub.updater.Update(payload) + h.subscription.updater.Update(payload) } func (h *gqlWSConnectionHandler) handleMessageTypeConnectionError() { - for _, sub := range h.subscriptions { - sub.updater.Update([]byte(connectionError)) + if h.subscription == nil { + return } + h.subscription.updater.Update([]byte(connectionError)) } -func (h *gqlWSConnectionHandler) broadcastErrorMessage(err error) { - errMsg := fmt.Sprintf(errorMessageTemplate, err) - for _, sub := range h.subscriptions { - sub.updater.Update([]byte(errMsg)) +func (h *gqlWSConnectionHandler) publishErrorMessage(err error) { + if h.subscription == nil { + return } + errMsg := fmt.Sprintf(errorMessageTemplate, err) + h.subscription.updater.Update([]byte(errMsg)) } func (h *gqlWSConnectionHandler) handleMessageTypeComplete(data []byte) { @@ -288,12 +282,14 @@ func (h *gqlWSConnectionHandler) handleMessageTypeComplete(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { + return + } + if h.subscription == nil { return } - sub.updater.Done() - delete(h.subscriptions, id) + h.subscription.updater.Done() + h.subscription = nil } func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { @@ -301,13 +297,15 @@ func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { if err != nil { return } - sub, ok := h.subscriptions[id] - if !ok { + if id != "1" { + return + } + if h.subscription == nil { return } value, valueType, _, err := jsonparser.Get(data, "payload") if err != nil { - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } switch valueType { @@ -315,39 +313,33 @@ func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { response := []byte(`{}`) response, err = jsonparser.Set(response, value, "errors") if err != nil { - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.subscription.updater.Update(response) case jsonparser.Object: response := []byte(`{"errors":[]}`) response, err = jsonparser.Set(response, value, "errors", "[0]") if err != nil { - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) return } - sub.updater.Update(response) + h.subscription.updater.Update(response) default: - sub.updater.Update([]byte(internalError)) + h.subscription.updater.Update([]byte(internalError)) } } -func (h *gqlWSConnectionHandler) unsubscribe(subscriptionID string) { - sub, ok := h.subscriptions[subscriptionID] - if !ok { +func (h *gqlWSConnectionHandler) unsubscribe() { + if h.subscription == nil { return } - sub.updater.Done() - delete(h.subscriptions, subscriptionID) - stopRequest := fmt.Sprintf(stopMessage, subscriptionID) + h.subscription.updater.Done() + h.subscription = nil + stopRequest := fmt.Sprintf(stopMessage, "1") _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) } func (h *gqlWSConnectionHandler) checkActiveSubscriptions() (hasActiveSubscriptions bool) { - for id, sub := range h.subscriptions { - if sub.ctx.Err() != nil { - h.unsubscribe(id) - } - } - return len(h.subscriptions) != 0 + return h.subscription != nil }