Skip to content

Commit

Permalink
fix: race condiion in dispatcher with fast resubscribe (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
abelanger5 authored Apr 23, 2024
1 parent 629917d commit 1786c0e
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 53 deletions.
137 changes: 118 additions & 19 deletions internal/services/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/rs/zerolog"

"github.com/hatchet-dev/hatchet/internal/datautils"
Expand Down Expand Up @@ -38,10 +39,71 @@ type DispatcherImpl struct {
dv datautils.DataDecoderValidator
repo repository.EngineRepository
dispatcherId string
workers sync.Map
workers *workers
a *hatcheterrors.Wrapped
}

var ErrWorkerNotFound = fmt.Errorf("worker not found")

type workers struct {
innerMap sync.Map
}

func (w *workers) Range(f func(key, value interface{}) bool) {
w.innerMap.Range(f)
}

func (w *workers) Add(workerId, sessionId string, worker *subscribedWorker) {
actual, _ := w.innerMap.LoadOrStore(workerId, &sync.Map{})

actual.(*sync.Map).Store(sessionId, worker)
}

func (w *workers) GetForSession(workerId, sessionId string) (*subscribedWorker, error) {
actual, ok := w.innerMap.Load(workerId)
if !ok {
return nil, ErrWorkerNotFound
}

worker, ok := actual.(*sync.Map).Load(sessionId)
if !ok {
return nil, ErrWorkerNotFound
}

return worker.(*subscribedWorker), nil
}

func (w *workers) Get(workerId string) ([]*subscribedWorker, error) {
actual, ok := w.innerMap.Load(workerId)

if !ok {
return nil, ErrWorkerNotFound
}

workers := []*subscribedWorker{}

actual.(*sync.Map).Range(func(key, value interface{}) bool {
workers = append(workers, value.(*subscribedWorker))
return true
})

return workers, nil
}

func (w *workers) DeleteForSession(workerId, sessionId string) {
actual, ok := w.innerMap.Load(workerId)

if !ok {
return
}

actual.(*sync.Map).Delete(sessionId)
}

func (w *workers) Delete(workerId string) {
w.innerMap.Delete(workerId)
}

type DispatcherOpt func(*DispatcherOpts)

type DispatcherOpts struct {
Expand Down Expand Up @@ -135,7 +197,7 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) {
dv: opts.dv,
repo: opts.repo,
dispatcherId: opts.dispatcherId,
workers: sync.Map{},
workers: &workers{},
s: s,
a: a,
}, nil
Expand Down Expand Up @@ -206,9 +268,13 @@ func (d *DispatcherImpl) Start() (func() error, error) {
d.l.Debug().Msg("draining existing connections")

d.workers.Range(func(key, value interface{}) bool {
w := value.(subscribedWorker)
value.(*sync.Map).Range(func(key, value interface{}) bool {
w := value.(*subscribedWorker)

w.finished <- true

w.finished <- true
return true
})

return true
})
Expand Down Expand Up @@ -275,7 +341,7 @@ func (d *DispatcherImpl) handleGroupKeyActionAssignedTask(ctx context.Context, t
}

// get the worker for this task
w, err := d.GetWorker(payload.WorkerId)
workers, err := d.workers.Get(payload.WorkerId)

if err != nil {
return fmt.Errorf("could not get worker: %w", err)
Expand All @@ -302,13 +368,24 @@ func (d *DispatcherImpl) handleGroupKeyActionAssignedTask(ctx context.Context, t
return fmt.Errorf("could not get group key run for engine: %w", err)
}

err = w.StartGroupKeyAction(ctx, metadata.TenantId, sqlcGroupKeyRun)
var multiErr error
var success bool

if err != nil {
return fmt.Errorf("could not send group key action to worker: %w", err)
for _, w := range workers {
err = w.StartGroupKeyAction(ctx, metadata.TenantId, sqlcGroupKeyRun)

if err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("could not send group key action to worker: %w", err))
} else {
success = true
}
}

if success {
return nil
}

return nil
return multiErr
}

func (d *DispatcherImpl) handleStepRunAssignedTask(ctx context.Context, task *msgqueue.Message) error {
Expand All @@ -331,7 +408,7 @@ func (d *DispatcherImpl) handleStepRunAssignedTask(ctx context.Context, task *ms
}

// get the worker for this task
w, err := d.GetWorker(payload.WorkerId)
workers, err := d.workers.Get(payload.WorkerId)

if err != nil {
return fmt.Errorf("could not get worker: %w", err)
Expand All @@ -346,13 +423,24 @@ func (d *DispatcherImpl) handleStepRunAssignedTask(ctx context.Context, task *ms

servertel.WithStepRunModel(span, stepRun)

err = w.StartStepRun(ctx, metadata.TenantId, stepRun)
var multiErr error
var success bool

if err != nil {
return fmt.Errorf("could not send step action to worker: %w", err)
for _, w := range workers {
err = w.StartStepRun(ctx, metadata.TenantId, stepRun)

if err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("could not send step action to worker: %w", err))
} else {
success = true
}
}

if success {
return nil
}

return nil
return multiErr
}

func (d *DispatcherImpl) handleStepRunCancelled(ctx context.Context, task *msgqueue.Message) error {
Expand All @@ -375,7 +463,7 @@ func (d *DispatcherImpl) handleStepRunCancelled(ctx context.Context, task *msgqu
}

// get the worker for this task
w, err := d.GetWorker(payload.WorkerId)
workers, err := d.workers.Get(payload.WorkerId)

if err != nil && !errors.Is(err, ErrWorkerNotFound) {
return fmt.Errorf("could not get worker: %w", err)
Expand All @@ -394,13 +482,24 @@ func (d *DispatcherImpl) handleStepRunCancelled(ctx context.Context, task *msgqu

servertel.WithStepRunModel(span, stepRun)

err = w.CancelStepRun(ctx, metadata.TenantId, stepRun)
var multiErr error
var success bool

if err != nil {
return fmt.Errorf("could not send job to worker: %w", err)
for _, w := range workers {
err = w.CancelStepRun(ctx, metadata.TenantId, stepRun)

if err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("could not send job to worker: %w", err))
} else {
success = true
}
}

if success {
return nil
}

return nil
return multiErr
}

func (d *DispatcherImpl) runUpdateHeartbeat(ctx context.Context) func() {
Expand Down
41 changes: 7 additions & 34 deletions internal/services/dispatcher/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/google/uuid"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -25,23 +26,6 @@ import (
"github.com/hatchet-dev/hatchet/internal/telemetry"
)

var ErrWorkerNotFound = fmt.Errorf("worker not found")

func (d *DispatcherImpl) GetWorker(workerId string) (*subscribedWorker, error) {
workerInt, ok := d.workers.Load(workerId)
if !ok {
return nil, ErrWorkerNotFound
}

worker, ok := workerInt.(subscribedWorker)

if !ok {
return nil, fmt.Errorf("failed to cast worker with id %s to subscribedWorker", workerId)
}

return &worker, nil
}

type subscribedWorker struct {
// stream is the server side of the RPC stream
stream contracts.Dispatcher_ListenServer
Expand Down Expand Up @@ -172,6 +156,7 @@ func (s *DispatcherImpl) Register(ctx context.Context, request *contracts.Worker
func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream contracts.Dispatcher_ListenServer) error {
tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
tenantId := sqlchelpers.UUIDToStr(tenant.ID)
sessionId := uuid.New().String()

s.l.Debug().Msgf("Received subscribe request from ID: %s", request.WorkerId)

Expand Down Expand Up @@ -200,7 +185,7 @@ func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream c

fin := make(chan bool)

s.workers.Store(request.WorkerId, subscribedWorker{stream: stream, finished: fin})
s.workers.Add(request.WorkerId, sessionId, &subscribedWorker{stream: stream, finished: fin})

defer func() {
// non-blocking send
Expand All @@ -209,7 +194,7 @@ func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream c
default:
}

s.workers.Delete(request.WorkerId)
s.workers.DeleteForSession(request.WorkerId, sessionId)
}()

// update the worker with a last heartbeat time every 5 seconds as long as the worker is connected
Expand Down Expand Up @@ -265,6 +250,7 @@ func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream c
func (s *DispatcherImpl) ListenV2(request *contracts.WorkerListenRequest, stream contracts.Dispatcher_ListenV2Server) error {
tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
tenantId := sqlchelpers.UUIDToStr(tenant.ID)
sessionId := uuid.New().String()

ctx := stream.Context()

Expand Down Expand Up @@ -293,7 +279,7 @@ func (s *DispatcherImpl) ListenV2(request *contracts.WorkerListenRequest, stream

fin := make(chan bool)

s.workers.Store(request.WorkerId, subscribedWorker{stream: stream, finished: fin})
s.workers.Add(request.WorkerId, sessionId, &subscribedWorker{stream: stream, finished: fin})

defer func() {
// non-blocking send
Expand All @@ -302,20 +288,7 @@ func (s *DispatcherImpl) ListenV2(request *contracts.WorkerListenRequest, stream
default:
}

s.workers.Delete(request.WorkerId)

inactive := db.WorkerStatusInactive

updateCtx, updateCtxCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCtxCancel()

_, err := s.repo.Worker().UpdateWorker(updateCtx, tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
Status: &inactive,
})

if err != nil {
s.l.Error().Err(err).Msgf("could not update worker %s status to inactive", request.WorkerId)
}
s.workers.DeleteForSession(request.WorkerId, sessionId)
}()

// Keep the connection alive for sending messages
Expand Down

0 comments on commit 1786c0e

Please sign in to comment.