Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance handler customization. #50

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 145 additions & 14 deletions gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gziphandler
import (
"bufio"
"compress/gzip"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -67,13 +68,29 @@ func addLevelPool(level int) {
}
}

// GzipWriter provides a Gzip response writer interface.
type GzipWriter interface {
Header() http.Header
Write([]byte) (int, error)
WriteHeader(int)
Hijack() (net.Conn, *bufio.ReadWriter, error)
Close() error
SetResponseWriter(http.ResponseWriter)
setLevel(int)
setIndex(int)
getLevel() int
setMinSize(int)
getMinSize() int
}

// GzipResponseWriter provides an http.ResponseWriter interface, which gzips
// bytes before writing them to the underlying response. This doesn't close the
// writers, so don't forget to do that.
// It can be configured to skip response smaller than minSize.
type GzipResponseWriter struct {
http.ResponseWriter
index int // Index for gzipWriterPools.
level int
gw *gzip.Writer

code int // Saves the WriteHeader value.
Expand All @@ -82,6 +99,38 @@ type GzipResponseWriter struct {
buf []byte // Holds the first part of the write before reaching the minSize or the end of the write.
}

// SetResponseWriter sets the initial ResponseWriter.
func (w *GzipResponseWriter) SetResponseWriter(rw http.ResponseWriter) {
w.ResponseWriter = rw
}

// setLevel sets gzip compression level.
func (w *GzipResponseWriter) setLevel(level int) {
w.level = level
}

// getLevel gets gzip compression level
func (w *GzipResponseWriter) getLevel() int {
return w.level
}

// setIndex sets index into gzipWriterPools
func (w *GzipResponseWriter) setIndex(index int) {
w.index = index
}

// setMinSize specified the minimum response size to gzip.
// If the response length is bigger than this value, it is compressed.
func (w *GzipResponseWriter) setMinSize(minSize int) {
w.minSize = minSize
}

// getMinSize specified the minimum response size to gzip.
// If the response length is bigger than this value, it is compressed.
func (w *GzipResponseWriter) getMinSize() int {
return w.minSize
}

// Write appends data to the gzip writer.
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
// If content type is not set.
Expand Down Expand Up @@ -111,7 +160,7 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) {
return len(b), nil
}

