Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
darkweak committed Aug 8, 2024
1 parent af6ec14 commit a966ba2
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 26 deletions.
1 change: 1 addition & 0 deletions pkg/api/souin.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func (s *SouinAPI) BulkDelete(key string, purge bool) {
// Delete will delete a record into the provider cache system and will update the Souin API if enabled
// The key can be a regexp to delete multiple items
func (s *SouinAPI) Delete(key string) {
fmt.Printf("Delete the key => %#v\n", key)
for _, current := range s.storers {
current.Delete(key)
}
Expand Down
33 changes: 19 additions & 14 deletions pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func (s *SouinBaseHandler) Store(
rq *http.Request,
requestCc *cacheobject.RequestCacheDirectives,
cachedKey string,
uri string,
) error {
statusCode := customWriter.GetStatusCode()
if !isCacheableCode(statusCode) {
Expand Down Expand Up @@ -341,6 +342,7 @@ func (s *SouinBaseHandler) Store(
variedKey,
) == nil {
s.Configuration.GetLogger().Sugar().Debugf("Stored the key %s in the %s provider", variedKey, currentStorer.Name())
res.Request = rq
} else {
mu.Lock()
fails = append(fails, fmt.Sprintf("; detail=%s-INSERTION-ERROR", currentStorer.Name()))
Expand All @@ -351,9 +353,9 @@ func (s *SouinBaseHandler) Store(

wg.Wait()
if len(fails) < s.storersLen {
go func(rs http.Response, key string) {
_ = s.SurrogateKeyStorer.Store(&rs, key)
}(res, variedKey)
go func(rs http.Response, key string, basekey string) {
_ = s.SurrogateKeyStorer.Store(&rs, key, uri, basekey)
}(res, variedKey, cachedKey)
status += "; stored"
}

Expand Down Expand Up @@ -387,6 +389,7 @@ func (s *SouinBaseHandler) Upstream(
next handlerFunc,
requestCc *cacheobject.RequestCacheDirectives,
cachedKey string,
uri string,
) error {
s.Configuration.GetLogger().Sugar().Debug("Request the upstream server")
prometheus.Increment(prometheus.RequestCounter)
Expand Down Expand Up @@ -434,7 +437,7 @@ func (s *SouinBaseHandler) Upstream(
customWriter.Header().Set(headerName, s.DefaultMatchedUrl.DefaultCacheControl)
}

err := s.Store(customWriter, rq, requestCc, cachedKey)
err := s.Store(customWriter, rq, requestCc, cachedKey, uri)
defer customWriter.Buf.Reset()

return singleflightValue{
Expand All @@ -458,7 +461,7 @@ func (s *SouinBaseHandler) Upstream(
for _, vh := range variedHeaders {
if rq.Header.Get(vh) != sfWriter.requestHeaders.Get(vh) {
// cachedKey += rfc.GetVariedCacheKey(rq, variedHeaders)
return s.Upstream(customWriter, rq, next, requestCc, cachedKey)
return s.Upstream(customWriter, rq, next, requestCc, cachedKey, uri)
}
}
}
Expand All @@ -474,7 +477,7 @@ func (s *SouinBaseHandler) Upstream(
return nil
}

func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string) error {
func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string, uri string) error {
s.Configuration.GetLogger().Sugar().Debug("Revalidate the request with the upstream server")
prometheus.Increment(prometheus.RequestRevalidationCounter)

Expand All @@ -496,7 +499,7 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
}

if statusCode != http.StatusNotModified {
err = s.Store(customWriter, rq, requestCc, cachedKey)
err = s.Store(customWriter, rq, requestCc, cachedKey, uri)
}
}

Expand Down Expand Up @@ -616,6 +619,8 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
cachedKey := req.Context().Value(context.Key).(string)

// Need to copy URL path before calling next because it can alter the URI
uri := req.URL.Path
bufPool := s.bufPool.Get().(*bytes.Buffer)
bufPool.Reset()
defer s.bufPool.Put(bufPool)
Expand Down Expand Up @@ -669,14 +674,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}

if validator.NeedRevalidation {
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
_, _ = customWriter.Send()

return err
}
if resCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, headerName)); resCc.NoCachePresent {
prometheus.Increment(prometheus.NoCachedResponseCounter)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
_, _ = customWriter.Send()

return err
Expand Down Expand Up @@ -711,9 +716,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
_, _ = io.Copy(customWriter.Buf, response.Body)
_, err := customWriter.Send()
customWriter = NewCustomWriter(req, rw, bufPool)
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string) {
_ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk)
}(validator, customWriter, req, next, requestCc, cachedKey)
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
_ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk, goUri)
}(validator, customWriter, req, next, requestCc, cachedKey, uri)
buf := s.bufPool.Get().(*bytes.Buffer)
buf.Reset()
defer s.bufPool.Put(buf)
Expand All @@ -723,7 +728,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n

