From 37294a2de465b100fe3e8c40eb4fc46316ed5bd5 Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Wed, 17 Jul 2024 12:14:14 +0400 Subject: [PATCH] Simplify --- wireclient/wiredriver.go | 52 +++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/wireclient/wiredriver.go b/wireclient/wiredriver.go index 051c74d..d346fd9 100644 --- a/wireclient/wiredriver.go +++ b/wireclient/wiredriver.go @@ -28,8 +28,8 @@ import ( "github.com/FerretDB/wire/internal/util/lazyerrors" ) -// lastRequestID stores last generated request ID. -var lastRequestID atomic.Int32 +// nextRequestID stores the last generated request ID. +var nextRequestID atomic.Int32 // Conn represents a single client connection. // @@ -38,12 +38,12 @@ type Conn struct { c net.Conn r *bufio.Reader w *bufio.Writer - l *slog.Logger // debug level only + l *slog.Logger // debug-level only } // New wraps the given connection. // -// The passed logger will be used only for debug level messages. +// The passed logger will be used only for debug-level messages. func New(c net.Conn, l *slog.Logger) *Conn { return &Conn{ c: c, @@ -57,6 +57,8 @@ func New(c net.Conn, l *slog.Logger) *Conn { // // Context can be used to cancel the connection attempt. // Canceling the context after the connection is established has no effect. +// +// The passed logger will be used only for debug-level messages. func Connect(ctx context.Context, uri string, l *slog.Logger) (*Conn, error) { u, err := url.Parse(uri) if err != nil { @@ -181,41 +183,41 @@ func (c *Conn) WriteRaw(ctx context.Context, b []byte) error { } // Request sends the given request to the connection and returns the response. -// If header MessageLength or RequestID is not specified, it assigns the proper values. -// For header.OpCode the wire.OpCodeMsg is used as default. +// If header's MessageLength or RequestID are not specified, it assigns the proper values. +// For header's OpCode the [wire.OpCodeMsg] is used as default. +// +// Passed context's deadline is honored if set. // -// It returns errors only for request/response parsing issues, or connection issues. -// All of the driver level errors are stored inside response. +// It returns errors only for request/response parsing or connection issues. +// All protocol-level errors are stored inside response. func (c *Conn) Request(ctx context.Context, header *wire.MsgHeader, body wire.MsgBody) (*wire.MsgHeader, wire.MsgBody, error) { if header == nil { header = new(wire.MsgHeader) } if header.MessageLength == 0 { - msgBin, err := body.MarshalBinary() + b, err := body.MarshalBinary() if err != nil { return nil, nil, lazyerrors.Error(err) } - header.MessageLength = int32(len(msgBin) + wire.MsgHeaderLen) - } - - if header.OpCode == 0 { - header.OpCode = wire.OpCodeMsg + header.MessageLength = int32(len(b) + wire.MsgHeaderLen) } if header.RequestID == 0 { - header.RequestID = lastRequestID.Add(1) + header.RequestID = nextRequestID.Add(1) } if header.ResponseTo != 0 { return nil, nil, lazyerrors.Errorf("setting response_to is not allowed") } - if m, ok := body.(*wire.OpMsg); ok { - if m.Flags != 0 { - return nil, nil, lazyerrors.Errorf("unsupported request flags %s", m.Flags) - } + if header.OpCode == 0 { + header.OpCode = wire.OpCodeMsg + } + + if body == nil { + return nil, nil, lazyerrors.Errorf("body can't be nil") } if err := c.Write(ctx, header, body); err != nil { @@ -228,18 +230,12 @@ func (c *Conn) Request(ctx context.Context, header *wire.MsgHeader, body wire.Ms } if resHeader.ResponseTo != header.RequestID { - return nil, nil, lazyerrors.Errorf( - "response_to is not equal to request_id (response_to=%d; expected=%d)", + err = fmt.Errorf( + "response's response_to=%d is not equal to request's request_id=%d", resHeader.ResponseTo, header.RequestID, ) } - if m, ok := resBody.(*wire.OpMsg); ok { - if m.Flags != 0 { - return nil, nil, lazyerrors.Errorf("unsupported response flags %s", m.Flags) - } - } - - return resHeader, resBody, nil + return resHeader, resBody, err }