diff --git a/topdown/http.go b/topdown/http.go index 0b5c16870a7..aa2f01e8f73 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -839,10 +839,7 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { return nil, handleHTTPSendErr(c.bctx, err) } - headers, err := parseResponseHeaders(cachedRespData.Headers) - if err != nil { - return nil, err - } + headers := parseResponseHeaders(cachedRespData.Headers) // check with the server if the stale response is still up-to-date. // If server returns a new response (ie. status_code=200), update the cache with the new response @@ -864,11 +861,16 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { } } - expiresAt, err := expiryFromHeaders(result.Header) - if err != nil { - return nil, err + if forceCaching(c.forceCacheParams) { + createdAt := getCurrentTime(c.bctx) + cachedRespData.ExpiresAt = createdAt.Add(time.Second * time.Duration(c.forceCacheParams.forceCacheDurationSeconds)) + } else { + expiresAt, err := expiryFromHeaders(result.Header) + if err != nil { + return nil, err + } + cachedRespData.ExpiresAt = expiresAt } - cachedRespData.ExpiresAt = expiresAt cachingMode, err := getCachingMode(c.key) if err != nil { @@ -1155,28 +1157,14 @@ type responseHeaders struct { // time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1 type deltaSeconds int32 -func parseResponseHeaders(headers http.Header) (*responseHeaders, error) { - var err error +func parseResponseHeaders(headers http.Header) (*responseHeaders) { result := responseHeaders{} - result.date, err = getResponseHeaderDate(headers) - if err != nil { - return nil, err - } - - result.cacheControl = parseCacheControlHeader(headers) - result.maxAge, err = parseMaxAgeCacheDirective(result.cacheControl) - if err != nil { - return nil, err - } - - result.expires = getResponseHeaderExpires(headers) - result.etag = headers.Get("etag") result.lastModified = headers.Get("last-modified") - return &result, nil + return &result } func revalidateCachedResponse(req *http.Request, client *http.Client, inputReqObj ast.Object, headers *responseHeaders) (*http.Response, bool, error) { diff --git a/topdown/http_test.go b/topdown/http_test.go index d1e080808b2..95208c6e91d 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -1464,6 +1464,141 @@ func TestHTTPSendInterQueryForceCaching(t *testing.T) { } } +func TestHTTPSendInterQueryForceCachingRefresh(t *testing.T) { + cacheTime := 300 + tests := []struct { + note string + request string + headers map[string][]string + skipDate bool + response string + expectedReqCount int + }{ + { + note: "http.send GET cache expired, reloads normally", + request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "force_cache": true, "force_cache_duration_seconds": %CACHE%}`, + headers: map[string][]string{}, + expectedReqCount: 2, + response: `{"x": 1}`, + }, + { + note: "http.send GET cache expired, no date, reloads normally", + request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "force_cache": true, "force_cache_duration_seconds": %CACHE%}`, + headers: map[string][]string{}, + expectedReqCount: 2, + skipDate: true, + response: `{"x": 1}`, + }, + { + note: "http.send GET cache expired, returns not modified", + request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "force_cache": true, "force_cache_duration_seconds": %CACHE%}`, + headers: map[string][]string{"Etag": {"1234"}}, + expectedReqCount: 2, + response: `{"x": 1}`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + t0 := time.Now().UTC() + + var requests []*http.Request + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r) + headers := w.Header() + + for k, v := range tc.headers { + headers[k] = v + } + + if tc.skipDate { + headers["Date"] = nil + } else { + headers.Set("Date", t0.Format(http.TimeFormat)) + } + + etag := w.Header().Get("etag") + + if r.Header.Get("if-none-match") != "" { + if r.Header.Get("if-none-match") == etag { + // add new headers and update existing header value + headers["Cache-Control"] = []string{"max-age=200, public"} + w.WriteHeader(http.StatusNotModified) + } + } else { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(tc.response)) + if err != nil { + t.Fatal(err) + } + } + })) + defer ts.Close() + + request := strings.ReplaceAll(tc.request, "%URL%", ts.URL) + request = strings.ReplaceAll(request, "%CACHE%", strconv.Itoa(cacheTime)) + full := fmt.Sprintf("http.send(%s, x)", request) + config, _ := iCache.ParseCachingConfig(nil) + interQueryCache := iCache.NewInterQueryCache(config) + q := NewQuery(ast.MustParseBody(full)). + WithInterQueryBuiltinCache(interQueryCache). + WithTime(t0) + + for i := 0; i < 2; i++ { + resp, err := q.Run(context.Background()) + if err != nil { + t.Fatal(err) + } + + if len(resp) < 1 { + t.Fatalf("missing response on query %d: %v", i, resp) + } + + resResponse := resp[0]["x"].Value.(ast.Object).Get(ast.StringTerm("raw_body")) + if ast.String(tc.response).Compare(resResponse.Value) != 0 { + t.Fatalf("Expected response on query %d to be %v, got %v", i, tc.response, resResponse.String()) + } + + var x interface{} + if err := util.UnmarshalJSON([]byte(request), &x); err != nil { + t.Fatalf("failed to unmarshal request on query %d: %v", i, err) + } + cacheKey, err := ast.InterfaceToValue(x) + if err != nil { + t.Fatalf("failed create request object on query %d: %v", i, err) + } + + val, found := interQueryCache.Get(cacheKey) + if !found { + t.Fatalf("Expected inter-query cache hit on query %d", i) + } + + m, err := val.(*interQueryCacheValue).copyCacheData() + if err != nil { + t.Fatal(err) + } + expectedExpiry := t0.Add(time.Second * time.Duration(cacheTime)) + if expectedExpiry.Sub(m.ExpiresAt).Abs() > time.Second*1 { + t.Fatalf("Expected cache to expire on query %d in %v secs got %s", i, cacheTime, t0.Sub(m.ExpiresAt).Abs()) + } + + m.ExpiresAt = t0.Add(-time.Hour * 1) + v, err := m.toCacheValue() + if err != nil { + t.Fatal(err) + } + + interQueryCache.Insert(cacheKey, v) + } + + actualCount := len(requests) + if actualCount != tc.expectedReqCount { + t.Errorf("Expected to get %d requests, got %d", tc.expectedReqCount, actualCount) + } + }) + } +} + func TestHTTPSendInterQueryCachingModifiedResp(t *testing.T) { tests := []struct { note string