Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SetClientBehavior method to allow users can select proxy client's do behavior #19

Merged
merged 10 commits into from
May 31, 2024
58 changes: 58 additions & 0 deletions proxy_client_behavior.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package reverseproxy

import "time"

type clientBehaviorType int

const (
do clientBehaviorType = iota
doDeadline
doRedirects
doTimeout
)

type clientBehavior struct {
clientBehaviorType clientBehaviorType
param interface{}
}

func ClientDo() clientBehavior {
return clientBehavior{
clientBehaviorType: do,
}
}

func ClientDoRedirects(param int) clientBehavior {
return clientBehavior{
clientBehaviorType: doRedirects,
param: param,
}
}

func ClientDoDeadline(param time.Time) clientBehavior {
return clientBehavior{
clientBehaviorType: doDeadline,
param: param,
}
}

func ClientDoTimeout(param time.Duration) clientBehavior {
return clientBehavior{
clientBehaviorType: doTimeout,
param: param,
}
}
43 changes: 31 additions & 12 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"reflect"
"strings"
"sync"
"time"
"unsafe"

"github.com/cloudwego/hertz/pkg/app"
Expand All @@ -45,6 +46,8 @@ import (
type ReverseProxy struct {
client *client.Client

clientBehavior clientBehavior

// target is set as a reverse proxy address
Target string

Expand Down Expand Up @@ -105,7 +108,6 @@ var hopHeaders = []string{
// To rewrite Host headers, use ReverseProxy directly with a custom
// director policy.
//
// Note: if no config.ClientOption is passed it will use the default global client.Client instance.
// When passing config.ClientOption it will initialize a local client.Client instance.
// Using ReverseProxy.SetClient if there is need for shared customized client.Client instance.
func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*ReverseProxy, error) {
Expand All @@ -116,13 +118,11 @@ func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*
req.Header.SetHostBytes(req.URI().Host())
},
}
if len(options) != 0 {
c, err := client.NewClient(options...)
if err != nil {
return nil, err
}
r.client = c
c, err := client.NewClient(options...)
if err != nil {
return nil, err
}
r.client = c
return r, nil
}

Expand Down Expand Up @@ -275,11 +275,8 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
req.Header.Add("X-Forwarded-For", ip)
}
}
fn := client.Do
if r.client != nil {
fn = r.client.Do
}
err := fn(c, req, resp)

err := r.doClientBehavior(c, req, resp)
if err != nil {
hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error())
r.getErrorHandler()(ctx, err)
Expand Down Expand Up @@ -345,13 +342,35 @@ func (r *ReverseProxy) SetSaveOriginResHeader(b bool) {
r.saveOriginResHeader = b
}

func (r *ReverseProxy) SetClientBehavior(cb clientBehavior) {
r.clientBehavior = cb
}

func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) {
if r.errorHandler != nil {
return r.errorHandler
}
return r.defaultErrorHandler
}

func (r *ReverseProxy) doClientBehavior(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
var err error
switch r.clientBehavior.clientBehaviorType {
case doDeadline:
deadline := r.clientBehavior.param.(time.Time)
err = r.client.DoDeadline(ctx, req, resp, deadline)
case doRedirects:
maxRedirectsCount := r.clientBehavior.param.(int)
err = r.client.DoRedirects(ctx, req, resp, maxRedirectsCount)
case doTimeout:
timeout := r.clientBehavior.param.(time.Duration)
err = r.client.DoTimeout(ctx, req, resp, timeout)
default:
err = r.client.Do(ctx, req, resp)
}
return err
}

// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
Expand Down
1 change: 1 addition & 0 deletions reverse_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ func TestReverseProxySaveRespHeader(t *testing.T) {

proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy")
proxy.SetSaveOriginResHeader(true)
proxy.SetClientBehavior(ClientDoRedirects(2))
if err != nil {
t.Errorf("proxy error: %v", err)
}
Expand Down
Loading