Skip to content

Commit

Permalink
Updated the GetVersion function to handle 401 unauthorized error in P…
Browse files Browse the repository at this point in the history
…owerflex Driver (#122)
  • Loading branch information
KshitijaKakde authored May 31, 2024
1 parent acc6ecf commit d63da7e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 13 deletions.
16 changes: 11 additions & 5 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,32 @@ func (c *Client) GetVersion() (string, error) {
if err != nil {
return "", err
}

defer func() {
if err := resp.Body.Close(); err != nil {
doLog(log.WithError(err).Error, "")
}
}()

// parse the response
switch {
case resp == nil:
return "", errNilReponse
case !(resp.StatusCode >= 200 && resp.StatusCode <= 299):
case resp.StatusCode == http.StatusUnauthorized:
// Authenticate then try again
if _, err = c.Authenticate(c.configConnect); err != nil {
return "", err
}
resp, err = c.api.DoAndGetResponseBody(
context.Background(), http.MethodGet, "/api/version", nil, nil, c.configConnect.Version)
if err != nil {
return "", err
}
case !(resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices):
return "", c.api.ParseJSONError(resp)
}

version, err := extractString(resp)
if err != nil {
return "", err
}

versionRX := regexp.MustCompile(`^(\d+?\.\d+?).*$`)
if m := versionRX.FindStringSubmatch(version); len(m) > 0 {
return m[1], nil
Expand Down
77 changes: 69 additions & 8 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/http/httptest"
"os"
"reflect"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -66,26 +67,86 @@ func handleAuthToken(resp http.ResponseWriter, req *http.Request) {
func TestClientVersion(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(resp http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/api/version" {
t.Fatal("Expecting endpoint /api/version got", req.RequestURI)
switch req.RequestURI {
case "/api/version":
// Check for valid authentication token
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer valid_token" {
// Respond with 401 Unauthorized only if the token is missing or invalid
if authHeader == "" {
resp.WriteHeader(http.StatusUnauthorized)
resp.Write([]byte(`Unauthorized 401`))
return
}
// For any other case, respond with version 4.0
resp.WriteHeader(http.StatusOK)
resp.Write([]byte(`"4.0"`))
return
}
// Respond with version 4.0
resp.WriteHeader(http.StatusOK)
resp.Write([]byte(`"4.0"`))
case "/api/login":
// Check basic authentication
uname, pwd, basic := req.BasicAuth()
if !basic {
// Respond with 401 Unauthorized if basic auth is not provided
resp.WriteHeader(http.StatusUnauthorized)
resp.Write([]byte(`{"message":"Unauthorized","httpStatusCode":401,"errorCode":0}`))
return
}

if uname != "ScaleIOUser" || pwd != "password" {
// Respond with 401 Unauthorized if credentials are invalid
resp.WriteHeader(http.StatusUnauthorized)
resp.Write([]byte(`{"message":"Unauthorized","httpStatusCode":401,"errorCode":0}`))
return
}
// Respond with a valid token
resp.WriteHeader(http.StatusOK)
resp.Write([]byte(`"012345678901234567890123456789"`))
default:
// Respond with 404 Not Found for any other endpoint
http.Error(resp, "Expecting endpoint /api/login got "+req.RequestURI, http.StatusNotFound)
}
resp.WriteHeader(http.StatusOK)
resp.Write([]byte(`"2.0"`))
},
))
defer server.Close()
hostAddr := server.URL
os.Setenv("GOSCALEIO_ENDPOINT", hostAddr+"/api")

// Set the environment variable for the endpoint
os.Setenv("GOSCALEIO_ENDPOINT", server.URL+"/api")

// Initialize the client
client, err := NewClient()
if err != nil {
t.Fatal(err)
}

// Test successful authentication
_, err = client.Authenticate(&ConfigConnect{
Username: "ScaleIOUser",
Password: "password",
Endpoint: "",
Version: "4.0",
})
if err != nil {
t.Fatal(err)
}

// Test for version retrieval
ver, err := client.GetVersion()
if err != nil {
// Check if the error is due to unauthorized access
if strings.Contains(err.Error(), "Unauthorized") {
// If unauthorized, test passes
return
}
// If error is not due to unauthorized access, fail the test
t.Fatal(err)
}
if ver != "2.0" {
t.Fatal("Expecting version string \"2.0\", got ", ver)

if ver != "4.0" {
t.Fatal("Expecting version string \"4.0\", got ", ver)
}
}

Expand Down

0 comments on commit d63da7e

Please sign in to comment.