diff --git a/.github/workflows/MQTT_test.yaml b/.github/workflows/MQTT_test.yaml deleted file mode 100644 index 2bfd0bed2bb..00000000000 --- a/.github/workflows/MQTT_test.yaml +++ /dev/null @@ -1,67 +0,0 @@ -name: MQTTEx -on: [push, pull_request] - -permissions: - pull-requests: write # to comment on PRs - contents: write # to comment on commits (to upload artifacts?) - -jobs: - test: - env: - GOPATH: /home/runner/work/nats-server - GO111MODULE: "on" - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - path: src/github.com/nats-io/nats-server - - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version-file: src/github.com/nats-io/nats-server/go.mod - cache-dependency-path: src/github.com/nats-io/nats-server/go.sum - - - name: Set up testing tools and environment - shell: bash --noprofile --norc -eo pipefail {0} - id: setup - run: | - wget https://github.com/hivemq/mqtt-cli/releases/download/v4.20.0/mqtt-cli-4.20.0.deb - sudo apt install ./mqtt-cli-4.20.0.deb - go install github.com/ConnectEverything/mqtt-test@4dd571c31318dcfebe5443242f53f262403ceafb - - # - name: Download benchmark result for ${{ github.base_ref || github.ref_name }} - # uses: dawidd6/action-download-artifact@v2 - # continue-on-error: true - # with: - # path: src/github.com/nats-io/nats-server/bench - # name: bench-output-${{ runner.os }} - # branch: ${{ github.base_ref || github.ref }} - - - name: Run tests and benchmarks - shell: bash --noprofile --norc -eo pipefail {0} - run: | - cd src/github.com/nats-io/nats-server - go test -v --run='MQTTEx' ./server - # go test --run='-' --count=10 --bench 'MQTT_' ./server | tee output.txt - # go test --run='-' --count=10 --bench 'MQTTEx' --timeout=20m --benchtime=100x ./server | tee -a output.txt - go test --run='-' --count=3 --bench 'MQTTEx' --benchtime=100x ./server - - # - name: Compare benchmarks - # uses: benchmark-action/github-action-benchmark@v1 - # with: - # tool: go - # output-file-path: src/github.com/nats-io/nats-server/output.txt - # github-token: ${{ secrets.GITHUB_TOKEN }} - # alert-threshold: 140% - # comment-on-alert: true - # # fail-on-alert: true - # external-data-json-path: src/github.com/nats-io/nats-server/bench/benchmark-data.json - - # - name: Store benchmark result for ${{ github.ref_name }} - # if: always() - # uses: actions/upload-artifact@v3 - # with: - # path: src/github.com/nats-io/nats-server/bench - # name: bench-output-${{ runner.os }} diff --git a/.github/workflows/mqtt-test.yaml b/.github/workflows/mqtt-test.yaml new file mode 100644 index 00000000000..893837e6b60 --- /dev/null +++ b/.github/workflows/mqtt-test.yaml @@ -0,0 +1,48 @@ +name: MQTT external test +on: [pull_request] + +jobs: + test: + env: + GOPATH: /home/runner/work/nats-server + GO111MODULE: "on" + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + path: src/github.com/nats-io/nats-server + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: src/github.com/nats-io/nats-server/go.mod + cache-dependency-path: src/github.com/nats-io/nats-server/go.sum + + - name: Set up testing tools and environment + shell: bash --noprofile --norc -eo pipefail {0} + id: setup + run: | + wget https://github.com/hivemq/mqtt-cli/releases/download/v4.20.0/mqtt-cli-4.20.0.deb + sudo apt install ./mqtt-cli-4.20.0.deb + go install github.com/ConnectEverything/mqtt-test@v0.1.0 + + - name: Run tests (3 times to detect flappers) + shell: bash --noprofile --norc -eo pipefail {0} + run: | + cd src/github.com/nats-io/nats-server + go test -v --count=3 --run='TestXMQTT' ./server + + - name: Run tests with --race + shell: bash --noprofile --norc -eo pipefail {0} + run: | + cd src/github.com/nats-io/nats-server + go test -v --race --failfast --run='TestXMQTT' ./server + + - name: Run benchmarks + shell: bash --noprofile --norc -eo pipefail {0} + run: | + cd src/github.com/nats-io/nats-server + go test --run='-' --count=3 --bench 'BenchmarkXMQTT' --benchtime=100x ./server + + # TODO: compare benchmarks diff --git a/server/filestore.go b/server/filestore.go index df58c981e80..689f0163e30 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -214,7 +214,7 @@ type msgBlock struct { bytes uint64 // User visible bytes count. rbytes uint64 // Total bytes (raw) including deleted. Used for rolling to new blk. msgs uint64 // User visible message count. - fss map[string]*SimpleState + fss *stree.SubjectTree[SimpleState] kfn string lwts int64 llts int64 @@ -2063,11 +2063,13 @@ func (fs *fileStore) expireMsgsOnRecover() { } // Make sure we do subject cleanup as well. mb.ensurePerSubjectInfoLoaded() - for subj, ss := range mb.fss { + mb.fss.Iter(func(bsubj []byte, ss *SimpleState) bool { + subj := bytesToString(bsubj) for i := uint64(0); i < ss.Msgs; i++ { fs.removePerSubject(subj) } - } + return true + }) mb.dirtyCloseWithRemove(true) deleted++ } @@ -2314,9 +2316,13 @@ func (mb *msgBlock) firstMatching(filter string, wc bool, start uint64, sm *Stor // Mark fss activity. mb.lsts = time.Now().UnixNano() + if filter == _EMPTY_ { + filter = fwcs + } + // If we only have 1 subject currently and it matches our filter we can also set isAll. - if !isAll && len(mb.fss) == 1 { - _, isAll = mb.fss[filter] + if !isAll && mb.fss.Size() == 1 { + _, isAll = mb.fss.Find(stringToBytes(filter)) } // Make sure to start at mb.first.seq if fseq < mb.first.seq if seq := atomic.LoadUint64(&mb.first.seq); seq > fseq { @@ -2325,16 +2331,15 @@ func (mb *msgBlock) firstMatching(filter string, wc bool, start uint64, sm *Stor lseq := atomic.LoadUint64(&mb.last.seq) // Optionally build the isMatch for wildcard filters. - tsa := [32]string{} - fsa := [32]string{} - var fts []string + _tsa, _fsa := [32]string{}, [32]string{} + tsa, fsa := _tsa[:0], _fsa[:0] var isMatch func(subj string) bool // Decide to build. if wc { - fts = tokenizeSubjectIntoSlice(fsa[:0], filter) + fsa = tokenizeSubjectIntoSlice(fsa[:0], filter) isMatch = func(subj string) bool { - tts := tokenizeSubjectIntoSlice(tsa[:0], subj) - return isSubsetMatchTokenized(tts, fts) + tsa = tokenizeSubjectIntoSlice(tsa[:0], subj) + return isSubsetMatchTokenized(tsa, fsa) } } @@ -2346,18 +2351,16 @@ func (mb *msgBlock) firstMatching(filter string, wc bool, start uint64, sm *Stor // 25th quantile of a match in a linear walk. Filter should be a wildcard. // We should consult fss if our cache is not loaded and we only have fss loaded. if !doLinearScan && wc && mb.cacheAlreadyLoaded() { - doLinearScan = len(mb.fss)*4 > int(lseq-fseq) + doLinearScan = mb.fss.Size()*4 > int(lseq-fseq) } if !doLinearScan { // If we have a wildcard match against all tracked subjects we know about. if wc { subs = subs[:0] - for subj := range mb.fss { - if isMatch(subj) { - subs = append(subs, subj) - } - } + mb.fss.Match(stringToBytes(filter), func(bsubj []byte, _ *SimpleState) { + subs = append(subs, string(bsubj)) + }) // Check if we matched anything if len(subs) == 0 { return nil, didLoad, ErrStoreMsgNotFound @@ -2365,7 +2368,7 @@ func (mb *msgBlock) firstMatching(filter string, wc bool, start uint64, sm *Stor } fseq = lseq + 1 for _, subj := range subs { - ss := mb.fss[subj] + ss, _ := mb.fss.Find(stringToBytes(subj)) if ss != nil && ss.firstNeedsUpdate { mb.recalculateFirstForSubj(subj, ss.First, ss) } @@ -2456,6 +2459,10 @@ func (mb *msgBlock) filteredPendingLocked(filter string, wc bool, sseq uint64) ( } } + if filter == _EMPTY_ { + filter = fwcs + } + update := func(ss *SimpleState) { total += ss.Msgs if first == 0 || ss.First < first { @@ -2469,9 +2476,9 @@ func (mb *msgBlock) filteredPendingLocked(filter string, wc bool, sseq uint64) ( // Make sure we have fss loaded. mb.ensurePerSubjectInfoLoaded() - tsa := [32]string{} - fsa := [32]string{} - fts := tokenizeSubjectIntoSlice(fsa[:0], filter) + _tsa, _fsa := [32]string{}, [32]string{} + tsa, fsa := _tsa[:0], _fsa[:0] + fsa = tokenizeSubjectIntoSlice(fsa[:0], filter) // 1. See if we match any subs from fss. // 2. If we match and the sseq is past ss.Last then we can use meta only. @@ -2481,25 +2488,22 @@ func (mb *msgBlock) filteredPendingLocked(filter string, wc bool, sseq uint64) ( if !wc { return subj == filter } - tts := tokenizeSubjectIntoSlice(tsa[:0], subj) - return isSubsetMatchTokenized(tts, fts) + tsa = tokenizeSubjectIntoSlice(tsa[:0], subj) + return isSubsetMatchTokenized(tsa, fsa) } var havePartial bool - for subj, ss := range mb.fss { - if isAll || isMatch(subj) { - if ss.firstNeedsUpdate { - mb.recalculateFirstForSubj(subj, ss.First, ss) - } - if sseq <= ss.First { - update(ss) - } else if sseq <= ss.Last { - // We matched but its a partial. - havePartial = true - break - } + mb.fss.Match(stringToBytes(filter), func(bsubj []byte, ss *SimpleState) { + if ss.firstNeedsUpdate { + mb.recalculateFirstForSubj(bytesToString(bsubj), ss.First, ss) } - } + if sseq <= ss.First { + update(ss) + } else if sseq <= ss.Last { + // We matched but its a partial. + havePartial = true + } + }) // If we did not encounter any partials we can return here. if !havePartial { @@ -2590,9 +2594,48 @@ func (fs *fileStore) FilteredState(sseq uint64, subj string) SimpleState { return ss } +// This is used to see if we can selectively jump start blocks based on filter subject and a floor block index. +// Will return -1 if no matches at all. +func (fs *fileStore) checkSkipFirstBlock(filter string, wc bool) int { + start := uint32(math.MaxUint32) + if wc { + fs.psim.Match(stringToBytes(filter), func(_ []byte, psi *psi) { + if psi.fblk < start { + start = psi.fblk + } + }) + } else if psi, ok := fs.psim.Find(stringToBytes(filter)); ok { + start = psi.fblk + } + // Nothing found. + if start == uint32(math.MaxUint32) { + return -1 + } + // Here we need to translate this to index into fs.blks. + mb := fs.bim[start] + if mb == nil { + return -1 + } + bi, _ := fs.selectMsgBlockWithIndex(atomic.LoadUint64(&mb.last.seq)) + return bi +} + // Optimized way for getting all num pending matching a filter subject. // Lock should be held. func (fs *fileStore) numFilteredPending(filter string, ss *SimpleState) { + fs.numFilteredPendingWithLast(filter, true, ss) +} + +// Optimized way for getting all num pending matching a filter subject and first sequence only. +// Lock should be held. +func (fs *fileStore) numFilteredPendingNoLast(filter string, ss *SimpleState) { + fs.numFilteredPendingWithLast(filter, false, ss) +} + +// Optimized way for getting all num pending matching a filter subject. +// Optionally look up last sequence. Sometimes do not need last and this avoids cost. +// Lock should be held. +func (fs *fileStore) numFilteredPendingWithLast(filter string, last bool, ss *SimpleState) { isAll := filter == _EMPTY_ || filter == fwcs // If isAll we do not need to do anything special to calculate the first and last and total. @@ -2602,6 +2645,12 @@ func (fs *fileStore) numFilteredPending(filter string, ss *SimpleState) { ss.Msgs = fs.state.Msgs return } + // Always reset. + ss.First, ss.Last, ss.Msgs = 0, 0, 0 + + if filter == _EMPTY_ { + filter = fwcs + } // We do need to figure out the first and last sequences. wc := subjectHasWildcard(filter) @@ -2625,7 +2674,6 @@ func (fs *fileStore) numFilteredPending(filter string, ss *SimpleState) { // Did not find anything. if stop == 0 { - ss.First, ss.Last, ss.Msgs = 0, 0, 0 return } @@ -2636,10 +2684,12 @@ func (fs *fileStore) numFilteredPending(filter string, ss *SimpleState) { ss.First = f } - // Hold this outside loop for psim fblk updates on misses. - i := start + 1 if ss.First == 0 { - // This is a miss. This can happen since psi.fblk is lazy, but should be very rare. + // This is a miss. This can happen since psi.fblk is lazy. + // We will make sure to update fblk. + + // Hold this outside loop for psim fblk updates when done. + i := start + 1 for ; i <= stop; i++ { mb := fs.bim[i] if mb == nil { @@ -2650,25 +2700,25 @@ func (fs *fileStore) numFilteredPending(filter string, ss *SimpleState) { break } } - } - // Update fblk if we missed matching some blocks, meaning fblk was outdated. - if i > start+1 { + // Update fblk since fblk was outdated. if !wc { if info, ok := fs.psim.Find(stringToBytes(filter)); ok { info.fblk = i } } else { - fs.psim.Match(stringToBytes(filter), func(_ []byte, psi *psi) { + fs.psim.Match(stringToBytes(filter), func(subj []byte, psi *psi) { if i > psi.fblk { psi.fblk = i } }) } } - // Now last - if mb = fs.bim[stop]; mb != nil { - _, _, l := mb.filteredPending(filter, wc, 0) - ss.Last = l + // Now gather last sequence if asked to do so. + if last { + if mb = fs.bim[stop]; mb != nil { + _, _, l := mb.filteredPending(filter, wc, 0) + ss.Last = l + } } } @@ -2681,6 +2731,10 @@ func (fs *fileStore) SubjectsState(subject string) map[string]SimpleState { return nil } + if subject == _EMPTY_ { + subject = fwcs + } + start, stop := fs.blks[0], fs.lmb // We can short circuit if not a wildcard using psim for start and stop. if !subjectHasWildcard(subject) { @@ -2712,21 +2766,20 @@ func (fs *fileStore) SubjectsState(subject string) map[string]SimpleState { } // Mark fss activity. mb.lsts = time.Now().UnixNano() - for subj, ss := range mb.fss { - if subject == _EMPTY_ || subject == fwcs || subjectIsSubsetMatch(subj, subject) { - if ss.firstNeedsUpdate { - mb.recalculateFirstForSubj(subj, ss.First, ss) - } - oss := fss[subj] - if oss.First == 0 { // New - fss[subj] = *ss - } else { - // Merge here. - oss.Last, oss.Msgs = ss.Last, oss.Msgs+ss.Msgs - fss[subj] = oss - } + mb.fss.Match(stringToBytes(subject), func(bsubj []byte, ss *SimpleState) { + subj := string(bsubj) + if ss.firstNeedsUpdate { + mb.recalculateFirstForSubj(subj, ss.First, ss) } - } + oss := fss[subj] + if oss.First == 0 { // New + fss[subj] = *ss + } else { + // Merge here. + oss.Last, oss.Msgs = ss.Last, oss.Msgs+ss.Msgs + fss[subj] = oss + } + }) if shouldExpire { // Expire this cache before moving on. mb.tryForceExpireCacheLocked() @@ -2784,8 +2837,9 @@ func (fs *fileStore) NumPending(sseq uint64, filter string, lastPerSubject bool) return fs.state.LastSeq - sseq + 1, validThrough } - var tsa, fsa [32]string - fts := tokenizeSubjectIntoSlice(fsa[:0], filter) + _tsa, _fsa := [32]string{}, [32]string{} + tsa, fsa := _tsa[:0], _fsa[:0] + fsa = tokenizeSubjectIntoSlice(fsa[:0], filter) isMatch := func(subj string) bool { if isAll { @@ -2794,8 +2848,8 @@ func (fs *fileStore) NumPending(sseq uint64, filter string, lastPerSubject bool) if !wc { return subj == filter } - tts := tokenizeSubjectIntoSlice(tsa[:0], subj) - return isSubsetMatchTokenized(tts, fts) + tsa = tokenizeSubjectIntoSlice(tsa[:0], subj) + return isSubsetMatchTokenized(tsa, fsa) } // Handle last by subject a bit differently. @@ -2895,20 +2949,18 @@ func (fs *fileStore) NumPending(sseq uint64, filter string, lastPerSubject bool) mb.lsts = time.Now().UnixNano() var havePartial bool - for subj, ss := range mb.fss { - if isMatch(subj) { - if ss.firstNeedsUpdate { - mb.recalculateFirstForSubj(subj, ss.First, ss) - } - if sseq <= ss.First { - t += ss.Msgs - } else if sseq <= ss.Last { - // We matched but its a partial. - havePartial = true - break - } + mb.fss.Match(stringToBytes(filter), func(bsubj []byte, ss *SimpleState) { + subj := bytesToString(bsubj) + if ss.firstNeedsUpdate { + mb.recalculateFirstForSubj(subj, ss.First, ss) } - } + if sseq <= ss.First { + t += ss.Msgs + } else if sseq <= ss.Last { + // We matched but its a partial. + havePartial = true + } + }) // See if we need to scan msgs here. if havePartial { @@ -2986,11 +3038,9 @@ func (fs *fileStore) NumPending(sseq uint64, filter string, lastPerSubject bool) // Mark fss activity. mb.lsts = time.Now().UnixNano() - for subj, ss := range mb.fss { - if isMatch(subj) { - adjust += ss.Msgs - } - } + mb.fss.Match(stringToBytes(filter), func(bsubj []byte, ss *SimpleState) { + adjust += ss.Msgs + }) } } else { // This is the last block. We need to scan per message here. @@ -3111,7 +3161,7 @@ func (fs *fileStore) newMsgBlockForWrite() (*msgBlock, error) { // Lock should be held to quiet race detector. mb.mu.Lock() mb.setupWriteCache(rbuf) - mb.fss = make(map[string]*SimpleState) + mb.fss = stree.NewSubjectTree[SimpleState]() // Set cache time to creation time to start. ts := time.Now().UnixNano() @@ -3563,10 +3613,11 @@ func (fs *fileStore) firstSeqForSubj(subj string) (uint64, error) { // Mark fss activity. mb.lsts = time.Now().UnixNano() - if ss := mb.fss[subj]; ss != nil { + bsubj := stringToBytes(subj) + if ss, ok := mb.fss.Find(bsubj); ok && ss != nil { // Adjust first if it was not where we thought it should be. if i != start { - if info, ok := fs.psim.Find(stringToBytes(subj)); ok { + if info, ok := fs.psim.Find(bsubj); ok { info.fblk = i } } @@ -3699,8 +3750,8 @@ func (fs *fileStore) enforceMsgPerSubjectLimit(fireCallback bool) { // Grab the ss entry for this subject in case sparse. mb.mu.Lock() mb.ensurePerSubjectInfoLoaded() - ss := mb.fss[subj] - if ss != nil && ss.firstNeedsUpdate { + ss, ok := mb.fss.Find(stringToBytes(subj)) + if ok && ss != nil && ss.firstNeedsUpdate { mb.recalculateFirstForSubj(subj, ss.First, ss) } mb.mu.Unlock() @@ -4795,11 +4846,11 @@ func (mb *msgBlock) writeMsgRecord(rl, seq uint64, subj string, mhdr, msg []byte } // Mark fss activity. mb.lsts = time.Now().UnixNano() - if ss := mb.fss[subj]; ss != nil { + if ss, ok := mb.fss.Find(stringToBytes(subj)); ok && ss != nil { ss.Msgs++ ss.Last = seq } else { - mb.fss[subj] = &SimpleState{Msgs: 1, First: seq, Last: seq} + mb.fss.Insert(stringToBytes(subj), SimpleState{Msgs: 1, First: seq, Last: seq}) } } @@ -5400,7 +5451,7 @@ func (mb *msgBlock) indexCacheBuf(buf []byte) error { // Create FSS if we should track. var popFss bool if mb.fssNotLoaded() { - mb.fss = make(map[string]*SimpleState) + mb.fss = stree.NewSubjectTree[SimpleState]() popFss = true } // Mark fss activity. @@ -5467,15 +5518,15 @@ func (mb *msgBlock) indexCacheBuf(buf []byte) error { // Handle FSS inline here. if popFss && slen > 0 && !mb.noTrack && !erased && !mb.dmap.Exists(seq) { bsubj := buf[index+msgHdrSize : index+msgHdrSize+uint32(slen)] - if ss := mb.fss[string(bsubj)]; ss != nil { + if ss, ok := mb.fss.Find(bsubj); ok && ss != nil { ss.Msgs++ ss.Last = seq } else { - mb.fss[string(bsubj)] = &SimpleState{ + mb.fss.Insert(bsubj, SimpleState{ Msgs: 1, First: seq, Last: seq, - } + }) } } } @@ -6166,6 +6217,8 @@ func (fs *fileStore) loadLast(subj string, sm *StoreMsg) (lsm *StoreMsg, err err if stop == 0 { return nil, ErrStoreMsgNotFound } + // These need to be swapped. + start, stop = stop, start } else if info, ok := fs.psim.Find(stringToBytes(subj)); ok { start, stop = info.lblk, info.fblk } else { @@ -6189,7 +6242,7 @@ func (fs *fileStore) loadLast(subj string, sm *StoreMsg) (lsm *StoreMsg, err err var l uint64 // Optimize if subject is not a wildcard. if !wc { - if ss := mb.fss[subj]; ss != nil { + if ss, ok := mb.fss.Find(stringToBytes(subj)); ok && ss != nil { l = ss.Last } } @@ -6283,7 +6336,7 @@ func (fs *fileStore) LoadNextMsg(filter string, wc bool, start uint64, sm *Store // let's check the psim to see if we can skip ahead. if start <= fs.state.FirstSeq { var ss SimpleState - fs.numFilteredPending(filter, &ss) + fs.numFilteredPendingNoLast(filter, &ss) // Nothing available. if ss.Msgs == 0 { return nil, fs.state.LastSeq, ErrStoreEOF @@ -6309,16 +6362,15 @@ func (fs *fileStore) LoadNextMsg(filter string, wc bool, start uint64, sm *Store // Similar to above if start <= first seq. // TODO(dlc) - For v2 track these by filter subject since they will represent filtered consumers. if i == bi { - var ss SimpleState - fs.numFilteredPending(filter, &ss) + nbi := fs.checkSkipFirstBlock(filter, wc) // Nothing available. - if ss.Msgs == 0 { + if nbi < 0 { return nil, fs.state.LastSeq, ErrStoreEOF } // See if we can jump ahead here. // Right now we can only spin on first, so if we have interior sparseness need to favor checking per block fss if loaded. // For v2 will track all blocks that have matches for psim. - if nbi, _ := fs.selectMsgBlockWithIndex(ss.First); nbi > i { + if nbi > i { i = nbi - 1 // For the iterator condition i++ } } @@ -6905,11 +6957,13 @@ func (fs *fileStore) Compact(seq uint64) (uint64, error) { bytes += mb.bytes // Make sure we do subject cleanup as well. mb.ensurePerSubjectInfoLoaded() - for subj, ss := range mb.fss { + mb.fss.Iter(func(bsubj []byte, ss *SimpleState) bool { + subj := bytesToString(bsubj) for i := uint64(0); i < ss.Msgs; i++ { fs.removePerSubject(subj) } - } + return true + }) // Now close. mb.dirtyCloseWithRemove(true) mb.mu.Unlock() @@ -7310,13 +7364,17 @@ func (mb *msgBlock) dirtyCloseWithRemove(remove bool) { // Lock should be held. func (mb *msgBlock) removeSeqPerSubject(subj string, seq uint64) { mb.ensurePerSubjectInfoLoaded() - ss := mb.fss[subj] - if ss == nil { + if mb.fss == nil { + return + } + bsubj := stringToBytes(subj) + ss, ok := mb.fss.Find(bsubj) + if !ok || ss == nil { return } if ss.Msgs == 1 { - delete(mb.fss, subj) + mb.fss.Delete(bsubj) return } @@ -7418,7 +7476,7 @@ func (mb *msgBlock) generatePerSubjectInfo() error { } // Create new one regardless. - mb.fss = make(map[string]*SimpleState) + mb.fss = stree.NewSubjectTree[SimpleState]() var smv StoreMsg fseq, lseq := atomic.LoadUint64(&mb.first.seq), atomic.LoadUint64(&mb.last.seq) @@ -7435,16 +7493,16 @@ func (mb *msgBlock) generatePerSubjectInfo() error { return err } if sm != nil && len(sm.subj) > 0 { - if ss := mb.fss[sm.subj]; ss != nil { + if ss, ok := mb.fss.Find(stringToBytes(sm.subj)); ok && ss != nil { ss.Msgs++ ss.Last = seq } else { - mb.fss[sm.subj] = &SimpleState{Msgs: 1, First: seq, Last: seq} + mb.fss.Insert(stringToBytes(sm.subj), SimpleState{Msgs: 1, First: seq, Last: seq}) } } } - if len(mb.fss) > 0 { + if mb.fss.Size() > 0 { // Make sure we run the cache expire timer. mb.llts = time.Now().UnixNano() // Mark fss activity. @@ -7465,7 +7523,7 @@ func (mb *msgBlock) ensurePerSubjectInfoLoaded() error { return nil } if mb.msgs == 0 { - mb.fss = make(map[string]*SimpleState) + mb.fss = stree.NewSubjectTree[SimpleState]() return nil } return mb.generatePerSubjectInfo() @@ -7482,9 +7540,8 @@ func (fs *fileStore) populateGlobalPerSubjectInfo(mb *msgBlock) { } // Now populate psim. - for subj, ss := range mb.fss { - if len(subj) > 0 { - bsubj := stringToBytes(subj) + mb.fss.Iter(func(bsubj []byte, ss *SimpleState) bool { + if len(bsubj) > 0 { if info, ok := fs.psim.Find(bsubj); ok { info.total += ss.Msgs if mb.index > info.lblk { @@ -7492,10 +7549,11 @@ func (fs *fileStore) populateGlobalPerSubjectInfo(mb *msgBlock) { } } else { fs.psim.Insert(bsubj, psi{total: ss.Msgs, fblk: mb.index, lblk: mb.index}) - fs.tsl += len(subj) + fs.tsl += len(bsubj) } } - } + return true + }) } // Close the message block. diff --git a/server/filestore_test.go b/server/filestore_test.go index 34d712d6b84..84e9e979583 100644 --- a/server/filestore_test.go +++ b/server/filestore_test.go @@ -4098,10 +4098,10 @@ func TestFileStoreNoFSSBugAfterRemoveFirst(t *testing.T) { mb := fs.blks[0] fs.mu.Unlock() mb.mu.RLock() - ss := mb.fss["foo.bar.0"] + ss, ok := mb.fss.Find([]byte("foo.bar.0")) mb.mu.RUnlock() - if ss != nil { + if ok && ss != nil { t.Fatalf("Expected no state for %q, but got %+v\n", "foo.bar.0", ss) } }) @@ -6782,7 +6782,7 @@ func TestFileStoreFSSExpireNumPending(t *testing.T) { require_True(t, elapsed > time.Since(start)) // Sleep enough so that all mb.fss should expire, which is 2s above. - time.Sleep(3 * time.Second) + time.Sleep(4 * time.Second) fs.mu.RLock() for i, mb := range fs.blks { mb.mu.RLock() @@ -6790,7 +6790,7 @@ func TestFileStoreFSSExpireNumPending(t *testing.T) { mb.mu.RUnlock() if fss != nil { fs.mu.RUnlock() - t.Fatalf("Detected loaded fss for mb %d", i) + t.Fatalf("Detected loaded fss for mb %d (size %d)", i, fss.Size()) } } fs.mu.RUnlock() @@ -6938,6 +6938,25 @@ func TestFileStoreLoadLastWildcard(t *testing.T) { require_Equal(t, cloads, 1) } +func TestFileStoreLoadLastWildcardWithPresenceMultipleBlocks(t *testing.T) { + sd := t.TempDir() + fs, err := newFileStore( + FileStoreConfig{StoreDir: sd, BlockSize: 64}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + // Make sure we have "foo.222.bar" in multiple blocks to show bug. + fs.StoreMsg("foo.22.bar", nil, []byte("hello")) + fs.StoreMsg("foo.22.baz", nil, []byte("ok")) + fs.StoreMsg("foo.22.baz", nil, []byte("ok")) + fs.StoreMsg("foo.22.bar", nil, []byte("hello22")) + require_True(t, fs.numMsgBlocks() > 1) + sm, err := fs.LoadLastMsg("foo.*.bar", nil) + require_NoError(t, err) + require_Equal(t, "hello22", string(sm.msg)) +} + // We want to make sure that we update psim correctly on a miss. func TestFileStoreFilteredPendingPSIMFirstBlockUpdate(t *testing.T) { sd := t.TempDir() @@ -7006,7 +7025,7 @@ func TestFileStoreWildcardFilteredPendingPSIMFirstBlockUpdate(t *testing.T) { for i := 0; i < 1000; i++ { fs.StoreMsg("foo.1.foo", nil, msg) } - // Bookend with 3 more,twoe foo.baz and two foo.bar. + // Bookend with 3 more, two foo.baz and two foo.bar. fs.StoreMsg("foo.22.baz", nil, msg) fs.StoreMsg("foo.22.baz", nil, msg) fs.StoreMsg("foo.22.bar", nil, msg) @@ -7065,6 +7084,84 @@ func TestFileStoreWildcardFilteredPendingPSIMFirstBlockUpdate(t *testing.T) { require_Equal(t, psi.lblk, 92) } +// Make sure if we only miss by one for fblk that we still update it. +func TestFileStoreFilteredPendingPSIMFirstBlockUpdateNextBlock(t *testing.T) { + sd := t.TempDir() + fs, err := newFileStore( + FileStoreConfig{StoreDir: sd, BlockSize: 128}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + msg := []byte("hello") + // Create 4 blocks, each block holds 2 msgs + for i := 0; i < 4; i++ { + fs.StoreMsg("foo.22.bar", nil, msg) + fs.StoreMsg("foo.22.baz", nil, msg) + } + require_Equal(t, fs.numMsgBlocks(), 4) + + fetch := func(subj string) *psi { + t.Helper() + fs.mu.RLock() + psi, ok := fs.psim.Find([]byte(subj)) + fs.mu.RUnlock() + require_True(t, ok) + return psi + } + + psi := fetch("foo.22.bar") + require_Equal(t, psi.total, 4) + require_Equal(t, psi.fblk, 1) + require_Equal(t, psi.lblk, 4) + + // Now remove first instance of "foo.22.bar" + removed, err := fs.RemoveMsg(1) + require_NoError(t, err) + require_True(t, removed) + + // Call into numFilterePending(), we want to make sure it updates fblk. + var ss SimpleState + fs.mu.Lock() + fs.numFilteredPending("foo.22.bar", &ss) + fs.mu.Unlock() + require_Equal(t, ss.Msgs, 3) + require_Equal(t, ss.First, 3) + require_Equal(t, ss.Last, 7) + + // Now make sure that we properly updated the psim entry. + psi = fetch("foo.22.bar") + require_Equal(t, psi.total, 3) + require_Equal(t, psi.fblk, 2) + require_Equal(t, psi.lblk, 4) + + // Now make sure wildcard calls into also update blks. + // First remove first "foo.22.baz" which will remove first block. + removed, err = fs.RemoveMsg(2) + require_NoError(t, err) + require_True(t, removed) + // Make sure 3 blks left + require_Equal(t, fs.numMsgBlocks(), 3) + + psi = fetch("foo.22.baz") + require_Equal(t, psi.total, 3) + require_Equal(t, psi.fblk, 1) + require_Equal(t, psi.lblk, 4) + + // Now call wildcard version of numFilteredPending to make sure it clears. + fs.mu.Lock() + fs.numFilteredPending("foo.*.baz", &ss) + fs.mu.Unlock() + require_Equal(t, ss.Msgs, 3) + require_Equal(t, ss.First, 4) + require_Equal(t, ss.Last, 8) + + psi = fetch("foo.22.baz") + require_Equal(t, psi.total, 3) + require_Equal(t, psi.fblk, 2) + require_Equal(t, psi.lblk, 4) +} + /////////////////////////////////////////////////////////////////////////// // Benchmarks /////////////////////////////////////////////////////////////////////////// diff --git a/server/jetstream_test.go b/server/jetstream_test.go index d02ff5fdf7c..858d11f0120 100644 --- a/server/jetstream_test.go +++ b/server/jetstream_test.go @@ -22728,3 +22728,33 @@ func TestJetStreamAuditStreams(t *testing.T) { }) require_NoError(t, err) } + +// https://github.com/nats-io/nats-server/issues/5570 +func TestJetStreamBadSubjectMappingStream(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + // Client for API requests. + nc, js := jsClientConnect(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{Name: "test"}) + require_NoError(t, err) + + _, err = js.UpdateStream(&nats.StreamConfig{ + Name: "test", + Sources: []*nats.StreamSource{ + { + Name: "mapping", + SubjectTransforms: []nats.SubjectTransformConfig{ + { + Source: "events.*", + Destination: "events.{{wildcard(1)}}{{split(3,1)}}", + }, + }, + }, + }, + }) + + require_Error(t, err, NewJSStreamUpdateError(errors.New("unable to get subject transform for source: invalid mapping destination: too many arguments passed to the function in {{wildcard(1)}}{{split(3,1)}}"))) +} diff --git a/server/memstore.go b/server/memstore.go index 6024df793f9..339005dc1ac 100644 --- a/server/memstore.go +++ b/server/memstore.go @@ -389,9 +389,9 @@ func (ms *memStore) filteredStateLocked(sseq uint64, filter string, lastPerSubje } } - tsa := [32]string{} - fsa := [32]string{} - fts := tokenizeSubjectIntoSlice(fsa[:0], filter) + _tsa, _fsa := [32]string{}, [32]string{} + tsa, fsa := _tsa[:0], _fsa[:0] + fsa = tokenizeSubjectIntoSlice(fsa[:0], filter) wc := subjectHasWildcard(filter) // 1. See if we match any subs from fss. @@ -405,8 +405,8 @@ func (ms *memStore) filteredStateLocked(sseq uint64, filter string, lastPerSubje if !wc { return subj == filter } - tts := tokenizeSubjectIntoSlice(tsa[:0], subj) - return isSubsetMatchTokenized(tts, fts) + tsa = tokenizeSubjectIntoSlice(tsa[:0], subj) + return isSubsetMatchTokenized(tsa, fsa) } update := func(fss *SimpleState) { @@ -602,9 +602,9 @@ func (ms *memStore) SubjectsTotals(filterSubject string) map[string]uint64 { return nil } - tsa := [32]string{} - fsa := [32]string{} - fts := tokenizeSubjectIntoSlice(fsa[:0], filterSubject) + _tsa, _fsa := [32]string{}, [32]string{} + tsa, fsa := _tsa[:0], _fsa[:0] + fsa = tokenizeSubjectIntoSlice(fsa[:0], filterSubject) isAll := filterSubject == _EMPTY_ || filterSubject == fwcs fst := make(map[string]uint64) @@ -613,7 +613,7 @@ func (ms *memStore) SubjectsTotals(filterSubject string) map[string]uint64 { if isAll { fst[subjs] = ss.Msgs } else { - if tts := tokenizeSubjectIntoSlice(tsa[:0], subjs); isSubsetMatchTokenized(tts, fts) { + if tsa = tokenizeSubjectIntoSlice(tsa[:0], subjs); isSubsetMatchTokenized(tsa, fsa) { fst[subjs] = ss.Msgs } } diff --git a/server/mqtt.go b/server/mqtt.go index 7ca49081914..33a00109929 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -974,7 +974,7 @@ func (s *Server) mqttHandleClosedClient(c *client) { // This needs to be done outside of any lock. if doClean { - if err := sess.clear(); err != nil { + if err := sess.clear(true); err != nil { c.Errorf(err.Error()) } } @@ -1449,7 +1449,7 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc // Opportunistically delete the old (legacy) consumer, from v2.10.10 and // before. Ignore any errors that might arise. rmLegacyDurName := mqttRetainedMsgsStreamName + "_" + jsa.id - jsa.deleteConsumer(mqttRetainedMsgsStreamName, rmLegacyDurName) + jsa.deleteConsumer(mqttRetainedMsgsStreamName, rmLegacyDurName, true) // Create a new, uniquely names consumer for retained messages for this // server. The prior one will expire eventually. @@ -1672,8 +1672,21 @@ func (jsa *mqttJSA) createDurableConsumer(cfg *CreateConsumerRequest) (*JSApiCon return ccr, ccr.ToError() } -func (jsa *mqttJSA) deleteConsumer(streamName, consName string) (*JSApiConsumerDeleteResponse, error) { +func (jsa *mqttJSA) sendMsg(subj string, msg []byte) { + if subj == _EMPTY_ { + return + } + jsa.sendq.push(&mqttJSPubMsg{subj: subj, msg: msg, hdr: -1}) +} + +// if noWait is specified, does not wait for the JS response, returns nil +func (jsa *mqttJSA) deleteConsumer(streamName, consName string, noWait bool) (*JSApiConsumerDeleteResponse, error) { subj := fmt.Sprintf(JSApiConsumerDeleteT, streamName, consName) + if noWait { + jsa.sendMsg(subj, nil) + return nil, nil + } + cdri, err := jsa.newRequest(mqttJSAConsumerDel, subj, 0, nil) if err != nil { return nil, err @@ -1950,9 +1963,13 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie } // If lastSeq is 0 (nothing to recover, or done doing it) and this is // from our own server, ignore. + as.mu.RLock() if as.rrmLastSeq == 0 && rm.Origin == as.jsa.id { + as.mu.RUnlock() return } + as.mu.RUnlock() + // At this point we either recover from our own server, or process a remote retained message. seq, _, _ := ackReplyInfo(reply) @@ -1960,11 +1977,13 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie as.handleRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm, false) // If we were recovering (lastSeq > 0), then check if we are done. + as.mu.Lock() if as.rrmLastSeq > 0 && seq >= as.rrmLastSeq { as.rrmLastSeq = 0 close(as.rrmDoneCh) as.rrmDoneCh = nil } + as.mu.Unlock() } func (as *mqttAccountSessionManager) processRetainedMsgDel(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { @@ -3072,7 +3091,7 @@ func (sess *mqttSession) save() error { // // Runs from the client's readLoop. // Lock not held on entry, but session is in the locked map. -func (sess *mqttSession) clear() error { +func (sess *mqttSession) clear(noWait bool) error { var durs []string var pubRelDur string @@ -3100,19 +3119,19 @@ func (sess *mqttSession) clear() error { sess.mu.Unlock() for _, dur := range durs { - if _, err := sess.jsa.deleteConsumer(mqttStreamName, dur); isErrorOtherThan(err, JSConsumerNotFoundErr) { + if _, err := sess.jsa.deleteConsumer(mqttStreamName, dur, noWait); isErrorOtherThan(err, JSConsumerNotFoundErr) { return fmt.Errorf("unable to delete consumer %q for session %q: %v", dur, sess.id, err) } } - if pubRelDur != "" { - _, err := sess.jsa.deleteConsumer(mqttOutStreamName, pubRelDur) + if pubRelDur != _EMPTY_ { + _, err := sess.jsa.deleteConsumer(mqttOutStreamName, pubRelDur, noWait) if isErrorOtherThan(err, JSConsumerNotFoundErr) { return fmt.Errorf("unable to delete consumer %q for session %q: %v", pubRelDur, sess.id, err) } } if seq > 0 { - err := sess.jsa.deleteMsg(mqttSessStreamName, seq, true) + err := sess.jsa.deleteMsg(mqttSessStreamName, seq, !noWait) // Ignore the various errors indicating that the message (or sequence) // is already deleted, can happen in a cluster. if isErrorOtherThan(err, JSSequenceNotFoundErrF) { @@ -3378,7 +3397,7 @@ func (sess *mqttSession) untrackPubRel(pi uint16) (jsAckSubject string) { func (sess *mqttSession) deleteConsumer(cc *ConsumerConfig) { sess.mu.Lock() sess.tmaxack -= cc.MaxAckPending - sess.jsa.sendq.push(&mqttJSPubMsg{subj: sess.jsa.prefixDomain(fmt.Sprintf(JSApiConsumerDeleteT, mqttStreamName, cc.Durable))}) + sess.jsa.deleteConsumer(mqttStreamName, cc.Durable, true) sess.mu.Unlock() } @@ -3717,7 +3736,7 @@ CHECK: // This Session lasts as long as the Network Connection. State data // associated with this Session MUST NOT be reused in any subsequent // Session. - if err := es.clear(); err != nil { + if err := es.clear(false); err != nil { asm.removeSession(es, true) return err } diff --git a/server/mqtt_ex_test.go b/server/mqtt_ex_bench_test.go similarity index 70% rename from server/mqtt_ex_test.go rename to server/mqtt_ex_bench_test.go index 44025bef806..efe2056f473 100644 --- a/server/mqtt_ex_test.go +++ b/server/mqtt_ex_bench_test.go @@ -17,51 +17,12 @@ package server import ( - "bytes" - "encoding/json" "fmt" - "os" - "os/exec" "strconv" "testing" "time" ) -func TestMQTTExCompliance(t *testing.T) { - mqttPath := os.Getenv("MQTT_CLI") - if mqttPath == "" { - if p, err := exec.LookPath("mqtt"); err == nil { - mqttPath = p - } - } - if mqttPath == "" { - t.Skip(`"mqtt" command is not found in $PATH nor $MQTT_CLI. See https://hivemq.github.io/mqtt-cli/docs/installation/#debian-package for installation instructions`) - } - - conf := createConfFile(t, []byte(fmt.Sprintf(` - listen: 127.0.0.1:-1 - server_name: mqtt - jetstream { - store_dir = %q - } - mqtt { - listen: 127.0.0.1:-1 - } - `, t.TempDir()))) - s, o := RunServerWithConfig(conf) - defer testMQTTShutdownServer(s) - - cmd := exec.Command(mqttPath, "test", "-V", "3", "-p", strconv.Itoa(o.MQTT.Port)) - - output, err := cmd.CombinedOutput() - t.Log(string(output)) - if err != nil { - if exitError, ok := err.(*exec.ExitError); ok { - t.Fatalf("mqtt cli exited with error: %v", exitError) - } - } -} - const ( KB = 1024 ) @@ -83,9 +44,6 @@ type mqttBenchContext struct { Host string Port int - - // full path to mqtt-test command - testCmdPath string } var mqttBenchDefaultMatrix = mqttBenchMatrix{ @@ -102,8 +60,12 @@ type MQTTBenchmarkResult struct { Bytes int64 `json:"bytes"` } -func BenchmarkMQTTEx(b *testing.B) { - bc := mqttNewBenchEx(b) +func BenchmarkXMQTT(b *testing.B) { + if mqttTestCommandPath == "" { + b.Skip(`"mqtt-test" command is not found in $PATH.`) + } + + bc := mqttBenchContext{} b.Run("Server", func(b *testing.B) { b.Cleanup(bc.startServer(b, false)) bc.runAll(b) @@ -142,11 +104,11 @@ func (bc mqttBenchContext) benchmarkPub(b *testing.B) { b.Run("PUB", func(b *testing.B) { m.runMatrix(b, bc, func(b *testing.B, bc *mqttBenchContext) { - bc.runCommand(b, "pub", + bc.runAndReport(b, "pub", "--qos", strconv.Itoa(bc.QOS), - "--n", strconv.Itoa(b.N), + "--messages", strconv.Itoa(b.N), "--size", strconv.Itoa(bc.MessageSize), - "--num-publishers", strconv.Itoa(bc.Publishers), + "--publishers", strconv.Itoa(bc.Publishers), ) }) }) @@ -165,11 +127,11 @@ func (bc mqttBenchContext) benchmarkPubRetained(b *testing.B) { b.Run("PUBRET", func(b *testing.B) { m.runMatrix(b, bc, func(b *testing.B, bc *mqttBenchContext) { - bc.runCommand(b, "pub", "--retain", + bc.runAndReport(b, "pub", "--retain", "--qos", strconv.Itoa(bc.QOS), - "--n", strconv.Itoa(b.N), + "--messages", strconv.Itoa(b.N), "--size", strconv.Itoa(bc.MessageSize), - "--num-publishers", strconv.Itoa(bc.Publishers), + "--publishers", strconv.Itoa(bc.Publishers), ) }) }) @@ -185,11 +147,11 @@ func (bc mqttBenchContext) benchmarkPubSub(b *testing.B) { b.Run("PUBSUB", func(b *testing.B) { m.runMatrix(b, bc, func(b *testing.B, bc *mqttBenchContext) { - bc.runCommand(b, "pubsub", + bc.runAndReport(b, "pubsub", "--qos", strconv.Itoa(bc.QOS), - "--n", strconv.Itoa(b.N), + "--messages", strconv.Itoa(b.N), "--size", strconv.Itoa(bc.MessageSize), - "--num-subscribers", strconv.Itoa(bc.Subscribers), + "--subscribers", strconv.Itoa(bc.Subscribers), ) }) }) @@ -206,67 +168,23 @@ func (bc mqttBenchContext) benchmarkSubRet(b *testing.B) { b.Run("SUBRET", func(b *testing.B) { m.runMatrix(b, bc, func(b *testing.B, bc *mqttBenchContext) { - bc.runCommand(b, "subret", + bc.runAndReport(b, "subret", "--qos", strconv.Itoa(bc.QOS), - "--n", strconv.Itoa(b.N), // number of subscribe requests - "--num-subscribers", strconv.Itoa(bc.Subscribers), - "--num-topics", strconv.Itoa(bc.Topics), + "--topics", strconv.Itoa(bc.Topics), // number of retained messages + "--subscribers", strconv.Itoa(bc.Subscribers), "--size", strconv.Itoa(bc.MessageSize), + "--repeat", strconv.Itoa(b.N), // number of subscribe requests ) }) }) } -func mqttBenchLookupCommand(b *testing.B, name string) string { +func (bc mqttBenchContext) runAndReport(b *testing.B, name string, extraArgs ...string) { b.Helper() - cmd, err := exec.LookPath(name) - if err != nil { - b.Skipf("%q command is not found in $PATH. Please `go install github.com/nats-io/meta-nats/apps/go/mqtt/...@latest` and try again.", name) - } - return cmd -} - -func (bc mqttBenchContext) runCommand(b *testing.B, name string, extraArgs ...string) { - b.Helper() - - args := append([]string{ - name, - "-q", - "--servers", fmt.Sprintf("%s:%d", bc.Host, bc.Port), - }, extraArgs...) - - cmd := exec.Command(bc.testCmdPath, args...) - stdout, err := cmd.StdoutPipe() - if err != nil { - b.Fatalf("Error executing %q: %v", cmd.String(), err) - } - defer stdout.Close() - errbuf := bytes.Buffer{} - cmd.Stderr = &errbuf - if err = cmd.Start(); err != nil { - b.Fatalf("Error executing %q: %v", cmd.String(), err) - } - r := &MQTTBenchmarkResult{} - if err = json.NewDecoder(stdout).Decode(r); err != nil { - b.Fatalf("failed to decode output of %q: %v\n\n%s", cmd.String(), err, errbuf.String()) - } - if err = cmd.Wait(); err != nil { - b.Fatalf("Error executing %q: %v\n\n%s", cmd.String(), err, errbuf.String()) - } - + r := mqttRunExCommandTest(b, name, mqttNewDial("", "", bc.Host, bc.Port, ""), extraArgs...) r.report(b) } -func (bc mqttBenchContext) initServer(b *testing.B) { - b.Helper() - bc.runCommand(b, "pubsub", - "--id", "__init__", - "--qos", "0", - "--n", "1", - "--size", "100", - "--num-subscribers", "1") -} - func (bc *mqttBenchContext) startServer(b *testing.B, disableRMSCache bool) func() { b.Helper() b.StopTimer() @@ -278,7 +196,7 @@ func (bc *mqttBenchContext) startServer(b *testing.B, disableRMSCache bool) func o = s.getOpts() bc.Host = o.MQTT.Host bc.Port = o.MQTT.Port - bc.initServer(b) + mqttInitTestServer(b, mqttNewDial("", "", bc.Host, bc.Port, "")) return func() { testMQTTShutdownServer(s) testDisableRMSCache = prevDisableRMSCache @@ -314,7 +232,7 @@ func (bc *mqttBenchContext) startCluster(b *testing.B, disableRMSCache bool) fun o := cl.randomNonLeader().getOpts() bc.Host = o.MQTT.Host bc.Port = o.MQTT.Port - bc.initServer(b) + mqttInitTestServer(b, mqttNewDial("", "", bc.Host, bc.Port, "")) return func() { cl.shutdown() testDisableRMSCache = prevDisableRMSCache @@ -410,15 +328,7 @@ func (r MQTTBenchmarkResult) report(b *testing.B) { nsOp := float64(ns) / float64(r.Ops) b.ReportMetric(nsOp/1000000, unit+"_ms/op") } - // Diable ReportAllocs() since it confuses the github benchmarking action // with the noise. // b.ReportAllocs() } - -func mqttNewBenchEx(b *testing.B) *mqttBenchContext { - cmd := mqttBenchLookupCommand(b, "mqtt-test") - return &mqttBenchContext{ - testCmdPath: cmd, - } -} diff --git a/server/mqtt_ex_test_test.go b/server/mqtt_ex_test_test.go new file mode 100644 index 00000000000..9acad558779 --- /dev/null +++ b/server/mqtt_ex_test_test.go @@ -0,0 +1,383 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !skip_mqtt_tests +// +build !skip_mqtt_tests + +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "strconv" + "strings" + "testing" + + "github.com/nats-io/nuid" +) + +type mqttDial string + +type mqttTarget struct { + singleServers []*Server + clusters []*cluster + configs []mqttTestConfig + all []mqttDial +} + +type mqttTestConfig struct { + name string + pub []mqttDial + sub []mqttDial +} + +func TestXMQTTCompliance(t *testing.T) { + if mqttCLICommandPath == _EMPTY_ { + t.Skip(`"mqtt" command is not found in $PATH nor $MQTT_CLI. See https://hivemq.github.io/mqtt-cli/docs/installation/#debian-package for installation instructions`) + } + + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + o = s.getOpts() + defer testMQTTShutdownServer(s) + + cmd := exec.Command(mqttCLICommandPath, "test", "-V", "3", "-p", strconv.Itoa(o.MQTT.Port)) + + output, err := cmd.CombinedOutput() + t.Log(string(output)) + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + t.Fatalf("mqtt cli exited with error: %v", exitError) + } + } +} + +func TestXMQTTRetainedMessages(t *testing.T) { + if mqttTestCommandPath == _EMPTY_ { + t.Skip(`"mqtt-test" command is not found in $PATH.`) + } + + for _, topo := range []struct { + name string + makef func(testing.TB) *mqttTarget + }{ + { + name: "single server", + makef: mqttMakeTestServer, + }, + { + name: "cluster", + makef: mqttMakeTestCluster(4, ""), + }, + } { + t.Run(topo.name, func(t *testing.T) { + target := topo.makef(t) + t.Cleanup(target.Shutdown) + + // initialize the MQTT assets by "touching" all nodes in the + // cluster, but then reload to start with fresh server state. + for _, dial := range target.all { + mqttInitTestServer(t, dial) + } + + numRMS := 100 + strNumRMS := strconv.Itoa(numRMS) + topics := make([]string, len(target.configs)) + + for i, tc := range target.configs { + // Publish numRMS retained messages one at a time, + // round-robin across pub nodes. Remember the topic for each + // test config to check the subs after reload. + topic := "subret_" + nuid.Next() + topics[i] = topic + iNode := 0 + for i := 0; i < numRMS; i++ { + pubTopic := fmt.Sprintf("%s/%d", topic, i) + dial := tc.pub[iNode%len(tc.pub)] + mqttRunExCommandTest(t, "pub", dial, + "--retain", + "--topic", pubTopic, + "--qos", "0", + "--size", "128", // message size 128 bytes + ) + iNode++ + } + } + + // Check all sub nodes for retained messages + for i, tc := range target.configs { + for _, dial := range tc.sub { + mqttRunExCommandTest(t, "sub", dial, + "--retained", strNumRMS, + "--qos", "0", + "--topic", topics[i], + ) + } + } + + // Reload the target + target.Reload(t) + + // Now check again + for i, tc := range target.configs { + for _, dial := range tc.sub { + mqttRunExCommandTestRetry(t, 1, "sub", dial, + "--retained", strNumRMS, + "--qos", "0", + "--topic", topics[i], + ) + } + } + }) + } +} + +func mqttInitTestServer(tb testing.TB, dial mqttDial) { + tb.Helper() + mqttRunExCommandTestRetry(tb, 5, "pub", dial) +} + +func mqttNewDialForServer(s *Server, username, password string) mqttDial { + o := s.getOpts().MQTT + return mqttNewDial(username, password, o.Host, o.Port, s.Name()) +} + +func mqttNewDial(username, password, host string, port int, comment string) mqttDial { + d := "" + switch { + case username != "" && password != "": + d = fmt.Sprintf("%s:%s@%s:%d", username, password, host, port) + case username != "": + d = fmt.Sprintf("%s@%s:%d", username, host, port) + default: + d = fmt.Sprintf("%s:%d", host, port) + } + if comment != "" { + d += "#" + comment + } + return mqttDial(d) +} + +func (d mqttDial) Get() (u, p, s, c string) { + if d == "" { + return "", "", "127.0.0.1:1883", "" + } + in := string(d) + if i := strings.LastIndex(in, "#"); i != -1 { + c = in[i+1:] + in = in[:i] + } + if i := strings.LastIndex(in, "@"); i != -1 { + up := in[:i] + in = in[i+1:] + u = up + if i := strings.Index(up, ":"); i != -1 { + u = up[:i] + p = up[i+1:] + } + } + s = in + return u, p, s, c +} + +func (d mqttDial) Name() string { + _, _, _, c := d.Get() + return c +} + +func (t *mqttTarget) Reload(tb testing.TB) { + tb.Helper() + + for _, c := range t.clusters { + c.stopAll() + c.restartAllSamePorts() + } + + for i, s := range t.singleServers { + o := s.getOpts() + s.Shutdown() + t.singleServers[i] = testMQTTRunServer(tb, o) + } + + for _, dial := range t.all { + mqttInitTestServer(tb, dial) + } +} + +func (t *mqttTarget) Shutdown() { + for _, c := range t.clusters { + c.shutdown() + } + for _, s := range t.singleServers { + testMQTTShutdownServer(s) + } +} + +func mqttMakeTestServer(tb testing.TB) *mqttTarget { + tb.Helper() + o := testMQTTDefaultOptions() + s := testMQTTRunServer(tb, o) + all := []mqttDial{mqttNewDialForServer(s, "", "")} + return &mqttTarget{ + singleServers: []*Server{s}, + all: all, + configs: []mqttTestConfig{ + { + name: "single server", + pub: all, + sub: all, + }, + }, + } +} + +func mqttMakeTestCluster(size int, domain string) func(tb testing.TB) *mqttTarget { + return func(tb testing.TB) *mqttTarget { + tb.Helper() + if size < 3 { + tb.Fatal("cluster size must be at least 3") + } + + if domain != "" { + domain = "domain: " + domain + ", " + } + clusterConf := ` + listen: 127.0.0.1:-1 + + server_name: %s + jetstream: {max_mem_store: 256MB, max_file_store: 2GB, ` + domain + `store_dir: '%s'} + + leafnodes { + listen: 127.0.0.1:-1 + } + + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + + mqtt { + listen: 127.0.0.1:-1 + stream_replicas: 3 + } + + accounts { + ONE { users = [ { user: "one", pass: "p" } ]; jetstream: enabled } + $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } + } +` + cl := createJetStreamClusterWithTemplate(tb, clusterConf, "MQTT", size) + cl.waitOnLeader() + + all := []mqttDial{} + for _, s := range cl.servers { + all = append(all, mqttNewDialForServer(s, "one", "p")) + } + + return &mqttTarget{ + clusters: []*cluster{cl}, + all: all, + configs: []mqttTestConfig{ + { + name: "publish to one", + pub: []mqttDial{ + mqttNewDialForServer(cl.randomServer(), "one", "p"), + }, + sub: all, + }, + { + name: "publish to all", + pub: all, + sub: all, + }, + }, + } + } +} + +var mqttCLICommandPath = func() string { + p := os.Getenv("MQTT_CLI") + if p == "" { + p, _ = exec.LookPath("mqtt") + } + return p +}() + +var mqttTestCommandPath = func() string { + p, _ := exec.LookPath("mqtt-test") + return p +}() + +func mqttRunExCommandTest(tb testing.TB, subCommand string, dial mqttDial, extraArgs ...string) *MQTTBenchmarkResult { + tb.Helper() + return mqttRunExCommandTestRetry(tb, 1, subCommand, dial, extraArgs...) +} + +func mqttRunExCommandTestRetry(tb testing.TB, n int, subCommand string, dial mqttDial, extraArgs ...string) (r *MQTTBenchmarkResult) { + tb.Helper() + var err error + for i := 0; i < n; i++ { + if r, err = mqttTryExCommandTest(tb, subCommand, dial, extraArgs...); err == nil { + return r + } + + if i < (n - 1) { + tb.Logf("failed to %q %s to %q on attempt %v, will retry.", subCommand, extraArgs, dial.Name(), i) + } else { + tb.Fatal(err) + } + } + return nil +} + +func mqttTryExCommandTest(tb testing.TB, subCommand string, dial mqttDial, extraArgs ...string) (r *MQTTBenchmarkResult, err error) { + tb.Helper() + if mqttTestCommandPath == "" { + tb.Skip(`"mqtt-test" command is not found in $PATH.`) + } + + args := []string{subCommand, // "-q", + "-s", string(dial), + } + args = append(args, extraArgs...) + cmd := exec.Command(mqttTestCommandPath, args...) + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("error executing %q: %v", cmd.String(), err) + } + defer stdout.Close() + errbuf := bytes.Buffer{} + cmd.Stderr = &errbuf + if err = cmd.Start(); err != nil { + return nil, fmt.Errorf("error executing %q: %v", cmd.String(), err) + } + out, err := io.ReadAll(stdout) + if err != nil { + return nil, fmt.Errorf("error executing %q: failed to read output: %v", cmd.String(), err) + } + if err = cmd.Wait(); err != nil { + return nil, fmt.Errorf("error executing %q: %v\n\n%s\n\n%s", cmd.String(), err, string(out), errbuf.String()) + } + + r = &MQTTBenchmarkResult{} + if err := json.Unmarshal(out, r); err != nil { + tb.Fatalf("error executing %q: failed to decode output: %v\n\n%s\n\n%s", cmd.String(), err, string(out), errbuf.String()) + } + return r, nil +} diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 613e21d41c8..09b372b910a 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -6780,65 +6780,6 @@ func TestMQTTConsumerMemStorageReload(t *testing.T) { } } -type unableToDeleteConsLogger struct { - DummyLogger - errCh chan string -} - -func (l *unableToDeleteConsLogger) Errorf(format string, args ...any) { - msg := fmt.Sprintf(format, args...) - if strings.Contains(msg, "unable to delete consumer") { - l.errCh <- msg - } -} - -func TestMQTTSessionNotDeletedOnDeleteConsumerError(t *testing.T) { - org := mqttJSAPITimeout - mqttJSAPITimeout = 1000 * time.Millisecond - defer func() { mqttJSAPITimeout = org }() - - cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", 2) - defer cl.shutdown() - - o := cl.opts[0] - s1 := cl.servers[0] - // Plug error logger to s1 - l := &unableToDeleteConsLogger{errCh: make(chan string, 10)} - s1.SetLogger(l, false, false) - - nc, js := jsClientConnect(t, s1) - defer nc.Close() - - mc, r := testMQTTConnectRetry(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port, 5) - defer mc.Close() - testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) - - testMQTTSub(t, 1, mc, r, []*mqttFilter{{filter: "foo", qos: 1}}, []byte{1}) - testMQTTFlush(t, mc, nil, r) - - // Now shutdown server 2, we should lose quorum - cl.servers[1].Shutdown() - - // Close the MQTT client: - testMQTTDisconnect(t, mc, nil) - - // We should have reported that there was an error deleting the consumer - select { - case <-l.errCh: - // OK - case <-time.After(time.Second): - t.Fatal("Server did not report any error") - } - - // Now restart the server 2 so that we can check that the session is still persisted. - cl.restartAllSamePorts() - cl.waitOnStreamLeader(globalAccountName, mqttSessStreamName) - - si, err := js.StreamInfo(mqttSessStreamName) - require_NoError(t, err) - require_True(t, si.State.Msgs == 1) -} - // Test for auto-cleanup of consumers. func TestMQTTConsumerInactiveThreshold(t *testing.T) { tdir := t.TempDir() diff --git a/server/stream.go b/server/stream.go index 716409bb1f9..a09afdbf323 100644 --- a/server/stream.go +++ b/server/stream.go @@ -1841,7 +1841,7 @@ func (mset *stream) updateWithAdvisory(config *StreamConfig, sendAdvisory bool) si.trs[i], err = NewSubjectTransform(s.SubjectTransforms[i].Source, s.SubjectTransforms[i].Destination) if err != nil { mset.mu.Unlock() - mset.srv.Errorf("Unable to get subject transform for source: %v", err) + return fmt.Errorf("unable to get subject transform for source: %v", err) } } }