Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekSi committed Jul 17, 2024
1 parent f391018 commit 37294a2
Showing 1 changed file with 24 additions and 28 deletions.
52 changes: 24 additions & 28 deletions wireclient/wiredriver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import (
"github.com/FerretDB/wire/internal/util/lazyerrors"
)

// lastRequestID stores last generated request ID.
var lastRequestID atomic.Int32
// nextRequestID stores the last generated request ID.
var nextRequestID atomic.Int32

// Conn represents a single client connection.
//
Expand All @@ -38,12 +38,12 @@ type Conn struct {
c net.Conn
r *bufio.Reader
w *bufio.Writer
l *slog.Logger // debug level only
l *slog.Logger // debug-level only
}

// New wraps the given connection.
//
// The passed logger will be used only for debug level messages.
// The passed logger will be used only for debug-level messages.
func New(c net.Conn, l *slog.Logger) *Conn {
return &Conn{
c: c,
Expand All @@ -57,6 +57,8 @@ func New(c net.Conn, l *slog.Logger) *Conn {
//
// Context can be used to cancel the connection attempt.
// Canceling the context after the connection is established has no effect.
//
// The passed logger will be used only for debug-level messages.
func Connect(ctx context.Context, uri string, l *slog.Logger) (*Conn, error) {
u, err := url.Parse(uri)
if err != nil {
Expand Down Expand Up @@ -181,41 +183,41 @@ func (c *Conn) WriteRaw(ctx context.Context, b []byte) error {
}

// Request sends the given request to the connection and returns the response.
// If header MessageLength or RequestID is not specified, it assigns the proper values.
// For header.OpCode the wire.OpCodeMsg is used as default.
// If header's MessageLength or RequestID are not specified, it assigns the proper values.
// For header's OpCode the [wire.OpCodeMsg] is used as default.
//
// Passed context's deadline is honored if set.
//
// It returns errors only for request/response parsing issues, or connection issues.
// All of the driver level errors are stored inside response.
// It returns errors only for request/response parsing or connection issues.
// All protocol-level errors are stored inside response.
func (c *Conn) Request(ctx context.Context, header *wire.MsgHeader, body wire.MsgBody) (*wire.MsgHeader, wire.MsgBody, error) {
if header == nil {
header = new(wire.MsgHeader)
}

if header.MessageLength == 0 {
msgBin, err := body.MarshalBinary()
b, err := body.MarshalBinary()
if err != nil {
return nil, nil, lazyerrors.Error(err)
}

header.MessageLength = int32(len(msgBin) + wire.MsgHeaderLen)
}

if header.OpCode == 0 {
header.OpCode = wire.OpCodeMsg
header.MessageLength = int32(len(b) + wire.MsgHeaderLen)
}

if header.RequestID == 0 {
header.RequestID = lastRequestID.Add(1)
header.RequestID = nextRequestID.Add(1)
}

if header.ResponseTo != 0 {
return nil, nil, lazyerrors.Errorf("setting response_to is not allowed")
}

if m, ok := body.(*wire.OpMsg); ok {
if m.Flags != 0 {
return nil, nil, lazyerrors.Errorf("unsupported request flags %s", m.Flags)
}
if header.OpCode == 0 {
header.OpCode = wire.OpCodeMsg
}

if body == nil {
return nil, nil, lazyerrors.Errorf("body can't be nil")
}

if err := c.Write(ctx, header, body); err != nil {
Expand All @@ -228,18 +230,12 @@ func (c *Conn) Request(ctx context.Context, header *wire.MsgHeader, body wire.Ms
}

if resHeader.ResponseTo != header.RequestID {
return nil, nil, lazyerrors.Errorf(
"response_to is not equal to request_id (response_to=%d; expected=%d)",
err = fmt.Errorf(
"response's response_to=%d is not equal to request's request_id=%d",
resHeader.ResponseTo,
header.RequestID,
)
}

if m, ok := resBody.(*wire.OpMsg); ok {
if m.Flags != 0 {
return nil, nil, lazyerrors.Errorf("unsupported response flags %s", m.Flags)
}
}

return resHeader, resBody, nil
return resHeader, resBody, err
}

0 comments on commit 37294a2

Please sign in to comment.