diff --git a/router.go b/router.go index 4a576e35..f6331261 100644 --- a/router.go +++ b/router.go @@ -14,10 +14,12 @@ import ( // https://play.golang.org/p/MxhRiL37R-9 type routerContextKeyType struct{} type routerRequestPatternContextKeyType struct{} +type routerRequestMethodContextKeyType struct{} var ( routerContextKey = routerContextKeyType{} routerRequestPatternContextKey = routerRequestPatternContextKeyType{} + routerRequestMethodContextKey = routerRequestMethodContextKeyType{} routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`) ) @@ -53,6 +55,22 @@ func routerPathPatternForRequest(r Request) string { return "" } +// RequestPatternFromContext returns the pattern that was matched for the request, if available. +func RequestPatternFromContext(ctx context.Context) (string, bool) { + if v := ctx.Value(routerRequestPatternContextKey); v != nil { + return v.(string), true + } + return "", false +} + +// RequestMethodFromContext returns the method of the request, if available. +func RequestMethodFromContext(ctx context.Context) (string, bool) { + if v := ctx.Value(routerRequestMethodContextKey); v != nil { + return v.(string), true + } + return "", false +} + func (r *Router) compile(pattern string) *regexp.Regexp { re, pos := ``, 0 for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) { @@ -134,6 +152,7 @@ func (r Router) Serve() Service { } req.Context = context.WithValue(req.Context, routerContextKey, &r) req.Context = context.WithValue(req.Context, routerRequestPatternContextKey, pathPattern) + req.Context = context.WithValue(req.Context, routerRequestMethodContextKey, req.Method) rsp := svc(req) if rsp.Request == nil { rsp.Request = &req @@ -157,37 +176,46 @@ func (r Router) Params(req Request) map[string]string { // Sugar // GET is shorthand for: -// r.Register("GET", pattern, svc) +// +// r.Register("GET", pattern, svc) func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) } // CONNECT is shorthand for: -// r.Register("CONNECT", pattern, svc) +// +// r.Register("CONNECT", pattern, svc) func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) } // DELETE is shorthand for: -// r.Register("DELETE", pattern, svc) +// +// r.Register("DELETE", pattern, svc) func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) } // HEAD is shorthand for: -// r.Register("HEAD", pattern, svc) +// +// r.Register("HEAD", pattern, svc) func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) } // OPTIONS is shorthand for: -// r.Register("OPTIONS", pattern, svc) +// +// r.Register("OPTIONS", pattern, svc) func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) } // PATCH is shorthand for: -// r.Register("PATCH", pattern, svc) +// +// r.Register("PATCH", pattern, svc) func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) } // POST is shorthand for: -// r.Register("POST", pattern, svc) +// +// r.Register("POST", pattern, svc) func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) } // PUT is shorthand for: -// r.Register("PUT", pattern, svc) +// +// r.Register("PUT", pattern, svc) func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) } // TRACE is shorthand for: -// r.Register("TRACE", pattern, svc) +// +// r.Register("TRACE", pattern, svc) func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) } diff --git a/router_test.go b/router_test.go index a3ec7735..5d884fd6 100644 --- a/router_test.go +++ b/router_test.go @@ -126,3 +126,25 @@ func TestRouterSetsRequest(t *testing.T) { req.Context = rsp.Request.Context assert.Equal(t, req, *rsp.Request) } + +func TestRouterSetsContextValues(t *testing.T) { + t.Parallel() + + router := Router{} + router.GET("/", func(req Request) Response { + return Response{} + }) + + ctx := context.Background() + req := NewRequest(ctx, "GET", "/", map[string]string{"r": "foo"}) + rsp := router.Serve()(req) + require.NotNil(t, rsp.Request) + + ctxPattern, ok := RequestPatternFromContext(rsp.Request.Context) + assert.True(t, ok) + assert.Equal(t, "/", ctxPattern) + + ctxMethod, ok := RequestMethodFromContext(rsp.Request.Context) + assert.True(t, ok) + assert.Equal(t, "GET", ctxMethod) +}