Skip to content

Commit

Permalink
Add Blocker package (#9)
Browse files Browse the repository at this point in the history
* add blocker package

* fix test descriptions

* update changelog

* update readme

* make check for method and scheme case insensitive

* rename test comments

* remove extra space

* blocker instead of block

* changed words

* wording
  • Loading branch information
psampaz authored Feb 25, 2020
1 parent d8325c2 commit 65081cc
Show file tree
Hide file tree
Showing 8 changed files with 420 additions and 37 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [v1.1.0] 2020-02-25
### Added
- Blocker package with built-in functionality to block by scheme, method and query params

## [v1.0.0] 2020-02-23
### Added
- Shield middleware
69 changes: 32 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,52 +115,47 @@ type Options struct {
Body []byte
}
```
## List of predefined block methods

## Examples of block functions
### Block based on a list of query param regexes

### Block requests based on a HTTP header
```go
func(r *http.Request) bool {
if r.Header.Get("X-Custom") != "" {
return false
}

return true
}
queryBlock := blocker.NewQuery(map[string]string{
"op": "search",
"page": `\d+`,
"limit": `\d+`,
})
shieldMiddleware := shield.New(shield.Options{
Block: queryBlock.Block,
Code: http.StatusBadRequest,
Headers: http.Header{"Content-Type": {"text/plain"}},
Body: []byte(http.StatusText(http.StatusBadRequest)),
})
```
### Block requests based on HTTP method
```go
func(r *http.Request) bool {
if r.Method == "GET" {
return false
}

return true
}
```
### Block requests based on HTTP scheme
```go
func(r *http.Request) bool {
if r.URL.Sheme == "https" {
return false
}
### Block based on a list of HTTP Method

return true
}
```
### Block requests based on query parameters
matched, err := regexp.MatchString(v, r.URL.Query().Get(k))
```go
func(r *http.Request) bool {
// allow only request that have a query param named page which is a number
matched, _ := regexp.MatchString(`\d+`, r.URL.Query().Get("page"))
if matched {
return false
}
return true
}
methodBlock := blocker.NewMethod([]string{http.MethodGet, http.MethodPost})
shieldMiddleware := shield.New(shield.Options{
Block: methodBlock.Block,
Code: http.StatusBadRequest,
Headers: http.Header{"Content-Type": {"text/plain"}},
Body: []byte(http.StatusText(http.StatusBadRequest)),
})
```

### Block based on a list of HTTP Scheme

```go
schemeBlock := blocker.NewScheme([]string{"https"})
shieldMiddleware := shield.New(shield.Options{
Block: schemeBlock.Block,
Code: http.StatusBadRequest,
Headers: http.Header{"Content-Type": {"text/plain"}},
Body: []byte(http.StatusText(http.StatusBadRequest)),
})
```
# Integration with popular routers

## Gorilla Mux
Expand Down
30 changes: 30 additions & 0 deletions blocker/method.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Package blocker provide a list of predefined predicate methods
package blocker

import (
"net/http"
"strings"
)

// Method provides functionality to block a request based on a list of allowed HTTP methods.
// Empty allowedMethods will block
type Method struct {
allowedMethods []string
}

// NewMethod is a constructor function for Method struct
func NewMethod(allowedMethods []string) *Method {
return &Method{allowedMethods}
}

// Block is a predicate for method based request blocking.
func (m *Method) Block(r *http.Request) bool {
block := true
for _, v := range m.allowedMethods {
if r.Method == strings.ToUpper(v) {
block = false
break
}
}
return block
}
77 changes: 77 additions & 0 deletions blocker/method_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package blocker

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)

func TestMethod(t *testing.T) {
t.Parallel()
tests := []struct {
name string
allowMethods []string
requestMethod string
shouldBlock bool
}{
{
name: "no methods to match - nil",
allowMethods: nil,
requestMethod: http.MethodGet,
shouldBlock: true,
},
{
name: "no methods to match - empty",
allowMethods: []string{},
requestMethod: http.MethodGet,
shouldBlock: true,
},
{
name: "single method match",
allowMethods: []string{http.MethodGet},
requestMethod: http.MethodGet,
shouldBlock: false,
},
{
name: "single method mismatch",
allowMethods: []string{http.MethodPost},
requestMethod: http.MethodGet,
shouldBlock: true,
},
{
name: "multiple method match",
allowMethods: []string{http.MethodGet, http.MethodPost},
requestMethod: http.MethodGet,
shouldBlock: false,
},
{
name: "multiple method mismatch",
allowMethods: []string{http.MethodGet, http.MethodPost},
requestMethod: http.MethodPut,
shouldBlock: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
blocker := NewMethod(tt.allowMethods)

req := httptest.NewRequest(tt.requestMethod, "http://localhost/", nil)
blocked := blocker.Block(req)

if blocked != tt.shouldBlock {
t.Errorf("Got status code %v, wanted %v", blocked, tt.shouldBlock)
}
})
}
}

func ExampleMethod_Block() {
m := NewMethod([]string{http.MethodGet, http.MethodPost})
r := httptest.NewRequest(http.MethodPut, "http://localhost:8080/", nil)
blocked := m.Block(r)
fmt.Printf("%t", blocked)
// Output: true
}
34 changes: 34 additions & 0 deletions blocker/query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package blocker

import (
"net/http"
"regexp"
)

// Query provides functionality to block a request based on a query parameters.
// Empty queryParams will block
type Query struct {
queryParams map[string]string
}

// NewQuery is a constructor function for Query struct
func NewQuery(queryParams map[string]string) *Query {
return &Query{queryParams}
}

// Block is a predicate for query based request blocking.
func (m *Query) Block(r *http.Request) bool {
if len(m.queryParams) == 0 {
return true
}

block := false
for k, v := range m.queryParams {
matched, err := regexp.MatchString(v, r.URL.Query().Get(k))
if err != nil || !matched {
block = true
break
}
}
return block
}
119 changes: 119 additions & 0 deletions blocker/query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package blocker

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)

func TestQuery(t *testing.T) {
t.Parallel()
tests := []struct {
name string
qparams map[string]string
url string
shouldBlock bool
}{
{
name: "no query params - nil",
qparams: nil,
url: "http://localhost/foo?q=1",
shouldBlock: true,
},
{
name: "no query params - empty",
qparams: map[string]string{},
url: "http://localhost/foo?q=1",
shouldBlock: true,
},
{
name: "single query param match - regex",
qparams: map[string]string{"q": `\d`},
url: "http://localhost/foo?q=1",
shouldBlock: false,
},
{
name: "simple query param match - value",
qparams: map[string]string{"q": "value"},
url: "http://localhost/foo?q=value",
shouldBlock: false,
},
{
name: "multiple query params match - values",
qparams: map[string]string{"q1": "value1", "q2": "value2"},
url: "http://localhost/foo?q1=value1&q2=value2",
shouldBlock: false,
},
{
name: "single query param not match",
qparams: map[string]string{"q": `\d`},
url: "http://localhost/foo?q=a",
shouldBlock: true,
},
{
name: "single query param regex - no actual query param",
qparams: map[string]string{"q": `\d`},
url: "http://localhost/foo",
shouldBlock: true,
},
{
name: "optional query param regex - query param does not exist",
qparams: map[string]string{"q": `^$|\d`},
url: "http://localhost/foo",
shouldBlock: false,
},
{
name: "optional query param regex - query param exists but empty",
qparams: map[string]string{"q": `^$|\d`},
url: "http://localhost/foo?q=",
shouldBlock: false,
},
{
name: "optional query param regex - query param exists",
qparams: map[string]string{"q": `^$|\d`},
url: "http://localhost/foo?q=1",
shouldBlock: false,
},
{
name: "multiple params regex - all exist",
qparams: map[string]string{"q1": `\d`, "q2": `\d`},
url: "http://localhost/foo?q1=1&q2=2",
shouldBlock: false,
},
{
name: "multiple params regex, one optional. optional query param does not exist",
qparams: map[string]string{"q1": `\d`, "q2": `^$|\d`},
url: "http://localhost/foo?q1=1",
shouldBlock: false,
},
{
name: "multiple params regex, one optional. optional query param exists",
qparams: map[string]string{"q1": `\d`, "q2": `^$|\d`},
url: "http://localhost/foo?q2=1",
shouldBlock: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
blocker := NewQuery(tt.qparams)

req := httptest.NewRequest(http.MethodGet, tt.url, nil)
blocked := blocker.Block(req)

if blocked != tt.shouldBlock {
t.Errorf("Got status code %v, wanted %v", blocked, tt.shouldBlock)
}
})
}
}

func ExampleQuery_Block() {
m := NewQuery(map[string]string{"page": `\d+`})
r := httptest.NewRequest(http.MethodPut, "http://localhost:8080?page=a", nil)
blocked := m.Block(r)
fmt.Printf("%t", blocked)
// Output: true
}
29 changes: 29 additions & 0 deletions blocker/scheme.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package blocker

import (
"net/http"
"strings"
)

// Scheme provides functionality to block a request based on a list of allowed HTTP schemes (http/https).
// Empty allowedSchemes will block
type Scheme struct {
allowedSchemes []string
}

// NewScheme is a constructor function for Scheme struct
func NewScheme(allowedSchemes []string) *Scheme {
return &Scheme{allowedSchemes}
}

// Block is a predicate for scheme based request blocking.
func (s *Scheme) Block(r *http.Request) bool {
block := true
for _, v := range s.allowedSchemes {
if r.URL.Scheme == strings.ToLower(v) {
block = false
break
}
}
return block
}
Loading

0 comments on commit 65081cc

Please sign in to comment.