diff --git a/handlers/allowedactions.go b/handlers/allowedactions.go new file mode 100644 index 0000000..f837757 --- /dev/null +++ b/handlers/allowedactions.go @@ -0,0 +1,33 @@ +package handlers + +import ( + "net/http" + + "strings" + + "github.com/go-kit/kit/log" +) + +func NewAllowedActions(l log.Logger, allowedActions []string) func(h http.Handler) http.Handler { + var actions = make(map[string]bool, len(allowedActions)) + for _, p := range allowedActions { + if p == "" { + continue + } + + actions[p] = true + } + + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + action := r.URL.Path[1:] + if _, exists := actions[action]; !exists { + l.Log("error", "action is not white-listed", "action", action, "allowed", strings.Join(allowedActions, ",")) + http.Error(w, "Unregisterd action", http.StatusBadRequest) + return + } + + h.ServeHTTP(w, r) + }) + } +} diff --git a/main.go b/main.go index f4e17a9..f600c1e 100644 --- a/main.go +++ b/main.go @@ -17,12 +17,13 @@ import ( ) var ( - allowedHosts argumentList - allowedImaginaryParams argumentList - imaginaryURL string - listenPort int64 - bucketRate float64 - bucketSize int64 + allowedHosts argumentList + allowedImaginaryParams argumentList + allowedImaginaryActions argumentList + imaginaryURL string + listenPort int64 + bucketRate float64 + bucketSize int64 Version = "dev" logger = log.With( @@ -35,12 +36,12 @@ var ( func init() { flag.Var(&allowedHosts, "allow-hosts", "Repeatable flag (or a comma-separated list) for hosts to allow for the URL parameter (e.g. \"d2dktr6aauwgqs.cloudfront.net\")") flag.Var(&allowedImaginaryParams, "allowed-params", "A comma seperated list of parameters allows to be sent upstream. If empty, everything is allowed.") + flag.Var(&allowedImaginaryActions, "allowed-actions", "A comma seperated list of actions allows to be sent upstream. If empty, everything is allowed.") flag.StringVar(&imaginaryURL, "imaginary-url", "http://localhost:9000", "URL to imaginary (default: http://localhost:9000)") flag.Int64Var(&listenPort, "listen-port", 8080, "Port to listen on") flag.Float64Var(&bucketRate, "bucket-rate", 20, "Rate limiter bucket fill rate (req/s)") flag.Int64Var(&bucketSize, "bucket-size", 500, "Rate limiter bucket size (burst capacity)") - } func main() { @@ -51,6 +52,7 @@ func main() { "version", Version, "allowed_hosts", allowedHosts.String(), "allowed_params", allowedImaginaryParams.String(), + "allowed_actions", allowedImaginaryActions.String(), "imaginary_backend", imaginaryURL, ) @@ -100,6 +102,15 @@ func decorateHandler(h http.Handler, b *ratelimit.Bucket) http.Handler { )) } + if len(allowedImaginaryActions) > 0 { + decorators = append( + decorators, + handlers.NewAllowedActions( + logger, + allowedImaginaryActions, + )) + } + // Defining early needed handlers last decorators = append( decorators,