if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
statusCode := customWriter.GetStatusCode()
if err != nil {
if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 {
Expand Down Expand Up @@ -785,7 +790,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
errorCacheCh := make(chan error)
go func(vr *http.Request, cw *CustomWriter) {
prometheus.Increment(prometheus.NoCachedResponseCounter)
errorCacheCh <- s.Upstream(cw, vr, next, requestCc, cachedKey)
errorCacheCh <- s.Upstream(cw, vr, next, requestCc, cachedKey, uri)
}(req, customWriter)

select {
Expand Down
4 changes: 2 additions & 2 deletions pkg/surrogate/providers/akamai.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func (*AkamaiSurrogateStorage) getHeaderSeparator() string {
}

// Store stores the response tags located in the first non empty supported header
func (a *AkamaiSurrogateStorage) Store(response *http.Response, cacheKey string) error {
func (a *AkamaiSurrogateStorage) Store(response *http.Response, cacheKey, uri, basekey string) error {
defer func() {
response.Header.Del(surrogateKey)
response.Header.Del(surrogateControl)
}()
e := a.baseStorage.Store(response, cacheKey)
e := a.baseStorage.Store(response, cacheKey, uri, basekey)
response.Header.Set(edgeCacheTag, response.Header.Get(surrogateKey))

return e
Expand Down
4 changes: 2 additions & 2 deletions pkg/surrogate/providers/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ func (*CloudflareSurrogateStorage) getHeaderSeparator() string {
}

// Store stores the response tags located in the first non empty supported header
func (c *CloudflareSurrogateStorage) Store(response *http.Response, cacheKey string) error {
func (c *CloudflareSurrogateStorage) Store(response *http.Response, cacheKey, uri, basekey string) error {
defer func() {
response.Header.Del(surrogateKey)
response.Header.Del(surrogateControl)
}()
e := c.baseStorage.Store(response, cacheKey)
e := c.baseStorage.Store(response, cacheKey, uri, basekey)
response.Header.Set(cacheTag, strings.Join(c.ParseHeaders(response.Header.Get(surrogateKey)), c.getHeaderSeparator()))

return e
Expand Down
6 changes: 5 additions & 1 deletion pkg/surrogate/providers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (s *baseStorage) purgeTag(tag string) []string {
}

// Store will take the lead to store the cache key for each provided Surrogate-key
func (s *baseStorage) Store(response *http.Response, cacheKey string) error {
func (s *baseStorage) Store(response *http.Response, cacheKey, uri, basekey string) error {
h := response.Header

cacheKey = url.QueryEscape(cacheKey)
Expand All @@ -223,13 +223,17 @@ func (s *baseStorage) Store(response *http.Response, cacheKey string) error {
for _, control := range controls {
if s.parent.candidateStore(control) {
s.storeTag(key, cacheKey, urlRegexp)

break
}
}
} else {
s.storeTag(key, cacheKey, urlRegexp)
}
}

s.storeTag(uri, cacheKey, urlRegexp)

return nil
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/surrogate/providers/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestBaseStorage_Store(t *testing.T) {

bs := mockCommonProvider()

e := bs.Store(&res, "((((invalid_key_but_escaped")
e := bs.Store(&res, "((((invalid_key_but_escaped", "", "")
if e != nil {
t.Error("It shouldn't throw an error with a valid key.")
}
Expand All @@ -116,7 +116,7 @@ func TestBaseStorage_Store(t *testing.T) {
_ = bs.Storage.Set("test5", []byte("first,second,fifth"), storageToInfiniteTTLMap[bs.Storage.Name()])
_ = bs.Storage.Set("testInvalid", []byte("invalid"), storageToInfiniteTTLMap[bs.Storage.Name()])

if e = bs.Store(&res, "stored"); e != nil {
if e = bs.Store(&res, "stored", "", ""); e != nil {
t.Error("It shouldn't throw an error with a valid key.")
}

Expand All @@ -133,10 +133,10 @@ func TestBaseStorage_Store(t *testing.T) {
}

res.Header.Set(surrogateKey, "something")
_ = bs.Store(&res, "/something")
_ = bs.Store(&res, "/something")
_ = bs.Store(&res, "/something", "", "")
_ = bs.Store(&res, "/something", "", "")
res.Header.Set(surrogateKey, "something")
_ = bs.Store(&res, "/some")
_ = bs.Store(&res, "/some", "", "")

storageSize := len(bs.Storage.MapKeys(surrogatePrefix))
if storageSize != 6 {
Expand All @@ -161,7 +161,7 @@ func TestBaseStorage_Store_Load(t *testing.T) {
wg.Add(1)
go func(r http.Response, iteration int, group *sync.WaitGroup) {
defer wg.Done()
_ = bs.Store(&r, fmt.Sprintf("my_dynamic_cache_key_%d", iteration))
_ = bs.Store(&r, fmt.Sprintf("my_dynamic_cache_key_%d", iteration), "", "")
}(res, i, &wg)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/surrogate/providers/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type SurrogateInterface interface {
Purge(http.Header) (cacheKeys []string, surrogateKeys []string)
Invalidate(method string, h http.Header)
purgeTag(string) []string
Store(*http.Response, string) error
Store(*http.Response, string, string, string) error
storeTag(string, string, *regexp.Regexp)
ParseHeaders(string) []string
List() map[string]string
Expand Down

0 comments on commit a966ba2

Please sign in to comment.