From 1e9141a89bd0c037143417d1646c7152b3b73222 Mon Sep 17 00:00:00 2001 From: Tom Date: Mon, 23 Sep 2024 11:31:22 +0100 Subject: [PATCH] Add public accessors for request pattern and method These are very useful values to be able to access easily while processing requests. Let's make them public and reachable via the context. --- router.go | 46 +++++++++++++++++++++++++++++++++++++--------- router_test.go | 22 ++++++++++++++++++++++ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/router.go b/router.go index 4a576e35..f6331261 100644 --- a/router.go +++ b/router.go @@ -14,10 +14,12 @@ import ( // https://play.golang.org/p/MxhRiL37R-9 type routerContextKeyType struct{} type routerRequestPatternContextKeyType struct{} +type routerRequestMethodContextKeyType struct{} var ( routerContextKey = routerContextKeyType{} routerRequestPatternContextKey = routerRequestPatternContextKeyType{} + routerRequestMethodContextKey = routerRequestMethodContextKeyType{} routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`) ) @@ -53,6 +55,22 @@ func routerPathPatternForRequest(r Request) string { return "" } +// RequestPatternFromContext returns the pattern that was matched for the request, if available. +func RequestPatternFromContext(ctx context.Context) (string, bool) { + if v := ctx.Value(routerRequestPatternContextKey); v != nil { + return v.(string), true + } + return "", false +} + +// RequestMethodFromContext returns the method of the request, if available. +func RequestMethodFromContext(ctx context.Context) (string, bool) { + if v := ctx.Value(routerRequestMethodContextKey); v != nil { + return v.(string), true + } + return "", false +} + func (r *Router) compile(pattern string) *regexp.Regexp { re, pos := ``, 0 for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) { @@ -134,6 +152,7 @@ func (r Router) Serve() Service { } req.Context = context.WithValue(req.Context, routerContextKey, &r) req.Context = context.WithValue(req.Context, routerRequestPatternContextKey, pathPattern) + req.Context = context.WithValue(req.Context, routerRequestMethodContextKey, req.Method) rsp := svc(req) if rsp.Request == nil { rsp.Request = &req @@ -157,37 +176,46 @@ func (r Router) Params(req Request) map[string]string { // Sugar // GET is shorthand for: -// r.Register("GET", pattern, svc) +// +// r.Register("GET", pattern, svc) func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) } // CONNECT is shorthand for: -// r.Register("CONNECT", pattern, svc) +// +// r.Register("CONNECT", pattern, svc) func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) } // DELETE is shorthand for: -// r.Register("DELETE", pattern, svc) +// +// r.Register("DELETE", pattern, svc) func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) } // HEAD is shorthand for: -// r.Register("HEAD", pattern, svc) +// +// r.Register("HEAD", pattern, svc) func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) } // OPTIONS is shorthand for: -// r.Register("OPTIONS", pattern, svc) +// +// r.Register("OPTIONS", pattern, svc) func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) } // PATCH is shorthand for: -// r.Register("PATCH", pattern, svc) +// +// r.Register("PATCH", pattern, svc) func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) } // POST is shorthand for: -// r.Register("POST", pattern, svc) +// +// r.Register("POST", pattern, svc) func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) } // PUT is shorthand for: -// r.Register("PUT", pattern, svc) +// +// r.Register("PUT", pattern, svc) func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) } // TRACE is shorthand for: -// r.Register("TRACE", pattern, svc) +// +// r.Register("TRACE", pattern, svc) func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) } diff --git a/router_test.go b/router_test.go index a3ec7735..5d884fd6 100644 --- a/router_test.go +++ b/router_test.go @@ -126,3 +126,25 @@ func TestRouterSetsRequest(t *testing.T) { req.Context = rsp.Request.Context assert.Equal(t, req, *rsp.Request) } + +func TestRouterSetsContextValues(t *testing.T) { + t.Parallel() + + router := Router{} + router.GET("/", func(req Request) Response { + return Response{} + }) + + ctx := context.Background() + req := NewRequest(ctx, "GET", "/", map[string]string{"r": "foo"}) + rsp := router.Serve()(req) + require.NotNil(t, rsp.Request) + + ctxPattern, ok := RequestPatternFromContext(rsp.Request.Context) + assert.True(t, ok) + assert.Equal(t, "/", ctxPattern) + + ctxMethod, ok := RequestMethodFromContext(rsp.Request.Context) + assert.True(t, ok) + assert.Equal(t, "GET", ctxMethod) +}