Skip to content

Commit

Permalink
fix: save response original header (#14)
Browse files Browse the repository at this point in the history
* fix: save response original header

* fix: remove unsafe b2s

* fix: remove unsafe b2s

* fix: remove unsafe b2s

* feat: add SetSaveOriginResHeader option
  • Loading branch information
li-jin-gou authored Aug 17, 2023
1 parent 331781c commit 1bb9078
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
31 changes: 30 additions & 1 deletion reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ type ReverseProxy struct {
// target is set as a reverse proxy address
Target string

// transforTrailer is whether to forward Trailer-related header
// transferTrailer is whether to forward Trailer-related header
transferTrailer bool

// saveOriginResponse is whether to save the original response header
saveOriginResHeader bool

// director must be a function which modifies the request
// into a new request. Its response is then redirected
// back to the original client unmodified.
Expand Down Expand Up @@ -214,6 +217,21 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
req := &ctx.Request
resp := &ctx.Response

// save tmp resp header
respTmpHeader := map[string][]string{}
if r.saveOriginResHeader {
resp.Header.SetNoDefaultContentType(true)
resp.Header.VisitAll(func(key, value []byte) {
keyStr := string(key)
valueStr := string(value)
if _, ok := respTmpHeader[keyStr]; ok {
respTmpHeader[keyStr] = []string{valueStr}
} else {
respTmpHeader[keyStr] = append(respTmpHeader[keyStr], valueStr)
}
})
}

if r.director != nil {
r.director(&ctx.Request)
}
Expand Down Expand Up @@ -261,6 +279,13 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
return
}

// add tmp resp header
for key, hs := range respTmpHeader {
for _, h := range hs {
resp.Header.Add(key, h)
}
}

removeResponseConnHeaders(ctx)

for _, h := range hopHeaders {
Expand Down Expand Up @@ -303,6 +328,10 @@ func (r *ReverseProxy) SetTransferTrailer(b bool) {
r.transferTrailer = b
}

func (r *ReverseProxy) SetSaveOriginResHeader(b bool) {
r.saveOriginResHeader = b
}

func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) {
if r.errorHandler != nil {
return r.errorHandler
Expand Down
34 changes: 34 additions & 0 deletions reverse_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"testing"
"time"

"github.com/cloudwego/hertz/pkg/common/test/assert"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/app/server"
Expand Down Expand Up @@ -556,3 +558,35 @@ func TestReverseProxyTransferTrailer(t *testing.T) {
t.Errorf("handler got X-Trailer Trailer value %q; want 'trailer_value'", c)
}
}

func TestReverseProxySaveRespHeader(t *testing.T) {
const backendResponse = "I am the backend"
const backendStatus = 404
r := server.New(server.WithHostPorts("127.0.0.1:9997"))

r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) {
ctx.Data(backendStatus, "application/json", []byte(backendResponse))
})

proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy")
proxy.SetSaveOriginResHeader(true)
if err != nil {
t.Errorf("proxy error: %v", err)
}

r.GET("/backend", func(c context.Context, ctx *app.RequestContext) {
ctx.Response.Header.Set("aaa", "bbb")
proxy.ServeHTTP(c, ctx)
})
go r.Spin()
time.Sleep(time.Second)
cli, _ := client.NewClient()
req := protocol.AcquireRequest()
res := protocol.AcquireResponse()
req.SetRequestURI("http://localhost:9997/backend")
err = cli.Do(context.Background(), req, res)
if err != nil {
t.Fatalf("Get: %v", err)
}
assert.DeepEqual(t, "bbb", res.Header.Get("aaa"))
}

0 comments on commit 1bb9078

Please sign in to comment.