diff --git a/capabilities.go b/capabilities.go index 45b7b546..e606bf03 100644 --- a/capabilities.go +++ b/capabilities.go @@ -60,6 +60,7 @@ var ( ) type capabilitiesEntry struct { + c *Capabilities mu sync.RWMutex nextUpdate time.Time etag string @@ -67,21 +68,24 @@ type capabilitiesEntry struct { capabilities map[string]interface{} } -func newCapabilitiesEntry() *capabilitiesEntry { - return &capabilitiesEntry{} +func newCapabilitiesEntry(c *Capabilities) *capabilitiesEntry { + return &capabilitiesEntry{ + c: c, + } } func (e *capabilitiesEntry) valid(now time.Time) bool { e.mu.RLock() defer e.mu.RUnlock() + return e.validLocked(now) +} + +func (e *capabilitiesEntry) validLocked(now time.Time) bool { return e.nextUpdate.After(now) } func (e *capabilitiesEntry) updateRequest(r *http.Request) { - e.mu.RLock() - defer e.mu.RUnlock() - if e.etag != "" { r.Header.Set("If-None-Match", e.etag) } @@ -94,19 +98,59 @@ func (e *capabilitiesEntry) invalidate() { e.nextUpdate = time.Now() } -func (e *capabilitiesEntry) errorIfMustRevalidate(err error) error { +func (e *capabilitiesEntry) errorIfMustRevalidate(err error) (bool, error) { if !e.mustRevalidate { - return nil + return false, nil } e.capabilities = nil - return err + return false, err } -func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error { +func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Time) (bool, error) { e.mu.Lock() defer e.mu.Unlock() + if e.validLocked(now) { + // Capabilities were updated while waiting for the lock. + return false, nil + } + + capUrl := *u + if !strings.Contains(capUrl.Path, "ocs/v2.php") { + if !strings.HasSuffix(capUrl.Path, "/") { + capUrl.Path += "/" + } + capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities" + } else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 { + capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities" + } + + log.Printf("Capabilities expired for %s, updating", capUrl.String()) + + client, pool, err := e.c.pool.Get(ctx, &capUrl) + if err != nil { + log.Printf("Could not get client for host %s: %s", capUrl.Host, err) + return false, err + } + defer pool.Put(client) + + req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) + if err != nil { + log.Printf("Could not create request to %s: %s", &capUrl, err) + return false, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("OCS-APIRequest", "true") + req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+e.c.version) + e.updateRequest(req) + + response, err := client.Do(req) + if err != nil { + return false, err + } + defer response.Body.Close() + url := response.Request.URL e.etag = response.Header.Get("ETag") @@ -127,7 +171,7 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error if response.StatusCode == http.StatusNotModified { log.Printf("Capabilities %+v from %s have not changed", e.capabilities, url) - return nil + return false, nil } else if response.StatusCode != http.StatusOK { log.Printf("Received unexpected HTTP status from %s: %s", url, response.Status) return e.errorIfMustRevalidate(ErrUnexpectedHttpStatus) @@ -164,19 +208,19 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error if !found || len(capaObj) == 0 { log.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse) e.capabilities = nil - return nil + return false, nil } var capa map[string]interface{} if err := json.Unmarshal(capaObj, &capa); err != nil { log.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse) e.capabilities = nil - return nil + return false, nil } log.Printf("Received capabilities %+v from %s", capa, url) e.capabilities = capa - return nil + return true, nil } func (e *capabilitiesEntry) GetCapabilities() map[string]interface{} { @@ -231,11 +275,15 @@ func (c *Capabilities) getCapabilities(key string) (*capabilitiesEntry, bool) { now := c.getNow() entry, found := c.entries[key] - if found && entry.valid(now) { - return entry, true + if !found { + // Upgrade to write-lock + c.mu.RUnlock() + defer c.mu.RLock() + + entry = c.newCapabilitiesEntry(key) } - return entry, false + return entry, entry.valid(now) } func (c *Capabilities) invalidateCapabilities(key string) { @@ -260,7 +308,7 @@ func (c *Capabilities) newCapabilitiesEntry(key string) *capabilitiesEntry { entry, found := c.entries[key] if !found { - entry = newCapabilitiesEntry() + entry = newCapabilitiesEntry(c) c.entries[key] = entry } @@ -279,52 +327,12 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st return entry.GetCapabilities(), true, nil } - capUrl := *u - if !strings.Contains(capUrl.Path, "ocs/v2.php") { - if !strings.HasSuffix(capUrl.Path, "/") { - capUrl.Path += "/" - } - capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities" - } else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 { - capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities" - } - - log.Printf("Capabilities expired for %s, updating", capUrl.String()) - - client, pool, err := c.pool.Get(ctx, &capUrl) + updated, err := entry.update(ctx, u, c.getNow()) if err != nil { - log.Printf("Could not get client for host %s: %s", capUrl.Host, err) - return nil, false, err - } - defer pool.Put(client) - - req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) - if err != nil { - log.Printf("Could not create request to %s: %s", &capUrl, err) - return nil, false, err - } - req.Header.Set("Accept", "application/json") - req.Header.Set("OCS-APIRequest", "true") - req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+c.version) - if entry != nil { - entry.updateRequest(req) - } - - resp, err := client.Do(req) - if err != nil { - return nil, false, err - } - defer resp.Body.Close() - - if entry == nil { - entry = c.newCapabilitiesEntry(key) - } - - if err := entry.update(resp, c.getNow()); err != nil { return nil, false, err } - return entry.GetCapabilities(), false, nil + return entry.GetCapabilities(), !updated, nil } func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { diff --git a/capabilities_test.go b/capabilities_test.go index fddefd0a..37fd4721 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -32,6 +32,7 @@ import ( "net/http/httptest" "net/url" "strings" + "sync" "sync/atomic" "testing" "time" @@ -528,3 +529,56 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { value = called.Load() assert.EqualValues(2, value) } + +func TestConcurrentExpired(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + assert := assert.New(t) + var called atomic.Uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { + called.Add(1) + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + assert.False(cached) + } + + count := 100 + start := make(chan struct{}) + var numCached atomic.Uint32 + var numFetched atomic.Uint32 + var finished sync.WaitGroup + for i := 0; i < count; i++ { + finished.Add(1) + go func() { + defer finished.Done() + <-start + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) { + assert.Equal(expectedString, value) + if cached { + numCached.Add(1) + } else { + numFetched.Add(1) + } + } + }() + } + + SetCapabilitiesGetNow(t, capabilities, func() time.Time { + return time.Now().Add(minCapabilitiesCacheDuration) + }) + + close(start) + finished.Wait() + + assert.EqualValues(2, called.Load()) + assert.EqualValues(count, numFetched.Load()+numCached.Load()) + assert.EqualValues(1, numFetched.Load()) + assert.EqualValues(count-1, numCached.Load()) +}