diff --git a/cmd/apiserver/main.go b/cmd/apiserver/main.go index 23792c40..03079c5c 100644 --- a/cmd/apiserver/main.go +++ b/cmd/apiserver/main.go @@ -69,6 +69,10 @@ func main() { } return options }). + WithServerFns(func(server *builder.GenericAPIServer) *builder.GenericAPIServer { + server.Handler.FullHandlerChain = clusterv1alpha1.NewClusterGatewayProxyRequestEscaper(server.Handler.FullHandlerChain) + return server + }). WithPostStartHook("init-master-loopback-client", singleton.InitLoopbackClient). Build() if err != nil { diff --git a/pkg/apis/cluster/v1alpha1/clustergateway_proxy.go b/pkg/apis/cluster/v1alpha1/clustergateway_proxy.go index 6803a6aa..5f6076cd 100644 --- a/pkg/apis/cluster/v1alpha1/clustergateway_proxy.go +++ b/pkg/apis/cluster/v1alpha1/clustergateway_proxy.go @@ -22,11 +22,14 @@ import ( "net/url" "os" gopath "path" + "regexp" "strings" "time" + "k8s.io/apiserver/pkg/server" utilfeature "k8s.io/apiserver/pkg/util/feature" "k8s.io/klog/v2" + "k8s.io/utils/strings/slices" "github.com/oam-dev/cluster-gateway/pkg/config" "github.com/oam-dev/cluster-gateway/pkg/featuregates" @@ -216,7 +219,7 @@ func (p *proxyHandler) ServeHTTP(writer http.ResponseWriter, request *http.Reque path := strings.TrimPrefix(request.URL.Path, apiPrefix+p.parentName+apiSuffix) newReq.Host = host newReq.URL.Path = gopath.Join(urlAddr.Path, path) - newReq.URL.RawQuery = request.URL.RawQuery + newReq.URL.RawQuery = unescapeQueryValues(request.URL.Query()).Encode() newReq.RequestURI = newReq.URL.RequestURI() cfg, err := NewConfigFromCluster(request.Context(), cluster) @@ -332,3 +335,55 @@ func (p *proxyHandler) getImpersonationConfig(req *http.Request) restclient.Impe Extra: user.GetExtra(), } } + +// NewClusterGatewayProxyRequestEscaper wrap the base http.Handler and escape +// the dryRun parameter. Otherwise, the dryRun request will be blocked by +// apiserver middlewares +func NewClusterGatewayProxyRequestEscaper(delegate http.Handler) http.Handler { + return &clusterGatewayProxyRequestEscaper{delegate: delegate} +} + +type clusterGatewayProxyRequestEscaper struct { + delegate http.Handler +} + +var ( + clusterGatewayProxyPathPattern = regexp.MustCompile(strings.Join([]string{ + server.APIGroupPrefix, + config.MetaApiGroupName, + config.MetaApiVersionName, + "clustergateways", + "[a-z0-9]([-a-z0-9]*[a-z0-9])?", + "proxy"}, "/")) + clusterGatewayProxyQueryKeysToEscape = []string{"dryRun"} + clusterGatewayProxyEscaperPrefix = "__" +) + +func (in *clusterGatewayProxyRequestEscaper) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if clusterGatewayProxyPathPattern.MatchString(req.URL.Path) { + newReq := req.Clone(req.Context()) + q := req.URL.Query() + for _, k := range clusterGatewayProxyQueryKeysToEscape { + if q.Has(k) { + q.Set(clusterGatewayProxyEscaperPrefix+k, q.Get(k)) + q.Del(k) + } + } + newReq.URL.RawQuery = q.Encode() + req = newReq + } + in.delegate.ServeHTTP(w, req) +} + +func unescapeQueryValues(values url.Values) url.Values { + unescaped := url.Values{} + for k, vs := range values { + if strings.HasPrefix(k, clusterGatewayProxyEscaperPrefix) && + slices.Contains(clusterGatewayProxyQueryKeysToEscape, + strings.TrimPrefix(k, clusterGatewayProxyEscaperPrefix)) { + k = strings.TrimPrefix(k, clusterGatewayProxyEscaperPrefix) + } + unescaped[k] = vs + } + return unescaped +} diff --git a/pkg/apis/cluster/v1alpha1/clustergateway_proxy_test.go b/pkg/apis/cluster/v1alpha1/clustergateway_proxy_test.go index 0e3c9eb7..124e9c1a 100644 --- a/pkg/apis/cluster/v1alpha1/clustergateway_proxy_test.go +++ b/pkg/apis/cluster/v1alpha1/clustergateway_proxy_test.go @@ -34,6 +34,8 @@ func TestProxyHandler(t *testing.T) { objName string inputOption *ClusterGatewayProxyOptions reqInfo request.RequestInfo + query string + expectedQuery string endpointPath string expectedFailure bool errorAssertFunc func(t *testing.T, err error) @@ -106,6 +108,33 @@ func TestProxyHandler(t *testing.T) { Verb: "get", }, }, + { + name: "normal proxy with query in endpoint should work", + parent: &fakeParentStorage{ + obj: &ClusterGateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "myName", + }, + Spec: ClusterGatewaySpec{ + Access: ClusterAccess{ + Credential: &ClusterAccessCredential{ + Type: CredentialTypeServiceAccountToken, + ServiceAccountToken: "myToken", + }, + }, + }, + }, + }, + objName: "myName", + inputOption: &ClusterGatewayProxyOptions{ + Path: "/abc", + }, + query: "__dryRun=All&fieldValidation=Strict", + expectedQuery: "dryRun=All&fieldValidation=Strict", + reqInfo: request.RequestInfo{ + Verb: "get", + }, + }, } for _, c := range cases { @@ -148,13 +177,14 @@ func TestProxyHandler(t *testing.T) { defer svr.Close() path := "/foo" targetPath := apiPrefix + c.objName + apiSuffix + path - resp, err := svr.Client().Get(svr.URL + targetPath) + resp, err := svr.Client().Get(svr.URL + targetPath + "?" + c.query) assert.NoError(t, err) data, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, text, string(data)) assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, gopath.Join(c.endpointPath, path), receivingReq.URL.Path) + assert.Equal(t, c.expectedQuery, receivingReq.URL.RawQuery) }) } }