From a784ede8d4427915762bc3ce5e4df0048315c175 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 24 Jul 2024 12:08:39 +0200 Subject: [PATCH 1/3] Improve caching when fetching capabilities. - Honor "max-age" in "Cache-Control" header. - Use "ETag" to avoid fetching unmodified data. --- capabilities.go | 187 ++++++++++++++++++++++++++++++------------- capabilities_test.go | 115 ++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 + 4 files changed, 242 insertions(+), 63 deletions(-) diff --git a/capabilities.go b/capabilities.go index 2472f5c2..4db47793 100644 --- a/capabilities.go +++ b/capabilities.go @@ -32,6 +32,8 @@ import ( "strings" "sync" "time" + + "github.com/marcw/cachecontrol" ) const ( @@ -41,18 +43,114 @@ const ( // Name of capability to enable the "v3" API for the signaling endpoint. FeatureSignalingV3Api = "signaling-v3" - // Cache received capabilities for one hour. - CapabilitiesCacheDuration = time.Hour - // Don't invalidate more than once per minute. maxInvalidateInterval = time.Minute ) type capabilitiesEntry struct { + mu sync.RWMutex nextUpdate time.Time + etag string capabilities map[string]interface{} } +func newCapabilitiesEntry() *capabilitiesEntry { + return &capabilitiesEntry{} +} + +func (e *capabilitiesEntry) valid(now time.Time) bool { + e.mu.RLock() + defer e.mu.RUnlock() + + return e.nextUpdate.After(now) +} + +func (e *capabilitiesEntry) updateRequest(r *http.Request) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.etag != "" { + r.Header.Set("If-None-Match", e.etag) + } +} + +func (e *capabilitiesEntry) invalidate() { + e.mu.Lock() + defer e.mu.Unlock() + + e.nextUpdate = time.Now() +} + +func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error { + e.mu.Lock() + defer e.mu.Unlock() + + url := response.Request.URL + e.etag = response.Header.Get("ETag") + + var maxAge time.Duration + if cacheControl := response.Header.Get("Cache-Control"); cacheControl != "" { + cc := cachecontrol.Parse(cacheControl) + maxAge = cc.MaxAge() + } + e.nextUpdate = now.Add(maxAge) + + if response.StatusCode == http.StatusNotModified { + log.Printf("Capabilities %+v from %s have not changed", e.capabilities, url) + return nil + } + + ct := response.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "application/json") { + log.Printf("Received unsupported content-type from %s: %s (%s)", url, ct, response.Status) + return ErrUnsupportedContentType + } + + body, err := io.ReadAll(response.Body) + if err != nil { + log.Printf("Could not read response body from %s: %s", url, err) + return err + } + + var ocs OcsResponse + if err := json.Unmarshal(body, &ocs); err != nil { + log.Printf("Could not decode OCS response %s from %s: %s", string(body), url, err) + return err + } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { + log.Printf("Incomplete OCS response %s from %s", string(body), url) + return fmt.Errorf("incomplete OCS response") + } + + var capaResponse CapabilitiesResponse + if err := json.Unmarshal(ocs.Ocs.Data, &capaResponse); err != nil { + log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), url, err) + return err + } + + capaObj, found := capaResponse.Capabilities[AppNameSpreed] + if !found || len(capaObj) == 0 { + log.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse) + return nil + } + + var capa map[string]interface{} + if err := json.Unmarshal(capaObj, &capa); err != nil { + log.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse) + return nil + } + + log.Printf("Received capabilities %+v from %s", capa, url) + e.capabilities = capa + return nil +} + +func (e *capabilitiesEntry) GetCapabilities() map[string]interface{} { + e.mu.RLock() + defer e.mu.RUnlock() + + return e.capabilities +} + type Capabilities struct { mu sync.RWMutex @@ -92,42 +190,46 @@ type CapabilitiesResponse struct { Capabilities map[string]json.RawMessage `json:"capabilities"` } -func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool) { +func (c *Capabilities) getCapabilities(key string) (*capabilitiesEntry, bool) { c.mu.RLock() defer c.mu.RUnlock() now := c.getNow() - if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) { - return entry.capabilities, true + entry, found := c.entries[key] + if found && entry.valid(now) { + return entry, true } - return nil, false + return entry, false } -func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) { +func (c *Capabilities) invalidateCapabilities(key string) { c.mu.Lock() defer c.mu.Unlock() now := c.getNow() - entry := &capabilitiesEntry{ - nextUpdate: now.Add(CapabilitiesCacheDuration), - capabilities: capabilities, + if entry, found := c.nextInvalidate[key]; found && entry.After(now) { + return } - c.entries[key] = entry + if entry, found := c.entries[key]; found { + entry.invalidate() + } + + c.nextInvalidate[key] = now.Add(maxInvalidateInterval) } -func (c *Capabilities) invalidateCapabilities(key string) { +func (c *Capabilities) newCapabilitiesEntry(key string) *capabilitiesEntry { c.mu.Lock() defer c.mu.Unlock() - now := c.getNow() - if entry, found := c.nextInvalidate[key]; found && entry.After(now) { - return + entry, found := c.entries[key] + if !found { + entry = newCapabilitiesEntry() + c.entries[key] = entry } - delete(c.entries, key) - c.nextInvalidate[key] = now.Add(maxInvalidateInterval) + return entry } func (c *Capabilities) getKeyForUrl(u *url.URL) string { @@ -137,8 +239,9 @@ func (c *Capabilities) getKeyForUrl(u *url.URL) string { func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) { key := c.getKeyForUrl(u) - if caps, found := c.getCapabilities(key); found { - return caps, true, nil + entry, valid := c.getCapabilities(key) + if valid { + return entry.GetCapabilities(), true, nil } capUrl := *u @@ -168,6 +271,9 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st req.Header.Set("Accept", "application/json") req.Header.Set("OCS-APIRequest", "true") req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+c.version) + if entry != nil { + entry.updateRequest(req) + } resp, err := client.Do(req) if err != nil { @@ -175,48 +281,15 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st } defer resp.Body.Close() - ct := resp.Header.Get("Content-Type") - if !strings.HasPrefix(ct, "application/json") { - log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) - return nil, false, ErrUnsupportedContentType - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("Could not read response body from %s: %s", capUrl.String(), err) - return nil, false, err + if entry == nil { + entry = c.newCapabilitiesEntry(key) } - var ocs OcsResponse - if err := json.Unmarshal(body, &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) + if err := entry.update(resp, c.getNow()); err != nil { return nil, false, err - } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { - log.Printf("Incomplete OCS response %s from %s", string(body), u) - return nil, false, fmt.Errorf("incomplete OCS response") - } - - var response CapabilitiesResponse - if err := json.Unmarshal(ocs.Ocs.Data, &response); err != nil { - log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), capUrl.String(), err) - return nil, false, err - } - - capaObj, found := response.Capabilities[AppNameSpreed] - if !found || len(capaObj) == 0 { - log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, false, nil - } - - var capa map[string]interface{} - if err := json.Unmarshal(capaObj, &capa); err != nil { - log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, false, nil } - log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) - c.setCapabilities(key, capa) - return capa, false, nil + return entry.GetCapabilities(), false, nil } func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { diff --git a/capabilities_test.go b/capabilities_test.go index 48742fb4..f3cef273 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -23,7 +23,10 @@ package signaling import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -35,7 +38,7 @@ import ( "github.com/gorilla/mux" ) -func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) { +func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse, http.ResponseWriter)) (*url.URL, *Capabilities) { pool, err := NewHttpClientPool(1, false) if err != nil { t.Fatal(err) @@ -86,10 +89,6 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie }, } - if callback != nil { - callback(response) - } - data, err := json.Marshal(response) if err != nil { t.Errorf("Could not marshal %+v: %s", response, err) @@ -107,9 +106,29 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie if data, err = json.Marshal(ocs); err != nil { t.Fatal(err) } + if !strings.Contains(t.Name(), "NoCache") { + w.Header().Add("Cache-Control", "max-age=60") + } + if strings.Contains(t.Name(), "ETag") { + h := sha256.New() + h.Write(data) // nolint + etag := fmt.Sprintf("\"%s\"", base64.StdEncoding.EncodeToString(h.Sum(nil))) + w.Header().Add("ETag", etag) + if inm := r.Header.Get("If-None-Match"); inm == etag { + w.WriteHeader(http.StatusNotModified) + if callback != nil { + callback(response, w) + } + + return + } + } w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(data) // nolint + if callback != nil { + callback(response, w) + } } r.HandleFunc("/ocs/v2.php/cloud/capabilities", handleCapabilitiesFunc) @@ -204,7 +223,7 @@ func TestInvalidateCapabilities(t *testing.T) { t.Parallel() CatchLogForTest(t) var called atomic.Uint32 - url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) { + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) { called.Add(1) }) @@ -273,3 +292,87 @@ func TestInvalidateCapabilities(t *testing.T) { t.Errorf("expected called %d, got %d", 3, value) } } + +func TestCapabilitiesNoCache(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + var called atomic.Uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) { + called.Add(1) + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } +} + +func TestCapabilitiesNoCacheETag(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + var called atomic.Uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) { + ct := w.Header().Get("Content-Type") + switch called.Add(1) { + case 1: + if ct == "" { + t.Error("expected content-type on first request") + } + case 2: + if ct != "" { + t.Errorf("expected no content-type on second request, got %s", ct) + } + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } +} diff --git a/go.mod b/go.mod index e27de6c3..89ee3a24 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gorilla/securecookie v1.1.2 github.com/gorilla/websocket v1.5.3 github.com/mailru/easyjson v0.7.7 + github.com/marcw/cachecontrol v0.0.0-20140722115028-30341fe9a7d5 github.com/nats-io/nats-server/v2 v2.10.18 github.com/nats-io/nats.go v1.36.0 github.com/notedit/janus-go v0.0.0-20200517101215-10eb8b95d1a0 diff --git a/go.sum b/go.sum index e40b5053..50aacc2c 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/marcw/cachecontrol v0.0.0-20140722115028-30341fe9a7d5 h1:Wnc+HxXmAhN6xRzhmPJTiip9/sVZzwa6XlWksxjObCA= +github.com/marcw/cachecontrol v0.0.0-20140722115028-30341fe9a7d5/go.mod h1:e4ZZwiqLDqvzKu9TVxuGnh2kXCWeU6PxLG2hw/+no7g= github.com/minio/highwayhash v1.0.3 h1:kbnuUMoHYyVl7szWjSxJnxw11k2U709jqFPPmIUyD6Q= github.com/minio/highwayhash v1.0.3/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= From 71ec43f0d2c923eca9e2f2ed0b2bb41fe3bf4acf Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 24 Jul 2024 13:43:06 +0200 Subject: [PATCH 2/3] Honor "no-cache" and "must-revalidate". --- capabilities.go | 44 ++++++++++++---- capabilities_test.go | 120 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 143 insertions(+), 21 deletions(-) diff --git a/capabilities.go b/capabilities.go index 4db47793..fe5d092f 100644 --- a/capabilities.go +++ b/capabilities.go @@ -24,7 +24,7 @@ package signaling import ( "context" "encoding/json" - "fmt" + "errors" "io" "log" "net/http" @@ -47,11 +47,16 @@ const ( maxInvalidateInterval = time.Minute ) +var ( + ErrUnexpectedHttpStatus = errors.New("unexpected_http_status") +) + type capabilitiesEntry struct { - mu sync.RWMutex - nextUpdate time.Time - etag string - capabilities map[string]interface{} + mu sync.RWMutex + nextUpdate time.Time + etag string + mustRevalidate bool + capabilities map[string]interface{} } func newCapabilitiesEntry() *capabilitiesEntry { @@ -81,6 +86,15 @@ func (e *capabilitiesEntry) invalidate() { e.nextUpdate = time.Now() } +func (e *capabilitiesEntry) errorIfMustRevalidate(err error) error { + if !e.mustRevalidate { + return nil + } + + e.capabilities = nil + return err +} + func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error { e.mu.Lock() defer e.mu.Unlock() @@ -91,51 +105,59 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error var maxAge time.Duration if cacheControl := response.Header.Get("Cache-Control"); cacheControl != "" { cc := cachecontrol.Parse(cacheControl) - maxAge = cc.MaxAge() + if nc, _ := cc.NoCache(); !nc { + maxAge = cc.MaxAge() + } + e.mustRevalidate = cc.MustRevalidate() } e.nextUpdate = now.Add(maxAge) if response.StatusCode == http.StatusNotModified { log.Printf("Capabilities %+v from %s have not changed", e.capabilities, url) return nil + } else if response.StatusCode != http.StatusOK { + log.Printf("Received unexpected HTTP status from %s: %s", url, response.Status) + return e.errorIfMustRevalidate(ErrUnexpectedHttpStatus) } ct := response.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { log.Printf("Received unsupported content-type from %s: %s (%s)", url, ct, response.Status) - return ErrUnsupportedContentType + return e.errorIfMustRevalidate(ErrUnsupportedContentType) } body, err := io.ReadAll(response.Body) if err != nil { log.Printf("Could not read response body from %s: %s", url, err) - return err + return e.errorIfMustRevalidate(err) } var ocs OcsResponse if err := json.Unmarshal(body, &ocs); err != nil { log.Printf("Could not decode OCS response %s from %s: %s", string(body), url, err) - return err + return e.errorIfMustRevalidate(err) } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { log.Printf("Incomplete OCS response %s from %s", string(body), url) - return fmt.Errorf("incomplete OCS response") + return e.errorIfMustRevalidate(ErrIncompleteResponse) } var capaResponse CapabilitiesResponse if err := json.Unmarshal(ocs.Ocs.Data, &capaResponse); err != nil { log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), url, err) - return err + return e.errorIfMustRevalidate(err) } capaObj, found := capaResponse.Capabilities[AppNameSpreed] if !found || len(capaObj) == 0 { log.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse) + e.capabilities = nil return nil } var capa map[string]interface{} if err := json.Unmarshal(capaObj, &capa); err != nil { log.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse) + e.capabilities = nil return nil } diff --git a/capabilities_test.go b/capabilities_test.go index f3cef273..a9bf8c4c 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -26,6 +26,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -38,7 +39,7 @@ import ( "github.com/gorilla/mux" ) -func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse, http.ResponseWriter)) (*url.URL, *Capabilities) { +func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse, http.ResponseWriter) error) (*url.URL, *Capabilities) { pool, err := NewHttpClientPool(1, false) if err != nil { t.Fatal(err) @@ -106,8 +107,15 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie if data, err = json.Marshal(ocs); err != nil { t.Fatal(err) } + var cc []string if !strings.Contains(t.Name(), "NoCache") { - w.Header().Add("Cache-Control", "max-age=60") + cc = append(cc, "max-age=60") + } + if strings.Contains(t.Name(), "MustRevalidate") && !strings.Contains(t.Name(), "NoMustRevalidate") { + cc = append(cc, "must-revalidate") + } + if len(cc) > 0 { + w.Header().Add("Cache-Control", strings.Join(cc, ", ")) } if strings.Contains(t.Name(), "ETag") { h := sha256.New() @@ -115,20 +123,27 @@ func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*Capabilitie etag := fmt.Sprintf("\"%s\"", base64.StdEncoding.EncodeToString(h.Sum(nil))) w.Header().Add("ETag", etag) if inm := r.Header.Get("If-None-Match"); inm == etag { - w.WriteHeader(http.StatusNotModified) if callback != nil { - callback(response, w) + if err := callback(response, w); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } } + w.WriteHeader(http.StatusNotModified) return } } w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write(data) // nolint if callback != nil { - callback(response, w) + if err := callback(response, w); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } } + + w.WriteHeader(http.StatusOK) + w.Write(data) // nolint } r.HandleFunc("/ocs/v2.php/cloud/capabilities", handleCapabilitiesFunc) @@ -223,8 +238,9 @@ func TestInvalidateCapabilities(t *testing.T) { t.Parallel() CatchLogForTest(t) var called atomic.Uint32 - url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) { + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { called.Add(1) + return nil }) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -297,8 +313,9 @@ func TestCapabilitiesNoCache(t *testing.T) { t.Parallel() CatchLogForTest(t) var called atomic.Uint32 - url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) { + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { called.Add(1) + return nil }) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -334,7 +351,7 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { t.Parallel() CatchLogForTest(t) var called atomic.Uint32 - url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) { + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { ct := w.Header().Get("Content-Type") switch called.Add(1) { case 1: @@ -346,6 +363,7 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { t.Errorf("expected no content-type on second request, got %s", ct) } } + return nil }) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -376,3 +394,85 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { t.Errorf("expected called %d, got %d", 2, value) } } + +func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + var called atomic.Uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { + if called.Add(1) == 2 { + return errors.New("trigger error") + } + + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + // Expired capabilities can still be used even in case of update errors if + // "must-revalidate" is not set. + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } +} + +func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + var called atomic.Uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { + if called.Add(1) == 2 { + return errors.New("trigger error") + } + + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + // Capabilities will be cleared if "must-revalidate" is set and an error + // occurs while fetching the updated data. + if value, _, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); found { + t.Errorf("should not have found value for \"foo\", got %s", value) + } + + if value := called.Load(); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } +} From 25cfad751918e41d80ab8372412321395ad97ca7 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 25 Jul 2024 07:53:47 +0200 Subject: [PATCH 3/3] Use default cache duration if no "Cache-Control" is included in the response. --- capabilities.go | 6 ++++ capabilities_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/capabilities.go b/capabilities.go index fe5d092f..8fd9e92f 100644 --- a/capabilities.go +++ b/capabilities.go @@ -43,6 +43,10 @@ const ( // Name of capability to enable the "v3" API for the signaling endpoint. FeatureSignalingV3Api = "signaling-v3" + // Cache capabilities for one minute if response does not contain a + // "Cache-Control" header. + defaultCapabilitiesCacheDuration = time.Minute + // Don't invalidate more than once per minute. maxInvalidateInterval = time.Minute ) @@ -109,6 +113,8 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error maxAge = cc.MaxAge() } e.mustRevalidate = cc.MustRevalidate() + } else { + maxAge = defaultCapabilitiesCacheDuration } e.nextUpdate = now.Add(maxAge) diff --git a/capabilities_test.go b/capabilities_test.go index a9bf8c4c..a6665647 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -334,6 +334,23 @@ func TestCapabilitiesNoCache(t *testing.T) { t.Errorf("expected called %d, got %d", 1, value) } + // Capabilities are cached for some time if no "Cache-Control" header is set. + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") + } + + if value := called.Load(); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + SetCapabilitiesGetNow(t, capabilities, func() time.Time { + return time.Now().Add(defaultCapabilitiesCacheDuration) + }) + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { t.Error("could not find value for \"foo\"") } else if value != expectedString { @@ -382,6 +399,57 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { t.Errorf("expected called %d, got %d", 1, value) } + SetCapabilitiesGetNow(t, capabilities, func() time.Time { + return time.Now().Add(defaultCapabilitiesCacheDuration) + }) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } +} + +func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + var called atomic.Uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { + if called.Add(1) == 2 { + return errors.New("trigger error") + } + + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := called.Load(); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + SetCapabilitiesGetNow(t, capabilities, func() time.Time { + return time.Now().Add(time.Minute) + }) + + // Expired capabilities can still be used even in case of update errors if + // "must-revalidate" is not set. if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { t.Error("could not find value for \"foo\"") } else if value != expectedString { @@ -423,6 +491,10 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { t.Errorf("expected called %d, got %d", 1, value) } + SetCapabilitiesGetNow(t, capabilities, func() time.Time { + return time.Now().Add(defaultCapabilitiesCacheDuration) + }) + // Expired capabilities can still be used even in case of update errors if // "must-revalidate" is not set. if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {