Skip to content

Commit

Permalink
Add FailedToDecrypt flag and impl it, along with MustGetEvent. Rust r…
Browse files Browse the repository at this point in the history
…ace fixes
  • Loading branch information
kegsay committed Nov 17, 2023
1 parent 044a859 commit b8832d6
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 18 deletions.
3 changes: 2 additions & 1 deletion internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ type Event struct {
// FFI bindings don't expose state key
Target string
// FFI bindings don't expose type
Membership string
Membership string
FailedToDecrypt bool
}

type Waiter interface {
Expand Down
37 changes: 36 additions & 1 deletion internal/api/js.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/chromedp/chromedp"
"github.com/matrix-org/complement-crypto/internal/chrome"
"github.com/matrix-org/complement/must"
"github.com/tidwall/gjson"
)

const CONSOLE_LOG_CONTROL_STRING = "CC:" // for "complement-crypto"
Expand Down Expand Up @@ -191,7 +192,41 @@ func (c *JSClient) UserID() string {
}

func (c *JSClient) MustGetEvent(t *testing.T, roomID, eventID string) Event {
return Event{}
// serialised output (if encrypted):
// {
// encrypted: { event }
// decrypted: { event }
// }
// else just returns { event }
evSerialised := chrome.MustExecuteInto[string](t, c.ctx, fmt.Sprintf(`
JSON.stringify(window.__client.getRoom("%s")?.getLiveTimeline()?.getEvents().filter((ev) => {
return ev.getId() === "%s";
})[0].toJSON());
`, roomID, eventID))
if !gjson.Valid(evSerialised) {
fatalf(t, "MustGetEvent(%s, %s): invalid event, got %s", roomID, eventID, evSerialised)
}
result := gjson.Parse(evSerialised)
decryptedEvent := result.Get("decrypted")
if !decryptedEvent.Exists() {
decryptedEvent = result
}
encryptedEvent := result.Get("encrypted")
//fmt.Printf("DECRYPTED: %s\nENCRYPTED: %s\n\n", decryptedEvent.Raw, encryptedEvent.Raw)
ev := Event{
ID: decryptedEvent.Get("event_id").Str,
Text: decryptedEvent.Get("content.body").Str,
Sender: decryptedEvent.Get("sender").Str,
}
if decryptedEvent.Get("type").Str == "m.room.member" {
ev.Membership = decryptedEvent.Get("content.membership").Str
ev.Target = decryptedEvent.Get("state_key").Str
}
if encryptedEvent.Exists() && decryptedEvent.Get("content.msgtype").Str == "m.bad.encrypted" {
ev.FailedToDecrypt = true
}

return ev
}

// StartSyncing to begin syncing from sync v2 / sliding sync.
Expand Down
95 changes: 80 additions & 15 deletions internal/api/rust.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -32,9 +33,11 @@ type RustRoomInfo struct {

type RustClient struct {
FFIClient *matrix_sdk_ffi.Client
rooms map[string]*RustRoomInfo
listeners map[int32]func(roomID string)
listenerID atomic.Int32
allRooms *matrix_sdk_ffi.RoomList
rooms map[string]*RustRoomInfo
roomsMu *sync.RWMutex
userID string
}

Expand All @@ -58,6 +61,7 @@ func NewRustClient(t *testing.T, opts ClientCreationOpts, ssURL string) (Client,
FFIClient: client,
rooms: make(map[string]*RustRoomInfo),
listeners: make(map[int32]func(roomID string)),
roomsMu: &sync.RWMutex{},
}
c.Logf(t, "NewRustClient[%s] created client", opts.UserID)
return &LoggedClient{Client: c}, nil
Expand All @@ -69,7 +73,16 @@ func (c *RustClient) Close(t *testing.T) {
}

func (c *RustClient) MustGetEvent(t *testing.T, roomID, eventID string) Event {
return Event{}
room := c.findRoom(t, roomID)
timelineItem, err := room.GetEventTimelineItemByEventId(eventID)
if err != nil {
fatalf(t, "MustGetEvent(%s, %s): %s", roomID, eventID, err)
}
ev := eventTimelineItemToEvent(timelineItem)
if ev == nil {
fatalf(t, "MustGetEvent(%s, %s): found timeline item but failed to convert it to an Event", roomID, eventID)
}
return *ev
}

// StartSyncing to begin syncing from sync v2 / sliding sync.
Expand All @@ -86,6 +99,7 @@ func (c *RustClient) StartSyncing(t *testing.T) (stopSyncing func()) {
})
must.NotError(t, "failed to call RoomList.LoadingState", err)
go syncService.Start()
c.allRooms = roomList

isSyncing := false

