Skip to content

Commit

Permalink
Merge pull request #145 from tinh-tinh/feat/ren/144-add-middleware-co…
Browse files Browse the repository at this point in the history
…nsumer

feat: add middleware consumer to include/exclude routes
  • Loading branch information
Ren0503 authored Dec 4, 2024
2 parents 7629e51 + b8d410d commit b4738d3
Show file tree
Hide file tree
Showing 5 changed files with 361 additions and 2 deletions.
21 changes: 21 additions & 0 deletions common/slice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package common

func Filter[T any](slice []T, f func(T) bool) []T {
var result []T
for _, v := range slice {
if f(v) {
result = append(result, v)
}
}
return result
}

func Remove[T any](slice []T, f func(T) bool) []T {
var result []T
for _, v := range slice {
if !f(v) {
result = append(result, v)
}
}
return result
}
28 changes: 28 additions & 0 deletions common/slice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package common_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tinh-tinh/tinhtinh/common"
)

func Test_Filter(t *testing.T) {
data := []int{1, 2, 3, 4, 5}

res := common.Filter(data, func(item int) bool {
return item%2 == 0
})

require.Equal(t, []int{2, 4}, res)
}

func Test_Remove(t *testing.T) {
data := []int{1, 2, 3, 4, 5}

res := common.Remove(data, func(item int) bool {
return item%2 == 0
})

require.Equal(t, []int{1, 3, 5}, res)
}
95 changes: 95 additions & 0 deletions core/consumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package core

import (
"github.com/tinh-tinh/tinhtinh/common"
)

const MethodAll = "ALL"

type RoutesPath struct {
Path string
Method string
}

func (c *Consumer) Include(includes ...RoutesPath) *Consumer {
c.includes = append(c.includes, includes...)
return c
}

func (c *Consumer) Exclude(excludes ...RoutesPath) *Consumer {
c.excludes = append(c.excludes, excludes...)
return c
}

type Consumer struct {
middlewares []Middleware
includes []RoutesPath
excludes []RoutesPath
}

func NewConsumer() *Consumer {
return &Consumer{}
}

func (c *Consumer) Apply(middlewares ...Middleware) *Consumer {
c.middlewares = append(c.middlewares, middlewares...)
return c
}

func (m *DynamicModule) Consumer(consumer *Consumer) *DynamicModule {
effectRoutes := []*Router{}
for _, i := range consumer.includes {
if i.Path == "*" && i.Method == MethodAll {
effectRoutes = m.Routers
} else if i.Path == "*" {
effectRoutes = common.Filter(m.Routers, func(r *Router) bool {
return r.Method == i.Method
})
} else if i.Method == MethodAll {
effectRoutes = common.Filter(m.Routers, func(r *Router) bool {
route := ParseRoute(" " + r.Path)
route.SetPrefix(r.Name)
return route.Path == i.Path
})
} else {
effectRoutes = common.Filter(m.Routers, func(r *Router) bool {
route := ParseRoute(" " + r.Path)
route.SetPrefix(r.Name)
return r.Method == i.Method && route.Path == i.Path
})
}
}
if len(consumer.includes) == 0 {
effectRoutes = m.Routers
}

for _, e := range consumer.excludes {
if e.Path == "*" && e.Method == MethodAll {
effectRoutes = common.Remove(effectRoutes, func(r *Router) bool {
return true
})
} else if e.Path == "*" {
effectRoutes = common.Remove(effectRoutes, func(r *Router) bool {
return r.Method == e.Method
})
} else if e.Method == MethodAll {
effectRoutes = common.Remove(effectRoutes, func(r *Router) bool {
route := ParseRoute(" " + r.Path)
route.SetPrefix(r.Name)
return route.Path == e.Path
})
} else {
effectRoutes = common.Remove(effectRoutes, func(r *Router) bool {
route := ParseRoute(" " + r.Path)
route.SetPrefix(r.Name)
return r.Method == e.Method && route.Path == e.Path
})
}
}

for _, r := range effectRoutes {
r.Middlewares = append(consumer.middlewares, r.Middlewares...)
}

return m
}
214 changes: 214 additions & 0 deletions core/consumer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package core_test

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
"github.com/tinh-tinh/tinhtinh/core"
)

