Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: correctly tell clients when the fallback key has been used #390

Merged
merged 2 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,12 @@ go test -p 1 -count 1 $(go list ./... | grep -v tests-e2e) -timeout 120s
Run end-to-end tests:

```shell
# Run each line in a separate terminal windows. Will need to `docker login`
# to ghcr and pull the image.
docker run --rm -e "SYNAPSE_COMPLEMENT_DATABASE=sqlite" -e "SERVER_NAME=synapse" -p 8888:8008 ghcr.io/matrix-org/synapse-service:v1.72.0
# Will need to `docker login` to ghcr and pull the image.
docker run -d --rm -e "SYNAPSE_COMPLEMENT_DATABASE=sqlite" -e "SERVER_NAME=synapse" -p 8888:8008 ghcr.io/matrix-org/synapse-service:v1.94.0

export SYNCV3_SECRET=foobar
export SYNCV3_SERVER=http://localhost:8888
export SYNCV3_DB="user=$(whoami) dbname=syncv3_test sslmode=disable"

(go build ./cmd/syncv3 && dropdb syncv3_test && createdb syncv3_test && cd tests-e2e && ./run-tests.sh -count=1 .)
```
1 change: 1 addition & 0 deletions internal/device_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type DeviceData struct {
OTKCounts MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`

DeviceLists DeviceLists `json:"dl"`
Expand Down
32 changes: 21 additions & 11 deletions sync2/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,20 +727,30 @@ func (p *poller) parseE2EEData(ctx context.Context, res *SyncResponse) error {
}
shouldSetOTKs = true
}
var changedFallbackTypes []string
var changedFallbackTypes []string // nil slice == don't set, empty slice = no fallback key
shouldSetFallbackKeys := false
if len(res.DeviceUnusedFallbackKeyTypes) > 0 {
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
} else {
for i := range res.DeviceUnusedFallbackKeyTypes {
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
break
}
if len(p.fallbackKeyTypes) != len(res.DeviceUnusedFallbackKeyTypes) {
// length mismatch always causes an update
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
shouldSetFallbackKeys = true
} else {
// lengths match, if they are non-zero then compare each element.
// if they are zero, check for nil vs empty slice.
if len(res.DeviceUnusedFallbackKeyTypes) == 0 {
isCurrentNil := res.DeviceUnusedFallbackKeyTypes == nil
isPreviousNil := p.fallbackKeyTypes == nil
if isCurrentNil != isPreviousNil {
shouldSetFallbackKeys = true
changedFallbackTypes = []string{}
}
}
for i := range res.DeviceUnusedFallbackKeyTypes {
if res.DeviceUnusedFallbackKeyTypes[i] != p.fallbackKeyTypes[i] {
changedFallbackTypes = res.DeviceUnusedFallbackKeyTypes
shouldSetFallbackKeys = true
break
}
}
shouldSetFallbackKeys = true
}

deviceListChanges := internal.ToDeviceListChangesMap(res.DeviceLists.Changed, res.DeviceLists.Left)
Expand Down
6 changes: 3 additions & 3 deletions sync3/extensions/e2ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (r *E2EERequest) Name() string {
type E2EEResponse struct {
OTKCounts map[string]int `json:"device_one_time_keys_count,omitempty"`
DeviceLists *E2EEDeviceList `json:"device_lists,omitempty"`
FallbackKeyTypes []string `json:"device_unused_fallback_key_types,omitempty"`
FallbackKeyTypes *[]string `json:"device_unused_fallback_key_types,omitempty"`
}

type E2EEDeviceList struct {
Expand All @@ -37,7 +37,7 @@ func (r *E2EEResponse) HasData(isInitial bool) bool {
if isInitial {
return true // ensure we send OTK counts immediately
}
return r.DeviceLists != nil || len(r.FallbackKeyTypes) > 0 || len(r.OTKCounts) > 0
return r.DeviceLists != nil || r.FallbackKeyTypes != nil || len(r.OTKCounts) > 0
}

func (r *E2EERequest) AppendLive(ctx context.Context, res *Response, extCtx Context, up caches.Update) {
Expand All @@ -63,7 +63,7 @@ func (r *E2EERequest) ProcessInitial(ctx context.Context, res *Response, extCtx
extRes := &E2EEResponse{}
hasUpdates := false
if dd.FallbackKeyTypes != nil && (dd.FallbackKeysChanged() || extCtx.IsInitial) {
extRes.FallbackKeyTypes = dd.FallbackKeyTypes
extRes.FallbackKeyTypes = &dd.FallbackKeyTypes
hasUpdates = true
}
if dd.OTKCounts != nil && (dd.OTKCountChanged() || extCtx.IsInitial) {
Expand Down
287 changes: 287 additions & 0 deletions tests-e2e/encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
package syncv3_test

import (
"encoding/json"
"fmt"
"testing"

"github.com/matrix-org/complement/b"
"github.com/matrix-org/complement/client"
"github.com/matrix-org/sliding-sync/sync3"
"github.com/matrix-org/sliding-sync/sync3/extensions"
"github.com/matrix-org/sliding-sync/testutils/m"
)

func TestEncryptionFallbackKey(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
roomID := alice.MustCreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
bob.JoinRoom(t, roomID, nil)

// snaffled from rust SDK
keysUploadBody := fmt.Sprintf(`{
"device_keys": {
"algorithms": [
"m.olm.v1.curve25519-aes-sha2",
"m.megolm.v1.aes-sha2"
],
"device_id": "MUPCQIATEC",
"keys": {
"curve25519:MUPCQIATEC": "NroPrV4HHJ/Wj0A0XMrHt7IuThVnwpT6tRZXQXkO4kI",
"ed25519:MUPCQIATEC": "G9zNR/pZb24Rm0FXiQYutSzcbQvii+AZn/4cmi6LOUI"
},
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "2CHK2tJO/p2OiNWC2jLKsH5t+pHwnomSHOIpAPuEVi2vJZ4BRRsb4tSFYzEx4cUDg3KCYjoQuCymYHpnk1uqDQ"
}
},
"user_id": "%s"
},
"fallback_keys": {
"signed_curve25519:AAAAAAAAAAA": {
"fallback": true,
"key": "s5+eOJYK1s5xPt51BlYEXx8fQ8NqpwAUjE1mVxw05V8",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "TLGi0LJEDxgt37gBCpd8huZa72h0UTB8jIEUoTz/rjbCcGQo1xOlvA5rU+RoTkF1KwVtduOMbZcSGg4ZTfBkDQ"
}
}
}
},
"one_time_keys": {
"signed_curve25519:AAAAAAAAAA0": {
"key": "IuCQvr2AaZC70tCG6g1ZardACNe3mcKZ2PjKJ2p49UM",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "FXBkzwuLkfriWJ1B2z9wTHvi7WTOZGvs2oSNJ7CycXJYC6k06sa7a+OMQtpMP2RTuIpiYC+wZ3nFoKp1FcCcBQ"
}
}
},
"signed_curve25519:AAAAAAAAAA4": {
"key": "pgeLFCJPLYUtyLPKDPr76xRYgPjjY4/lEUH98tExxCo",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "/o44D5qjTdiYORSXmCVYE3Vzvbz2OlIBC58ELe+EAAgIZTJyDxmBJIFotP6CIuFmB/p4lGCd41Fb6T5BnmLvBQ"
}
}
},
"signed_curve25519:AAAAAAAAAA8": {
"key": "gAhoEOtrGTEG+gfAsCU+JS7+wJTlC51+kZ9vLr9BZGA",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "DLDj1c2UncqcCrEwSUEf31ni6W+E6D58EEGFIWj++ydBxuiEnHqFMF7AZU8GGcjQBDIH13uNe8xxO7/KeBbUDQ"
}
}
}
}
}`, bob.UserID, bob.UserID, bob.UserID, bob.UserID, bob.UserID, bob.UserID)

bob.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"},
client.WithRawBody([]byte(keysUploadBody)), client.WithContentType("application/json"),
)

res := bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{"signed_curve25519"}), m.MatchOTKCounts(map[string]int{
"signed_curve25519": 3,
}))

// claim a OTK, it should decrease the count
mustClaimOTK(t, alice, bob)
// claiming OTKs does not wake up the sync loop, so send something to kick it.
alice.MustSendTyping(t, roomID, true, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// OTK was claimed so change should be included.
// fallback key was not touched so should be missing.
MatchOTKAndFallbackTypes(map[string]int{
"signed_curve25519": 2,
}, nil),
)

mustClaimOTK(t, alice, bob)
alice.MustSendTyping(t, roomID, false, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// OTK was claimed so change should be included.
// fallback key was not touched so should be missing.
MatchOTKAndFallbackTypes(map[string]int{
"signed_curve25519": 1,
}, nil),
)

mustClaimOTK(t, alice, bob)
alice.MustSendTyping(t, roomID, true, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// OTK was claimed so change should be included.
// fallback key was not touched so should be missing.
MatchOTKAndFallbackTypes(map[string]int{
"signed_curve25519": 0,
}, nil),
)

mustClaimOTK(t, alice, bob)
alice.MustSendTyping(t, roomID, false, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// no OTK change here so it shouldn't be included.
// we should be explicitly sent device_unused_fallback_key_types: []
MatchOTKAndFallbackTypes(nil, []string{}),
)

// now re-upload a fallback key, it should be repopulated.
keysUploadBody = fmt.Sprintf(`{
"fallback_keys": {
"signed_curve25519:AAAAAAAAADA": {
"fallback": true,
"key": "N8DKj83RTN7lLZrH6shMqHbVhNrxd96OQseQVFmNgTU",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "ZnKsVcNmOLBv0LMGeNpCfCO2am9L223EiyddWPx9wPOtuYt6KZIPox/SFwVmqBwkUdnmeTb6tVgCpZwcH8doDw"
}
}
}
}
}`, bob.UserID)
bob.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"},
client.WithRawBody([]byte(keysUploadBody)), client.WithContentType("application/json"),
)

alice.MustSendTyping(t, roomID, true, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// no OTK change here so it shouldn't be included.
// we should be explicitly sent device_unused_fallback_key_types: ["signed_curve25519"]
MatchOTKAndFallbackTypes(nil, []string{"signed_curve25519"}),
)

// another claim should remove it
mustClaimOTK(t, alice, bob)

alice.MustSendTyping(t, roomID, false, 1000)
res = bob.SlidingSyncUntil(t, res.Pos, sync3.Request{},
// no OTK change here so it shouldn't be included.
// we should be explicitly sent device_unused_fallback_key_types: []
MatchOTKAndFallbackTypes(nil, []string{}),
)
}

// Regression test to make sure EX uploads a fallback key initially.
// EX relies on device_unused_fallback_key_types: [] being present in the
// sync response before it will upload any fallback keys at all, it doesn't
// automatically do it on first login.
func TestEncryptionFallbackKeyToldIfMissingInitially(t *testing.T) {
alice := registerNewUser(t)
bob := registerNewUser(t)
roomID := alice.MustCreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})
bob.JoinRoom(t, roomID, nil)
res := bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{}))

// upload a fallback key and do another initial request => should include key
keysUploadBody := fmt.Sprintf(`{
"fallback_keys": {
"signed_curve25519:AAAAAAAAADA": {
"fallback": true,
"key": "N8DKj83RTN7lLZrH6shMqHbVhNrxd96OQseQVFmNgTU",
"signatures": {
"%s": {
"ed25519:MUPCQIATEC": "ZnKsVcNmOLBv0LMGeNpCfCO2am9L223EiyddWPx9wPOtuYt6KZIPox/SFwVmqBwkUdnmeTb6tVgCpZwcH8doDw"
}
}
}
}
}`, bob.UserID)
bob.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "upload"},
client.WithRawBody([]byte(keysUploadBody)), client.WithContentType("application/json"),
)
sentinelEventID := bob.SendEventSynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
"body": "Sentinel",
},
})
bob.SlidingSyncUntilEventID(t, "", roomID, sentinelEventID)
res = bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{"signed_curve25519"}))

// consume the fallback key and do another initial request => should be []
mustClaimOTK(t, alice, bob)
sentinelEventID = bob.SendEventSynced(t, roomID, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{
"msgtype": "m.text",
"body": "Sentinel 2",
},
})
bob.SlidingSyncUntilEventID(t, "", roomID, sentinelEventID)
res = bob.SlidingSync(t, sync3.Request{
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{
Enabled: &boolTrue,
},
},
},
})
m.MatchResponse(t, res, m.MatchFallbackKeyTypes([]string{}))
}

func MatchOTKAndFallbackTypes(otkCount map[string]int, fallbackKeyTypes []string) m.RespMatcher {
return func(r *sync3.Response) error {
err := m.MatchOTKCounts(otkCount)(r)
if err != nil {
return err
}
// we should explicitly be sent device_unused_fallback_key_types: []
return m.MatchFallbackKeyTypes(fallbackKeyTypes)(r)
}
}

func mustClaimOTK(t *testing.T, claimer, claimee *CSAPI) {
claimRes := claimer.MustDo(t, "POST", []string{"_matrix", "client", "v3", "keys", "claim"}, client.WithJSONBody(t, map[string]any{
"one_time_keys": map[string]any{
claimee.UserID: map[string]any{
claimee.DeviceID: "signed_curve25519",
},
},
}))
var res struct {
Failures map[string]any `json:"failures"`
OTKs map[string]map[string]any `json:"one_time_keys"`
}
if err := json.NewDecoder(claimRes.Body).Decode(&res); err != nil {
t.Fatalf("failed to decode OTK response: %s", err)
}
if len(res.Failures) > 0 {
t.Fatalf("OTK response had failures: %+v", res.Failures)
}
otk := res.OTKs[claimee.UserID][claimee.DeviceID]
if otk == nil {
t.Fatalf("OTK was not claimed for %s|%s", claimee.UserID, claimee.DeviceID)
}
}
8 changes: 7 additions & 1 deletion testutils/m/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,13 @@ func MatchFallbackKeyTypes(fallbackKeyTypes []string) RespMatcher {
if res.Extensions.E2EE == nil {
return fmt.Errorf("MatchFallbackKeyTypes: no E2EE extension present")
}
if !reflect.DeepEqual(res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
if res.Extensions.E2EE.FallbackKeyTypes == nil { // not supplied
if fallbackKeyTypes == nil {
return nil
}
return fmt.Errorf("MatchFallbackKeyTypes: FallbackKeyTypes is missing but want %v", fallbackKeyTypes)
}
if !reflect.DeepEqual(*res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes) {
return fmt.Errorf("MatchFallbackKeyTypes: got %v want %v", res.Extensions.E2EE.FallbackKeyTypes, fallbackKeyTypes)
}
return nil
Expand Down
Loading