Skip to content

Commit

Permalink
Support force cache even when server doesn't set the Date header
Browse files Browse the repository at this point in the history
This fixes a bug where when a upstream server doesn't set a Date
header the response will be cached but then once the cache is
expired it never gets refreshed because parseResponseHeaders
returns a no date header error

This also makes it so that the force_cache_duration_seconds
is respected when an upstream returns a 304 not modified

Signed-off-by: Peter <[email protected]>
  • Loading branch information
c2zwdjnlcg committed Aug 24, 2023
1 parent 50d8ae8 commit 7ac99e3
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 30 deletions.
44 changes: 14 additions & 30 deletions topdown/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,10 +839,7 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
return nil, handleHTTPSendErr(c.bctx, err)
}

headers, err := parseResponseHeaders(cachedRespData.Headers)
if err != nil {
return nil, err
}
headers := parseResponseHeaders(cachedRespData.Headers)

// check with the server if the stale response is still up-to-date.
// If server returns a new response (ie. status_code=200), update the cache with the new response
Expand All @@ -864,11 +861,16 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
}
}

expiresAt, err := expiryFromHeaders(result.Header)
if err != nil {
return nil, err
if forceCaching(c.forceCacheParams) {
createdAt := getCurrentTime(c.bctx)
cachedRespData.ExpiresAt = createdAt.Add(time.Second * time.Duration(c.forceCacheParams.forceCacheDurationSeconds))
} else {
expiresAt, err := expiryFromHeaders(result.Header)
if err != nil {
return nil, err
}
cachedRespData.ExpiresAt = expiresAt
}
cachedRespData.ExpiresAt = expiresAt

cachingMode, err := getCachingMode(c.key)
if err != nil {
Expand Down Expand Up @@ -1143,40 +1145,22 @@ func (c *interQueryCacheData) Clone() (cache.InterQueryCacheValue, error) {
}

type responseHeaders struct {
date time.Time // origination date and time of response
cacheControl map[string]string // response cache-control header
maxAge deltaSeconds // max-age cache control directive
expires time.Time // date/time after which the response is considered stale
etag string // identifier for a specific version of the response
lastModified string // date and time response was last modified as per origin server
etag string // identifier for a specific version of the response
lastModified string // date and time response was last modified as per origin server
}

// deltaSeconds specifies a non-negative integer, representing
// time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1
type deltaSeconds int32

func parseResponseHeaders(headers http.Header) (*responseHeaders, error) {
var err error
func parseResponseHeaders(headers http.Header) *responseHeaders {
result := responseHeaders{}

result.date, err = getResponseHeaderDate(headers)
if err != nil {
return nil, err
}

result.cacheControl = parseCacheControlHeader(headers)
result.maxAge, err = parseMaxAgeCacheDirective(result.cacheControl)
if err != nil {
return nil, err
}

result.expires = getResponseHeaderExpires(headers)

result.etag = headers.Get("etag")

result.lastModified = headers.Get("last-modified")

return &result, nil
return &result
}

func revalidateCachedResponse(req *http.Request, client *http.Client, inputReqObj ast.Object, headers *responseHeaders) (*http.Response, bool, error) {
Expand Down
145 changes: 145 additions & 0 deletions topdown/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,151 @@ func TestHTTPSendInterQueryForceCaching(t *testing.T) {
}
}

func TestHTTPSendInterQueryForceCachingRefresh(t *testing.T) {
cacheTime := 300
tests := []struct {
note string
request string
headers map[string][]string
skipDate bool
response string
expectedReqCount int
}{
{
note: "http.send GET cache expired, reloads normally",
request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "force_cache": true, "force_cache_duration_seconds": %CACHE%}`,
headers: map[string][]string{},
expectedReqCount: 2,
response: `{"x": 1}`,
},
{
note: "http.send GET cache expired, no date, reloads normally",
request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "force_cache": true, "force_cache_duration_seconds": %CACHE%}`,
headers: map[string][]string{},
expectedReqCount: 2,
skipDate: true,
response: `{"x": 1}`,
},
{
note: "http.send GET cache expired, returns not modified",
request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "force_cache": true, "force_cache_duration_seconds": %CACHE%}`,
headers: map[string][]string{"Etag": {"1234"}},
expectedReqCount: 2,
response: `{"x": 1}`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
t0 := time.Now().UTC()

var requests []*http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)
headers := w.Header()

for k, v := range tc.headers {
headers[k] = v
}

if tc.skipDate {
headers["Date"] = nil
} else {
headers.Set("Date", t0.Format(http.TimeFormat))
}

etag := w.Header().Get("etag")

if r.Header.Get("if-none-match") != "" {
if r.Header.Get("if-none-match") == etag {
// add new headers and update existing header value
headers["Cache-Control"] = []string{"max-age=200, public"}
w.WriteHeader(http.StatusNotModified)
}
} else {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(tc.response))
if err != nil {
t.Fatal(err)
}
}
}))
defer ts.Close()

request := strings.ReplaceAll(tc.request, "%URL%", ts.URL)
request = strings.ReplaceAll(request, "%CACHE%", strconv.Itoa(cacheTime))
full := fmt.Sprintf("http.send(%s, x)", request)
config, _ := iCache.ParseCachingConfig(nil)
interQueryCache := iCache.NewInterQueryCache(config)
q := NewQuery(ast.MustParseBody(full)).
WithInterQueryBuiltinCache(interQueryCache).
WithTime(t0)

/* Run tests twice once to populate the cache
then expire it out and run again to simulate an
expired cache
*/
for i := 0; i < 2; i++ {
resp, err := q.Run(context.Background())
if err != nil {
t.Fatal(err)
}

// make sure we have a valid response
if len(resp) < 1 {
t.Fatalf("missing response on query %d: %v", i, resp)
}

// check the body is what we expect
resResponse := resp[0]["x"].Value.(ast.Object).Get(ast.StringTerm("raw_body"))
if ast.String(tc.response).Compare(resResponse.Value) != 0 {
t.Fatalf("Expected response on query %d to be %v, got %v", i, tc.response, resResponse.String())
}

// pull the result out of the cache
var x interface{}
if err := util.UnmarshalJSON([]byte(request), &x); err != nil {
t.Fatalf("failed to unmarshal request on query %d: %v", i, err)
}
cacheKey, err := ast.InterfaceToValue(x)
if err != nil {
t.Fatalf("failed create request object on query %d: %v", i, err)
}

val, found := interQueryCache.Get(cacheKey)
if !found {
t.Fatalf("Expected inter-query cache hit on query %d", i)
}

m, err := val.(*interQueryCacheValue).copyCacheData()
if err != nil {
t.Fatal(err)
}

// Make sure the cache expires based on the force cache time setting
expectedExpiry := t0.Add(time.Second * time.Duration(cacheTime))
if expectedExpiry.Sub(m.ExpiresAt).Abs() > time.Second*1 {
t.Fatalf("Expected cache to expire on query %d in %v secs got %s", i, cacheTime, t0.Sub(m.ExpiresAt).Abs())
}

// Push an expired entry back into the cache for the next run
m.ExpiresAt = t0.Add(-time.Hour * 1)
v, err := m.toCacheValue()
if err != nil {
t.Fatal(err)
}

interQueryCache.Insert(cacheKey, v)
}

actualCount := len(requests)
if actualCount != tc.expectedReqCount {
t.Errorf("Expected to get %d requests, got %d", tc.expectedReqCount, actualCount)
}
})
}
}

func TestHTTPSendInterQueryCachingModifiedResp(t *testing.T) {
tests := []struct {
note string
Expand Down

0 comments on commit 7ac99e3

Please sign in to comment.