diff --git a/internal/context.go b/internal/context.go index de8cae9c..18480626 100644 --- a/internal/context.go +++ b/internal/context.go @@ -16,6 +16,7 @@ var ( // logging metadata for a single request type data struct { userID string + deviceID string since int64 next int64 numRooms int @@ -37,13 +38,14 @@ func RequestContext(ctx context.Context) context.Context { } // add the user ID to this request context. Need to have called RequestContext first. -func SetRequestContextUserID(ctx context.Context, userID string) { +func SetRequestContextUserID(ctx context.Context, userID, deviceID string) { d := ctx.Value(ctxData) if d == nil { return } da := d.(*data) da.userID = userID + da.deviceID = deviceID if hub := sentry.GetHubFromContext(ctx); hub != nil { sentry.ConfigureScope(func(scope *sentry.Scope) { scope.SetUser(sentry.User{Username: userID}) @@ -79,6 +81,9 @@ func DecorateLogger(ctx context.Context, l *zerolog.Event) *zerolog.Event { if da.userID != "" { l = l.Str("u", da.userID) } + if da.deviceID != "" { + l = l.Str("dev", da.deviceID) + } if da.since >= 0 { l = l.Int64("p", da.since) } diff --git a/internal/pool.go b/internal/pool.go new file mode 100644 index 00000000..27c6bb0c --- /dev/null +++ b/internal/pool.go @@ -0,0 +1,67 @@ +package internal + +type WorkerPool struct { + N int + ch chan func() +} + +// Create a new worker pool of size N. Up to N work can be done concurrently. +// The size of N depends on the expected frequency of work and contention for +// shared resources. Large values of N allow more frequent work at the cost of +// more contention for shared resources like cpu, memory and fds. Small values +// of N allow less frequent work but control the amount of shared resource contention. +// Ideally this value will be derived from whatever shared resource constraints you +// are hitting up against, rather than set to a fixed value. For example, if you have +// a database connection limit of 100, then setting N to some fraction of the limit is +// preferred to setting this to an arbitrary number < 100. If more than N work is requested, +// eventually WorkerPool.Queue will block until some work is done. +// +// The larger N is, the larger the up front memory costs are due to the implementation of WorkerPool. +func NewWorkerPool(n int) *WorkerPool { + return &WorkerPool{ + N: n, + // If we have N workers, we can process N work concurrently. + // If we have >N work, we need to apply backpressure to stop us + // making more and more work which takes up more and more memory. + // By setting the channel size to N, we ensure that backpressure is + // being applied on the producer, stopping it from creating more work, + // and hence bounding memory consumption. Work is still being produced + // upstream on the homeserver, but we will consume it when we're ready + // rather than gobble it all at once. + // + // Note: we aren't forced to set this to N, it just serves as a useful + // metric which scales on the number of workers. The amount of in-flight + // work is N, so it makes sense to allow up to N work to be queued up before + // applying backpressure. If the channel buffer is < N then the channel can + // become the bottleneck in the case where we have lots of instantaneous work + // to do. If the channel buffer is too large, we needlessly consume memory as + // make() will allocate a backing array of whatever size you give it up front (sad face) + ch: make(chan func(), n), + } +} + +// Start the workers. Only call this once. 
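[editor's note] A minimal usage sketch of the pool above, written as an example test assumed to live alongside the package inside the module (the internal package is not importable from outside the repo); the pool size here is arbitrary and would normally be derived from a shared-resource limit as the doc comment recommends:

package internal_test

import (
	"fmt"
	"sync"

	"github.com/matrix-org/sliding-sync/internal"
)

func ExampleWorkerPool() {
	// Size the pool from whatever shared-resource limit you are protecting,
	// e.g. a fraction of a hypothetical database connection cap.
	wp := internal.NewWorkerPool(4)
	wp.Start()
	defer wp.Stop()

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		wp.Queue(func() { // blocks once all workers are busy and the buffer is full
			defer wg.Done()
		})
	}
	wg.Wait()
	fmt.Println("all work done")
	// Output: all work done
}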
+func (wp *WorkerPool) Start() { + for i := 0; i < wp.N; i++ { + go wp.worker() + } +} + +// Stop the worker pool. Only really useful for tests as a worker pool should be started once +// and persist for the lifetime of the process, else it causes needless goroutine churn. +// Only call this once. +func (wp *WorkerPool) Stop() { + close(wp.ch) +} + +// Queue some work on the pool. May or may not block until some work is processed. +func (wp *WorkerPool) Queue(fn func()) { + wp.ch <- fn +} + +// worker impl +func (wp *WorkerPool) worker() { + for fn := range wp.ch { + fn() + } +} diff --git a/internal/pool_test.go b/internal/pool_test.go new file mode 100644 index 00000000..34222077 --- /dev/null +++ b/internal/pool_test.go @@ -0,0 +1,186 @@ +package internal + +import ( + "sync" + "testing" + "time" +) + +// Test basic functions of WorkerPool +func TestWorkerPool(t *testing.T) { + wp := NewWorkerPool(2) + wp.Start() + defer wp.Stop() + + // we should process this concurrently as N=2 so it should take 1s not 2s + var wg sync.WaitGroup + wg.Add(2) + start := time.Now() + wp.Queue(func() { + time.Sleep(time.Second) + wg.Done() + }) + wp.Queue(func() { + time.Sleep(time.Second) + wg.Done() + }) + wg.Wait() + took := time.Since(start) + if took > 2*time.Second { + t.Fatalf("took %v for queued work, it should have been faster than 2s", took) + } +} + +func TestWorkerPoolDoesWorkPriorToStart(t *testing.T) { + wp := NewWorkerPool(2) + + // return channel to use to see when work is done + ch := make(chan int, 2) + wp.Queue(func() { + ch <- 1 + }) + wp.Queue(func() { + ch <- 2 + }) + + // the work should not be done yet + time.Sleep(100 * time.Millisecond) + if len(ch) > 0 { + t.Fatalf("Queued work was done before Start()") + } + + // the work should be starting now + wp.Start() + defer wp.Stop() + + sum := 0 + for { + select { + case <-time.After(time.Second): + t.Fatalf("timed out waiting for work to be done") + case val := <-ch: + sum += val + } + if sum == 3 { // 2 + 1 + break + } + } +} + +type workerState struct { + id int + state int // not running, queued, running, finished + unblock *sync.WaitGroup // decrement to unblock this worker +} + +func TestWorkerPoolBackpressure(t *testing.T) { + // this test assumes backpressure starts at n*2+1 due to a chan buffer of size n, and n in-flight work. + n := 2 + wp := NewWorkerPool(n) + wp.Start() + defer wp.Stop() + + var mu sync.Mutex + stateNotRunning := 0 + stateQueued := 1 + stateRunning := 2 + stateFinished := 3 + size := (2 * n) + 1 + running := make([]*workerState, size) + + go func() { + // we test backpressure by scheduling (n*2)+1 work and ensuring that we see the following running states: + // [2,2,1,1,0] <-- 2 running, 2 queued, 1 blocked <-- THIS IS BACKPRESSURE + // [3,2,2,1,1] <-- 1 finished, 2 running, 2 queued + // [3,3,2,2,1] <-- 2 finished, 2 running , 1 queued + // [3,3,3,2,2] <-- 3 finished, 2 running + for i := 0; i < size; i++ { + // set initial state of this piece of work + wg := &sync.WaitGroup{} + wg.Add(1) + state := &workerState{ + id: i, + state: stateNotRunning, + unblock: wg, + } + mu.Lock() + running[i] = state + mu.Unlock() + + // queue the work on the pool. The final piece of work will block here and remain in + // stateNotRunning and not transition to stateQueued until the first piece of work is done. 
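[editor's note] The arithmetic the test below relies on: n workers can each hold one in-flight job and the channel buffers n more, so the (2n+1)th Queue call is the first one that blocks. A standalone sketch of that behaviour, independent of the test harness in this patch:

package internal

import (
	"testing"
	"time"
)

// With n workers and a buffer of n, 2n jobs are accepted immediately;
// job 2n+1 blocks in Queue until a worker frees up.
func TestWorkerPoolQueueBlocksAtTwoNPlusOne(t *testing.T) {
	n := 2
	wp := NewWorkerPool(n)
	wp.Start()
	defer wp.Stop()

	release := make(chan struct{})
	for i := 0; i < 2*n; i++ {
		wp.Queue(func() { <-release }) // first n run (and block), next n sit in the buffer
	}

	blocked := make(chan struct{})
	go func() {
		wp.Queue(func() { <-release }) // the (2n+1)th job: Queue itself blocks here
		close(blocked)
	}()

	select {
	case <-blocked:
		t.Fatalf("Queue returned before any worker finished; expected backpressure")
	case <-time.After(100 * time.Millisecond):
		// still blocked, as expected
	}
	close(release) // unblock everything so the deferred Stop can run cleanly
	<-blocked
}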
+ wp.Queue(func() { + mu.Lock() + if running[state.id].state != stateQueued { + // we ran work in the worker faster than the code underneath .Queue, so let it catch up + mu.Unlock() + time.Sleep(10 * time.Millisecond) + mu.Lock() + } + running[state.id].state = stateRunning + mu.Unlock() + + running[state.id].unblock.Wait() + mu.Lock() + running[state.id].state = stateFinished + mu.Unlock() + }) + + // mark this work as queued + mu.Lock() + running[i].state = stateQueued + mu.Unlock() + } + }() + + // wait for the workers to be doing work and assert the states of each task + time.Sleep(time.Second) + + assertStates(t, &mu, running, []int{ + stateRunning, stateRunning, stateQueued, stateQueued, stateNotRunning, + }) + + // now let the first task complete + running[0].unblock.Done() + // wait for the pool to grab more work + time.Sleep(100 * time.Millisecond) + // assert new states + assertStates(t, &mu, running, []int{ + stateFinished, stateRunning, stateRunning, stateQueued, stateQueued, + }) + + // now let the second task complete + running[1].unblock.Done() + // wait for the pool to grab more work + time.Sleep(100 * time.Millisecond) + // assert new states + assertStates(t, &mu, running, []int{ + stateFinished, stateFinished, stateRunning, stateRunning, stateQueued, + }) + + // now let the third task complete + running[2].unblock.Done() + // wait for the pool to grab more work + time.Sleep(100 * time.Millisecond) + // assert new states + assertStates(t, &mu, running, []int{ + stateFinished, stateFinished, stateFinished, stateRunning, stateRunning, + }) + +} + +func assertStates(t *testing.T, mu *sync.Mutex, running []*workerState, wantStates []int) { + t.Helper() + mu.Lock() + defer mu.Unlock() + if len(running) != len(wantStates) { + t.Fatalf("assertStates: bad wantStates length, got %d want %d", len(wantStates), len(running)) + } + for i := range running { + state := running[i] + wantVal := wantStates[i] + if state.state != wantVal { + t.Errorf("work[%d] got state %d want %d", i, state.state, wantVal) + } + } +} diff --git a/pubsub/v2.go b/pubsub/v2.go index 2ed379f4..7dfb01e0 100644 --- a/pubsub/v2.go +++ b/pubsub/v2.go @@ -91,9 +91,7 @@ type V2InitialSyncComplete struct { func (*V2InitialSyncComplete) Type() string { return "V2InitialSyncComplete" } type V2DeviceData struct { - UserID string - DeviceID string - Pos int64 + UserIDToDeviceIDs map[string][]string } func (*V2DeviceData) Type() string { return "V2DeviceData" } diff --git a/state/accumulator.go b/state/accumulator.go index 0105786f..acab4555 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -293,34 +293,20 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia // - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state) // - It adds entries to the membership log for membership events. func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) { - // Insert the events. Check for duplicates which can happen in the real world when joining - // Matrix HQ on Synapse. 
- dedupedEvents := make([]Event, 0, len(timeline)) - seenEvents := make(map[string]struct{}) - for i := range timeline { - e := Event{ - JSON: timeline[i], - RoomID: roomID, - } - if err := e.ensureFieldsSetOnEvent(); err != nil { - return 0, nil, fmt.Errorf("event malformed: %s", err) - } - if _, ok := seenEvents[e.ID]; ok { - logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg( - "Accumulator.Accumulate: seen the same event ID twice, ignoring", - ) - continue - } - if i == 0 && prevBatch != "" { - // tag the first timeline event with the prev batch token - e.PrevBatch = sql.NullString{ - String: prevBatch, - Valid: true, - } - } - dedupedEvents = append(dedupedEvents, e) - seenEvents[e.ID] = struct{}{} + // The first stage of accumulating events is mostly around validation around what the upstream HS sends us. For accumulation to work correctly + // we expect: + // - there to be no duplicate events + // - if there are new events, they are always new. + // Both of these assumptions can be false for different reasons + dedupedEvents, err := a.filterAndParseTimelineEvents(txn, roomID, timeline, prevBatch) + if err != nil { + err = fmt.Errorf("filterTimelineEvents: %w", err) + return } + if len(dedupedEvents) == 0 { + return 0, nil, err // nothing to do + } + eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false) if err != nil { return 0, nil, err @@ -413,6 +399,91 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, return numNew, timelineNIDs, nil } +// filterAndParseTimelineEvents takes a raw timeline array from sync v2 and applies sanity to it: +// - removes duplicate events: this is just a bug which has been seen on Synapse on matrix.org +// - removes old events: this is an edge case when joining rooms over federation, see https://github.com/matrix-org/sliding-sync/issues/192 +// - parses it and returns Event structs. +// - check which events are unknown. If all events are known, filter them all out. +func (a *Accumulator) filterAndParseTimelineEvents(txn *sqlx.Tx, roomID string, timeline []json.RawMessage, prevBatch string) ([]Event, error) { + // Check for duplicates which can happen in the real world when joining + // Matrix HQ on Synapse, as well as when you join rooms for the first time over federation. + dedupedEvents := make([]Event, 0, len(timeline)) + seenEvents := make(map[string]struct{}) + for i := range timeline { + e := Event{ + JSON: timeline[i], + RoomID: roomID, + } + if err := e.ensureFieldsSetOnEvent(); err != nil { + return nil, fmt.Errorf("event malformed: %s", err) + } + if _, ok := seenEvents[e.ID]; ok { + logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg( + "Accumulator.filterAndParseTimelineEvents: seen the same event ID twice, ignoring", + ) + continue + } + if i == 0 && prevBatch != "" { + // tag the first timeline event with the prev batch token + e.PrevBatch = sql.NullString{ + String: prevBatch, + Valid: true, + } + } + dedupedEvents = append(dedupedEvents, e) + seenEvents[e.ID] = struct{}{} + } + + // if we only have a single timeline event we cannot determine if it is old or not, as we rely on already seen events + // being after (higher index) than it. + if len(dedupedEvents) <= 1 { + return dedupedEvents, nil + } + + // Figure out which of these events are unseen and hence brand new live events. + // In some cases, we may have unseen OLD events - see https://github.com/matrix-org/sliding-sync/issues/192 + // in which case we need to drop those events. 
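[editor's note] The rule that drops those old events is applied a few lines further down; isolated on plain event IDs it is just "keep the suffix strictly after the last already-seen event". A hypothetical helper (not part of the patch) showing the same slicing behaviour:

package state

// keepUnseenSuffix returns the events strictly after the last already-seen event.
// [S,S,U,U] -> [U,U]; [U,S,S,U] -> [U]; [U,U,U] -> all; [U,S,S,S] -> none.
func keepUnseenSuffix(eventIDs []string, seen map[string]bool) []string {
	lastSeen := -1
	for i := len(eventIDs) - 1; i >= 0; i-- {
		if seen[eventIDs[i]] {
			lastSeen = i
			break
		}
	}
	return eventIDs[lastSeen+1:]
}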
+ dedupedEventIDs := make([]string, 0, len(seenEvents)) + for evID := range seenEvents { + dedupedEventIDs = append(dedupedEventIDs, evID) + } + unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, dedupedEventIDs) + if err != nil { + return nil, fmt.Errorf("filterAndParseTimelineEvents: failed to SelectUnknownEventIDs: %w", err) + } + + if len(unknownEventIDs) == 0 { + // every event has been seen already, no work to do + return nil, nil + } + + // In the happy case, we expect to see timeline arrays like this: (SEEN=S, UNSEEN=U) + // [S,S,U,U] -> want last 2 + // [U,U,U] -> want all + // In the backfill edge case, we might see: + // [U,S,S,S] -> want none + // [U,S,S,U] -> want last 1 + // We should never see scenarios like: + // [U,S,S,U,S,S] <- we should only see 1 contiguous block of seen events. + // If we do, we'll just ignore all unseen events less than the highest seen event. + + // The algorithm starts at the end and just looks for the first S event, returning the subslice after that S event (which may be []) + seenIndex := -1 + for i := len(dedupedEvents) - 1; i >= 0; i-- { + _, unseen := unknownEventIDs[dedupedEvents[i].ID] + if !unseen { + seenIndex = i + break + } + } + // seenIndex can be -1 if all are unseen, or len-1 if all are seen, either way if we +1 this slices correctly: + // no seen events s[A,B,C] => s[-1+1:] => [A,B,C] + // C is seen event s[A,B,C] => s[2+1:] => [] + // B is seen event s[A,B,C] => s[1+1:] => [C] + // A is seen event s[A,B,C] => s[0+1:] => [B,C] + return dedupedEvents[seenIndex+1:], nil +} + // Delta returns a list of events of at most `limit` for the room not including `lastEventNID`. // Returns the latest NID of the last event (most recent) func (a *Accumulator) Delta(roomID string, lastEventNID int64, limit int) (eventsJSON []json.RawMessage, latest int64, err error) { diff --git a/state/accumulator_test.go b/state/accumulator_test.go index 250854e8..64ee6c86 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -11,7 +11,6 @@ import ( "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sqlutil" "github.com/matrix-org/sliding-sync/sync2" - "github.com/matrix-org/sliding-sync/testutils" "github.com/tidwall/gjson" ) @@ -417,86 +416,6 @@ func TestAccumulatorDupeEvents(t *testing.T) { } } -// Regression test for corrupt state snapshots. -// This seems to have happened in the wild, whereby the snapshot exhibited 2 things: -// - A message event having a event_replaces_nid. This should be impossible as messages are not state. -// - Duplicate events in the state snapshot. -// -// We can reproduce a message event having a event_replaces_nid by doing the following: -// - Create a room with initial state A,C -// - Accumulate events D, A, B(msg). This should be impossible because we already got A initially but whatever, roll with it, blame state resets or something. -// - This leads to A,B being processed and D ignored if you just take the newest results. -// -// This can then be tested by: -// - Query the current room snapshot. This will include B(msg) when it shouldn't. 
-func TestAccumulatorMisorderedGraceful(t *testing.T) { - alice := "@alice:localhost" - bob := "@bob:localhost" - - eventA := testutils.NewStateEvent(t, "m.room.member", alice, alice, map[string]interface{}{"membership": "join"}) - eventC := testutils.NewStateEvent(t, "m.room.create", "", alice, map[string]interface{}{}) - eventD := testutils.NewStateEvent( - t, "m.room.member", bob, "join", map[string]interface{}{"membership": "join"}, - ) - eventBMsg := testutils.NewEvent( - t, "m.room.message", bob, map[string]interface{}{"body": "hello"}, - ) - t.Logf("A=member-alice, B=msg, C=create, D=member-bob") - - db, close := connectToDB(t) - defer close() - accumulator := NewAccumulator(db) - roomID := "!TestAccumulatorStateReset:localhost" - // Create a room with initial state A,C - _, err := accumulator.Initialise(roomID, []json.RawMessage{ - eventA, eventC, - }) - if err != nil { - t.Fatalf("failed to Initialise accumulator: %s", err) - } - - // Accumulate events D, A, B(msg). - err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { - _, _, err = accumulator.Accumulate(txn, roomID, "", []json.RawMessage{eventD, eventA, eventBMsg}) - return err - }) - if err != nil { - t.Fatalf("failed to Accumulate: %s", err) - } - - eventIDs := []string{ - gjson.GetBytes(eventA, "event_id").Str, - gjson.GetBytes(eventBMsg, "event_id").Str, - gjson.GetBytes(eventC, "event_id").Str, - gjson.GetBytes(eventD, "event_id").Str, - } - t.Logf("Events A,B,C,D: %v", eventIDs) - txn := accumulator.db.MustBeginTx(context.Background(), nil) - idsToNIDs, err := accumulator.eventsTable.SelectNIDsByIDs(txn, eventIDs) - if err != nil { - t.Fatalf("Failed to SelectNIDsByIDs: %s", err) - } - if len(idsToNIDs) != len(eventIDs) { - t.Errorf("SelectNIDsByIDs: asked for %v got %v", eventIDs, idsToNIDs) - } - t.Logf("Events: %v", idsToNIDs) - - wantEventNIDs := []int64{ - idsToNIDs[eventIDs[0]], idsToNIDs[eventIDs[2]], idsToNIDs[eventIDs[3]], - } - sort.Slice(wantEventNIDs, func(i, j int) bool { - return wantEventNIDs[i] < wantEventNIDs[j] - }) - // Query the current room snapshot - gotSnapshotEvents := currentSnapshotNIDs(t, accumulator.snapshotTable, roomID) - if len(gotSnapshotEvents) != len(wantEventNIDs) { // events A,C,D - t.Errorf("corrupt snapshot, got %v want %v", gotSnapshotEvents, wantEventNIDs) - } - if !reflect.DeepEqual(wantEventNIDs, gotSnapshotEvents) { - t.Errorf("got %v want %v", gotSnapshotEvents, wantEventNIDs) - } -} - // Regression test for corrupt state snapshots. // This seems to have happened in the wild, whereby the snapshot exhibited 2 things: // - A message event having a event_replaces_nid. This should be impossible as messages are not state. diff --git a/sync2/device_data_ticker.go b/sync2/device_data_ticker.go new file mode 100644 index 00000000..7d77aaea --- /dev/null +++ b/sync2/device_data_ticker.go @@ -0,0 +1,90 @@ +package sync2 + +import ( + "sync" + "time" + + "github.com/matrix-org/sliding-sync/pubsub" +) + +// This struct remembers user+device IDs to notify for then periodically +// emits them all to the caller. Use to rate limit the frequency of device list +// updates. +type DeviceDataTicker struct { + // data structures to periodically notify downstream about device data updates + // The ticker controls the frequency of updates. The done channel is used to stop ticking + // and clean up the goroutine. The notify map contains the values to notify for. 
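[editor's note] Typical wiring of the ticker, mirroring how handler2 uses it later in this diff (SetCallback before starting, Run on its own goroutine, Stop on teardown). This is a hypothetical sketch, not code from the patch:

package sync2

import (
	"log"
	"time"

	"github.com/matrix-org/sliding-sync/pubsub"
)

// setupDeviceDataTicker batches device-data notifications and flushes them once a second.
func setupDeviceDataTicker() *DeviceDataTicker {
	ticker := NewDeviceDataTicker(time.Second)
	ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
		// one payload per tick, covering every (user, device) remembered since the last tick
		log.Printf("notifying %d users of device data updates", len(payload.UserIDToDeviceIDs))
	})
	go ticker.Run()

	// pollers then call Remember as updates arrive; duplicates within a tick collapse
	ticker.Remember(PollerID{UserID: "@alice:example.org", DeviceID: "DEVICE1"})
	return ticker // caller is responsible for Stop() on teardown
}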
+ ticker *time.Ticker + done chan struct{} + notifyMap *sync.Map // map of PollerID to bools, unwrapped when notifying + fn func(payload *pubsub.V2DeviceData) +} + +// Create a new device data ticker, which batches calls to Remember and invokes a callback every +// d duration. If d is 0, no batching is performed and the callback is invoked synchronously, which +// is useful for testing. +func NewDeviceDataTicker(d time.Duration) *DeviceDataTicker { + ddt := &DeviceDataTicker{ + done: make(chan struct{}), + notifyMap: &sync.Map{}, + } + if d != 0 { + ddt.ticker = time.NewTicker(d) + } + return ddt +} + +// Stop ticking. +func (t *DeviceDataTicker) Stop() { + if t.ticker != nil { + t.ticker.Stop() + } + close(t.done) +} + +// Set the function which should be called when the tick happens. +func (t *DeviceDataTicker) SetCallback(fn func(payload *pubsub.V2DeviceData)) { + t.fn = fn +} + +// Remember this user/device ID, and emit it later on. +func (t *DeviceDataTicker) Remember(pid PollerID) { + t.notifyMap.Store(pid, true) + if t.ticker == nil { + t.emitUpdate() + } +} + +func (t *DeviceDataTicker) emitUpdate() { + var p pubsub.V2DeviceData + p.UserIDToDeviceIDs = make(map[string][]string) + // populate the pubsub payload + t.notifyMap.Range(func(key, value any) bool { + pid := key.(PollerID) + devices := p.UserIDToDeviceIDs[pid.UserID] + devices = append(devices, pid.DeviceID) + p.UserIDToDeviceIDs[pid.UserID] = devices + // clear the map of this value + t.notifyMap.Delete(key) + return true // keep enumerating + }) + // notify if we have entries + if len(p.UserIDToDeviceIDs) > 0 { + t.fn(&p) + } +} + +// Blocks forever, ticking until Stop() is called. +func (t *DeviceDataTicker) Run() { + if t.ticker == nil { + return + } + for { + select { + case <-t.done: + return + case <-t.ticker.C: + t.emitUpdate() + } + } +} diff --git a/sync2/device_data_ticker_test.go b/sync2/device_data_ticker_test.go new file mode 100644 index 00000000..daa50819 --- /dev/null +++ b/sync2/device_data_ticker_test.go @@ -0,0 +1,125 @@ +package sync2 + +import ( + "reflect" + "sort" + "sync" + "testing" + "time" + + "github.com/matrix-org/sliding-sync/pubsub" +) + +func TestDeviceTickerBasic(t *testing.T) { + duration := time.Millisecond + ticker := NewDeviceDataTicker(duration) + var payloads []*pubsub.V2DeviceData + ticker.SetCallback(func(payload *pubsub.V2DeviceData) { + payloads = append(payloads, payload) + }) + var wg sync.WaitGroup + wg.Add(1) + go func() { + t.Log("starting the ticker") + ticker.Run() + wg.Done() + }() + time.Sleep(duration * 2) // wait until the ticker is consuming + t.Log("remembering a poller") + ticker.Remember(PollerID{ + UserID: "a", + DeviceID: "b", + }) + time.Sleep(duration * 2) + if len(payloads) != 1 { + t.Fatalf("expected 1 callback, got %d", len(payloads)) + } + want := map[string][]string{ + "a": {"b"}, + } + assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want) + // check stopping works + payloads = []*pubsub.V2DeviceData{} + ticker.Stop() + wg.Wait() + time.Sleep(duration * 2) + if len(payloads) != 0 { + t.Fatalf("got extra payloads: %+v", payloads) + } +} + +func TestDeviceTickerBatchesCorrectly(t *testing.T) { + duration := 100 * time.Millisecond + ticker := NewDeviceDataTicker(duration) + var payloads []*pubsub.V2DeviceData + ticker.SetCallback(func(payload *pubsub.V2DeviceData) { + payloads = append(payloads, payload) + }) + go ticker.Run() + defer ticker.Stop() + ticker.Remember(PollerID{ + UserID: "a", + DeviceID: "b", + }) + ticker.Remember(PollerID{ + UserID: "a", 
+ DeviceID: "bb", // different device, same user + }) + ticker.Remember(PollerID{ + UserID: "a", + DeviceID: "b", // dupe poller ID + }) + ticker.Remember(PollerID{ + UserID: "x", + DeviceID: "y", // new device and user + }) + time.Sleep(duration * 2) + if len(payloads) != 1 { + t.Fatalf("expected 1 callback, got %d", len(payloads)) + } + want := map[string][]string{ + "a": {"b", "bb"}, + "x": {"y"}, + } + assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want) +} + +func TestDeviceTickerForgetsAfterEmitting(t *testing.T) { + duration := time.Millisecond + ticker := NewDeviceDataTicker(duration) + var payloads []*pubsub.V2DeviceData + + ticker.SetCallback(func(payload *pubsub.V2DeviceData) { + payloads = append(payloads, payload) + }) + ticker.Remember(PollerID{ + UserID: "a", + DeviceID: "b", + }) + + go ticker.Run() + defer ticker.Stop() + ticker.Remember(PollerID{ + UserID: "a", + DeviceID: "b", + }) + time.Sleep(10 * duration) + if len(payloads) != 1 { + t.Fatalf("got %d payloads, want 1", len(payloads)) + } +} + +func assertPayloadEqual(t *testing.T, got, want map[string][]string) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("got %+v\nwant %+v\n", got, want) + } + for userID, wantDeviceIDs := range want { + gotDeviceIDs := got[userID] + sort.Strings(wantDeviceIDs) + sort.Strings(gotDeviceIDs) + if !reflect.DeepEqual(gotDeviceIDs, wantDeviceIDs) { + t.Errorf("user %v got devices %v want %v", userID, gotDeviceIDs, wantDeviceIDs) + } + } +} diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 60e8fece..15d8037e 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -7,6 +7,7 @@ import ( "hash/fnv" "os" "sync" + "time" "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sqlutil" @@ -43,13 +44,16 @@ type Handler struct { // room_id => fnv_hash([typing user ids]) typingMap map[string]uint64 + deviceDataTicker *sync2.DeviceDataTicker + e2eeWorkerPool *internal.WorkerPool + numPollers prometheus.Gauge subSystem string } func NewHandler( pMap sync2.IPollerMap, v2Store *sync2.Storage, store *state.Storage, - pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, + pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration, ) (*Handler, error) { h := &Handler{ pMap: pMap, @@ -60,7 +64,9 @@ func NewHandler( Highlight int Notif int }), - typingMap: make(map[string]uint64), + typingMap: make(map[string]uint64), + deviceDataTicker: sync2.NewDeviceDataTicker(deviceDataUpdateDuration), + e2eeWorkerPool: internal.NewWorkerPool(500), // TODO: assign as fraction of db max conns, not hardcoded } if enablePrometheus { @@ -85,6 +91,9 @@ func (h *Handler) Listen() { sentry.CaptureException(err) } }() + h.e2eeWorkerPool.Start() + h.deviceDataTicker.SetCallback(h.OnBulkDeviceDataUpdate) + go h.deviceDataTicker.Run() } func (h *Handler) Teardown() { @@ -94,6 +103,7 @@ func (h *Handler) Teardown() { h.Store.Teardown() h.v2Store.Teardown() h.pMap.Terminate() + h.deviceDataTicker.Stop() if h.numPollers != nil { prometheus.Unregister(h.numPollers) } @@ -192,27 +202,38 @@ func (h *Handler) UpdateDeviceSince(ctx context.Context, userID, deviceID, since } func (h *Handler) OnE2EEData(ctx context.Context, userID, deviceID string, otkCounts map[string]int, fallbackKeyTypes []string, deviceListChanges map[string]int) { - // some of these fields may be set - partialDD := internal.DeviceData{ - UserID: userID, - DeviceID: deviceID, - OTKCounts: otkCounts, - FallbackKeyTypes: fallbackKeyTypes, - DeviceLists: 
internal.DeviceLists{ - New: deviceListChanges, - }, - } - nextPos, err := h.Store.DeviceDataTable.Upsert(&partialDD) - if err != nil { - logger.Err(err).Str("user", userID).Msg("failed to upsert device data") - internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) - return - } - h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2DeviceData{ - UserID: userID, - DeviceID: deviceID, - Pos: nextPos, + var wg sync.WaitGroup + wg.Add(1) + h.e2eeWorkerPool.Queue(func() { + defer wg.Done() + // some of these fields may be set + partialDD := internal.DeviceData{ + UserID: userID, + DeviceID: deviceID, + OTKCounts: otkCounts, + FallbackKeyTypes: fallbackKeyTypes, + DeviceLists: internal.DeviceLists{ + New: deviceListChanges, + }, + } + _, err := h.Store.DeviceDataTable.Upsert(&partialDD) + if err != nil { + logger.Err(err).Str("user", userID).Msg("failed to upsert device data") + internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) + return + } + // remember this to notify on pubsub later + h.deviceDataTicker.Remember(sync2.PollerID{ + UserID: userID, + DeviceID: deviceID, + }) }) + wg.Wait() +} + +// Called periodically by deviceDataTicker, contains many updates +func (h *Handler) OnBulkDeviceDataUpdate(payload *pubsub.V2DeviceData) { + h.v2Pub.Notify(pubsub.ChanV2, payload) } func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) { diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go index 20f064ab..fa315228 100644 --- a/sync2/handler2/handler_test.go +++ b/sync2/handler2/handler_test.go @@ -1,14 +1,15 @@ package handler2_test import ( - "github.com/jmoiron/sqlx" - "github.com/matrix-org/sliding-sync/sqlutil" "os" "reflect" "sync" "testing" "time" + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" + "github.com/matrix-org/sliding-sync/pubsub" "github.com/matrix-org/sliding-sync/state" "github.com/matrix-org/sliding-sync/sync2" @@ -127,7 +128,7 @@ func TestHandlerFreshEnsurePolling(t *testing.T) { pMap := &mockPollerMap{} pub := newMockPub() sub := &mockSub{} - h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false) + h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false, time.Minute) assertNoError(t, err) alice := "@alice:localhost" deviceID := "ALICE" diff --git a/sync2/poller.go b/sync2/poller.go index 04b8ca0c..f6c99a18 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -25,6 +25,9 @@ type PollerID struct { // alias time.Sleep so tests can monkey patch it out var timeSleep = time.Sleep +// log at most once every duration. Always logs before terminating. +var logInterval = 30 * time.Second + // V2DataReceiver is the receiver for all the v2 sync data the poller gets type V2DataReceiver interface { // Update the since token for this device. Called AFTER all other data in this sync response has been processed. 
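[editor's note] The new logInterval above feeds maybeLogStats further down in this file: counters accumulate on the poller and are only flushed to the log once per interval, plus a forced flush before terminating. The throttle itself is just a timestamp check; a minimal sketch with a hypothetical statsLogger stand-in:

package sync2

import "time"

// statsLogger is a hypothetical stand-in for the poller's counters: log at most
// once per interval unless forced (as the poller does before terminating).
type statsLogger struct {
	lastLogged time.Time
	interval   time.Duration
}

func (s *statsLogger) maybeLog(force bool, flush func()) {
	if !force && time.Since(s.lastLogged) < s.interval {
		return // logged too recently; keep accumulating
	}
	s.lastLogged = time.Now()
	flush() // emit the accumulated counters and reset them
}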
@@ -64,14 +67,17 @@ type IPollerMap interface { // PollerMap is a map of device ID to Poller type PollerMap struct { - v2Client Client - callbacks V2DataReceiver - pollerMu *sync.Mutex - Pollers map[PollerID]*poller - executor chan func() - executorRunning bool - processHistogramVec *prometheus.HistogramVec - timelineSizeHistogramVec *prometheus.HistogramVec + v2Client Client + callbacks V2DataReceiver + pollerMu *sync.Mutex + Pollers map[PollerID]*poller + executor chan func() + executorRunning bool + processHistogramVec *prometheus.HistogramVec + timelineSizeHistogramVec *prometheus.HistogramVec + gappyStateSizeVec *prometheus.HistogramVec + numOutstandingSyncReqsGauge prometheus.Gauge + totalNumPollsCounter prometheus.Counter } // NewPollerMap makes a new PollerMap. Guarantees that the V2DataReceiver will be called on the same @@ -122,7 +128,28 @@ func NewPollerMap(v2Client Client, enablePrometheus bool) *PollerMap { Buckets: []float64{0.0, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0}, }, []string{"limited"}) prometheus.MustRegister(pm.timelineSizeHistogramVec) - + pm.gappyStateSizeVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "sliding_sync", + Subsystem: "poller", + Name: "gappy_state_size", + Help: "Number of events in a state block during a sync v2 gappy sync", + Buckets: []float64{1.0, 10.0, 100.0, 1000.0, 10000.0}, + }, nil) + prometheus.MustRegister(pm.gappyStateSizeVec) + pm.totalNumPollsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "sliding_sync", + Subsystem: "poller", + Name: "total_num_polls", + Help: "Total number of poll loops iterated.", + }) + prometheus.MustRegister(pm.totalNumPollsCounter) + pm.numOutstandingSyncReqsGauge = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "sliding_sync", + Subsystem: "poller", + Name: "num_outstanding_sync_v2_reqs", + Help: "Number of sync v2 requests that have yet to return a response.", + }) + prometheus.MustRegister(pm.numOutstandingSyncReqsGauge) } return pm } @@ -144,6 +171,15 @@ func (h *PollerMap) Terminate() { if h.timelineSizeHistogramVec != nil { prometheus.Unregister(h.timelineSizeHistogramVec) } + if h.gappyStateSizeVec != nil { + prometheus.Unregister(h.gappyStateSizeVec) + } + if h.totalNumPollsCounter != nil { + prometheus.Unregister(h.totalNumPollsCounter) + } + if h.numOutstandingSyncReqsGauge != nil { + prometheus.Unregister(h.numOutstandingSyncReqsGauge) + } close(h.executor) } @@ -203,6 +239,9 @@ func (h *PollerMap) EnsurePolling(pid PollerID, accessToken, v2since string, isS poller = newPoller(pid, accessToken, h.v2Client, h, logger, !needToWait && !isStartup) poller.processHistogramVec = h.processHistogramVec poller.timelineSizeVec = h.timelineSizeHistogramVec + poller.gappyStateSizeVec = h.gappyStateSizeVec + poller.numOutstandingSyncReqs = h.numOutstandingSyncReqsGauge + poller.totalNumPolls = h.totalNumPollsCounter go poller.Poll(v2since) h.Pollers[pid] = poller @@ -342,9 +381,24 @@ type poller struct { terminated *atomic.Bool wg *sync.WaitGroup - pollHistogramVec *prometheus.HistogramVec - processHistogramVec *prometheus.HistogramVec - timelineSizeVec *prometheus.HistogramVec + // stats about poll response data, for logging purposes + lastLogged time.Time + totalStateCalls int + totalTimelineCalls int + totalReceipts int + totalTyping int + totalInvites int + totalDeviceEvents int + totalAccountData int + totalChangedDeviceLists int + totalLeftDeviceLists int + + pollHistogramVec *prometheus.HistogramVec + processHistogramVec *prometheus.HistogramVec + timelineSizeVec 
*prometheus.HistogramVec + gappyStateSizeVec *prometheus.HistogramVec + numOutstandingSyncReqs prometheus.Gauge + totalNumPolls prometheus.Counter } func newPoller(pid PollerID, accessToken string, client Client, receiver V2DataReceiver, logger zerolog.Logger, initialToDeviceOnly bool) *poller { @@ -399,7 +453,7 @@ func (p *poller) Poll(since string) { defer func() { panicErr := recover() if panicErr != nil { - logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msg(string(debug.Stack())) + logger.Error().Str("user", p.userID).Str("device", p.deviceID).Msgf("%s. Traceback:\n%s", panicErr, debug.Stack()) internal.GetSentryHubFromContextOrDefault(ctx).RecoverWithContext(ctx, panicErr) } p.receiver.OnTerminated(ctx, p.userID, p.deviceID) @@ -418,6 +472,7 @@ func (p *poller) Poll(since string) { break } } + p.maybeLogStats(true) // always unblock EnsurePolling else we can end up head-of-line blocking other pollers! if state.firstTime { state.firstTime = false @@ -429,6 +484,9 @@ func (p *poller) Poll(since string) { // s (which is assumed to be non-nil). Returns a non-nil error iff the poller loop // should halt. func (p *poller) poll(ctx context.Context, s *pollLoopState) error { + if p.totalNumPolls != nil { + p.totalNumPolls.Inc() + } if s.failCount > 0 { // don't backoff when doing v2 syncs because the response is only in the cache for a short // period of time (on massive accounts on matrix.org) such that if you wait 2,4,8min between @@ -442,7 +500,13 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error { } start := time.Now() spanCtx, region := internal.StartSpan(ctx, "DoSyncV2") + if p.numOutstandingSyncReqs != nil { + p.numOutstandingSyncReqs.Inc() + } resp, statusCode, err := p.client.DoSyncV2(spanCtx, p.accessToken, s.since, s.firstTime, p.initialToDeviceOnly) + if p.numOutstandingSyncReqs != nil { + p.numOutstandingSyncReqs.Dec() + } region.End() p.trackRequestDuration(time.Since(start), s.since == "", s.firstTime) if p.terminated.Load() { @@ -488,6 +552,7 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error { p.wg.Done() } p.trackProcessDuration(time.Since(start), wasInitial, wasFirst) + p.maybeLogStats(false) return nil } @@ -526,6 +591,7 @@ func (p *poller) parseToDeviceMessages(ctx context.Context, res *SyncResponse) { if len(res.ToDevice.Events) == 0 { return } + p.totalDeviceEvents += len(res.ToDevice.Events) p.receiver.AddToDeviceMessages(ctx, p.userID, p.deviceID, res.ToDevice.Events) } @@ -564,6 +630,8 @@ func (p *poller) parseE2EEData(ctx context.Context, res *SyncResponse) { deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left) if deviceListChanges != nil || changedFallbackTypes != nil || changedOTKCounts != nil { + p.totalChangedDeviceLists += len(res.DeviceLists.Changed) + p.totalLeftDeviceLists += len(res.DeviceLists.Left) p.receiver.OnE2EEData(ctx, p.userID, p.deviceID, changedOTKCounts, changedFallbackTypes, deviceListChanges) } } @@ -574,6 +642,7 @@ func (p *poller) parseGlobalAccountData(ctx context.Context, res *SyncResponse) if len(res.AccountData.Events) == 0 { return } + p.totalAccountData += len(res.AccountData.Events) p.receiver.OnAccountData(ctx, p.userID, AccountDataGlobalRoom, res.AccountData.Events) } @@ -604,6 +673,7 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { }) hub.CaptureMessage(warnMsg) }) + p.trackGappyStateSize(len(prependStateEvents)) roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...) 
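[editor's note] The metrics added to poll() earlier in this hunk follow a simple guard-then-Inc pattern: count every iteration, and wrap the sync request in a gauge increment/decrement so in-flight requests are visible. A reduced sketch of that pattern (hypothetical helper, not part of the patch):

package example

import "github.com/prometheus/client_golang/prometheus"

// doSyncWithMetrics counts one poll iteration and tracks the request as in-flight
// for its duration. Nil metrics are skipped, matching the poller's guards.
func doSyncWithMetrics(totalPolls prometheus.Counter, outstanding prometheus.Gauge, doSync func() error) error {
	if totalPolls != nil {
		totalPolls.Inc()
	}
	if outstanding != nil {
		outstanding.Inc()
	}
	err := doSync()
	if outstanding != nil {
		outstanding.Dec()
	}
	return err
}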
} } @@ -648,17 +718,39 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) { for roomID, roomData := range res.Rooms.Invite { p.receiver.OnInvite(ctx, p.userID, roomID, roomData.InviteState.Events) } - var l *zerolog.Event - if len(res.Rooms.Invite) > 0 || len(res.Rooms.Join) > 0 { - l = p.logger.Info() - } else { - l = p.logger.Debug() + + p.totalReceipts += receiptCalls + p.totalStateCalls += stateCalls + p.totalTimelineCalls += timelineCalls + p.totalTyping += typingCalls + p.totalInvites += len(res.Rooms.Invite) +} + +func (p *poller) maybeLogStats(force bool) { + if !force && time.Since(p.lastLogged) < logInterval { + // only log at most once every logInterval + return } - l.Ints( - "rooms [invite,join,leave]", []int{len(res.Rooms.Invite), len(res.Rooms.Join), len(res.Rooms.Leave)}, + p.lastLogged = time.Now() + p.logger.Info().Ints( + "rooms [timeline,state,typing,receipts,invites]", []int{ + p.totalTimelineCalls, p.totalStateCalls, p.totalTyping, p.totalReceipts, p.totalInvites, + }, ).Ints( - "storage [states,timelines,typing,receipts]", []int{stateCalls, timelineCalls, typingCalls, receiptCalls}, - ).Int("to_device", len(res.ToDevice.Events)).Msg("Poller: accumulated data") + "device [events,changed,left,account]", []int{ + p.totalDeviceEvents, p.totalChangedDeviceLists, p.totalLeftDeviceLists, p.totalAccountData, + }, + ).Msg("Poller: accumulated data") + + p.totalAccountData = 0 + p.totalChangedDeviceLists = 0 + p.totalDeviceEvents = 0 + p.totalInvites = 0 + p.totalLeftDeviceLists = 0 + p.totalReceipts = 0 + p.totalStateCalls = 0 + p.totalTimelineCalls = 0 + p.totalTyping = 0 } func (p *poller) trackTimelineSize(size int, limited bool) { @@ -671,3 +763,10 @@ func (p *poller) trackTimelineSize(size int, limited bool) { } p.timelineSizeVec.WithLabelValues(label).Observe(float64(size)) } + +func (p *poller) trackGappyStateSize(size int) { + if p.gappyStateSizeVec == nil { + return + } + p.gappyStateSizeVec.WithLabelValues().Observe(float64(size)) +} diff --git a/sync3/caches/user.go b/sync3/caches/user.go index efd7aed8..ef160f8b 100644 --- a/sync3/caches/user.go +++ b/sync3/caches/user.go @@ -393,8 +393,16 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID strin i int }) for roomID, events := range roomIDToEvents { - for i, ev := range events { - evID := gjson.GetBytes(ev, "event_id").Str + for i, evJSON := range events { + ev := gjson.ParseBytes(evJSON) + evID := ev.Get("event_id").Str + sender := ev.Get("sender").Str + if sender != userID { + // don't ask for txn IDs for events which weren't sent by us. + // If we do, we'll needlessly hit the database, increasing latencies when + // catching up from the live buffer. 
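[editor's note] The guard described above (together with the len(eventIDs) == 0 early return that follows) means the transaction-ID lookup is skipped entirely when none of the timeline events were sent by this user. In isolation the filter is just the following; ownEventIDs is a hypothetical helper, not part of the patch:

package caches

import (
	"encoding/json"

	"github.com/tidwall/gjson"
)

// ownEventIDs collects IDs of events sent by userID, since only our own events
// can carry one of our transaction IDs; anything else would be a pointless DB hit.
func ownEventIDs(userID string, events []json.RawMessage) []string {
	var ids []string
	for _, raw := range events {
		ev := gjson.ParseBytes(raw)
		if ev.Get("sender").Str != userID {
			continue
		}
		ids = append(ids, ev.Get("event_id").Str)
	}
	return ids
}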
+ continue + } eventIDs = append(eventIDs, evID) eventIDToEvent[evID] = struct { roomID string @@ -405,6 +413,10 @@ func (c *UserCache) AnnotateWithTransactionIDs(ctx context.Context, userID strin } } } + if len(eventIDs) == 0 { + // don't do any work if we have no events + return roomIDToEvents + } eventIDToTxnID := c.txnIDs.TransactionIDForEvents(userID, deviceID, eventIDs) for eventID, txnID := range eventIDToTxnID { data, ok := eventIDToEvent[eventID] diff --git a/sync3/caches/user_test.go b/sync3/caches/user_test.go index a1e5ec8f..5809317b 100644 --- a/sync3/caches/user_test.go +++ b/sync3/caches/user_test.go @@ -83,8 +83,8 @@ func TestAnnotateWithTransactionIDs(t *testing.T) { data: tc.eventIDToTxnIDs, } uc := caches.NewUserCache(userID, nil, nil, fetcher) - got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(tc.roomIDToEvents)) - want := convertIDTxnToEventStub(tc.wantRoomIDToEvents) + got := uc.AnnotateWithTransactionIDs(context.Background(), userID, "DEVICE", convertIDToEventStub(userID, tc.roomIDToEvents)) + want := convertIDTxnToEventStub(userID, tc.wantRoomIDToEvents) if !reflect.DeepEqual(got, want) { t.Errorf("%s : got %v want %v", tc.name, js(got), js(want)) } @@ -96,27 +96,27 @@ func js(in interface{}) string { return string(b) } -func convertIDToEventStub(roomToEventIDs map[string][]string) map[string][]json.RawMessage { +func convertIDToEventStub(sender string, roomToEventIDs map[string][]string) map[string][]json.RawMessage { result := make(map[string][]json.RawMessage) for roomID, eventIDs := range roomToEventIDs { events := make([]json.RawMessage, len(eventIDs)) for i := range eventIDs { - events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x"}`, eventIDs[i])) + events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s"}`, eventIDs[i], sender)) } result[roomID] = events } return result } -func convertIDTxnToEventStub(roomToEventIDs map[string][][2]string) map[string][]json.RawMessage { +func convertIDTxnToEventStub(sender string, roomToEventIDs map[string][][2]string) map[string][]json.RawMessage { result := make(map[string][]json.RawMessage) for roomID, eventIDs := range roomToEventIDs { events := make([]json.RawMessage, len(eventIDs)) for i := range eventIDs { if eventIDs[i][1] == "" { - events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x"}`, eventIDs[i][0])) + events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s"}`, eventIDs[i][0], sender)) } else { - events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","unsigned":{"transaction_id":"%s"}}`, eventIDs[i][0], eventIDs[i][1])) + events[i] = json.RawMessage(fmt.Sprintf(`{"event_id":"%s","type":"x","sender":"%s","unsigned":{"transaction_id":"%s"}}`, eventIDs[i][0], sender, eventIDs[i][1])) } } result[roomID] = events diff --git a/sync3/conn.go b/sync3/conn.go index 14ff1689..8010447b 100644 --- a/sync3/conn.go +++ b/sync3/conn.go @@ -30,7 +30,7 @@ type ConnHandler interface { // Callback which is allowed to block as long as the context is active. Return the response // to send back or an error. Errors of type *internal.HandlerError are inspected for the correct // status code to send back. 
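[editor's note] For context, the *internal.HandlerError inspection mentioned above boils down to a type check; its exported StatusCode and Err fields are visible in conn_test.go later in this diff. A sketch of that inspection, with statusCodeFor as a hypothetical helper:

package sync3

import (
	"errors"
	"net/http"

	"github.com/matrix-org/sliding-sync/internal"
)

// statusCodeFor surfaces a handler error's status code, defaulting to 500 for
// anything that isn't an *internal.HandlerError.
func statusCodeFor(err error) int {
	var herr *internal.HandlerError
	if errors.As(err, &herr) {
		return herr.StatusCode
	}
	return http.StatusInternalServerError
}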
- OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error) + OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error) OnUpdate(ctx context.Context, update caches.Update) Destroy() Alive() bool @@ -88,7 +88,7 @@ func (c *Conn) OnUpdate(ctx context.Context, update caches.Update) { // upwards but will NOT be logged to Sentry (neither here nor by the caller). Errors // should be reported to Sentry as close as possible to the point of creating the error, // to provide the best possible Sentry traceback. -func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err error) { +func (c *Conn) tryRequest(ctx context.Context, req *Request, start time.Time) (res *Response, err error) { // TODO: include useful information from the request in the sentry hub/context // Might be better done in the caller though? defer func() { @@ -116,7 +116,7 @@ func (c *Conn) tryRequest(ctx context.Context, req *Request) (res *Response, err ctx, task := internal.StartTask(ctx, taskType) defer task.End() internal.Logf(ctx, "connstate", "starting user=%v device=%v pos=%v", c.UserID, c.ConnID.DeviceID, req.pos) - return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0) + return c.handler.OnIncomingRequest(ctx, c.ConnID, req, req.pos == 0, start) } func (c *Conn) isOutstanding(pos int64) bool { @@ -132,7 +132,7 @@ func (c *Conn) isOutstanding(pos int64) bool { // If an error is returned, it will be logged by the caller and transmitted to the // client. It will NOT be reported to Sentry---this should happen as close as possible // to the creation of the error (or else Sentry cannot provide a meaningful traceback.) -func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Response, herr *internal.HandlerError) { +func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request, start time.Time) (resp *Response, herr *internal.HandlerError) { c.cancelOutstandingRequestMu.Lock() if c.cancelOutstandingRequest != nil { c.cancelOutstandingRequest() @@ -217,7 +217,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request) (resp *Respo req.SetTimeoutMSecs(1) } - resp, err := c.tryRequest(ctx, req) + resp, err := c.tryRequest(ctx, req, start) if err != nil { herr, ok := err.(*internal.HandlerError) if !ok { diff --git a/sync3/conn_test.go b/sync3/conn_test.go index a0be14a9..c326938c 100644 --- a/sync3/conn_test.go +++ b/sync3/conn_test.go @@ -16,7 +16,7 @@ type connHandlerMock struct { fn func(ctx context.Context, cid ConnID, req *Request, isInitial bool) (*Response, error) } -func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, init bool) (*Response, error) { +func (c *connHandlerMock) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, init bool, start time.Time) (*Response, error) { return c.fn(ctx, cid, req, init) } func (c *connHandlerMock) UserID() string { @@ -47,7 +47,7 @@ func TestConn(t *testing.T) { // initial request resp, err := c.OnIncomingRequest(ctx, &Request{ pos: 0, - }) + }, time.Now()) assertNoError(t, err) assertPos(t, resp.Pos, 1) assertInt(t, resp.Lists["a"].Count, 101) @@ -55,14 +55,14 @@ func TestConn(t *testing.T) { // happy case, pos=1 resp, err = c.OnIncomingRequest(ctx, &Request{ pos: 1, - }) + }, time.Now()) assertPos(t, resp.Pos, 2) assertInt(t, resp.Lists["a"].Count, 102) assertNoError(t, err) // bogus position returns a 400 _, err = c.OnIncomingRequest(ctx, &Request{ pos: 31415, - }) 
+ }, time.Now()) if err == nil { t.Fatalf("expected error, got none") } @@ -106,7 +106,7 @@ func TestConnBlocking(t *testing.T) { Sort: []string{"hi"}, }, }, - }) + }, time.Now()) }() go func() { defer wg.Done() @@ -118,7 +118,7 @@ func TestConnBlocking(t *testing.T) { Sort: []string{"hi2"}, }, }, - }) + }, time.Now()) }() go func() { wg.Wait() @@ -148,18 +148,18 @@ func TestConnRetries(t *testing.T) { }, }}, nil }}) - resp, err := c.OnIncomingRequest(ctx, &Request{}) + resp, err := c.OnIncomingRequest(ctx, &Request{}, time.Now()) assertPos(t, resp.Pos, 1) assertInt(t, resp.Lists["a"].Count, 20) assertInt(t, callCount, 1) assertNoError(t, err) - resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}) + resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now()) assertPos(t, resp.Pos, 2) assertInt(t, resp.Lists["a"].Count, 20) assertInt(t, callCount, 2) assertNoError(t, err) // retry! Shouldn't invoke handler again - resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}) + resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now()) assertPos(t, resp.Pos, 2) assertInt(t, resp.Lists["a"].Count, 20) assertInt(t, callCount, 2) // this doesn't increment @@ -170,7 +170,7 @@ func TestConnRetries(t *testing.T) { "a": { Sort: []string{SortByName}, }, - }}) + }}, time.Now()) assertPos(t, resp.Pos, 2) assertInt(t, resp.Lists["a"].Count, 20) assertInt(t, callCount, 3) // this doesn't increment @@ -191,25 +191,25 @@ func TestConnBufferRes(t *testing.T) { }, }}, nil }}) - resp, err := c.OnIncomingRequest(ctx, &Request{}) + resp, err := c.OnIncomingRequest(ctx, &Request{}, time.Now()) assertNoError(t, err) assertPos(t, resp.Pos, 1) assertInt(t, resp.Lists["a"].Count, 1) assertInt(t, callCount, 1) - resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}) + resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1}, time.Now()) assertNoError(t, err) assertPos(t, resp.Pos, 2) assertInt(t, resp.Lists["a"].Count, 2) assertInt(t, callCount, 2) // retry with modified request data that shouldn't prompt data to be returned. // should invoke handler again! - resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1, UnsubscribeRooms: []string{"a"}}) + resp, err = c.OnIncomingRequest(ctx, &Request{pos: 1, UnsubscribeRooms: []string{"a"}}, time.Now()) assertNoError(t, err) assertPos(t, resp.Pos, 2) assertInt(t, resp.Lists["a"].Count, 2) assertInt(t, callCount, 3) // this DOES increment, the response is buffered and not returned yet. 
// retry with same request body, so should NOT invoke handler again and return buffered response - resp, err = c.OnIncomingRequest(ctx, &Request{pos: 2, UnsubscribeRooms: []string{"a"}}) + resp, err = c.OnIncomingRequest(ctx, &Request{pos: 2, UnsubscribeRooms: []string{"a"}}, time.Now()) assertNoError(t, err) assertPos(t, resp.Pos, 3) assertInt(t, resp.Lists["a"].Count, 3) @@ -228,7 +228,7 @@ func TestConnErrors(t *testing.T) { // random errors = 500 errCh <- errors.New("oops") - _, herr := c.OnIncomingRequest(ctx, &Request{}) + _, herr := c.OnIncomingRequest(ctx, &Request{}, time.Now()) if herr.StatusCode != 500 { t.Fatalf("random errors should be status 500, got %d", herr.StatusCode) } @@ -237,7 +237,7 @@ func TestConnErrors(t *testing.T) { StatusCode: 400, Err: errors.New("no way!"), } - _, herr = c.OnIncomingRequest(ctx, &Request{}) + _, herr = c.OnIncomingRequest(ctx, &Request{}, time.Now()) if herr.StatusCode != 400 { t.Fatalf("expected status 400, got %d", herr.StatusCode) } @@ -258,7 +258,7 @@ func TestConnErrorsNoCache(t *testing.T) { } }}) // errors should not be cached - resp, herr := c.OnIncomingRequest(ctx, &Request{}) + resp, herr := c.OnIncomingRequest(ctx, &Request{}, time.Now()) if herr != nil { t.Fatalf("expected no error, got %+v", herr) } @@ -267,12 +267,12 @@ func TestConnErrorsNoCache(t *testing.T) { StatusCode: 400, Err: errors.New("no way!"), } - _, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}) + _, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}, time.Now()) if herr.StatusCode != 400 { t.Fatalf("expected status 400, got %d", herr.StatusCode) } // but doing the exact same request should now work - _, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}) + _, herr = c.OnIncomingRequest(ctx, &Request{pos: resp.PosInt()}, time.Now()) if herr != nil { t.Fatalf("expected no error, got %+v", herr) } @@ -361,7 +361,7 @@ func TestConnBufferRememberInflight(t *testing.T) { var err *internal.HandlerError for i, step := range steps { t.Logf("Executing step %d", i) - resp, err = c.OnIncomingRequest(ctx, step.req) + resp, err = c.OnIncomingRequest(ctx, step.req, time.Now()) if !step.wantErr { assertNoError(t, err) } diff --git a/sync3/handler/connstate.go b/sync3/handler/connstate.go index 6bac6f30..9f3f79ab 100644 --- a/sync3/handler/connstate.go +++ b/sync3/handler/connstate.go @@ -52,12 +52,13 @@ type ConnState struct { joinChecker JoinChecker extensionsHandler extensions.HandlerInterface + setupHistogramVec *prometheus.HistogramVec processHistogramVec *prometheus.HistogramVec } func NewConnState( userID, deviceID string, userCache *caches.UserCache, globalCache *caches.GlobalCache, - ex extensions.HandlerInterface, joinChecker JoinChecker, histVec *prometheus.HistogramVec, + ex extensions.HandlerInterface, joinChecker JoinChecker, setupHistVec *prometheus.HistogramVec, histVec *prometheus.HistogramVec, maxPendingEventUpdates int, ) *ConnState { cs := &ConnState{ @@ -72,6 +73,7 @@ func NewConnState( extensionsHandler: ex, joinChecker: joinChecker, lazyCache: NewLazyCache(), + setupHistogramVec: setupHistVec, processHistogramVec: histVec, } cs.live = &connStateLive{ @@ -160,7 +162,7 @@ func (s *ConnState) load(ctx context.Context, req *sync3.Request) error { } // OnIncomingRequest is guaranteed to be called sequentially (it's protected by a mutex in conn.go) -func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool) (*sync3.Response, error) { +func (s *ConnState) OnIncomingRequest(ctx 
context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool, start time.Time) (*sync3.Response, error) { if s.anchorLoadPosition <= 0 { // load() needs no ctx so drop it _, region := internal.StartSpan(ctx, "load") @@ -172,6 +174,8 @@ func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req } region.End() } + setupTime := time.Since(start) + s.trackSetupDuration(setupTime, isInitial) return s.onIncomingRequest(ctx, req, isInitial) } @@ -192,6 +196,9 @@ func (s *ConnState) onIncomingRequest(reqCtx context.Context, req *sync3.Request } internal.Logf(reqCtx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData) } + for roomID, sub := range s.muxedReq.RoomSubscriptions { + internal.Logf(reqCtx, "connstate", "room sub[%v] %v", roomID, sub) + } // work out which rooms we'll return data for and add their relevant subscriptions to the builder // for it to mix together @@ -597,7 +604,7 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu Initial: true, IsDM: userRoomData.IsDM, JoinedCount: metadata.JoinCount, - InvitedCount: metadata.InviteCount, + InvitedCount: &metadata.InviteCount, PrevBatch: userRoomData.RequestedLatestEvents.PrevBatch, } } @@ -610,6 +617,17 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu return rooms } +func (s *ConnState) trackSetupDuration(dur time.Duration, isInitial bool) { + if s.setupHistogramVec == nil { + return + } + val := "0" + if isInitial { + val = "1" + } + s.setupHistogramVec.WithLabelValues(val).Observe(float64(dur.Seconds())) +} + func (s *ConnState) trackProcessDuration(dur time.Duration, isInitial bool) { if s.processHistogramVec == nil { return diff --git a/sync3/handler/connstate_live.go b/sync3/handler/connstate_live.go index 6e0a3703..a01a7ef6 100644 --- a/sync3/handler/connstate_live.go +++ b/sync3/handler/connstate_live.go @@ -37,7 +37,7 @@ func (s *connStateLive) onUpdate(up caches.Update) { select { case s.updates <- up: case <-time.After(BufferWaitTime): - logger.Warn().Interface("update", up).Str("user", s.userID).Msg( + logger.Warn().Interface("update", up).Str("user", s.userID).Str("device", s.deviceID).Msg( "cannot send update to connection, buffer exceeded. 
Destroying connection.", ) s.bufferFull = true @@ -80,36 +80,47 @@ func (s *connStateLive) liveUpdate( internal.Logf(ctx, "liveUpdate", "timed out after %v", timeLeftToWait) return case update := <-s.updates: - internal.Logf(ctx, "liveUpdate", "process live update") - - s.processLiveUpdate(ctx, update, response) - // pass event to extensions AFTER processing - roomIDsToLists := s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists) - s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extensions.Context{ - IsInitial: false, - RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(), - UserID: s.userID, - DeviceID: s.deviceID, - RoomIDsToLists: roomIDsToLists, - }) + s.processUpdate(ctx, update, response, ex) // if there's more updates and we don't have lots stacked up already, go ahead and process another for len(s.updates) > 0 && response.ListOps() < 50 { update = <-s.updates - s.processLiveUpdate(ctx, update, response) - s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extensions.Context{ - IsInitial: false, - RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(), - UserID: s.userID, - DeviceID: s.deviceID, - RoomIDsToLists: roomIDsToLists, - }) + s.processUpdate(ctx, update, response, ex) } } } + + // If a client constantly changes their request params in every request they make, we will never consume from + // the update channel as the response will always have data already. In an effort to prevent starvation of new + // data, we will process some updates even though we have data already, but only if A) we didn't live stream + // due to natural circumstances, B) it isn't an initial request and C) there is in fact some data there. + numQueuedUpdates := len(s.updates) + if !hasLiveStreamed && !isInitial && numQueuedUpdates > 0 { + for i := 0; i < numQueuedUpdates; i++ { + update := <-s.updates + s.processUpdate(ctx, update, response, ex) + } + log.Debug().Int("num_queued", numQueuedUpdates).Msg("liveUpdate: caught up") + internal.Logf(ctx, "connstate", "liveUpdate caught up %d updates", numQueuedUpdates) + } + log.Trace().Bool("live_streamed", hasLiveStreamed).Msg("liveUpdate: returning") // TODO: op consolidation } +func (s *connStateLive) processUpdate(ctx context.Context, update caches.Update, response *sync3.Response, ex extensions.Request) { + internal.Logf(ctx, "liveUpdate", "process live update %s", update.Type()) + s.processLiveUpdate(ctx, update, response) + // pass event to extensions AFTER processing + roomIDsToLists := s.lists.ListsByVisibleRoomIDs(s.muxedReq.Lists) + s.extensionsHandler.HandleLiveUpdate(ctx, update, ex, &response.Extensions, extensions.Context{ + IsInitial: false, + RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(), + UserID: s.userID, + DeviceID: s.deviceID, + RoomIDsToLists: roomIDsToLists, + }) +} + func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, response *sync3.Response) bool { internal.AssertWithContext(ctx, "processLiveUpdate: response list length != internal list length", s.lists.Len() == len(response.Lists)) internal.AssertWithContext(ctx, "processLiveUpdate: request list length != internal list length", s.lists.Len() == len(s.muxedReq.Lists)) @@ -208,7 +219,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, thisRoom.Name = internal.CalculateRoomName(metadata, 5) // TODO: customisable? 
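[editor's note] The anti-starvation drain added to liveUpdate above is worth calling out: it snapshots the queue length first and consumes exactly that many updates, so a client that changes its request params on every request still makes progress through the buffered channel without looping forever on new arrivals. The shape of that drain, reduced to its essentials (a hypothetical sketch, not the patch code):

package example

// drainPending consumes only the updates that were already queued when we
// checked, never blocking on updates that arrive afterwards.
func drainPending(updates chan interface{}, process func(interface{})) int {
	n := len(updates) // snapshot; later arrivals wait for the next request
	for i := 0; i < n; i++ {
		process(<-updates)
	}
	return n
}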
} if delta.InviteCountChanged { - thisRoom.InvitedCount = roomUpdate.GlobalRoomMetadata().InviteCount + thisRoom.InvitedCount = &roomUpdate.GlobalRoomMetadata().InviteCount } if delta.JoinCountChanged { thisRoom.JoinedCount = roomUpdate.GlobalRoomMetadata().JoinCount diff --git a/sync3/handler/connstate_test.go b/sync3/handler/connstate_test.go index bed2beff..17700ea8 100644 --- a/sync3/handler/connstate_test.go +++ b/sync3/handler/connstate_test.go @@ -107,7 +107,7 @@ func TestConnStateInitial(t *testing.T) { } return result } - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) if userID != cs.UserID() { t.Fatalf("UserID returned wrong value, got %v want %v", cs.UserID(), userID) } @@ -118,7 +118,7 @@ func TestConnStateInitial(t *testing.T) { {0, 9}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -168,7 +168,7 @@ func TestConnStateInitial(t *testing.T) { {0, 9}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -206,7 +206,7 @@ func TestConnStateInitial(t *testing.T) { {0, 9}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -272,7 +272,7 @@ func TestConnStateMultipleRanges(t *testing.T) { userCache.LazyRoomDataOverride = mockLazyRoomOverride dispatcher.Register(context.Background(), userCache.UserID, userCache) dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache) - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) // request first page res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{ @@ -282,7 +282,7 @@ func TestConnStateMultipleRanges(t *testing.T) { {0, 2}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -308,7 +308,7 @@ func TestConnStateMultipleRanges(t *testing.T) { {0, 2}, {4, 6}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -343,7 +343,7 @@ func TestConnStateMultipleRanges(t *testing.T) { {0, 2}, {4, 6}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -383,7 +383,7 @@ func TestConnStateMultipleRanges(t *testing.T) { {0, 2}, {4, 6}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -451,7 +451,7 @@ func TestBumpToOutsideRange(t *testing.T) { userCache.LazyRoomDataOverride = mockLazyRoomOverride dispatcher.Register(context.Background(), userCache.UserID, userCache) dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache) - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) // Ask for A,B res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{ Lists: map[string]sync3.RequestList{"a": { @@ -460,7 +460,7 @@ func TestBumpToOutsideRange(t 
*testing.T) { {0, 1}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -495,7 +495,7 @@ func TestBumpToOutsideRange(t *testing.T) { {0, 1}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -562,7 +562,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) { } dispatcher.Register(context.Background(), userCache.UserID, userCache) dispatcher.Register(context.Background(), sync3.DispatcherAllUsers, globalCache) - cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, 1000) + cs := NewConnState(userID, deviceID, userCache, globalCache, &NopExtensionHandler{}, &NopJoinTracker{}, nil, nil, 1000) // subscribe to room D res, err := cs.OnIncomingRequest(context.Background(), ConnID, &sync3.Request{ RoomSubscriptions: map[string]sync3.RoomSubscription{ @@ -576,7 +576,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) { {0, 1}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -630,7 +630,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) { {0, 1}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } @@ -664,7 +664,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) { {0, 1}, }), }}, - }, false) + }, false, time.Now()) if err != nil { t.Fatalf("OnIncomingRequest returned error : %s", err) } diff --git a/sync3/handler/ensure_polling.go b/sync3/handler/ensure_polling.go index 76d7cfcd..430d378d 100644 --- a/sync3/handler/ensure_polling.go +++ b/sync3/handler/ensure_polling.go @@ -2,9 +2,11 @@ package handler import ( "context" + "sync" + "github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/sync2" - "sync" + "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/sliding-sync/pubsub" ) @@ -30,15 +32,27 @@ type EnsurePoller struct { // pendingPolls tracks the status of pollers that we are waiting to start. pendingPolls map[sync2.PollerID]pendingInfo notifier pubsub.Notifier + // the total number of outstanding ensurepolling requests. + numPendingEnsurePolling prometheus.Gauge } -func NewEnsurePoller(notifier pubsub.Notifier) *EnsurePoller { - return &EnsurePoller{ +func NewEnsurePoller(notifier pubsub.Notifier, enablePrometheus bool) *EnsurePoller { + p := &EnsurePoller{ chanName: pubsub.ChanV3, mu: &sync.Mutex{}, pendingPolls: make(map[sync2.PollerID]pendingInfo), notifier: notifier, } + if enablePrometheus { + p.numPendingEnsurePolling = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "sliding_sync", + Subsystem: "api", + Name: "num_devices_pending_ensure_polling", + Help: "Number of devices blocked on EnsurePolling returning.", + }) + prometheus.MustRegister(p.numPendingEnsurePolling) + } + return p } // EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. 
It is @@ -73,6 +87,7 @@ func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, to done: false, ch: ch, } + p.calculateNumOutstanding() // increment total p.mu.Unlock() // ask the pollers to poll for this device p.notifier.Notify(p.chanName, &pubsub.V3EnsurePolling{ @@ -116,6 +131,7 @@ func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComple pending.done = true pending.ch = nil p.pendingPolls[pid] = pending + p.calculateNumOutstanding() // decrement total log.Trace().Msg("OnInitialSyncComplete: closing channel") close(ch) } @@ -137,4 +153,21 @@ func (p *EnsurePoller) OnExpiredToken(payload *pubsub.V2ExpiredToken) { func (p *EnsurePoller) Teardown() { p.notifier.Close() + if p.numPendingEnsurePolling != nil { + prometheus.Unregister(p.numPendingEnsurePolling) + } +} + +// must hold p.mu +func (p *EnsurePoller) calculateNumOutstanding() { + if p.numPendingEnsurePolling == nil { + return + } + var total int + for _, pi := range p.pendingPolls { + if !pi.done { + total++ + } + } + p.numPendingEnsurePolling.Set(float64(total)) } diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index fe135560..84c461c3 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -59,8 +59,10 @@ type SyncLiveHandler struct { GlobalCache *caches.GlobalCache maxPendingEventUpdates int - numConns prometheus.Gauge - histVec *prometheus.HistogramVec + numConns prometheus.Gauge + setupHistVec *prometheus.HistogramVec + histVec *prometheus.HistogramVec + slowReqs prometheus.Counter } func NewSync3Handler( @@ -90,7 +92,7 @@ func NewSync3Handler( } // set up pubsub mechanism to start from this point - sh.EnsurePoller = NewEnsurePoller(pub) + sh.EnsurePoller = NewEnsurePoller(pub, enablePrometheus) sh.V2Sub = pubsub.NewV2Sub(sub, sh) return sh, nil @@ -129,9 +131,15 @@ func (h *SyncLiveHandler) Teardown() { if h.numConns != nil { prometheus.Unregister(h.numConns) } + if h.setupHistVec != nil { + prometheus.Unregister(h.setupHistVec) + } if h.histVec != nil { prometheus.Unregister(h.histVec) } + if h.slowReqs != nil { + prometheus.Unregister(h.slowReqs) + } } func (h *SyncLiveHandler) updateMetrics() { @@ -148,15 +156,30 @@ func (h *SyncLiveHandler) addPrometheusMetrics() { Name: "num_active_conns", Help: "Number of active sliding sync connections.", }) + h.setupHistVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "sliding_sync", + Subsystem: "api", + Name: "setup_duration_secs", + Help: "Time taken in seconds after receiving a request before we start calculating a sliding sync response.", + Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, + }, []string{"initial"}) h.histVec = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "sliding_sync", Subsystem: "api", Name: "process_duration_secs", - Help: "Time taken in seconds for the sliding sync response to calculated, excludes long polling", + Help: "Time taken in seconds for the sliding sync response to be calculated, excludes long polling", Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, }, []string{"initial"}) + h.slowReqs = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "sliding_sync", + Subsystem: "api", + Name: "slow_requests", + Help: "Counter of slow (>=50s) requests, initial or otherwise.", + }) prometheus.MustRegister(h.numConns) + prometheus.MustRegister(h.setupHistVec) prometheus.MustRegister(h.histVec) + prometheus.MustRegister(h.slowReqs) } func (h *SyncLiveHandler) ServeHTTP(w 
http.ResponseWriter, req *http.Request) { @@ -173,9 +196,13 @@ func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { Err: err, } } - // artificially wait a bit before sending back the error - // this guards against tightlooping when the client hammers the server with invalid requests - time.Sleep(time.Second) + if herr.ErrCode != "M_UNKNOWN_POS" { + // artificially wait a bit before sending back the error + // this guards against tightlooping when the client hammers the server with invalid requests, + // but not for M_UNKNOWN_POS which we expect to send back after expiring a client's connection. + // We want to recover rapidly in that scenario, hence not sleeping. + time.Sleep(time.Second) + } w.WriteHeader(herr.StatusCode) w.Write(herr.JSON()) } @@ -183,6 +210,16 @@ func (h *SyncLiveHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Entry point for sync v3 func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error { + start := time.Now() + defer func() { + dur := time.Since(start) + if dur > 50*time.Second { + if h.slowReqs != nil { + h.slowReqs.Add(1.0) + } + internal.DecorateLogger(req.Context(), log.Warn()).Dur("duration", dur).Msg("slow request") + } + }() var requestBody sync3.Request if req.ContentLength != 0 { defer req.Body.Close() @@ -232,7 +269,6 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error return herr } requestBody.SetPos(cpos) - internal.SetRequestContextUserID(req.Context(), conn.UserID) log := hlog.FromRequest(req).With().Str("user", conn.UserID).Int64("pos", cpos).Logger() var timeout int @@ -249,7 +285,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error requestBody.SetTimeoutMSecs(timeout) log.Trace().Int("timeout", timeout).Msg("recv") - resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody) + resp, herr := conn.OnIncomingRequest(req.Context(), &requestBody, start) if herr != nil { logErrorOrWarning("failed to OnIncomingRequest", herr) return herr @@ -334,6 +370,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ } } log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger() + internal.SetRequestContextUserID(req.Context(), token.UserID, token.DeviceID) internal.Logf(taskCtx, "setupConnection", "identified access token as user=%s device=%s", token.UserID, token.DeviceID) // Record the fact that we've recieved a request from this token @@ -361,8 +398,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ } pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID} - log.Trace().Any("pid", pid).Msg("checking poller exists and is running") - h.EnsurePoller.EnsurePolling(taskCtx, pid, token.AccessTokenHash) + log.Trace().Any("pid", pid).Msg("checking poller exists and is running") + h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash) log.Trace().Msg("poller exists and is running") // this may take a while so if the client has given up (e.g timed out) by this point, just stop. // We'll be quicker next time as the poller will already exist. @@ -392,7 +429,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ // to check for an existing connection though, as it's possible for the client to call /sync // twice for a new connection. 
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler { - return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.histVec, h.maxPendingEventUpdates) + return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates) }) if created { log.Info().Msg("created new connection") @@ -634,9 +671,14 @@ func (h *SyncLiveHandler) OnUnreadCounts(p *pubsub.V2UnreadCounts) { func (h *SyncLiveHandler) OnDeviceData(p *pubsub.V2DeviceData) { ctx, task := internal.StartTask(context.Background(), "OnDeviceData") defer task.End() - conns := h.ConnMap.Conns(p.UserID, p.DeviceID) - for _, conn := range conns { - conn.OnUpdate(ctx, caches.DeviceDataUpdate{}) + internal.Logf(ctx, "device_data", fmt.Sprintf("%v users to notify", len(p.UserIDToDeviceIDs))) + for userID, deviceIDs := range p.UserIDToDeviceIDs { + for _, deviceID := range deviceIDs { + conns := h.ConnMap.Conns(userID, deviceID) + for _, conn := range conns { + conn.OnUpdate(ctx, caches.DeviceDataUpdate{}) + } + } } } diff --git a/sync3/room.go b/sync3/room.go index 388b06b7..6bf7c35d 100644 --- a/sync3/room.go +++ b/sync3/room.go @@ -2,6 +2,7 @@ package sync3 import ( "encoding/json" + "github.com/matrix-org/sliding-sync/internal" "github.com/matrix-org/sliding-sync/sync3/caches" @@ -17,7 +18,7 @@ type Room struct { Initial bool `json:"initial,omitempty"` IsDM bool `json:"is_dm,omitempty"` JoinedCount int `json:"joined_count,omitempty"` - InvitedCount int `json:"invited_count,omitempty"` + InvitedCount *int `json:"invited_count,omitempty"` PrevBatch string `json:"prev_batch,omitempty"` NumLive int `json:"num_live,omitempty"` } diff --git a/tests-e2e/main_test.go b/tests-e2e/main_test.go index 6d5805fb..52476773 100644 --- a/tests-e2e/main_test.go +++ b/tests-e2e/main_test.go @@ -159,7 +159,7 @@ func MatchRoomInviteState(events []Event, partial bool) m.RoomMatcher { } } if !found { - return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist", want) + return fmt.Errorf("MatchRoomInviteState: want event %+v but it does not exist or failed to pass equality checks", want) } } return nil diff --git a/tests-e2e/membership_transitions_test.go b/tests-e2e/membership_transitions_test.go index 9bfb7e1f..2271d78e 100644 --- a/tests-e2e/membership_transitions_test.go +++ b/tests-e2e/membership_transitions_test.go @@ -64,7 +64,22 @@ func TestRoomStateTransitions(t *testing.T) { m.MatchRoomHighlightCount(1), m.MatchRoomInitial(true), m.MatchRoomRequiredState(nil), - // TODO m.MatchRoomInviteState(inviteStrippedState.InviteState.Events), + m.MatchInviteCount(1), + m.MatchJoinCount(1), + MatchRoomInviteState([]Event{ + { + Type: "m.room.create", + StateKey: ptr(""), + // no content as it includes the room version which we don't want to guess/hardcode + }, + { + Type: "m.room.join_rules", + StateKey: ptr(""), + Content: map[string]interface{}{ + "join_rule": "public", + }, + }, + }, true), }, joinRoomID: {}, }), @@ -105,6 +120,8 @@ func TestRoomStateTransitions(t *testing.T) { }, }), m.MatchRoomInitial(true), + m.MatchJoinCount(2), + m.MatchInviteCount(0), m.MatchRoomHighlightCount(0), )) } @@ -467,7 +484,7 @@ func TestMemberCounts(t *testing.T) { m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{ secondRoomID: { m.MatchRoomInitial(false), - m.MatchInviteCount(0), + m.MatchNoInviteCount(), m.MatchJoinCount(0), // omitempty }, })) @@ -486,7 +503,7 @@ func 
TestMemberCounts(t *testing.T) { m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{ secondRoomID: { m.MatchRoomInitial(false), - m.MatchInviteCount(0), + m.MatchNoInviteCount(), m.MatchJoinCount(2), }, })) diff --git a/tests-e2e/num_live_test.go b/tests-e2e/num_live_test.go index 5a47a02f..acc97261 100644 --- a/tests-e2e/num_live_test.go +++ b/tests-e2e/num_live_test.go @@ -1,10 +1,13 @@ package syncv3_test import ( + "fmt" "testing" + "time" "github.com/matrix-org/sliding-sync/sync3" "github.com/matrix-org/sliding-sync/testutils/m" + "github.com/tidwall/gjson" ) func TestNumLive(t *testing.T) { @@ -126,3 +129,70 @@ func TestNumLive(t *testing.T) { }, })) } + +// Test that if you constantly change req params, we still see live traffic. It does this by: +// - Creating 11 rooms. +// - Hitting /sync with a range [0,1] then [0,2] then [0,3]. Each time this causes a new room to be returned. +// - Interleaving each /sync request with genuine events sent into a room. +// - ensuring we see the genuine events by the time we finish. +func TestReqParamStarvation(t *testing.T) { + alice := registerNewUser(t) + bob := registerNewUser(t) + roomID := alice.CreateRoom(t, map[string]interface{}{ + "preset": "public_chat", + }) + numOtherRooms := 10 + for i := 0; i < numOtherRooms; i++ { + bob.CreateRoom(t, map[string]interface{}{ + "preset": "public_chat", + }) + } + bob.JoinRoom(t, roomID, nil) + res := bob.SlidingSyncUntilMembership(t, "", roomID, bob, "join") + + wantEventIDs := make(map[string]bool) + for i := 0; i < numOtherRooms; i++ { + res = bob.SlidingSync(t, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + Ranges: sync3.SliceRanges{{0, int64(i)}}, // [0,0], [0,1], ... [0,9] + }, + }, + }, WithPos(res.Pos)) + + // mark off any event we see in wantEventIDs + for _, r := range res.Rooms { + for _, ev := range r.Timeline { + gotEventID := gjson.GetBytes(ev, "event_id").Str + wantEventIDs[gotEventID] = false + } + } + + // send an event in the first few syncs to add to wantEventIDs + // We do this for the first few /syncs and don't dictate which response they should arrive + // in, as we do not know and cannot force the proxy to deliver the event in a particular response. + if i < 3 { + eventID := alice.SendEventSynced(t, roomID, Event{ + Type: "m.room.message", + Content: map[string]interface{}{ + "msgtype": "m.text", + "body": fmt.Sprintf("msg %d", i), + }, + }) + wantEventIDs[eventID] = true + } + + // it's possible the proxy won't see this event before the next /sync + // and that is the reason why we don't send it, as opposed to starvation. + // To try to counter this, sleep a bit. This is why we sleep on every cycle and + // why we send the events early on. 
+ time.Sleep(50 * time.Millisecond) + } + + // at this point wantEventIDs should all have false values if we got the events + for evID, unseen := range wantEventIDs { + if unseen { + t.Errorf("failed to see event %v", evID) + } + } +} diff --git a/tests-integration/regressions_test.go b/tests-integration/regressions_test.go new file mode 100644 index 00000000..d24869ce --- /dev/null +++ b/tests-integration/regressions_test.go @@ -0,0 +1,112 @@ +package syncv3 + +import ( + "encoding/json" + "testing" + "time" + + "github.com/matrix-org/sliding-sync/sync2" + "github.com/matrix-org/sliding-sync/sync3" + "github.com/matrix-org/sliding-sync/testutils" + "github.com/matrix-org/sliding-sync/testutils/m" +) + +// catch all file for any kind of regression test which doesn't fall into a unique category + +// Regression test for https://github.com/matrix-org/sliding-sync/issues/192 +// - Bob on his server invites Alice to a room. +// - Alice joins the room first over federation. Proxy does the right thing and sets her membership to join. There is no timeline though due to not having backfilled. +// - Alice's client backfills in the room which pulls in the invite event, but the SS proxy doesn't see it as it's backfill, not /sync. +// - Charlie joins the same room via SS, which makes the SS proxy see 50 timeline events, which includes the invite. +// As the proxy has never seen this invite event before, it assumes it is newer than the join event and inserts it, corrupting state. +// +// Manually confirmed this can happen with 3x Element clients. We need to make sure we drop those earlier events. +// The first join over federation presents itself as a single join event in the timeline, with the create event, etc in state. +func TestBackfillInviteDoesntCorruptState(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + // setup code + v2 := runTestV2Server(t) + v3 := runTestServer(t, v2, pqString) + defer v2.close() + defer v3.close() + + fedBob := "@bob:over_federation" + charlie := "@charlie:localhost" + charlieToken := "CHARLIE_TOKEN" + joinEvent := testutils.NewJoinEvent(t, alice) + + room := roomEvents{ + roomID: "!TestBackfillInviteDoesntCorruptState:localhost", + events: []json.RawMessage{ + joinEvent, + }, + state: createRoomState(t, fedBob, time.Now()), + } + v2.addAccount(t, alice, aliceToken) + v2.queueResponse(alice, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(room), + }, + }) + + // alice syncs and should see the room. 
+ aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + Ranges: sync3.SliceRanges{{0, 20}}, + RoomSubscription: sync3.RoomSubscription{ + TimelineLimit: 5, + }, + }, + }, + }) + m.MatchResponse(t, aliceRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID})))) + + // Alice's client "backfills" new data in, meaning the next user who joins is going to see a different set of timeline events + dummyMsg := testutils.NewMessageEvent(t, fedBob, "you didn't see this before joining") + charlieJoinEvent := testutils.NewJoinEvent(t, charlie) + backfilledTimelineEvents := append( + room.state, []json.RawMessage{ + dummyMsg, + testutils.NewStateEvent(t, "m.room.member", alice, fedBob, map[string]interface{}{ + "membership": "invite", + }), + joinEvent, + charlieJoinEvent, + }..., + ) + + // now charlie also joins the room, causing a different response from /sync v2 + v2.addAccount(t, charlie, charlieToken) + v2.queueResponse(charlie, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(roomEvents{ + roomID: room.roomID, + events: backfilledTimelineEvents, + }), + }, + }) + + // and now charlie hits SS, which might corrupt membership state for alice. + charlieRes := v3.mustDoV3Request(t, charlieToken, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + Ranges: sync3.SliceRanges{{0, 20}}, + }, + }, + }) + m.MatchResponse(t, charlieRes, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(m.MatchV3SyncOp(0, 0, []string{room.roomID})))) + + // alice should not see dummyMsg or the invite + aliceRes = v3.mustDoV3RequestWithPos(t, aliceToken, aliceRes.Pos, sync3.Request{}) + m.MatchResponse(t, aliceRes, m.MatchNoV3Ops(), m.LogResponse(t), m.MatchRoomSubscriptionsStrict( + map[string][]m.RoomMatcher{ + room.roomID: { + m.MatchJoinCount(3), // alice, bob, charlie, + m.MatchNoInviteCount(), + m.MatchNumLive(1), + m.MatchRoomTimeline([]json.RawMessage{charlieJoinEvent}), + }, + }, + )) +} diff --git a/tests-integration/room_subscriptions_test.go b/tests-integration/room_subscriptions_test.go index ba89ad90..3b4bbd5c 100644 --- a/tests-integration/room_subscriptions_test.go +++ b/tests-integration/room_subscriptions_test.go @@ -137,12 +137,9 @@ func TestRoomSubscriptionMisorderedTimeline(t *testing.T) { }) m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{ room.roomID: { - // TODO: this is the correct result, but due to how timeline loading works currently - // it will be returning the last 5 events BEFORE D,E, which isn't ideal but also isn't - // incorrect per se due to the fact that clients don't know when D,E have been processed - // on the server. - // m.MatchRoomTimeline(append(abcInitialEvents, deLiveEvents...)), - m.MatchRoomTimeline(append(roomState[len(roomState)-2:], abcInitialEvents...)), + // we append live events AFTER processing the new timeline limit, so 7 events not 5. + // TODO: ideally we'd just return abcde here. 
+ m.MatchRoomTimeline(append(roomState[len(roomState)-2:], append(abcInitialEvents, deLiveEvents...)...)), }, }), m.LogResponse(t)) diff --git a/testutils/m/match.go b/testutils/m/match.go index e44d0d82..a7b2f635 100644 --- a/testutils/m/match.go +++ b/testutils/m/match.go @@ -48,10 +48,22 @@ func MatchJoinCount(count int) RoomMatcher { } } +func MatchNoInviteCount() RoomMatcher { + return func(r sync3.Room) error { + if r.InvitedCount != nil { + return fmt.Errorf("MatchNoInviteCount: invited_count is present when it should be missing: val=%v", *r.InvitedCount) + } + return nil + } +} + func MatchInviteCount(count int) RoomMatcher { return func(r sync3.Room) error { - if r.InvitedCount != count { - return fmt.Errorf("MatchInviteCount: got %v want %v", r.InvitedCount, count) + if r.InvitedCount == nil { + return fmt.Errorf("MatchInviteCount: invited_count is missing") + } + if *r.InvitedCount != count { + return fmt.Errorf("MatchInviteCount: got %v want %v", *r.InvitedCount, count) + } return nil } diff --git a/v3.go b/v3.go index 74ae2be2..cc72d093 100644 --- a/v3.go +++ b/v3.go @@ -92,8 +92,10 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han } } bufferSize := 50 + deviceDataUpdateFrequency := time.Second if opts.TestingSynchronousPubsub { bufferSize = 0 + deviceDataUpdateFrequency = 0 // don't batch } if opts.MaxPendingEventUpdates == 0 { opts.MaxPendingEventUpdates = 2000 @@ -102,7 +104,7 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics) // create v2 handler - h2, err := handler2.NewHandler(pMap, storev2, store, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency) if err != nil { panic(err) }
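
The setup-duration changes in sync3/handler/handler.go and connstate.go above boil down to a HistogramVec keyed on a single "initial" label ("1" for initial requests, "0" otherwise), observed with the time spent (EnsurePolling, anchor loading, etc.) before response calculation starts. Below is a minimal, self-contained sketch of that pattern using the same prometheus/client_golang API the diff imports; the registry wiring and the fake sleep are illustrative only, not the real handler code.

package main

import (
	"fmt"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	// A private registry so the sketch is runnable on its own; the real code
	// registers against the default registry in addPrometheusMetrics().
	reg := prometheus.NewRegistry()
	setupHist := prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: "sliding_sync",
		Subsystem: "api",
		Name:      "setup_duration_secs",
		Help:      "Time taken in seconds after receiving a request before we start calculating a sliding sync response.",
		Buckets:   []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
	}, []string{"initial"})
	reg.MustRegister(setupHist)

	start := time.Now()
	time.Sleep(10 * time.Millisecond) // stand-in for EnsurePolling / anchor load work
	isInitial := true

	label := "0"
	if isInitial {
		label = "1"
	}
	setupHist.WithLabelValues(label).Observe(time.Since(start).Seconds())
	fmt.Println("observed setup duration with initial =", label)
}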
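
The anti-starvation change in sync3/handler/connstate_live.go (exercised by TestReqParamStarvation) follows a simple pattern: if the response was built without ever blocking on the live update channel, drain a snapshot of whatever is already queued before returning, so a client that changes its request params on every call still consumes live data. The sketch below shows that pattern with hypothetical types (a plain int channel instead of caches.Update); it is not the real connStateLive code.

package main

import "fmt"

// catchUp drains a snapshot of the queued updates, but only when we did not
// live-stream naturally, the request is not initial, and there is work queued.
func catchUp(updates chan int, hasLiveStreamed, isInitial bool, process func(int)) {
	n := len(updates) // snapshot: only drain what was queued at this point
	if hasLiveStreamed || isInitial || n == 0 {
		return
	}
	for i := 0; i < n; i++ {
		process(<-updates)
	}
}

func main() {
	updates := make(chan int, 8)
	for i := 1; i <= 3; i++ {
		updates <- i
	}
	// Simulates a request that returned data immediately (no live streaming):
	// the three queued updates are still processed before we respond.
	catchUp(updates, false, false, func(u int) { fmt.Println("processed queued update", u) })
}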
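
The new gauge in sync3/handler/ensure_polling.go is maintained by recomputing the count of not-yet-done entries under the existing mutex, rather than incrementing and decrementing on each transition. A simplified stand-in for that approach (the map key and struct here are hypothetical, not the real sync2.PollerID/pendingInfo types):

package main

import (
	"fmt"
	"sync"

	"github.com/prometheus/client_golang/prometheus"
)

type pendingInfo struct{ done bool }

type tracker struct {
	mu      sync.Mutex
	pending map[string]pendingInfo
	gauge   prometheus.Gauge // nil when metrics are disabled
}

// recount must be called while holding t.mu, mirroring calculateNumOutstanding.
func (t *tracker) recount() {
	if t.gauge == nil {
		return
	}
	total := 0
	for _, p := range t.pending {
		if !p.done {
			total++
		}
	}
	t.gauge.Set(float64(total))
}

func main() {
	g := prometheus.NewGauge(prometheus.GaugeOpts{Name: "num_devices_pending_ensure_polling"})
	t := &tracker{pending: map[string]pendingInfo{}, gauge: g}

	t.mu.Lock()
	t.pending["alice/DEVICE1"] = pendingInfo{done: false}
	t.recount() // gauge is now 1
	t.pending["alice/DEVICE1"] = pendingInfo{done: true}
	t.recount() // gauge is back to 0
	t.mu.Unlock()
	fmt.Println("gauge recomputed under lock")
}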
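
The switch of Room.InvitedCount from int to *int in sync3/room.go (with MatchInviteCount/MatchNoInviteCount updated to match) exists because `omitempty` on a plain int drops a genuine count of 0, so clients cannot distinguish "zero invited users" from "no change". A minimal sketch of that JSON behaviour, using hypothetical struct names rather than the real sync3.Room:

package main

import (
	"encoding/json"
	"fmt"
)

type roomPlain struct {
	InvitedCount int `json:"invited_count,omitempty"`
}

type roomPtr struct {
	InvitedCount *int `json:"invited_count,omitempty"`
}

func main() {
	zero := 0

	a, _ := json.Marshal(roomPlain{InvitedCount: 0})
	b, _ := json.Marshal(roomPtr{InvitedCount: &zero})
	c, _ := json.Marshal(roomPtr{InvitedCount: nil})

	fmt.Println(string(a)) // {}                  -- a real count of 0 is silently dropped
	fmt.Println(string(b)) // {"invited_count":0} -- explicit zero survives with a pointer
	fmt.Println(string(c)) // {}                  -- nil means "no update", field omitted
}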