// startGzip initialize any GZIP specific informations.
// startGzip initialize any GZIP specific information.
func (w *GzipResponseWriter) startGzip() error {

// Set the GZIP header.
Expand Down Expand Up @@ -207,47 +256,85 @@ func (w *GzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// verify Hijacker interface implementation
var _ http.Hijacker = &GzipResponseWriter{}

// Deprecated use MustNewGzipHandler instead
// MustNewGzipLevelHandler behaves just like NewGzipLevelHandler except that in
// an error case it panics rather than returning an error.
func MustNewGzipLevelHandler(level int) func(http.Handler) http.Handler {
wrap, err := NewGzipLevelHandler(level)
wrap, err := NewGzipHandler(CompressionLevel(level), MinSizeDefault)
if err != nil {
panic(err)
}
return wrap
}

// MustNewGzipHandler behaves just like NewGzipHandler except that in
// an error case it panics rather than returning an error.
func MustNewGzipHandler(options ...func(GzipWriter) error) func(http.Handler) http.Handler {
wrap, err := NewGzipHandler(options...)
if err != nil {
panic(err)
}
return wrap
}

// Deprecated use NewGzipHandler instead
// NewGzipLevelHandler returns a wrapper function (often known as middleware)
// which can be used to wrap an HTTP handler to transparently gzip the response
// body if the client supports it (via the Accept-Encoding header). Responses will
// be encoded at the given gzip compression level. An error will be returned only
// if an invalid gzip compression level is given, so if one can ensure the level
// is valid, the returned error can be safely ignored.
func NewGzipLevelHandler(level int) (func(http.Handler) http.Handler, error) {
return NewGzipLevelAndMinSize(level, DefaultMinSize)
return NewGzipHandler(CompressionLevel(level), MinSizeDefault)
}

// Deprecated use NewGzipHandler instead
// NewGzipLevelAndMinSize behave as NewGzipLevelHandler except it let the caller
// specify the minimum size before compression.
func NewGzipLevelAndMinSize(level, minSize int) (func(http.Handler) http.Handler, error) {
if level != gzip.DefaultCompression && (level < gzip.BestSpeed || level > gzip.BestCompression) {
return nil, fmt.Errorf("invalid compression level requested: %d", level)
return NewGzipHandler(CompressionLevel(level), MinSize(minSize))
}

// NewGzipHandler returns a wrapper function (often known as middleware)
// which can be used to wrap an HTTP handler to transparently gzip the response
// body if the client supports it (via the Accept-Encoding header). Responses will
// be encoded at the given gzip compression level. An error will be returned only
// if an invalid options is given.
func NewGzipHandler(options ...func(GzipWriter) error) (func(http.Handler) http.Handler, error) {
return NewHandler(&GzipResponseWriter{}, options...)
}

// NewHandler
// NewHandler behave as NewGzipHandler except it let the caller
// specify a GzipWriter.
func NewHandler(gw GzipWriter, options ...func(GzipWriter) error) (func(http.Handler) http.Handler, error) {
if gw == nil {
return nil, errors.New("the GzipWriter must be defined")
}

for _, opt := range options {
err := opt(gw)
if err != nil {
return nil, err
}
}

if err := checkMinSize(gw); err != nil {
return nil, err
}
if minSize < 0 {
return nil, fmt.Errorf("minimum size must be more than zero")
if err := checkLevel(gw); err != nil {
return nil, err
}

return func(h http.Handler) http.Handler {
index := poolIndex(level)
index := poolIndex(gw.getLevel())

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(vary, acceptEncoding)

if acceptsGzip(r) {
gw := &GzipResponseWriter{
ResponseWriter: w,
index: index,
minSize: minSize,
}
gw.setIndex(index)
gw.SetResponseWriter(w)
defer gw.Close()

h.ServeHTTP(gw, r)
Expand All @@ -258,11 +345,55 @@ func NewGzipLevelAndMinSize(level, minSize int) (func(http.Handler) http.Handler
}, nil
}

// MinSize specified the minimum response size to gzip.
// If the response length is bigger than this value, it is compressed.
func MinSize(minSize int) func(GzipWriter) error {
return func(gw GzipWriter) error {
gw.setMinSize(minSize)
return nil
}
}

// MinSizeDefault specified the default minimum response size to gzip. (DefaultMinSize: 512)
func MinSizeDefault(gw GzipWriter) error {
gw.setMinSize(DefaultMinSize)
return nil
}

// CompressionLevel specified the compression level
func CompressionLevel(level int) func(GzipWriter) error {
return func(gw GzipWriter) error {
gw.setLevel(level)
return nil
}
}

// CompressionLevelDefault specified the default compression level. (gzip.DefaultCompression)
func CompressionLevelDefault(gw GzipWriter) error {
gw.setLevel(gzip.DefaultCompression)
return nil
}

func checkLevel(gw GzipWriter) error {
level := gw.getLevel()
if level != gzip.DefaultCompression && (level < gzip.BestSpeed || level > gzip.BestCompression) {
return fmt.Errorf("invalid compression level requested: %d", level)
}
return nil
}

func checkMinSize(gw GzipWriter) error {
if gw.getMinSize() < 0 {
return errors.New("minimum size must be more than zero")
}
return nil
}

// GzipHandler wraps an HTTP handler, to transparently gzip the response body if
// the client supports it (via the Accept-Encoding header). This will compress at
// the default compression level.
func GzipHandler(h http.Handler) http.Handler {
wrapper, _ := NewGzipLevelHandler(gzip.DefaultCompression)
wrapper, _ := NewGzipHandler(CompressionLevelDefault, MinSizeDefault)
return wrapper(h)
}

Expand Down
86 changes: 86 additions & 0 deletions gzip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,31 @@ func TestNewGzipLevelHandler(t *testing.T) {
}
}

func TestNewGzipHandler(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
io.WriteString(w, testBody)
})

for lvl := gzip.BestSpeed; lvl <= gzip.BestCompression; lvl++ {
wrapper, err := NewGzipHandler(CompressionLevel(lvl), MinSizeDefault)
if !assert.Nil(t, err, "NewGzipLevleHandler returned error for level:", lvl) {
continue
}

req, _ := http.NewRequest("GET", "/whatever", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "gzip", res.Header.Get("Content-Encoding"))
assert.Equal(t, "Accept-Encoding", res.Header.Get("Vary"))
assert.Equal(t, gzipStrLevel(testBody, lvl), resp.Body.Bytes())
}
}

func TestNewGzipLevelHandlerReturnsErrorForInvalidLevels(t *testing.T) {
var err error
_, err = NewGzipLevelHandler(-42)
Expand All @@ -115,6 +140,15 @@ func TestNewGzipLevelHandlerReturnsErrorForInvalidLevels(t *testing.T) {
assert.NotNil(t, err)
}

func TestNewGzipHandlerReturnsErrorForInvalidLevels(t *testing.T) {
var err error
_, err = NewGzipHandler(MinSizeDefault, CompressionLevel(-42))
assert.NotNil(t, err)

_, err = NewGzipHandler(MinSizeDefault, CompressionLevel(42))
assert.NotNil(t, err)
}

func TestMustNewGzipLevelHandlerWillPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
Expand All @@ -125,6 +159,16 @@ func TestMustNewGzipLevelHandlerWillPanic(t *testing.T) {
_ = MustNewGzipLevelHandler(-42)
}

func TestMustNewGzipHandlerWillPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("panic was not called")
}
}()

_ = MustNewGzipHandler(MinSizeDefault, CompressionLevel(-42))
}

func TestGzipHandlerNoBody(t *testing.T) {
tests := []struct {
statusCode int
Expand Down Expand Up @@ -220,6 +264,11 @@ func TestGzipHandlerMinSizeMustBePositive(t *testing.T) {
assert.Error(t, err)
}

func TestNewGzipHandlerMinSizeMustBePositive(t *testing.T) {
_, err := NewGzipHandler(CompressionLevelDefault, MinSize(-1))
assert.Error(t, err)
}

func TestGzipHandlerMinSize(t *testing.T) {
responseLength := 0
b := []byte{'x'}
Expand Down Expand Up @@ -257,6 +306,43 @@ func TestGzipHandlerMinSize(t *testing.T) {
}
}

func TestGzipHandlerWithMinSize(t *testing.T) {
responseLength := 0
b := []byte{'x'}

wrapper, _ := NewGzipHandler(CompressionLevelDefault, MinSize(128))
handler := wrapper(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Write responses one byte at a time to ensure that the flush
// mechanism, if used, is working properly.
for i := 0; i < responseLength; i++ {
n, err := w.Write(b)
assert.Equal(t, 1, n)
assert.Nil(t, err)
}
},
))

r, _ := http.NewRequest("GET", "/whatever", &bytes.Buffer{})
r.Header.Add("Accept-Encoding", "gzip")

// Short response is not compressed
responseLength = 127
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Result().Header.Get(contentEncoding) == "gzip" {
t.Error("Expected uncompressed response, got compressed")
}

// Long response is not compressed
responseLength = 128
w = httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Result().Header.Get(contentEncoding) != "gzip" {
t.Error("Expected compressed response, got uncompressed")
}
}

func TestGzipDoubleClose(t *testing.T) {
// reset the pool for the default compression so we can make sure duplicates
// aren't added back by double close
Expand Down