func Test_Consumer(t *testing.T) {
const (
Tenant core.CtxKey = "tenant"
Location core.CtxKey = "location"
)

tenantMiddleware := func(ctx core.Ctx) error {
tenant := ctx.Headers("x-tenant-id")
if tenant != "" {
ctx.Set(Tenant, tenant)
}
return ctx.Next()
}

locationMiddleware := func(ctx core.Ctx) error {
location := ctx.Headers("x-location-id")
if location != "" {
ctx.Set(Location, location)
}
return ctx.Next()
}

userMiddleware := func(ctx core.Ctx) error {
user := ctx.Headers("x-user-id")
if user != "" {
ctx.Set(Tenant, user)
}
return ctx.Next()
}

userController := func(module *core.DynamicModule) *core.DynamicController {
ctrl := module.NewController("user")

ctrl.Get("", func(ctx core.Ctx) error {
return ctx.JSON(core.Map{
"data": ctx.Get(Tenant),
})
})

ctrl.Get("location", func(ctx core.Ctx) error {
return ctx.JSON(core.Map{
"data": ctx.Get(Location),
})
})

ctrl.Get("special", func(ctx core.Ctx) error {
return ctx.JSON(core.Map{
"data": "special",
})
})

return ctrl
}

userModule := func(module *core.DynamicModule) *core.DynamicModule {
user := module.New(core.NewModuleOptions{
Controllers: []core.Controller{userController},
})

user.Consumer(core.NewConsumer().Apply(userMiddleware).Include(core.RoutesPath{
Path: "/user", Method: http.MethodGet,
}, core.RoutesPath{
Path: "*", Method: http.MethodGet,
}, core.RoutesPath{
Path: "/user/location", Method: core.MethodAll,
}, core.RoutesPath{
Path: "/user/special", Method: http.MethodGet,
}))

return user
}

postController := func(module *core.DynamicModule) *core.DynamicController {
ctrl := module.NewController("post")

ctrl.Get("", func(ctx core.Ctx) error {
return ctx.JSON(core.Map{
"data": ctx.Get(Tenant),
})
})

ctrl.Get("exclude", func(ctx core.Ctx) error {
return ctx.JSON(core.Map{
"data": ctx.Get(Location),
})
})

ctrl.Post("special", func(ctx core.Ctx) error {
return ctx.JSON(core.Map{
"data": "special",
})
})

return ctrl
}

postModule := func(module *core.DynamicModule) *core.DynamicModule {
post := module.New(core.NewModuleOptions{
Controllers: []core.Controller{postController},
})

post.Consumer(core.NewConsumer().Apply(userMiddleware).Exclude(core.RoutesPath{
Path: "*", Method: http.MethodGet,
}, core.RoutesPath{
Path: "/post/exclude", Method: core.MethodAll,
}, core.RoutesPath{
Path: "/post/special", Method: http.MethodPost,
}))

return post
}

appModule := func() *core.DynamicModule {
app := core.NewModule(core.NewModuleOptions{
Imports: []core.Module{userModule, postModule},
})

app.Consumer(core.NewConsumer().Apply(tenantMiddleware).Include(core.RoutesPath{
Path: "*", Method: core.MethodAll,
}))

app.Consumer(core.NewConsumer().Apply(locationMiddleware).Exclude(core.RoutesPath{
Path: "/post/exclude", Method: core.MethodAll,
}))

return app
}

app := core.CreateFactory(appModule)
app.SetGlobalPrefix("/api")

testServer := httptest.NewServer(app.PrepareBeforeListen())
defer testServer.Close()

testClient := testServer.Client()

req, err := http.NewRequest("GET", testServer.URL+"/api/user", nil)
require.Nil(t, err)
req.Header.Set("x-tenant-id", "test")
req.Header.Set("x-location-id", "test2")

resp, err := testClient.Do(req)
require.Nil(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

data, err := io.ReadAll(resp.Body)
require.Nil(t, err)

var res Response
err = json.Unmarshal(data, &res)
require.Nil(t, err)
require.Equal(t, "test", res.Data)

req, err = http.NewRequest("GET", testServer.URL+"/api/user/location", nil)
require.Nil(t, err)
req.Header.Set("x-tenant-id", "test")
req.Header.Set("x-location-id", "test2")

resp, err = testClient.Do(req)
require.Nil(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

data, err = io.ReadAll(resp.Body)
require.Nil(t, err)

err = json.Unmarshal(data, &res)
require.Nil(t, err)
require.Equal(t, "test2", res.Data)

req, err = http.NewRequest("GET", testServer.URL+"/api/post", nil)
require.Nil(t, err)
req.Header.Set("x-tenant-id", "test")
req.Header.Set("x-location-id", "test2")

resp, err = testClient.Do(req)
require.Nil(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

data, err = io.ReadAll(resp.Body)
require.Nil(t, err)

err = json.Unmarshal(data, &res)
require.Nil(t, err)
require.Equal(t, "test", res.Data)

req, err = http.NewRequest("GET", testServer.URL+"/api/post/exclude", nil)
require.Nil(t, err)
req.Header.Set("x-tenant-id", "test")
req.Header.Set("x-location-id", "test2")

resp, err = testClient.Do(req)
require.Nil(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

data, err = io.ReadAll(resp.Body)
require.Nil(t, err)

err = json.Unmarshal(data, &res)
require.Nil(t, err)
require.Nil(t, res.Data)
}
5 changes: 3 additions & 2 deletions core/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
)

type DynamicModule struct {
isRoot bool
Scope Scope
Routers []*Router
Middlewares []Middleware
Expand Down Expand Up @@ -53,7 +54,7 @@ func NewModule(opt NewModuleOptions) *DynamicModule {
if opt.Scope == "" {
opt.Scope = Global
}
module := &DynamicModule{}
module := &DynamicModule{isRoot: true}
initModule(module, opt)

return module
Expand All @@ -70,7 +71,7 @@ func (m *DynamicModule) New(opt NewModuleOptions) *DynamicModule {
if opt.Scope == "" {
opt.Scope = Global
}
newMod := &DynamicModule{}
newMod := &DynamicModule{isRoot: false}
newMod.DataProviders = append(newMod.DataProviders, m.getExports()...)
newMod.Middlewares = append(newMod.Middlewares, m.Middlewares...)

Expand Down

0 comments on commit b4738d3

Please sign in to comment.