Skip to content

Commit

Permalink
Merge pull request #401 from matrix-org/kegan/msc4102
Browse files Browse the repository at this point in the history
Implement MSC4102
  • Loading branch information
kegsay authored Feb 16, 2024
2 parents bbb886e + 782703c commit aa3ea8f
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 16 deletions.
40 changes: 24 additions & 16 deletions state/receipt_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
)

type receiptEDU struct {
Type string `json:"type"`
Content map[string]struct {
Read map[string]receiptInfo `json:"m.read,omitempty"`
ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
} `json:"content"`
Type string `json:"type"`
Content map[string]receiptContent `json:"content"`
}

type receiptContent struct {
Read map[string]receiptInfo `json:"m.read,omitempty"`
ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
}

type receiptInfo struct {
Expand Down Expand Up @@ -164,29 +166,35 @@ func (t *ReceiptTable) bulkInsert(tableName string, txn *sqlx.Tx, receipts []int
// client connections.
func PackReceiptsIntoEDU(receipts []internal.Receipt) (json.RawMessage, error) {
newReceiptEDU := receiptEDU{
Type: "m.receipt",
Content: make(map[string]struct {
Read map[string]receiptInfo `json:"m.read,omitempty"`
ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"`
}),
Type: "m.receipt",
Content: make(map[string]receiptContent),
}
for _, r := range receipts {
thisReceiptIsUnthreaded := r.ThreadID == ""
receiptsForEvent := newReceiptEDU.Content[r.EventID]
if r.IsPrivate {
if receiptsForEvent.ReadPrivate == nil {
receiptsForEvent.ReadPrivate = make(map[string]receiptInfo)
}
receiptsForEvent.ReadPrivate[r.UserID] = receiptInfo{
TS: r.TS,
ThreadID: r.ThreadID,
// MSC4102: always replace threaded receipts with unthreaded ones if there is a clash
_, receiptAlreadyExists := receiptsForEvent.ReadPrivate[r.UserID]
if !receiptAlreadyExists || (receiptAlreadyExists && thisReceiptIsUnthreaded) {
receiptsForEvent.ReadPrivate[r.UserID] = receiptInfo{
TS: r.TS,
ThreadID: r.ThreadID,
}
}
} else {
if receiptsForEvent.Read == nil {
receiptsForEvent.Read = make(map[string]receiptInfo)
}
receiptsForEvent.Read[r.UserID] = receiptInfo{
TS: r.TS,
ThreadID: r.ThreadID,
// MSC4102: always replace threaded receipts with unthreaded ones if there is a clash
_, receiptAlreadyExists := receiptsForEvent.Read[r.UserID]
if !receiptAlreadyExists || (receiptAlreadyExists && thisReceiptIsUnthreaded) {
receiptsForEvent.Read[r.UserID] = receiptInfo{
TS: r.TS,
ThreadID: r.ThreadID,
}
}
}
newReceiptEDU.Content[r.EventID] = receiptsForEvent
Expand Down
176 changes: 176 additions & 0 deletions state/receipt_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,182 @@ func parsedReceiptsEqual(t *testing.T, got, want []internal.Receipt) {
}
}

func TestReceiptPacking(t *testing.T) {
testCases := []struct {
receipts []internal.Receipt
wantEDU receiptEDU
name string
}{
{
name: "single receipt",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 42,
},
},
},
},
},
},
{
name: "two distinct receipt",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
},
{
RoomID: "!foo2",
EventID: "$bar2",
UserID: "@baz2",
TS: 422,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 42,
},
},
},
"$bar2": {
Read: map[string]receiptInfo{
"@baz2": {
TS: 422,
},
},
},
},
},
},
{
name: "MSC4102: unthreaded wins when threaded first",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
ThreadID: "thread_id",
},
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 420,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 420,
},
},
},
},
},
},
{
name: "MSC4102: unthreaded wins when unthreaded first",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 420,
},
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
ThreadID: "thread_id",
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
Read: map[string]receiptInfo{
"@baz": {
TS: 420,
},
},
},
},
},
},
{
name: "MSC4102: unthreaded wins in private receipts when unthreaded first",
receipts: []internal.Receipt{
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 420,
IsPrivate: true,
},
{
RoomID: "!foo",
EventID: "$bar",
UserID: "@baz",
TS: 42,
ThreadID: "thread_id",
IsPrivate: true,
},
},
wantEDU: receiptEDU{
Type: "m.receipt",
Content: map[string]receiptContent{
"$bar": {
ReadPrivate: map[string]receiptInfo{
"@baz": {
TS: 420,
},
},
},
},
},
},
}
for _, tc := range testCases {
edu, err := PackReceiptsIntoEDU(tc.receipts)
if err != nil {
t.Fatalf("%s: PackReceiptsIntoEDU: %s", tc.name, err)
}
gotEDU := receiptEDU{
Type: "m.receipt",
Content: make(map[string]receiptContent),
}
if err := json.Unmarshal(edu, &gotEDU); err != nil {
t.Fatalf("%s: json.Unmarshal: %s", tc.name, err)
}
if !reflect.DeepEqual(gotEDU, tc.wantEDU) {
t.Errorf("%s: EDU mismatch, got %+v\n want %+v", tc.name, gotEDU, tc.wantEDU)
}
}
}

func TestReceiptTable(t *testing.T) {
db, close := connectToDB(t)
defer close()
Expand Down

0 comments on commit aa3ea8f

Please sign in to comment.