Skip to content

Commit

Permalink
fix(internal/server): return 404 not found on unknown component queries
Browse files Browse the repository at this point in the history
Signed-off-by: Gyuho Lee <[email protected]>
  • Loading branch information
gyuho committed Oct 21, 2024
1 parent 871533e commit 4c903be
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 4 deletions.
5 changes: 5 additions & 0 deletions client/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/url"

v1 "github.com/leptonai/gpud/api/v1"
"github.com/leptonai/gpud/errdefs"
"github.com/leptonai/gpud/internal/server"
"sigs.k8s.io/yaml"
)
Expand Down Expand Up @@ -213,7 +214,11 @@ func GetStates(ctx context.Context, addr string, opts ...OpOption) (v1.LeptonSta
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotFound {
return nil, errdefs.ErrNotFound
}
return nil, errors.New("server not ready, response not 200")
}

Expand Down
12 changes: 9 additions & 3 deletions components/components.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,19 @@ func GetComponent(name string) (Component, error) {
defaultSetMu.RLock()
defer defaultSetMu.RUnlock()

if defaultSet == nil {
return getComponent(defaultSet, name)
}

func getComponent(set map[string]Component, name string) (Component, error) {
if set == nil {
return nil, fmt.Errorf("component set not initialized: %w", errdefs.ErrUnavailable)
}
if _, ok := defaultSet[name]; !ok {

v, ok := set[name]
if !ok {
return nil, fmt.Errorf("component %s not found: %w", name, errdefs.ErrNotFound)
}
return defaultSet[name], nil
return v, nil
}

func GetAllComponents() map[string]Component {
Expand Down
17 changes: 17 additions & 0 deletions components/components_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package components

import (
"errors"
"testing"

"github.com/leptonai/gpud/errdefs"
)

func TestGetComponentErrors(t *testing.T) {
if _, err := getComponent(nil, "nvidia"); !errors.Is(err, errdefs.ErrUnavailable) {
t.Errorf("expected ErrUnavailable, got %v", err)
}
if _, err := getComponent(map[string]Component{}, "nvidia"); !errors.Is(err, errdefs.ErrNotFound) {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
6 changes: 6 additions & 0 deletions e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand All @@ -18,6 +19,7 @@ import (

v1 "github.com/leptonai/gpud/api/v1"
client_v1 "github.com/leptonai/gpud/client/v1"
"github.com/leptonai/gpud/errdefs"
"github.com/leptonai/gpud/internal/server"
)

Expand Down Expand Up @@ -314,6 +316,10 @@ func TestGpudHealthzInfo(t *testing.T) {
}
}

if _, err := client_v1.GetStates(ctx, "https://"+ep, append(opts, client_v1.WithComponent("unknown!!!"))...); !errors.Is(err, errdefs.ErrNotFound) {
t.Errorf("expected ErrNotFound, got %v", err)
}

states, err := client_v1.GetStates(ctx, "https://"+ep, opts...)
if err != nil {
t.Errorf("failed to get states: %v", err)
Expand Down
3 changes: 2 additions & 1 deletion internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
Expand Down Expand Up @@ -72,7 +73,7 @@ func (g *globalHandler) getReqComponents(c *gin.Context) ([]string, error) {
var ret []string
for _, component := range strings.Split(components, ",") {
if _, err := lep_components.GetComponent(component); err != nil {
return nil, fmt.Errorf("failed to get component: %v", err)
return nil, fmt.Errorf("failed to get component: %v (%w)", err, errors.Unwrap(err))
}
ret = append(ret, component)
}
Expand Down
21 changes: 21 additions & 0 deletions internal/server/handlers_components.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"errors"
"net/http"
"sort"
"time"
Expand Down Expand Up @@ -122,6 +123,11 @@ func (g *globalHandler) getStates(c *gin.Context) {
var states v1.LeptonStates
components, err := g.getReqComponents(c)
if err != nil {
if errors.Is(err, errdefs.ErrNotFound) {
c.JSON(http.StatusNotFound, gin.H{"code": errdefs.ErrNotFound, "message": "component not found: " + err.Error()})
return
}

c.JSON(http.StatusBadRequest, gin.H{"code": errdefs.ErrInvalidArgument, "message": "failed to parse components: " + err.Error()})
return
}
Expand Down Expand Up @@ -193,6 +199,11 @@ func (g *globalHandler) getEvents(c *gin.Context) {
var events v1.LeptonEvents
components, err := g.getReqComponents(c)
if err != nil {
if errors.Is(err, errdefs.ErrNotFound) {
c.JSON(http.StatusNotFound, gin.H{"code": errdefs.ErrNotFound, "message": "component not found: " + err.Error()})
return
}

c.JSON(http.StatusBadRequest, gin.H{"code": errdefs.ErrInvalidArgument, "message": "failed to parse components: " + err.Error()})
return
}
Expand Down Expand Up @@ -270,6 +281,11 @@ func (g *globalHandler) getInfo(c *gin.Context) {
var infos v1.LeptonInfo
components, err := g.getReqComponents(c)
if err != nil {
if errors.Is(err, errdefs.ErrNotFound) {
c.JSON(http.StatusNotFound, gin.H{"code": errdefs.ErrNotFound, "message": "component not found: " + err.Error()})
return
}

c.JSON(http.StatusBadRequest, gin.H{"code": errdefs.ErrInvalidArgument, "message": "failed to parse components: " + err.Error()})
return
}
Expand Down Expand Up @@ -377,6 +393,11 @@ const (
func (g *globalHandler) getMetrics(c *gin.Context) {
components, err := g.getReqComponents(c)
if err != nil {
if errors.Is(err, errdefs.ErrNotFound) {
c.JSON(http.StatusNotFound, gin.H{"code": errdefs.ErrNotFound, "message": "component not found: " + err.Error()})
return
}

c.JSON(http.StatusBadRequest, gin.H{"code": errdefs.ErrInvalidArgument, "message": "failed to parse components: " + err.Error()})
return
}
Expand Down

0 comments on commit 4c903be

Please sign in to comment.