diff --git a/server/jetstream_test.go b/server/jetstream_test.go index 36c478a7b6f..f0e1cd57fec 100644 --- a/server/jetstream_test.go +++ b/server/jetstream_test.go @@ -24250,3 +24250,79 @@ func TestJetStreamSourceRemovalAndReAdd(t *testing.T) { require_Equal(t, m.Subject, fmt.Sprintf("foo.%d", i)) } } + +func TestJetStreamRateLimitHighStreamIngest(t *testing.T) { + cfgFmt := []byte(fmt.Sprintf(` + jetstream: { + enabled: true + store_dir: %s + max_buffered_size: 1kb + max_buffered_msgs: 1 + } + `, t.TempDir())) + + conf := createConfFile(t, cfgFmt) + s, opts := RunServerWithConfig(conf) + defer s.Shutdown() + + require_Equal(t, opts.StreamMaxBufferedSize, 1024) + require_Equal(t, opts.StreamMaxBufferedMsgs, 1) + + nc, js := jsClientConnect(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"test"}, + }) + require_NoError(t, err) + + // Create a reply inbox that we can await API requests on. + // This is instead of using nc.Request(). + inbox := nc.NewRespInbox() + resp := make(chan *nats.Msg, 1000) + _, err = nc.ChanSubscribe(inbox, resp) + require_NoError(t, err) + + // Publish a large number of messages using Core NATS withou + // waiting for the responses from the API. + msg := &nats.Msg{ + Subject: "test", + Reply: inbox, + } + for i := 0; i < 1000; i++ { + require_NoError(t, nc.PublishMsg(msg)) + } + + // Now sort through the API responses. We're looking for one + // that tells us that we were rate-limited. If we don't find + // one then we fail the test. + var rateLimited bool + for i, msg := 0, <-resp; i < 1000; i, msg = i+1, <-resp { + if msg.Header.Get("Status") == "429" { + rateLimited = true + break + } + } + require_True(t, rateLimited) +} + +func TestJetStreamRateLimitHighStreamIngestDefaults(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"test"}, + }) + require_NoError(t, err) + + stream, err := s.globalAccount().lookupStream("TEST") + require_NoError(t, err) + + require_Equal(t, stream.msgs.mlen, streamDefaultMaxQueueMsgs) + require_Equal(t, stream.msgs.msz, streamDefaultMaxQueueBytes) +} diff --git a/server/opts.go b/server/opts.go index 38f0d181d1a..dcd029492c0 100644 --- a/server/opts.go +++ b/server/opts.go @@ -331,6 +331,8 @@ type Options struct { JetStreamLimits JSLimitOpts JetStreamTpm JSTpmOpts JetStreamMaxCatchup int64 + StreamMaxBufferedMsgs int + StreamMaxBufferedSize int64 StoreDir string `json:"-"` SyncInterval time.Duration `json:"-"` SyncAlways bool `json:"-"` @@ -2373,6 +2375,18 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er return &configErr{tk, fmt.Sprintf("%s %s", strings.ToLower(mk), err)} } opts.JetStreamMaxCatchup = s + case "max_buffered_size": + s, err := getStorageSize(mv) + if err != nil { + return &configErr{tk, fmt.Sprintf("%s %s", strings.ToLower(mk), err)} + } + opts.StreamMaxBufferedSize = s + case "max_buffered_msgs": + mlen, ok := mv.(int64) + if !ok { + return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)} + } + opts.StreamMaxBufferedMsgs = int(mlen) default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ diff --git a/server/stream.go b/server/stream.go index 1230203aec0..85c550f620e 100644 --- a/server/stream.go +++ b/server/stream.go @@ -218,6 +218,12 @@ type ExternalStream struct { DeliverPrefix string `json:"deliver"` } +// For managing stream ingest. +const ( + streamDefaultMaxQueueMsgs = 10_000 + streamDefaultMaxQueueBytes = 1024 * 1024 * 128 +) + // Stream is a jetstream stream of messages. When we receive a message internally destined // for a Stream we will direct link from the client to this structure. type stream struct { @@ -576,6 +582,16 @@ func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileSt c := s.createInternalJetStreamClient() ic := s.createInternalJetStreamClient() + // Work out the stream ingest limits. + mlen := s.opts.StreamMaxBufferedMsgs + msz := uint64(s.opts.StreamMaxBufferedSize) + if mlen == 0 { + mlen = streamDefaultMaxQueueMsgs + } + if msz == 0 { + msz = streamDefaultMaxQueueBytes + } + qpfx := fmt.Sprintf("[ACC:%s] stream '%s' ", a.Name, config.Name) mset := &stream{ acc: a, @@ -588,12 +604,18 @@ func (a *Account) addStreamWithAssignment(config *StreamConfig, fsConfig *FileSt tier: tier, stype: cfg.Storage, consumers: make(map[string]*consumer), - msgs: newIPQueue[*inMsg](s, qpfx+"messages"), - gets: newIPQueue[*directGetReq](s, qpfx+"direct gets"), - qch: make(chan struct{}), - mqch: make(chan struct{}), - uch: make(chan struct{}, 4), - sch: make(chan struct{}, 1), + msgs: newIPQueue[*inMsg](s, qpfx+"messages", + ipqSizeCalculation(func(msg *inMsg) uint64 { + return uint64(len(msg.hdr) + len(msg.msg) + len(msg.rply) + len(msg.subj)) + }), + ipqLimitByLen[*inMsg](mlen), + ipqLimitBySize[*inMsg](msz), + ), + gets: newIPQueue[*directGetReq](s, qpfx+"direct gets"), + qch: make(chan struct{}), + mqch: make(chan struct{}), + uch: make(chan struct{}, 4), + sch: make(chan struct{}, 1), } // Start our signaling routine to process consumers. @@ -4156,7 +4178,13 @@ func (im *inMsg) returnToPool() { func (mset *stream) queueInbound(ib *ipQueue[*inMsg], subj, rply string, hdr, msg []byte, si *sourceInfo, mt *msgTrace) { im := inMsgPool.Get().(*inMsg) im.subj, im.rply, im.hdr, im.msg, im.si, im.mt = subj, rply, hdr, msg, si, mt - ib.push(im) + if _, err := ib.push(im); err != nil { + mset.srv.RateLimitWarnf("Dropping messages due to excessive stream ingest rate on '%s' > '%s': %s", mset.acc.Name, mset.name(), err) + if rply != _EMPTY_ { + hdr := []byte("NATS/1.0 429 Too Many Requests\r\n\r\n") + mset.outq.send(newJSPubMsg(rply, _EMPTY_, _EMPTY_, hdr, nil, nil, 0)) + } + } } var dgPool = sync.Pool{