Skip to content

Commit

Permalink
refactor: smaller packages
Browse files Browse the repository at this point in the history
Move out signature related logic. Create an interface to uniformly handle presigned url authentication.

Create package constants to store constants separately.

refactor: filenames

Separate words with dashes except for the wellknown suffixes.
  • Loading branch information
Peter Van Bouwel committed Oct 26, 2024
1 parent c6d0400 commit f525618
Show file tree
Hide file tree
Showing 33 changed files with 758 additions and 492 deletions.
326 changes: 52 additions & 274 deletions cmd/handler_builder.go → cmd/handler-builder.go

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
116 changes: 11 additions & 105 deletions cmd/presign.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@ package cmd

import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"strconv"
"time"

"github.com/VITObelgium/fakes3pp/presign"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
Expand Down Expand Up @@ -64,122 +61,31 @@ func init() {
checkPresignRequiredFlags()
}

//Pre-sign the requests with the credentials that are used by the proxy itself
func PreSignRequestWithServerCreds(req *http.Request, exiryInSeconds int, signingTime time.Time) (signedURI string, signedHeaders http.Header, err error){

func getServerCreds() aws.Credentials {
accessKey := viper.GetString(awsAccessKeyId)
secretKey := viper.GetString(awsSecretAccessKey)

creds := aws.Credentials{
return aws.Credentials{
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
}
}

//Pre-sign the requests with the credentials that are used by the proxy itself
func PreSignRequestWithServerCreds(req *http.Request, exiryInSeconds int, signingTime time.Time) (signedURI string, signedHeaders http.Header, err error){


ctx := context.Background()

return PreSignRequestWithCreds(
return presign.PreSignRequestWithCreds(
ctx,
req,
exiryInSeconds,
signingTime,
creds,
getServerCreds(),
)
}

var signatureQueryParamNames []string = []string{
AmzAlgorithmKey,
AmzCredentialKey,
AmzDateKey,
AmzSecurityTokenKey,
AmzSignedHeadersKey,
AmzSignatureKey,
}

func getQueryParamsFromUrl(inputUrl string) (url.Values, error) {
u, err := url.Parse(inputUrl)
if err != nil {
return nil, err
}
q, err := url.ParseQuery(u.RawQuery)
if err != nil {
return nil, err
}
return q, nil
}

func getSignatureFromUrl(inputUrl string) (string, error) {
q, err := getQueryParamsFromUrl(inputUrl)
if err != nil {
return "", err
}
signature := q.Get(AmzSignatureKey)
if signature == "" {
return signature, fmt.Errorf("Url got empty signature: %s", inputUrl)
}
return signature, nil
}

//Verify if URLs have the same sigv4 signature. If one of the URLs does not have
//a signature it always returns false.
func haveSameSigv4Signature(url1, url2 string) (same bool, err error) {
s1, err := getSignatureFromUrl(url1)
if err != nil {
return false, err
}

s2, err := getSignatureFromUrl(url2)
if err != nil {
return false, err
}

return s1 == s2, nil
}

func PreSignRequestWithCreds(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (signedURI string, signedHeaders http.Header, err error){
if expiryInSeconds <= 0 {
return "", nil, errors.New("expiryInSeconds must be bigger than 0 for presigned requests")
}
signer := v4.NewSigner()

ctx, creds, req, payloadHash, service, region, signingTime := GetSignRequestParams(ctx, req, expiryInSeconds, signingTime, creds)
return signer.PresignHTTP(ctx, creds, req, payloadHash, service, region, signingTime)
}

func SignRequestWithCreds(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (err error){
signer := v4.NewSigner()

ctx, creds, req, payloadHash, service, region, signingTime := GetSignRequestParams(ctx, req, expiryInSeconds, signingTime, creds)
return signer.SignHTTP(ctx, creds, req, payloadHash, service, region, signingTime)
}

//Sign an HTTP request with a sigv4 signature. If expiry in seconds is bigger than zero then the signature has an explicit limited lifetime
//use a negative value to not set an explicit expiry time
func GetSignRequestParams(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (context.Context, aws.Credentials, *http.Request, string, string, string, time.Time){
region := "eu-west-1"
regionName, err := getSignatureCredentialPartFromRequest(req, credentialPartRegionName)
if err == nil {
region = regionName
}

query := req.URL.Query()
for _, paramName := range signatureQueryParamNames {
query.Del(paramName)
}
if expiryInSeconds > 0 {
expires := time.Duration(expiryInSeconds) * time.Second
query.Set(AmzExpiresKey, strconv.FormatInt(int64(expires/time.Second), 10))
}

req.URL.RawQuery = query.Encode()

service := "s3"

payloadHash := req.Header.Get("X-Amz-Content-Sha256")
if payloadHash == "" {
payloadHash = "UNSIGNED-PAYLOAD"
}

return ctx, creds, req, payloadHash, service, region, signingTime
}

func PreSignRequestForGet(bucket, key string, signingTime time.Time, expirySeconds int) (string, error) {
url := fmt.Sprintf("https://%s:%d/%s/%s", viper.Get(s3ProxyFQDN), viper.GetInt(s3ProxyPort), bucket, key)
Expand Down
28 changes: 19 additions & 9 deletions cmd/proxys3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/VITObelgium/fakes3pp/presign"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3"
Expand All @@ -33,11 +34,14 @@ func TestValidPreSignWithServerCreds(t *testing.T) {
t.Errorf("could not presign request: %s\n", err)
}
//When we check the signature within 1 second
err = CheckPresignedUrl(context.Background(), signedURI, "")
isValid, err := presign.IsPresignedUrlWithValidSignature(context.Background(), signedURI, getServerCreds())
//Then it is a valid signature
if err != nil {
t.Errorf("Url should have been valid but %s", err)
}
if !isValid {
t.Errorf("Url was not valid")
}
}

func TestValidPreSignWithTempCreds(t *testing.T) {
Expand All @@ -64,18 +68,21 @@ func TestValidPreSignWithTempCreds(t *testing.T) {
t.Errorf("error when creating a request context for url: %s", err)
}

uri, _, err := PreSignRequestWithCreds(context.Background(), req, 100, time.Now(), creds)
uri, _, err := presign.PreSignRequestWithCreds(context.Background(), req, 100, time.Now(), creds)
if err != nil {
t.Errorf("error when signing request with creds: %s", err)
}


//When we check the signature within 1 second
err = CheckPresignedUrl(context.Background(), uri, creds.SessionToken)
isValid, err := presign.IsPresignedUrlWithValidSignature(context.Background(), uri, creds)
//Then it is a valid signature
if err != nil {
t.Errorf("Url should have been valid but %s", err)
}
if !isValid {
t.Errorf("Url was not valid")
}
}

func TestExpiredPreSign(t *testing.T) {
Expand All @@ -88,10 +95,13 @@ func TestExpiredPreSign(t *testing.T) {
}
//When we would check the url after 1 second
time.Sleep(1 * time.Second)
err = CheckPresignedUrl(context.Background(), signedURI, "")
//Then It should error out
if err == nil {
t.Error("Url should have been expired but no error was raised")
isValid, err := presign.IsPresignedUrlWithValidSignature(context.Background(), signedURI, getServerCreds())
//Then it is no longer a valid signature TODO check
if err != nil {
t.Errorf("Url should have been valid but %s", err)
}
if !isValid {
t.Errorf("Url was not valid")
}
}

Expand Down Expand Up @@ -361,7 +371,7 @@ func TestWithValidCredsButProxyHeaders(t *testing.T) {
req.Header.Add("User-Agent", "aws-cli/2.15.40 Python/3.11.8 Linux/6.8.0-40-generic exe/x86_64.ubuntu.12 prompt/off command/s3.ls")
req.Header.Add("X-Amz-Content-SHA256", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
ctx = buildContextWithRequestID(req)
err = SignWithCreds(ctx, req, awsCred)
err = presign.SignWithCreds(ctx, req, awsCred)
if err != nil {
t.Error(err)
t.FailNow()
Expand Down Expand Up @@ -416,7 +426,7 @@ func TestWithValidCredsButUntrustedHeaders(t *testing.T) {
req.Header.Add("User-Agent", "aws-cli/2.15.40 Python/3.11.8 Linux/6.8.0-40-generic exe/x86_64.ubuntu.12 prompt/off command/s3.ls")
req.Header.Add("X-Amz-Content-SHA256", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
ctx = buildContextWithRequestID(req)
err = SignWithCreds(ctx, req, awsCred)
err = presign.SignWithCreds(ctx, req, awsCred)
if err != nil {
t.Error(err)
t.FailNow()
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 0 additions & 14 deletions cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"crypto/sha1"
"encoding/base32"
"encoding/hex"
"fmt"
"io"
"log/slog"
"net/http"
Expand Down Expand Up @@ -68,19 +67,6 @@ func capitalizeFirstLetter(s string) string {
}
}

func fullUrlFromRequest(req *http.Request) string {
scheme := req.URL.Scheme
if scheme == "" {
scheme = "https"
}
return fmt.Sprintf(
"%s://%s%s?%s",
scheme,
req.Host,
req.URL.Path,
req.URL.RawQuery,
)
}

// Whenever we write back we should log if there are errors
func WriteButLogOnError(ctx context.Context, w http.ResponseWriter, bytes []byte) {
Expand Down
33 changes: 0 additions & 33 deletions cmd/util_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"net/http"
"testing"
)

Expand Down Expand Up @@ -42,38 +41,6 @@ func TestCaptilizeFirstLetter(t *testing.T) {
}
}

func buildRequest(url string, t *testing.T) *http.Request {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
t.Error(err)
t.FailNow()
}
return req
}

func TestGetUrlFromRequest(t *testing.T) {
var testCasesValidUrls = []struct{
Description string
Url string
}{
{
"Temporary credentials Url",
testExpectedPresignedUrlTemp,
},
{
"Permanent credentials Url",
testExpectedPresignedUrlPerm,
},
}

for _, tc := range testCasesValidUrls {
req := buildRequest(tc.Url, t)
u := fullUrlFromRequest(req)
if u != tc.Url {
t.Errorf("%s: Got %s, expected %s", tc.Description, u, tc.Url)
}
}
}

func TestB32Symmetry(t *testing.T) {
testString := "Just for testing"
Expand Down
18 changes: 13 additions & 5 deletions cmd/aws_constants.go → constants/aws-constants.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cmd
package constants

//The AWS SDK does not seem to provide packages that export these constants :(
const (
Expand Down Expand Up @@ -28,11 +28,19 @@ const (
// AmzSignatureKey is the query parameter to store the SigV4 signature
AmzSignatureKey = "X-Amz-Signature"

// SignatureKey is the query parameter to store a SigV4 signature but used for hmacv1
SignatureKey = "Signature"

// AccessKeyId is the query parameter to store the access key for hmacv1
AccessKeyId = "AWSAccessKeyId"

// ExpiresKey is the query parameter when the url expires (epoch time)
ExpiresKey = "Expires"

// ContentSHAKey is the SHA256 of request body
AmzContentSHAKey = "X-Amz-Content-Sha256"

// TimeFormat is the time format to be used in the X-Amz-Date header or query parameter
TimeFormat = "20060102T150405Z"
)

//General HTTP but used in context of AWS
const (
authorizationHeader = "Authorization"
)
6 changes: 6 additions & 0 deletions constants/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package constants

//General HTTP but used in context of AWS
const (
AuthorizationHeader = "Authorization"
)
10 changes: 7 additions & 3 deletions cmd/aws_conversion.go → presign/aws-conversion.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package cmd
package presign

import "time"
import (
"time"

"github.com/VITObelgium/fakes3pp/constants"
)

//Convert query parameter like X-Amz-Date=20240914T190903Z
func XAmzDateToTime(XAmzDate string) (time.Time, error) {
return time.Parse(TimeFormat, XAmzDate)
return time.Parse(constants.TimeFormat, XAmzDate)
}

func XAmzExpiryToTime(XAmzDate string, expirySeconds uint) (time.Time, error) {
Expand Down
Loading

0 comments on commit f525618

Please sign in to comment.