Expand Down Expand Up @@ -116,7 +130,7 @@ func (c *RustClient) StartSyncing(t *testing.T) (stopSyncing func()) {
// provide a bogus room ID.
func (c *RustClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error) {
t.Helper()
r := c.findRoom(roomID)
r := c.findRoom(t, roomID)
if r == nil {
rooms := c.FFIClient.Rooms()
return false, fmt.Errorf("failed to find room %s, got %d rooms", roomID, len(rooms))
Expand Down Expand Up @@ -174,7 +188,7 @@ func (c *RustClient) SendMessage(t *testing.T, roomID, text string) (eventID str

func (c *RustClient) MustBackpaginate(t *testing.T, roomID string, count int) {
t.Helper()
r := c.findRoom(roomID)
r := c.findRoom(t, roomID)
must.NotEqual(t, r, nil, "unknown room")
must.NotError(t, "failed to backpaginate", r.PaginateBackwards(matrix_sdk_ffi.PaginationOptionsSingleRequest{
EventLimit: uint16(count),
Expand All @@ -185,16 +199,50 @@ func (c *RustClient) UserID() string {
return c.userID
}

func (c *RustClient) findRoom(roomID string) *matrix_sdk_ffi.Room {
func (c *RustClient) findRoomInMap(roomID string) *matrix_sdk_ffi.Room {
c.roomsMu.RLock()
defer c.roomsMu.RUnlock()
// do we have a reference to it already?
roomInfo := c.rooms[roomID]
if roomInfo != nil {
return roomInfo.room
}
return nil
}

// findRoom returns the room, waiting up to 5s for it to appear
func (c *RustClient) findRoom(t *testing.T, roomID string) *matrix_sdk_ffi.Room {
room := c.findRoomInMap(roomID)
if room != nil {
return room
}
// try to find it in all_rooms
if c.allRooms != nil {
roomListItem, err := c.allRooms.Room(roomID)
if err != nil {
c.Logf(t, "allRooms.Room(%s) err: %s", roomID, err)
} else if roomListItem != nil {
room := roomListItem.FullRoom()
c.roomsMu.Lock()
c.rooms[roomID] = &RustRoomInfo{
room: room,
}
c.roomsMu.Unlock()
return room
}
}
// try to find it from cache?
rooms := c.FFIClient.Rooms()
for i, r := range rooms {
rid := r.Id()
// ensure we only store rooms once
_, exists := c.rooms[rid]
if !exists {
c.roomsMu.Lock()
c.rooms[rid] = &RustRoomInfo{
room: rooms[i],
}
c.roomsMu.Unlock()
}
if r.Id() == roomID {
return c.rooms[rid].room
Expand All @@ -210,18 +258,19 @@ func (c *RustClient) Logf(t *testing.T, format string, args ...interface{}) {
}

func (c *RustClient) ensureListening(t *testing.T, roomID string) *matrix_sdk_ffi.Room {
r := c.findRoom(roomID)
r := c.findRoom(t, roomID)
must.NotEqual(t, r, nil, fmt.Sprintf("room %s does not exist", roomID))

info := c.rooms[roomID]
if info.attachedListener {
return r
}

t.Logf("[%s]AddTimelineListener[%s]", c.userID, roomID)
c.Logf(t, "[%s]AddTimelineListener[%s]", c.userID, roomID)
// we need a timeline listener before we can send messages
r.AddTimelineListener(&timelineListener{fn: func(diff []*matrix_sdk_ffi.TimelineDiff) {
result := r.AddTimelineListener(&timelineListener{fn: func(diff []*matrix_sdk_ffi.TimelineDiff) {
timeline := c.rooms[roomID].timeline
c.Logf(t, "[%s]AddTimelineListener[%s] TimelineDiff len=%d", c.userID, roomID, len(diff))
for _, d := range diff {
switch d.Change() {
case matrix_sdk_ffi.TimelineChangeInsert:
Expand Down Expand Up @@ -275,6 +324,17 @@ func (c *RustClient) ensureListening(t *testing.T, roomID string) *matrix_sdk_ff
l(roomID)
}
}})
events := make([]*Event, len(result.Items))
for i := range result.Items {
events[i] = timelineItemToEvent(result.Items[i])
}
c.rooms[roomID].timeline = events
c.Logf(t, "[%s]AddTimelineListener[%s] result.Items len=%d", c.userID, roomID, len(result.Items))
if len(events) > 0 {
for _, l := range c.listeners {
l(roomID)
}
}
info.attachedListener = true
return r
}
Expand Down Expand Up @@ -361,19 +421,22 @@ func timelineItemToEvent(item *matrix_sdk_ffi.TimelineItem) *Event {
if ev == nil { // e.g day divider
return nil
}
evv := *ev
if evv == nil {
return eventTimelineItemToEvent(*ev)
}

func eventTimelineItemToEvent(item *matrix_sdk_ffi.EventTimelineItem) *Event {
if item == nil {
return nil
}
eventID := ""
if evv.EventId() != nil {
eventID = *evv.EventId()
if item.EventId() != nil {
eventID = *item.EventId()
}
complementEvent := Event{
ID: eventID,
Sender: evv.Sender(),
Sender: item.Sender(),
}
switch k := evv.Content().Kind().(type) {
switch k := item.Content().Kind().(type) {
case matrix_sdk_ffi.TimelineItemContentKindRoomMembership:
complementEvent.Target = k.UserId
change := *k.Change
Expand Down Expand Up @@ -401,9 +464,11 @@ func timelineItemToEvent(item *matrix_sdk_ffi.TimelineItem) *Event {
default:
fmt.Printf("%s unhandled membership %d\n", k.UserId, change)
}
case matrix_sdk_ffi.TimelineItemContentKindUnableToDecrypt:
complementEvent.FailedToDecrypt = true
}

content := evv.Content()
content := item.Content()
if content != nil {
msg := content.AsMessage()
if msg != nil {
Expand Down
3 changes: 2 additions & 1 deletion tests/membership_acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ func TestBobCanSeeButNotDecryptHistoryInPublicRoom(t *testing.T) {

// bob hits scrollback and should see but not be able to decrypt the message
bob.MustBackpaginate(t, roomID, 5)
ev := bob.MustGetEvent(t, roomID, evID) // TODO
ev := bob.MustGetEvent(t, roomID, evID)
must.NotEqual(t, ev.Text, beforeJoinBody, "bob was able to decrypt a message from before he was joined")
must.Equal(t, ev.FailedToDecrypt, true, "message not marked as failed to decrypt")
})
}

0 comments on commit b8832d6

Please sign in to comment.