Skip to content

Commit

Permalink
Improve saslContinue logic
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekSi committed Oct 8, 2024
1 parent 9eeb1eb commit 988a730
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 33 deletions.
16 changes: 16 additions & 0 deletions op_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,22 @@ func (msg *OpMsg) RawDocument() (wirebson.RawDocument, error) {
return s.documents[0], nil
}

// DecodeDeepDocument returns the value of msg as deeply-decoded [*wirebson.Document].
//
// The error is returned if msg contains anything other than a single section of kind 0
// with a single document.
//
// Most callers do not need deeply-decoded document and should use more effective combination of
// [OpMsg.RawDocument] and [wirebson.RawDocument.Decode] instead.
func (msg *OpMsg) DecodeDeepDocument() (*wirebson.Document, error) {
raw, err := msg.RawDocument()
if err != nil {
return nil, err
}

return raw.DecodeDeep()
}

func (msg *OpMsg) msgbody() {}

// check implements [MsgBody].
Expand Down
100 changes: 73 additions & 27 deletions wireclient/wireclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ import (
// nextRequestID stores the last generated request ID.
var nextRequestID atomic.Int32

// skipEmptyExchange is the value of `saslStart`'s options used by [*Conn.Login].
const skipEmptyExchange = false

// Conn represents a single client connection.
//
// It is not safe for concurrent use.
Expand Down Expand Up @@ -269,12 +272,7 @@ func (c *Conn) Ping(ctx context.Context) error {
return fmt.Errorf("wireclient.Conn.Ping: %w", err)
}

resRaw, err := resBody.(*wire.OpMsg).RawDocument()
if err != nil {
return fmt.Errorf("wireclient.Conn.Ping: %w", err)
}

res, err := resRaw.Decode()
res, err := resBody.(*wire.OpMsg).DecodeDeepDocument()
if err != nil {
return fmt.Errorf("wireclient.Conn.Ping: %w", err)
}
Expand All @@ -286,7 +284,8 @@ func (c *Conn) Ping(ctx context.Context) error {
return nil
}

// Login authenticates the connection with the given credentials.
// Login authenticates the connection with the given credentials
// using some unspecified sequences of commands.
//
// It should not be used to test various authentication scenarios.
func (c *Conn) Login(ctx context.Context, username, password, authDB string) error {
Expand All @@ -309,10 +308,23 @@ func (c *Conn) Login(ctx context.Context, username, password, authDB string) err
"$db", authDB,
)

for step := range 3 {
// one `saslStart`, two `saslContinue` requests
steps := 3

if skipEmptyExchange {
if err = cmd.Add("options", wirebson.MustDocument("skipEmptyExchange", true)); err != nil {
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}

// only one `saslContinue`
steps = 2
}

for step := 1; step <= steps; step++ {
c.l.DebugContext(
ctx, "Login",
slog.Int("step", step), slog.Bool("done", conv.Done()), slog.Bool("valid", conv.Valid()),
ctx, "Login: client",
slog.Int("step", step), slog.String("payload", payload),
slog.Bool("done", conv.Done()), slog.Bool("valid", conv.Valid()),
)

var body *wire.OpMsg
Expand All @@ -325,36 +337,51 @@ func (c *Conn) Login(ctx context.Context, username, password, authDB string) err
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}

var resRaw wirebson.RawDocument
if resRaw, err = resBody.(*wire.OpMsg).RawDocument(); err != nil {
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}

var res *wirebson.Document
if res, err = resRaw.Decode(); err != nil {
if res, err = resBody.(*wire.OpMsg).DecodeDeepDocument(); err != nil {
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}

if ok := res.Get("ok"); ok != 1.0 {
return fmt.Errorf("wireclient.Conn.Login: %s failed (ok was %v)", cmd.Command(), ok)
}

// when `saslContinue` is called twice the SCRAM client conversation is
// completed before the second `saslContinue` request,
// ensure to move only the incomplete conversation forward
if !conv.Done() {
payload, err = conv.Step(string(res.Get("payload").(wirebson.Binary).B))
if err != nil {
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}
}
payload = string(res.Get("payload").(wirebson.Binary).B)

c.l.DebugContext(
ctx, "Login: server",
slog.Int("step", step), slog.String("payload", payload),
)

if res.Get("done").(bool) {
if skipEmptyExchange {
if step != 2 {
return fmt.Errorf("wireclient.Conn.Login: expected server conversation to be done at step 2")
}

if _, err = conv.Step(payload); err != nil {
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}
} else {
if step != 3 {
return fmt.Errorf("wireclient.Conn.Login: expected server conversation to be done at step 3")
}
}

if !conv.Done() {
return fmt.Errorf("wireclient.Conn.Login: conversation is not done")
}

if !conv.Valid() {
return fmt.Errorf("wireclient.Conn.Login: conversation is not valid")
return fmt.Errorf("wireclient.Conn.Login: conversation is done, but not valid")
}

return nil
return c.checkAuth(ctx)
}

payload, err = conv.Step(payload)
if err != nil {
return fmt.Errorf("wireclient.Conn.Login: %w", err)
}

cmd = wirebson.MustDocument(
Expand All @@ -368,6 +395,25 @@ func (c *Conn) Login(ctx context.Context, username, password, authDB string) err
return fmt.Errorf("wireclient.Conn.Login: too many steps")
}

// checkAuth checks if the connection is authenticated.
func (c *Conn) checkAuth(ctx context.Context) error {
_, resBody, err := c.Request(ctx, wire.MustOpMsg("listDatabases", int32(1), "$db", "admin"))
if err != nil {
return fmt.Errorf("wireclient.Conn.checkAuth: %w", err)
}

res, err := resBody.(*wire.OpMsg).DecodeDeepDocument()
if err != nil {
return fmt.Errorf("wireclient.Conn.Ping: %w", err)
}

if ok := res.Get("ok"); ok != 1.0 {
return fmt.Errorf("wireclient.Conn.checkAuth: failed (ok was %v)", ok)
}

return nil
}

// sleep waits until the given duration is over or the context is canceled.
func sleep(ctx context.Context, d time.Duration) {
ctx, cancel := context.WithTimeout(ctx, d)
Expand Down
26 changes: 20 additions & 6 deletions wireclient/wireclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,36 @@ func TestConn(t *testing.T) {
t.Cleanup(cancel)

t.Run("Login", func(t *testing.T) {
conn := ConnectPing(ctx, uri, logger(t))
require.NotNil(t, conn)
t.Run("InvalidUsername", func(t *testing.T) {
conn := ConnectPing(ctx, uri, logger(t))
require.NotNil(t, conn)

t.Cleanup(func() {
require.NoError(t, conn.Close())
})
t.Cleanup(func() {
require.NoError(t, conn.Close())
})

t.Run("InvalidUsername", func(t *testing.T) {
assert.Error(t, conn.Login(ctx, "invalid", "invalid", "admin"))
})

t.Run("InvalidDatabase", func(t *testing.T) {
conn := ConnectPing(ctx, uri, logger(t))
require.NotNil(t, conn)

t.Cleanup(func() {
require.NoError(t, conn.Close())
})

assert.Error(t, conn.Login(ctx, "username", "password", "invalid"))
})

t.Run("Valid", func(t *testing.T) {
conn := ConnectPing(ctx, uri, logger(t))
require.NotNil(t, conn)

t.Cleanup(func() {
require.NoError(t, conn.Close())
})

assert.NoError(t, conn.Login(ctx, "username", "password", "admin"))
})
})
Expand Down

0 comments on commit 988a730

Please sign in to comment.