Skip to content

Commit

Permalink
api: Add authorization to viewership API (#60)
Browse files Browse the repository at this point in the history
* api: Return a 404 on asset not found

* api: Allow asset views to be fetched only by ID

* api: Create consts for the param names

* api/auth: Use path param on authorization instead

* api: Add auth middleware to views API

* api: Forward Origin header on auth request

* api/auth: Forward CORS response headers back

* api/auth: Fix authorization error

We were sending the auth req url not the original url

* api: Add Allow-Credentials to cors headers

* api/auth: Rename auth request vars for clarity

* api/auth: Add some logs to auth middleware

Log origiunal request

* api/auth: Grab original URI correctly

* Revert "api/auth: Add some logs to auth middleware"

This reverts commit 06d7a32.
  • Loading branch information
victorges authored Sep 17, 2022
1 parent ebaae8c commit 0b71230
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 26 deletions.
58 changes: 44 additions & 14 deletions api/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@ import (
)

var (
authorizationHeaders = []string{"Authorization", "Cookie"}
authTimeout = 3 * time.Second
authorizationHeaders = []string{"Authorization", "Cookie", "Origin"}
// the response headers proxied from the auth request are basically cors headers
proxiedResponseHeaders = []string{
"Access-Control-Allow-Origin",
"Access-Control-Allow-Credentials",
"Access-Control-Allow-Methods",
"Access-Control-Allow-Headers",
"Access-Control-Expose-Headers",
"Access-Control-Max-Age",
}
authTimeout = 3 * time.Second

authRequestDuration = metrics.Factory.NewSummaryVec(
prometheus.SummaryOpts{
Expand All @@ -34,33 +43,54 @@ func authorization(authUrl string) middleware {
ctx, cancel := context.WithTimeout(r.Context(), authTimeout)
defer cancel()

status := getStreamStatus(r)
req, err := http.NewRequestWithContext(ctx, r.Method, authUrl, nil)
authReq, err := http.NewRequestWithContext(ctx, r.Method, authUrl, nil)
if err != nil {
respondError(rw, http.StatusInternalServerError, err)
return
}
req.Header.Set("X-Original-Uri", req.URL.String())
req.Header.Set("X-Livepeer-Stream-Id", status.ID)
for _, header := range authorizationHeaders {
req.Header[header] = r.Header[header]
authReq.Header.Set("X-Original-Uri", originalReqUri(r))
if streamID := apiParam(r, streamIDParam); streamID != "" {
authReq.Header.Set("X-Livepeer-Stream-Id", streamID)
} else if assetID := apiParam(r, assetIDParam); assetID != "" {
authReq.Header.Set("X-Livepeer-Asset-Id", assetID)
}
res, err := httpClient.Do(req)
copyHeaders(authorizationHeaders, r.Header, authReq.Header)
authRes, err := httpClient.Do(authReq)
if err != nil {
respondError(rw, http.StatusInternalServerError, fmt.Errorf("error authorizing request: %w", err))
return
}
copyHeaders(proxiedResponseHeaders, authRes.Header, rw.Header())

if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusNoContent {
if contentType := res.Header.Get("Content-Type"); contentType != "" {
if authRes.StatusCode != http.StatusOK && authRes.StatusCode != http.StatusNoContent {
if contentType := authRes.Header.Get("Content-Type"); contentType != "" {
rw.Header().Set("Content-Type", contentType)
}
rw.WriteHeader(res.StatusCode)
if _, err := io.Copy(rw, res.Body); err != nil {
glog.Errorf("Error writing auth error response. err=%q, status=%d, headers=%+v", err, res.StatusCode, res.Header)
rw.WriteHeader(authRes.StatusCode)
if _, err := io.Copy(rw, authRes.Body); err != nil {
glog.Errorf("Error writing auth error response. err=%q, status=%d, headers=%+v", err, authRes.StatusCode, authRes.Header)
}
return
}
next.ServeHTTP(rw, r)
})
}

func originalReqUri(r *http.Request) string {
proto := "http"
if r.TLS != nil {
proto = "https"
}
if fwdProto := r.Header.Get("X-Forwarded-Proto"); fwdProto != "" {
proto = fwdProto
}
return fmt.Sprintf("%s://%s%s", proto, r.Host, r.URL.RequestURI())
}

func copyHeaders(headers []string, src, dest http.Header) {
for _, header := range headers {
if vals := src[header]; len(vals) > 0 {
dest[header] = vals
}
}
}
5 changes: 4 additions & 1 deletion api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/golang/glog"
"github.com/livepeer/livepeer-data/health"
"github.com/livepeer/livepeer-data/views"
)

type errorResponse struct {
Expand All @@ -18,7 +19,9 @@ func respondError(rw http.ResponseWriter, defaultStatus int, errs ...error) {
response := errorResponse{}
for _, err := range errs {
response.Errors = append(response.Errors, err.Error())
if errors.Is(err, health.ErrStreamNotFound) || errors.Is(err, health.ErrEventNotFound) {
if errors.Is(err, health.ErrStreamNotFound) ||
errors.Is(err, health.ErrEventNotFound) ||
errors.Is(err, views.ErrAssetNotFound) {
status = http.StatusNotFound
}
}
Expand Down
22 changes: 16 additions & 6 deletions api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ const (
sseRetryBackoff = 10 * time.Second
ssePingDelay = 20 * time.Second
sseBufferSize = 128

streamIDParam = "streamId"
assetIDParam = "assetId"
)

type APIHandlerOptions struct {
Expand Down Expand Up @@ -54,14 +57,14 @@ func NewHandler(serverCtx context.Context, opts APIHandlerOptions, healthcore *h
func addStreamHealthHandlers(router *httprouter.Router, handler *apiHandler) {
healthcore, opts := handler.core, handler.opts
middlewares := []middleware{
streamStatus(healthcore, "streamId"),
streamStatus(healthcore),
regionProxy(opts.RegionalHostFormat, opts.OwnRegion),
}
if opts.AuthURL != "" {
middlewares = append(middlewares, authorization(opts.AuthURL))
}
addApiHandler := func(apiPath, name string, handler http.HandlerFunc) {
fullPath := path.Join(opts.APIRoot, "/stream/:streamId", apiPath)
fullPath := path.Join(opts.APIRoot, "/stream/:"+streamIDParam, apiPath)
fullHandler := prepareHandlerFunc(name, opts.Prometheus, handler, middlewares...)
router.Handler("GET", fullPath, fullHandler)
}
Expand All @@ -71,10 +74,13 @@ func addStreamHealthHandlers(router *httprouter.Router, handler *apiHandler) {

func addViewershipHandlers(router *httprouter.Router, handler *apiHandler) {
opts := handler.opts
// TODO: Add authorization to views API
middlewares := []middleware{}
if opts.AuthURL != "" {
middlewares = append(middlewares, authorization(opts.AuthURL))
}
addApiHandler := func(apiPath, name string, handler http.HandlerFunc) {
fullPath := path.Join(opts.APIRoot, "/views/:assetId", apiPath)
fullHandler := prepareHandlerFunc(name, opts.Prometheus, handler)
fullPath := path.Join(opts.APIRoot, "/views/:"+assetIDParam, apiPath)
fullHandler := prepareHandlerFunc(name, opts.Prometheus, handler, middlewares...)
router.Handler("GET", fullPath, fullHandler)
}
addApiHandler("/total", "get_total_views", handler.getTotalViews)
Expand All @@ -87,6 +93,10 @@ func (h *apiHandler) cors() middleware {
}
rw.Header().Set("Access-Control-Allow-Origin", "*")
rw.Header().Set("Access-Control-Allow-Headers", "*")
if origin := r.Header.Get("Origin"); origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
}
next.ServeHTTP(rw, r)
})
}
Expand All @@ -100,7 +110,7 @@ func (h *apiHandler) healthcheck(rw http.ResponseWriter, r *http.Request) {
}

func (h *apiHandler) getTotalViews(rw http.ResponseWriter, r *http.Request) {
views, err := h.views.GetTotalViews(r.Context(), apiParam(r, "assetId"))
views, err := h.views.GetTotalViews(r.Context(), apiParam(r, assetIDParam))
if err != nil {
respondError(rw, http.StatusInternalServerError, err)
return
Expand Down
2 changes: 1 addition & 1 deletion api/streamStatus.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const (
streamStatusKey contextKey = iota
)

func streamStatus(healthcore *health.Core, streamIDParam string) middleware {
func streamStatus(healthcore *health.Core) middleware {
return inlineMiddleware(func(rw http.ResponseWriter, r *http.Request, next http.Handler) {
streamID := apiParam(r, streamIDParam)
if streamID == "" {
Expand Down
7 changes: 3 additions & 4 deletions views/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/prometheus/common/model"
)

var ErrAssetNotFound = errors.New("asset not found")

type TotalViews struct {
ID string `json:"id"`
StartViews int64 `json:"startViews"`
Expand Down Expand Up @@ -41,10 +43,7 @@ func NewClient(opts ClientOptions) (*Client, error) {
func (c *Client) GetTotalViews(ctx context.Context, id string) ([]TotalViews, error) {
asset, err := c.lp.GetAsset(id)
if errors.Is(err, livepeer.ErrNotExists) {
asset, err = c.lp.GetAssetByPlaybackID(id, false)
}
if errors.Is(err, livepeer.ErrNotExists) {
return nil, errors.New("asset not found")
return nil, ErrAssetNotFound
} else if err != nil {
return nil, fmt.Errorf("error getting asset: %w", err)
}
Expand Down

0 comments on commit 0b71230

Please sign in to comment.