Skip to content

Commit

Permalink
make better use of Gin middleware system (#3849)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Oct 8, 2024
1 parent c8cdb77 commit d13dc10
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 135 deletions.
96 changes: 49 additions & 47 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ type API struct {
SRTServer SRTServer
Parent apiParent

httpServer *httpp.WrappedServer
httpServer *httpp.Server
mutex sync.RWMutex
}

Expand All @@ -158,77 +158,79 @@ func (a *API) Initialize() error {
router := gin.New()
router.SetTrustedProxies(a.TrustedProxies.ToTrustedProxies()) //nolint:errcheck

router.NoRoute(a.middlewareOrigin, a.middlewareAuth)
group := router.Group("/", a.middlewareOrigin, a.middlewareAuth)
router.Use(a.middlewareOrigin)
router.Use(a.middlewareAuth)

group.GET("/v3/config/global/get", a.onConfigGlobalGet)
group.PATCH("/v3/config/global/patch", a.onConfigGlobalPatch)
group := router.Group("/v3")

group.GET("/v3/config/pathdefaults/get", a.onConfigPathDefaultsGet)
group.PATCH("/v3/config/pathdefaults/patch", a.onConfigPathDefaultsPatch)
group.GET("/config/global/get", a.onConfigGlobalGet)
group.PATCH("/config/global/patch", a.onConfigGlobalPatch)

group.GET("/v3/config/paths/list", a.onConfigPathsList)
group.GET("/v3/config/paths/get/*name", a.onConfigPathsGet)
group.POST("/v3/config/paths/add/*name", a.onConfigPathsAdd)
group.PATCH("/v3/config/paths/patch/*name", a.onConfigPathsPatch)
group.POST("/v3/config/paths/replace/*name", a.onConfigPathsReplace)
group.DELETE("/v3/config/paths/delete/*name", a.onConfigPathsDelete)
group.GET("/config/pathdefaults/get", a.onConfigPathDefaultsGet)
group.PATCH("/config/pathdefaults/patch", a.onConfigPathDefaultsPatch)

group.GET("/v3/paths/list", a.onPathsList)
group.GET("/v3/paths/get/*name", a.onPathsGet)
group.GET("/config/paths/list", a.onConfigPathsList)
group.GET("/config/paths/get/*name", a.onConfigPathsGet)
group.POST("/config/paths/add/*name", a.onConfigPathsAdd)
group.PATCH("/config/paths/patch/*name", a.onConfigPathsPatch)
group.POST("/config/paths/replace/*name", a.onConfigPathsReplace)
group.DELETE("/config/paths/delete/*name", a.onConfigPathsDelete)

group.GET("/paths/list", a.onPathsList)
group.GET("/paths/get/*name", a.onPathsGet)

if !interfaceIsEmpty(a.HLSServer) {
group.GET("/v3/hlsmuxers/list", a.onHLSMuxersList)
group.GET("/v3/hlsmuxers/get/*name", a.onHLSMuxersGet)
group.GET("/hlsmuxers/list", a.onHLSMuxersList)
group.GET("/hlsmuxers/get/*name", a.onHLSMuxersGet)
}

if !interfaceIsEmpty(a.RTSPServer) {
group.GET("/v3/rtspconns/list", a.onRTSPConnsList)
group.GET("/v3/rtspconns/get/:id", a.onRTSPConnsGet)
group.GET("/v3/rtspsessions/list", a.onRTSPSessionsList)
group.GET("/v3/rtspsessions/get/:id", a.onRTSPSessionsGet)
group.POST("/v3/rtspsessions/kick/:id", a.onRTSPSessionsKick)
group.GET("/rtspconns/list", a.onRTSPConnsList)
group.GET("/rtspconns/get/:id", a.onRTSPConnsGet)
group.GET("/rtspsessions/list", a.onRTSPSessionsList)
group.GET("/rtspsessions/get/:id", a.onRTSPSessionsGet)
group.POST("/rtspsessions/kick/:id", a.onRTSPSessionsKick)
}

if !interfaceIsEmpty(a.RTSPSServer) {
group.GET("/v3/rtspsconns/list", a.onRTSPSConnsList)
group.GET("/v3/rtspsconns/get/:id", a.onRTSPSConnsGet)
group.GET("/v3/rtspssessions/list", a.onRTSPSSessionsList)
group.GET("/v3/rtspssessions/get/:id", a.onRTSPSSessionsGet)
group.POST("/v3/rtspssessions/kick/:id", a.onRTSPSSessionsKick)
group.GET("/rtspsconns/list", a.onRTSPSConnsList)
group.GET("/rtspsconns/get/:id", a.onRTSPSConnsGet)
group.GET("/rtspssessions/list", a.onRTSPSSessionsList)
group.GET("/rtspssessions/get/:id", a.onRTSPSSessionsGet)
group.POST("/rtspssessions/kick/:id", a.onRTSPSSessionsKick)
}

if !interfaceIsEmpty(a.RTMPServer) {
group.GET("/v3/rtmpconns/list", a.onRTMPConnsList)
group.GET("/v3/rtmpconns/get/:id", a.onRTMPConnsGet)
group.POST("/v3/rtmpconns/kick/:id", a.onRTMPConnsKick)
group.GET("/rtmpconns/list", a.onRTMPConnsList)
group.GET("/rtmpconns/get/:id", a.onRTMPConnsGet)
group.POST("/rtmpconns/kick/:id", a.onRTMPConnsKick)
}

if !interfaceIsEmpty(a.RTMPSServer) {
group.GET("/v3/rtmpsconns/list", a.onRTMPSConnsList)
group.GET("/v3/rtmpsconns/get/:id", a.onRTMPSConnsGet)
group.POST("/v3/rtmpsconns/kick/:id", a.onRTMPSConnsKick)
group.GET("/rtmpsconns/list", a.onRTMPSConnsList)
group.GET("/rtmpsconns/get/:id", a.onRTMPSConnsGet)
group.POST("/rtmpsconns/kick/:id", a.onRTMPSConnsKick)
}

if !interfaceIsEmpty(a.WebRTCServer) {
group.GET("/v3/webrtcsessions/list", a.onWebRTCSessionsList)
group.GET("/v3/webrtcsessions/get/:id", a.onWebRTCSessionsGet)
group.POST("/v3/webrtcsessions/kick/:id", a.onWebRTCSessionsKick)
group.GET("/webrtcsessions/list", a.onWebRTCSessionsList)
group.GET("/webrtcsessions/get/:id", a.onWebRTCSessionsGet)
group.POST("/webrtcsessions/kick/:id", a.onWebRTCSessionsKick)
}

if !interfaceIsEmpty(a.SRTServer) {
group.GET("/v3/srtconns/list", a.onSRTConnsList)
group.GET("/v3/srtconns/get/:id", a.onSRTConnsGet)
group.POST("/v3/srtconns/kick/:id", a.onSRTConnsKick)
group.GET("/srtconns/list", a.onSRTConnsList)
group.GET("/srtconns/get/:id", a.onSRTConnsGet)
group.POST("/srtconns/kick/:id", a.onSRTConnsKick)
}

group.GET("/v3/recordings/list", a.onRecordingsList)
group.GET("/v3/recordings/get/*name", a.onRecordingsGet)
group.DELETE("/v3/recordings/deletesegment", a.onRecordingDeleteSegment)
group.GET("/recordings/list", a.onRecordingsList)
group.GET("/recordings/get/*name", a.onRecordingsGet)
group.DELETE("/recordings/deletesegment", a.onRecordingDeleteSegment)

network, address := restrictnetwork.Restrict("tcp", a.Address)

a.httpServer = &httpp.WrappedServer{
a.httpServer = &httpp.Server{
Network: network,
Address: address,
ReadTimeout: time.Duration(a.ReadTimeout),
Expand Down Expand Up @@ -270,14 +272,14 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) {
}

func (a *API) middlewareOrigin(ctx *gin.Context) {
ctx.Writer.Header().Set("Access-Control-Allow-Origin", a.AllowOrigin)
ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
ctx.Header("Access-Control-Allow-Origin", a.AllowOrigin)
ctx.Header("Access-Control-Allow-Credentials", "true")

// preflight requests
if ctx.Request.Method == http.MethodOptions &&
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PATCH, DELETE")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PATCH, DELETE")
ctx.Header("Access-Control-Allow-Headers", "Authorization, Content-Type")
ctx.AbortWithStatus(http.StatusNoContent)
return
}
Expand Down
32 changes: 18 additions & 14 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Metrics struct {
AuthManager metricsAuthManager
Parent metricsParent

httpServer *httpp.WrappedServer
httpServer *httpp.Server
mutex sync.Mutex
pathManager api.PathManager
rtspServer api.RTSPServer
Expand All @@ -68,11 +68,15 @@ type Metrics struct {
func (m *Metrics) Initialize() error {
router := gin.New()
router.SetTrustedProxies(m.TrustedProxies.ToTrustedProxies()) //nolint:errcheck
router.NoRoute(m.onRequest)

router.Use(m.middlewareOrigin)
router.Use(m.middlewareAuth)

router.GET("/metrics", m.onMetrics)

network, address := restrictnetwork.Restrict("tcp", m.Address)

m.httpServer = &httpp.WrappedServer{
m.httpServer = &httpp.Server{
Network: network,
Address: address,
ReadTimeout: time.Duration(m.ReadTimeout),
Expand Down Expand Up @@ -103,23 +107,21 @@ func (m *Metrics) Log(level logger.Level, format string, args ...interface{}) {
m.Parent.Log(level, "[metrics] "+format, args...)
}

func (m *Metrics) onRequest(ctx *gin.Context) {
ctx.Writer.Header().Set("Access-Control-Allow-Origin", m.AllowOrigin)
ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
func (m *Metrics) middlewareOrigin(ctx *gin.Context) {
ctx.Header("Access-Control-Allow-Origin", m.AllowOrigin)
ctx.Header("Access-Control-Allow-Credentials", "true")

// preflight requests
if ctx.Request.Method == http.MethodOptions &&
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization")
ctx.Writer.WriteHeader(http.StatusNoContent)
return
}

if ctx.Request.URL.Path != "/metrics" || ctx.Request.Method != http.MethodGet {
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Header("Access-Control-Allow-Headers", "Authorization")
ctx.AbortWithStatus(http.StatusNoContent)
return
}
}

func (m *Metrics) middlewareAuth(ctx *gin.Context) {
err := m.AuthManager.Authenticate(&auth.Request{
IP: net.ParseIP(ctx.ClientIP()),
Action: conf.AuthActionMetrics,
Expand All @@ -135,10 +137,12 @@ func (m *Metrics) onRequest(ctx *gin.Context) {
// wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError)

ctx.Writer.WriteHeader(http.StatusUnauthorized)
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
}

func (m *Metrics) onMetrics(ctx *gin.Context) {
out := ""

data, err := m.pathManager.APIPathsList()
Expand Down
19 changes: 9 additions & 10 deletions internal/playback/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Server struct {
AuthManager serverAuthManager
Parent logger.Writer

httpServer *httpp.WrappedServer
httpServer *httpp.Server
mutex sync.RWMutex
}

Expand All @@ -41,15 +41,14 @@ func (s *Server) Initialize() error {
router := gin.New()
router.SetTrustedProxies(s.TrustedProxies.ToTrustedProxies()) //nolint:errcheck

router.NoRoute(s.middlewareOrigin)
group := router.Group("/", s.middlewareOrigin)
router.Use(s.middlewareOrigin)

group.GET("/list", s.onList)
group.GET("/get", s.onGet)
router.GET("/list", s.onList)
router.GET("/get", s.onGet)

network, address := restrictnetwork.Restrict("tcp", s.Address)

s.httpServer = &httpp.WrappedServer{
s.httpServer = &httpp.Server{
Network: network,
Address: address,
ReadTimeout: time.Duration(s.ReadTimeout),
Expand Down Expand Up @@ -104,14 +103,14 @@ func (s *Server) safeFindPathConf(name string) (*conf.Path, error) {
}

func (s *Server) middlewareOrigin(ctx *gin.Context) {
ctx.Writer.Header().Set("Access-Control-Allow-Origin", s.AllowOrigin)
ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
ctx.Header("Access-Control-Allow-Origin", s.AllowOrigin)
ctx.Header("Access-Control-Allow-Credentials", "true")

// preflight requests
if ctx.Request.Method == http.MethodOptions &&
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization")
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Header("Access-Control-Allow-Headers", "Authorization")
ctx.AbortWithStatus(http.StatusNoContent)
return
}
Expand Down
32 changes: 20 additions & 12 deletions internal/pprof/pprof.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,22 @@ type PPROF struct {
AuthManager pprofAuthManager
Parent pprofParent

httpServer *httpp.WrappedServer
httpServer *httpp.Server
}

// Initialize initializes PPROF.
func (pp *PPROF) Initialize() error {
router := gin.New()
router.SetTrustedProxies(pp.TrustedProxies.ToTrustedProxies()) //nolint:errcheck
router.NoRoute(pp.onRequest)

router.Use(pp.middlewareOrigin)
router.Use(pp.middlewareAuth)

router.Use(pp.onRequest)

network, address := restrictnetwork.Restrict("tcp", pp.Address)

pp.httpServer = &httpp.WrappedServer{
pp.httpServer = &httpp.Server{
Network: network,
Address: address,
ReadTimeout: time.Duration(pp.ReadTimeout),
Expand Down Expand Up @@ -79,37 +83,41 @@ func (pp *PPROF) Log(level logger.Level, format string, args ...interface{}) {
pp.Parent.Log(level, "[pprof] "+format, args...)
}

func (pp *PPROF) onRequest(ctx *gin.Context) {
ctx.Writer.Header().Set("Access-Control-Allow-Origin", pp.AllowOrigin)
ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
func (pp *PPROF) middlewareOrigin(ctx *gin.Context) {
ctx.Header("Access-Control-Allow-Origin", pp.AllowOrigin)
ctx.Header("Access-Control-Allow-Credentials", "true")

// preflight requests
if ctx.Request.Method == http.MethodOptions &&
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization")
ctx.Writer.WriteHeader(http.StatusNoContent)
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET")
ctx.Header("Access-Control-Allow-Headers", "Authorization")
ctx.AbortWithStatus(http.StatusNoContent)
return
}
}

func (pp *PPROF) middlewareAuth(ctx *gin.Context) {
err := pp.AuthManager.Authenticate(&auth.Request{
IP: net.ParseIP(ctx.ClientIP()),
Action: conf.AuthActionMetrics,
HTTPRequest: ctx.Request,
})
if err != nil {
if err.(*auth.Error).AskCredentials { //nolint:errorlint
ctx.Writer.Header().Set("WWW-Authenticate", `Basic realm="mediamtx"`)
ctx.Writer.WriteHeader(http.StatusUnauthorized)
ctx.Header("WWW-Authenticate", `Basic realm="mediamtx"`)
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}

// wait some seconds to mitigate brute force attacks
<-time.After(auth.PauseAfterError)

ctx.Writer.WriteHeader(http.StatusUnauthorized)
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
}

func (pp *PPROF) onRequest(ctx *gin.Context) {
http.DefaultServeMux.ServeHTTP(ctx.Writer, ctx.Request)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ func (nilWriter) Write(p []byte) (int, error) {
return len(p), nil
}

// WrappedServer is a wrapper around http.Server that provides:
// Server is a wrapper around http.Server that provides:
// - net.Listener allocation and closure
// - TLS allocation
// - exit on panic
// - logging
// - server header
// - filtering of invalid requests
type WrappedServer struct {
type Server struct {
Network string
Address string
ReadTimeout time.Duration
Expand All @@ -42,8 +42,8 @@ type WrappedServer struct {
loader *certloader.CertLoader
}

// Initialize initializes a WrappedServer.
func (s *WrappedServer) Initialize() error {
// Initialize initializes a Server.
func (s *Server) Initialize() error {
var tlsConfig *tls.Config
if s.Encryption {
if s.ServerCert == "" {
Expand Down Expand Up @@ -91,7 +91,7 @@ func (s *WrappedServer) Initialize() error {
}

// Close closes all resources and waits for all routines to return.
func (s *WrappedServer) Close() {
func (s *Server) Close() {
ctx, ctxCancel := context.WithCancel(context.Background())
ctxCancel()
s.inner.Shutdown(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestFilterEmptyPath(t *testing.T) {
s := &WrappedServer{
s := &Server{
Network: "tcp",
Address: "localhost:4555",
ReadTimeout: 10 * time.Second,
Expand Down
Loading

0 comments on commit d13dc10

Please sign in to comment.