diff --git a/read.go b/read.go index f2c7a801..381cea3d 100644 --- a/read.go +++ b/read.go @@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - return 0, nil, errors.New("previous message not read to completion") + err = errors.New("previous message not read to completion") + c.close(fmt.Errorf("failed to get reader: %w", err)) + return 0, nil, err } h, err := c.readLoop(ctx) @@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) { } func (mr *msgReader) Read(p []byte) (n int, err error) { - defer func() { - if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { - err = io.EOF - } - if errors.Is(err, io.EOF) { - err = io.EOF - mr.putFlateReader() - return - } - errd.Wrap(&err, "failed to read") - }() - err = mr.c.readMu.lock(mr.ctx) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to read: %w", err) } defer mr.c.readMu.unlock() @@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { p = p[:n] mr.dict.write(p) } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { + mr.putFlateReader() + return n, io.EOF + } + if err != nil { + err = fmt.Errorf("failed to read: %w", err) + mr.c.close(err) + } return n, err } diff --git a/write.go b/write.go index 2d20b292..baa5e6e2 100644 --- a/write.go +++ b/write.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "sync" "time" "github.com/klauspost/compress/flate" @@ -71,7 +70,7 @@ type msgWriterState struct { c *Conn mu *mu - writeMu sync.Mutex + writeMu *mu ctx context.Context opcode opcode @@ -83,8 +82,9 @@ type msgWriterState struct { func newMsgWriterState(c *Conn) *msgWriterState { mw := &msgWriterState{ - c: c, - mu: newMu(c), + c: c, + mu: newMu(c), + writeMu: newMu(c), } return mw } @@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { - defer errd.Wrap(&err, "failed to write") + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return 0, fmt.Errorf("failed to write: %w", err) + } + defer mw.writeMu.unlock() - mw.writeMu.Lock() - defer mw.writeMu.Unlock() + defer func() { + if err != nil { + err = fmt.Errorf("failed to write: %w", err) + mw.c.close(err) + } + }() if mw.c.flate() { // Only enables flate if the length crosses the @@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) { func (mw *msgWriterState) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") - mw.writeMu.Lock() - defer mw.writeMu.Unlock() + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return err + } + defer mw.writeMu.unlock() _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { @@ -214,7 +225,7 @@ func (mw *msgWriterState) close() { putBufioWriter(mw.c.bw) } - mw.writeMu.Lock() + mw.writeMu.forceLock() mw.dict.close() } @@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } // frame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) { - err := c.writeFrameMu.lock(ctx) +func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { + err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } @@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco case c.writeTimeout <- ctx: } + defer func() { + if err != nil { + err = fmt.Errorf("failed to write frame: %w", err) + c.close(err) + } + }() + c.writeHeader.fin = fin c.writeHeader.opcode = opcode c.writeHeader.payloadLength = int64(len(p))