Skip to content

Commit

Permalink
Don't update capabilities concurrently from same host.
Browse files Browse the repository at this point in the history
If capabilities are expired and requested from multiple clients concurrently,
this could cause concurrent (duplicate) requests to the same Nextcloud host.
With this change, only a single request is sent to Nextcloud in such cases.
  • Loading branch information
fancycode committed Oct 9, 2024
1 parent d692a3b commit 5067fb6
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 59 deletions.
126 changes: 67 additions & 59 deletions capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,32 @@ var (
)

type capabilitiesEntry struct {
c *Capabilities
mu sync.RWMutex
nextUpdate time.Time
etag string
mustRevalidate bool
capabilities map[string]interface{}
}

func newCapabilitiesEntry() *capabilitiesEntry {
return &capabilitiesEntry{}
func newCapabilitiesEntry(c *Capabilities) *capabilitiesEntry {
return &capabilitiesEntry{
c: c,
}
}

func (e *capabilitiesEntry) valid(now time.Time) bool {
e.mu.RLock()
defer e.mu.RUnlock()

return e.validLocked(now)
}

func (e *capabilitiesEntry) validLocked(now time.Time) bool {
return e.nextUpdate.After(now)
}

func (e *capabilitiesEntry) updateRequest(r *http.Request) {
e.mu.RLock()
defer e.mu.RUnlock()

if e.etag != "" {
r.Header.Set("If-None-Match", e.etag)
}
Expand All @@ -94,19 +98,59 @@ func (e *capabilitiesEntry) invalidate() {
e.nextUpdate = time.Now()
}

func (e *capabilitiesEntry) errorIfMustRevalidate(err error) error {
func (e *capabilitiesEntry) errorIfMustRevalidate(err error) (bool, error) {
if !e.mustRevalidate {
return nil
return false, nil
}

e.capabilities = nil
return err
return false, err
}

func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error {
func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Time) (bool, error) {
e.mu.Lock()
defer e.mu.Unlock()

if e.validLocked(now) {
// Capabilities were updated while waiting for the lock.
return false, nil
}

capUrl := *u
if !strings.Contains(capUrl.Path, "ocs/v2.php") {
if !strings.HasSuffix(capUrl.Path, "/") {
capUrl.Path += "/"
}
capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities"
} else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 {
capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities"
}

log.Printf("Capabilities expired for %s, updating", capUrl.String())

client, pool, err := e.c.pool.Get(ctx, &capUrl)
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return false, err
}
defer pool.Put(client)

req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return false, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+e.c.version)
e.updateRequest(req)

response, err := client.Do(req)
if err != nil {
return false, err
}
defer response.Body.Close()

url := response.Request.URL
e.etag = response.Header.Get("ETag")

Expand All @@ -127,7 +171,7 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error

if response.StatusCode == http.StatusNotModified {
log.Printf("Capabilities %+v from %s have not changed", e.capabilities, url)
return nil
return false, nil
} else if response.StatusCode != http.StatusOK {
log.Printf("Received unexpected HTTP status from %s: %s", url, response.Status)
return e.errorIfMustRevalidate(ErrUnexpectedHttpStatus)
Expand Down Expand Up @@ -164,19 +208,19 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error
if !found || len(capaObj) == 0 {
log.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse)
e.capabilities = nil
return nil
return false, nil
}

var capa map[string]interface{}
if err := json.Unmarshal(capaObj, &capa); err != nil {
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse)
e.capabilities = nil
return nil
return false, nil
}

log.Printf("Received capabilities %+v from %s", capa, url)
e.capabilities = capa
return nil
return true, nil
}

func (e *capabilitiesEntry) GetCapabilities() map[string]interface{} {
Expand Down Expand Up @@ -231,11 +275,15 @@ func (c *Capabilities) getCapabilities(key string) (*capabilitiesEntry, bool) {

now := c.getNow()
entry, found := c.entries[key]
if found && entry.valid(now) {
return entry, true
if !found {
// Upgrade to write-lock
c.mu.RUnlock()
defer c.mu.RLock()

entry = c.newCapabilitiesEntry(key)
}

return entry, false
return entry, entry.valid(now)
}

func (c *Capabilities) invalidateCapabilities(key string) {
Expand All @@ -260,7 +308,7 @@ func (c *Capabilities) newCapabilitiesEntry(key string) *capabilitiesEntry {

entry, found := c.entries[key]
if !found {
entry = newCapabilitiesEntry()
entry = newCapabilitiesEntry(c)
c.entries[key] = entry
}

Expand All @@ -279,52 +327,12 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
return entry.GetCapabilities(), true, nil
}

capUrl := *u
if !strings.Contains(capUrl.Path, "ocs/v2.php") {
if !strings.HasSuffix(capUrl.Path, "/") {
capUrl.Path += "/"
}
capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities"
} else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 {
capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities"
}

log.Printf("Capabilities expired for %s, updating", capUrl.String())

client, pool, err := c.pool.Get(ctx, &capUrl)
updated, err := entry.update(ctx, u, c.getNow())
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return nil, false, err
}
defer pool.Put(client)

req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return nil, false, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+c.version)
if entry != nil {
entry.updateRequest(req)
}

resp, err := client.Do(req)
if err != nil {
return nil, false, err
}
defer resp.Body.Close()

if entry == nil {
entry = c.newCapabilitiesEntry(key)
}

if err := entry.update(resp, c.getNow()); err != nil {
return nil, false, err
}

return entry.GetCapabilities(), false, nil
return entry.GetCapabilities(), !updated, nil
}

func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {
Expand Down
54 changes: 54 additions & 0 deletions capabilities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -528,3 +529,56 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
value = called.Load()
assert.EqualValues(2, value)
}

func TestConcurrentExpired(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
called.Add(1)
return nil
})

ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}

count := 100
start := make(chan struct{})
var numCached atomic.Uint32
var numFetched atomic.Uint32
var finished sync.WaitGroup
for i := 0; i < count; i++ {
finished.Add(1)
go func() {
defer finished.Done()
<-start
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
if cached {
numCached.Add(1)
} else {
numFetched.Add(1)
}
}
}()
}

SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
})

close(start)
finished.Wait()

assert.EqualValues(2, called.Load())
assert.EqualValues(count, numFetched.Load()+numCached.Load())
assert.EqualValues(1, numFetched.Load())
assert.EqualValues(count-1, numCached.Load())
}

0 comments on commit 5067fb6

Please sign in to comment.