diff --git a/connection_impl.go b/connection_impl.go
index db2f8d69..aa15e1e9 100644
--- a/connection_impl.go
+++ b/connection_impl.go
@@ -24,8 +24,14 @@ import (
 	"time"
 )
 
+type connState = int32
+
 const (
 	defaultZeroCopyTimeoutSec = 60
+
+	connStateNone         = 0
+	connStateConnected    = 1
+	connStateDisconnected = 2
 )
 
 // connection is the implement of Connection
@@ -45,9 +51,9 @@ type connection struct {
 	outputBuffer    *LinkBuffer
 	outputBarrier   *barrier
 	supportZeroCopy bool
-	maxSize         int   // The maximum size of data between two Release().
-	bookSize        int   // The size of data that can be read at once.
-	state           int32 // 0: not connected, 1: connected, 2: disconnected. Connection state should be changed sequentially.
+	maxSize         int       // The maximum size of data between two Release().
+	bookSize        int       // The size of data that can be read at once.
+	state           connState // Connection state should be changed sequentially.
 }
 
 var (
@@ -333,7 +339,7 @@ func (c *connection) init(conn Conn, opts *options) (err error) {
 	c.bookSize, c.maxSize = defaultLinkBufferSize, defaultLinkBufferSize
 	c.inputBuffer, c.outputBuffer = NewLinkBuffer(defaultLinkBufferSize), NewLinkBuffer()
 	c.outputBarrier = barrierPool.Get().(*barrier)
-	c.state = 0
+	c.state = connStateNone
 
 	c.initNetFD(conn) // conn must be *netFD{}
 	c.initFDOperator()
@@ -536,3 +542,15 @@ func (c *connection) waitFlush() (err error) {
 		return Exception(ErrWriteTimeout, c.remoteAddr.String())
 	}
 }
+
+func (c *connection) getState() connState {
+	return atomic.LoadInt32(&c.state)
+}
+
+func (c *connection) setState(newState connState) {
+	atomic.StoreInt32(&c.state, newState)
+}
+
+func (c *connection) changeState(from, to connState) bool {
+	return atomic.CompareAndSwapInt32(&c.state, from, to)
+}
diff --git a/connection_onevent.go b/connection_onevent.go
index 35b7c001..5dc986da 100644
--- a/connection_onevent.go
+++ b/connection_onevent.go
@@ -134,7 +134,7 @@ func (c *connection) onPrepare(opts *options) (err error) {
 func (c *connection) onConnect() {
 	var onConnect, _ = c.onConnectCallback.Load().(OnConnect)
 	if onConnect == nil {
-		atomic.StoreInt32(&c.state, 1)
+		c.changeState(connStateNone, connStateConnected)
 		return
 	}
 	if !c.lock(connecting) {
@@ -142,35 +142,7 @@ func (c *connection) onConnect() {
 		return
 	}
 	var onRequest, _ = c.onRequestCallback.Load().(OnRequest)
-	c.onProcess(
-		// only process when conn active and have unread data
-		func(c *connection) bool {
-			// if onConnect not called
-			if atomic.LoadInt32(&c.state) == 0 {
-				return true
-			}
-			// check for onRequest
-			return onRequest != nil && c.Reader().Len() > 0
-		},
-		func(c *connection) {
-			if atomic.CompareAndSwapInt32(&c.state, 0, 1) {
-				c.ctx = onConnect(c.ctx, c)
-
-				if !c.IsActive() && atomic.CompareAndSwapInt32(&c.state, 1, 2) {
-					// since we hold connecting lock, so we should help to call onDisconnect here
-					var onDisconnect, _ = c.onDisconnectCallback.Load().(OnDisconnect)
-					if onDisconnect != nil {
-						onDisconnect(c.ctx, c)
-					}
-				}
-				c.unlock(connecting)
-				return
-			}
-			if onRequest != nil {
-				_ = onRequest(c.ctx, c)
-			}
-		},
-	)
+	c.onProcess(onConnect, onRequest)
 }
 
 // when onDisconnect called, c.IsActive() must return false
@@ -182,15 +154,16 @@ func (c *connection) onDisconnect() {
 	var onConnect, _ = c.onConnectCallback.Load().(OnConnect)
 	if onConnect == nil {
 		// no need lock if onConnect is nil
-		atomic.StoreInt32(&c.state, 2)
+		// it's ok to force set state to disconnected since onConnect is nil
+		c.setState(connStateDisconnected)
 		onDisconnect(c.ctx, c)
 		return
 	}
 	// check if OnConnect finished when onConnect != nil && onDisconnect != nil
-	if atomic.LoadInt32(&c.state) > 0 && c.lock(connecting) { // means OnConnect already finished
+	if c.getState() != connStateNone && c.lock(connecting) { // means OnConnect already finished
 		// protect onDisconnect run once
 		// if CAS return false, means OnConnect already helps to run onDisconnect
-		if atomic.CompareAndSwapInt32(&c.state, 1, 2) {
+		if c.changeState(connStateConnected, connStateDisconnected) {
 			onDisconnect(c.ctx, c)
 		}
 		c.unlock(connecting)
@@ -207,63 +180,66 @@ func (c *connection) onRequest() (needTrigger bool) {
 		return true
 	}
 	// wait onConnect finished first
-	if atomic.LoadInt32(&c.state) == 0 && c.onConnectCallback.Load() != nil {
+	if c.getState() == connStateNone && c.onConnectCallback.Load() != nil {
 		// let onConnect to call onRequest
 		return
 	}
-	processed := c.onProcess(
-		// only process when conn active and have unread data
-		func(c *connection) bool {
-			return c.Reader().Len() > 0
-		},
-		func(c *connection) {
-			_ = onRequest(c.ctx, c)
-		},
-	)
+	processed := c.onProcess(nil, onRequest)
 	// if not processed, should trigger read
 	return !processed
 }
 
-// onProcess is responsible for executing the process function serially,
-// and make sure the connection has been closed correctly if user call c.Close() in process function.
-func (c *connection) onProcess(isProcessable func(c *connection) bool, process func(c *connection)) (processed bool) {
-	if process == nil {
-		return false
-	}
+// onProcess is responsible for executing the onConnect/onRequest function serially,
+// and make sure the connection has been closed correctly if user call c.Close() in onConnect/onRequest function.
+func (c *connection) onProcess(onConnect OnConnect, onRequest OnRequest) (processed bool) {
 	// task already exists
 	if !c.lock(processing) {
 		return false
 	}
-	// add new task
-	var task = func() {
+
+	task := func() {
 		panicked := true
 		defer func() {
+			if !panicked {
+				return
+			}
 			// cannot use recover() here, since we don't want to break the panic stack
-			if panicked {
-				c.unlock(processing)
-				if c.IsActive() {
-					c.Close()
-				} else {
-					c.closeCallback(false, false)
-				}
+			c.unlock(processing)
+			if c.IsActive() {
+				c.Close()
+			} else {
+				c.closeCallback(false, false)
 			}
 		}()
+		// trigger onConnect first
+		if onConnect != nil && c.changeState(connStateNone, connStateConnected) {
+			c.ctx = onConnect(c.ctx, c)
+			if !c.IsActive() && c.changeState(connStateConnected, connStateDisconnected) {
+				// since we hold connecting lock, so we should help to call onDisconnect here
+				onDisconnect, _ := c.onDisconnectCallback.Load().(OnDisconnect)
+				if onDisconnect != nil {
+					onDisconnect(c.ctx, c)
+				}
+			}
+			c.unlock(connecting)
+		}
 	START:
-		// `process` must be executed at least once if `isProcessable` in order to cover the `send & close by peer` case.
-		// Then the loop processing must ensure that the connection `IsActive`.
-		if isProcessable(c) {
-			process(c)
+		// The `onRequest` must be executed at least once if conn has any readable data,
+		// in order to cover the `send & close by peer` case.
+		if onRequest != nil && c.Reader().Len() > 0 {
+			_ = onRequest(c.ctx, c)
 		}
-		// `process` must either eventually read all the input data or actively Close the connection,
+		// The processing loop must ensure that the connection meets `IsActive`.
+		// `onRequest` must either eventually read all the input data or actively Close the connection,
 		// otherwise the goroutine will fall into a dead loop.
 		var closedBy who
 		for {
 			closedBy = c.status(closing)
-			// close by user or no processable
-			if closedBy == user || !isProcessable(c) {
+			// close by user or not processable
+			if closedBy == user || onRequest == nil || c.Reader().Len() == 0 {
 				break
 			}
-			process(c)
+			_ = onRequest(c.ctx, c)
 		}
 		// handling callback if connection has been closed.
 		if closedBy != none {
@@ -288,14 +264,15 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f
 			panicked = false
 			return
 		}
-		// double check isProcessable
-		if isProcessable(c) && c.lock(processing) {
+		// double check is processable
+		if onRequest != nil && c.Reader().Len() > 0 && c.lock(processing) {
 			goto START
 		}
 		// task exits
 		panicked = false
 		return
 	}
+	// add new task
 	runTask(c.ctx, task)
 	return true
 }
diff --git a/connection_test.go b/connection_test.go
index 163645ab..65dd69ff 100644
--- a/connection_test.go
+++ b/connection_test.go
@@ -32,6 +32,30 @@ import (
 	"time"
 )
 
+func BenchmarkConnectionIO(b *testing.B) {
+	var dataSize = 1024 * 16
+	var writeBuffer = make([]byte, dataSize)
+	var rfd, wfd = GetSysFdPairs()
+	var rconn, wconn = new(connection), new(connection)
+	rconn.init(&netFD{fd: rfd}, &options{onRequest: func(ctx context.Context, connection Connection) error {
+		read, _ := connection.Reader().Next(dataSize)
+		_ = wconn.Reader().Release()
+		_, _ = connection.Writer().WriteBinary(read)
+		_ = connection.Writer().Flush()
+		return nil
+	}})
+	wconn.init(&netFD{fd: wfd}, new(options))
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		_, _ = wconn.WriteBinary(writeBuffer)
+		_ = wconn.Flush()
+		_, _ = wconn.Reader().Next(dataSize)
+		_ = wconn.Reader().Release()
+	}
+}
+
 func TestConnectionWrite(t *testing.T) {
 	var cycle, caps = 10000, 256
 	var msg, buf = make([]byte, caps), make([]byte, caps)
diff --git a/go.mod b/go.mod
index b9a2a4bf..5ff5a988 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,6 @@ module github.com/cloudwego/netpoll
 go 1.15
 
 require (
-	github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3
+	github.com/bytedance/gopkg v0.1.0
 	golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
 )
diff --git a/go.sum b/go.sum
index 49445cde..36af6c14 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3 h1:ZKUHguI38SRQJkq7hhmwn8lAv3xM6B5qkj1IneS15YY=
-github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ=
+github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs=
+github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ=
 github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
diff --git a/netpoll_config.go b/netpoll_config.go
index dda65a35..85c05925 100644
--- a/netpoll_config.go
+++ b/netpoll_config.go
@@ -17,14 +17,12 @@ package netpoll
 import (
 	"context"
 	"io"
-	"time"
 )
 
 // global config
 var (
-	defaultLinkBufferSize                = pagesize
-	defaultGracefulShutdownCheckInterval = time.Second
-	featureAlwaysNoCopyRead              = false
+	defaultLinkBufferSize   = pagesize
+	featureAlwaysNoCopyRead = false
 )
 
 // Config expose some tuning parameters to control the internal behaviors of netpoll.
diff --git a/netpoll_server.go b/netpoll_server.go
index ace92ba6..248ae909 100644
--- a/netpoll_server.go
+++ b/netpoll_server.go
@@ -63,27 +63,33 @@ func (s *server) Close(ctx context.Context) error {
 	s.operator.Control(PollDetach)
 	s.ln.Close()
 
-	var ticker = time.NewTicker(defaultGracefulShutdownCheckInterval)
-	defer ticker.Stop()
-	var hasConn bool
 	for {
-		hasConn = false
+		activeConn := 0
 		s.connections.Range(func(key, value interface{}) bool {
 			var conn, ok = value.(gracefulExit)
 			if !ok || conn.isIdle() {
 				value.(Connection).Close()
+			} else {
+				activeConn++
 			}
-			hasConn = true
 			return true
 		})
-		if !hasConn { // all connections have been closed
+		if activeConn == 0 { // all connections have been closed
 			return nil
 		}
 
+		// smart control graceful shutdown check interval
+		// we should wait for more time if there are more active connections
+		waitTime := time.Millisecond * time.Duration(activeConn)
+		if waitTime > time.Second { // max wait time is 1000 ms
+			waitTime = time.Millisecond * 1000
+		} else if waitTime < time.Millisecond*50 { // min wait time is 50 ms
+			waitTime = time.Millisecond * 50
+		}
 		select {
 		case <-ctx.Done():
			return ctx.Err()
-		case <-ticker.C:
+		case <-time.After(waitTime):
 			continue
 		}
 	}
diff --git a/netpoll_unix_test.go b/netpoll_unix_test.go
index 70638c57..24689fd0 100644
--- a/netpoll_unix_test.go
+++ b/netpoll_unix_test.go
@@ -64,15 +64,6 @@ func Assert(t *testing.T, cond bool, val ...interface{}) {
 	}
 }
 
-func TestMain(m *testing.M) {
-	// defaultGracefulShutdownCheckInterval will affect shutdown function running time,
-	// so for speed up tests, we change it to 10ms here
-	oldGracefulShutdownCheckInterval := defaultGracefulShutdownCheckInterval
-	defaultGracefulShutdownCheckInterval = time.Millisecond * 10
-	m.Run()
-	defaultGracefulShutdownCheckInterval = oldGracefulShutdownCheckInterval
-}
-
 var testPort int32 = 10000
 
 // getTestAddress return a unique port for every tests, so all tests will not share a same listerner
diff --git a/poll_manager_test.go b/poll_manager_test.go
index c5648a76..f61c5282 100644
--- a/poll_manager_test.go
+++ b/poll_manager_test.go
@@ -61,7 +61,7 @@ func TestPollManagerSetNumLoops(t *testing.T) {
 	poll := pm.Pick()
 	newGs := runtime.NumGoroutine()
 	Assert(t, poll != nil)
-	Assert(t, newGs-startGs == 1, newGs, startGs)
+	Assert(t, newGs-startGs >= 1, newGs, startGs)
 	t.Logf("old=%d, new=%d", startGs, newGs)
 
 	// change pollers
diff --git a/poll_test.go b/poll_test.go
index 5980dde7..b3c7f2e8 100644
--- a/poll_test.go
+++ b/poll_test.go
@@ -90,7 +90,7 @@ func TestPollMod(t *testing.T) {
 		runtime.Gosched()
 	}
 	r, w, h = atomic.LoadInt32(&rn), atomic.LoadInt32(&wn), atomic.LoadInt32(&hn)
-	Assert(t, r == 0 && w == 1 && h == 0, r, w, h)
+	Assert(t, r == 0 && w >= 1 && h == 0, r, w, h)
 
 	err = p.Control(rop, PollR2RW) // trigger write
 	MustNil(t, err)
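
Note on the netpoll_server.go change above: server.Close no longer polls on a fixed ticker; it re-checks the connection table at an interval that scales with the number of still-active connections, 1ms per connection, clamped to the range [50ms, 1s]. Below is a minimal standalone sketch of that clamping rule for illustration only; the helper name shutdownWaitTime and the sample values are assumptions and not part of netpoll's API.

package main

import (
	"fmt"
	"time"
)

// shutdownWaitTime mirrors the backoff rule added to server.Close:
// wait 1ms per still-active connection, but never less than 50ms and
// never more than 1000ms between graceful-shutdown checks.
func shutdownWaitTime(activeConn int) time.Duration {
	waitTime := time.Millisecond * time.Duration(activeConn)
	if waitTime > time.Second { // max wait time is 1000 ms
		waitTime = 1000 * time.Millisecond
	} else if waitTime < 50*time.Millisecond { // min wait time is 50 ms
		waitTime = 50 * time.Millisecond
	}
	return waitTime
}

func main() {
	for _, n := range []int{1, 200, 5000} {
		fmt.Printf("active=%d wait=%v\n", n, shutdownWaitTime(n))
	}
	// prints: active=1 wait=50ms, active=200 wait=200ms, active=5000 wait=1s
}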