Skip to content

Commit

Permalink
chore: fix race
Browse files Browse the repository at this point in the history
  • Loading branch information
jensneuse committed Oct 17, 2024
1 parent 64e9734 commit a1de22c
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64
return err
}

handler.Subscribe(sub)

netConn := handler.NetConn()
if err := c.epoll.Add(netConn); err != nil {
return err
Expand All @@ -280,8 +282,6 @@ func (c *subscriptionClient) asyncSubscribeWS(reqCtx *resolve.Context, id uint64
go c.runEpoll(c.engineCtx)
}

handler.Subscribe(sub)

return nil
}

Expand Down
118 changes: 49 additions & 69 deletions v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"net"
"strconv"
"strings"
"time"

Expand All @@ -23,32 +22,26 @@ import (
// it is responsible for managing all subscriptions using the underlying WebSocket connection
// if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate
type gqlTWSConnectionHandler struct {
conn net.Conn
ctx context.Context
log log.Logger
subscribeCh chan Subscription
nextSubscriptionID int
subscriptions map[string]Subscription
readTimeout time.Duration
conn net.Conn
ctx context.Context
log log.Logger
subscribeCh chan Subscription
subscription *Subscription
readTimeout time.Duration
}

func (h *gqlTWSConnectionHandler) ServerClose() {
for _, sub := range h.subscriptions {
sub.updater.Done()
if h.subscription != nil {
h.subscription.updater.Done()
}
_ = h.conn.Close()
}

func (h *gqlTWSConnectionHandler) ClientClose() {
for k, v := range h.subscriptions {
v.updater.Done()
delete(h.subscriptions, k)

req := fmt.Sprintf(completeMessage, k)
err := wsutil.WriteClientText(h.conn, []byte(req))
if err != nil {
h.log.Error("failed to write complete message", log.Error(err))
}
if h.subscription != nil {
h.subscription.updater.Done()
stopRequest := fmt.Sprintf(completeMessage, "1")
_ = wsutil.WriteClientText(h.conn, []byte(stopRequest))
}
_ = h.conn.Close()
}
Expand Down Expand Up @@ -130,12 +123,10 @@ func (h *gqlTWSConnectionHandler) NetConn() net.Conn {

func newGQLTWSConnectionHandler(ctx context.Context, conn net.Conn, rt time.Duration, l log.Logger) *gqlTWSConnectionHandler {
return &gqlTWSConnectionHandler{
conn: conn,
ctx: ctx,
log: l,
nextSubscriptionID: 0,
subscriptions: map[string]Subscription{},
readTimeout: rt,
conn: conn,
ctx: ctx,
log: l,
readTimeout: rt,
}
}

Expand All @@ -162,7 +153,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) {
if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err))
}
h.broadcastErrorMessage(err)
h.publishErrorMessage(err)
return
}

Expand All @@ -178,7 +169,7 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) {
continue
case err := <-errCh:
h.log.Error("gqlWSConnectionHandler.StartBlocking", log.Error(err))
h.broadcastErrorMessage(err)
h.publishErrorMessage(err)
return
case <-ticker.C:
sub.updater.Heartbeat()
Expand Down Expand Up @@ -213,21 +204,16 @@ func (h *gqlTWSConnectionHandler) StartBlocking(sub Subscription) {
}

func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() {
for id := range h.subscriptions {
h.unsubscribe(id)
}
h.unsubscribe()
_ = h.conn.Close()
}

func (h *gqlTWSConnectionHandler) unsubscribe(subscriptionID string) {
sub, ok := h.subscriptions[subscriptionID]
if !ok {
func (h *gqlTWSConnectionHandler) unsubscribe() {
if h.subscription == nil {
return
}
sub.updater.Done()
delete(h.subscriptions, subscriptionID)

req := fmt.Sprintf(completeMessage, subscriptionID)
h.subscription.updater.Done()
req := fmt.Sprintf(completeMessage, "1")
err := wsutil.WriteClientText(h.conn, []byte(req))
if err != nil {
h.log.Error("failed to write complete message", log.Error(err))
Expand All @@ -242,57 +228,54 @@ func (h *gqlTWSConnectionHandler) subscribe(sub Subscription) {
return
}

h.nextSubscriptionID++

subscriptionID := strconv.Itoa(h.nextSubscriptionID)
subscribeRequest := fmt.Sprintf(subscribeMessage, subscriptionID, string(graphQLBody))
subscribeRequest := fmt.Sprintf(subscribeMessage, "1", string(graphQLBody))
err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest))
if err != nil {
h.log.Error("failed to write subscribe message", log.Error(err))
return
}

h.subscriptions[subscriptionID] = sub
h.subscription = &sub
}

func (h *gqlTWSConnectionHandler) broadcastErrorMessage(err error) {
func (h *gqlTWSConnectionHandler) publishErrorMessage(err error) {
errMsg := fmt.Sprintf(errorMessageTemplate, err)
for _, sub := range h.subscriptions {
sub.updater.Update([]byte(errMsg))
}
h.subscription.updater.Update([]byte(errMsg))
}

func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) {
id, err := jsonparser.GetString(data, "id")
if err != nil {
return
}
sub, ok := h.subscriptions[id]
if !ok {
if id != "1" {
return
}
if h.subscription == nil {
return
}
sub.updater.Done()
delete(h.subscriptions, id)
h.subscription.updater.Done()
}

func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) {
id, err := jsonparser.GetString(data, "id")
if err != nil {
return
}
sub, ok := h.subscriptions[id]
if !ok {
if id != "1" {
return
}
if h.subscription == nil {
return
}

value, valueType, _, err := jsonparser.Get(data, "payload")
if err != nil {
h.log.Error(
"failed to get payload from error message",
log.Error(err),
log.ByteString("raw message", data),
)
sub.updater.Update([]byte(internalError))
h.subscription.updater.Update([]byte(internalError))
return
}

Expand All @@ -306,20 +289,20 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) {
log.Error(err),
log.ByteString("raw message", value),
)
sub.updater.Update([]byte(internalError))
h.subscription.updater.Update([]byte(internalError))
return
}
sub.updater.Update(response)
h.subscription.updater.Update(response)
case jsonparser.Object:
response := []byte(`{"errors":[]}`)
response, err = jsonparser.Set(response, value, "errors", "[0]")
if err != nil {
sub.updater.Update([]byte(internalError))
h.subscription.updater.Update([]byte(internalError))
return
}
sub.updater.Update(response)
h.subscription.updater.Update(response)
default:
sub.updater.Update([]byte(internalError))
h.subscription.updater.Update([]byte(internalError))
}
}

Expand All @@ -335,8 +318,10 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) {
if err != nil {
return
}
sub, ok := h.subscriptions[id]
if !ok {
if id != "1" {
return
}
if h.subscription == nil {
return
}

Expand All @@ -346,11 +331,11 @@ func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) {
"failed to get payload from next message",
log.Error(err),
)
sub.updater.Update([]byte(internalError))
h.subscription.updater.Update([]byte(internalError))
return
}

sub.updater.Update(value)
h.subscription.updater.Update(value)
}

// readBlocking is a dedicated loop running in a separate goroutine
Expand All @@ -375,10 +360,5 @@ func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, dataCh chan
}

func (h *gqlTWSConnectionHandler) hasActiveSubscriptions() (hasActiveSubscriptions bool) {
for id, sub := range h.subscriptions {
if sub.ctx.Err() != nil {
h.unsubscribe(id)
}
}
return len(h.subscriptions) != 0
return h.subscription != nil
}
Loading

0 comments on commit a1de22c

Please sign in to comment.