diff --git a/pkg/openSearch/openSearchClient/client.go b/pkg/openSearch/openSearchClient/client.go
index 67a6ec1..d9fe629 100644
--- a/pkg/openSearch/openSearchClient/client.go
+++ b/pkg/openSearch/openSearchClient/client.go
@@ -73,6 +73,150 @@ func (c *Client) Search(indexName string, requestBody []byte) (responseBody []by
 	return result, nil
 }
 
+// SearchStream runs the search given in requestBody against indexName using the scroll API
+// and streams the raw response bodies through the returned reader.
+//
+// The initial search is executed before SearchStream returns, so a failing search is reported
+// via the returned error. The remaining pages are fetched by a background goroutine and written
+// to the reader as consecutive JSON documents, one per page; the final page without hits is
+// not written. An error that occurs while scrolling is returned by the reader's Read method,
+// a complete result ends with io.EOF.
+//
+// scrollTimeout is passed to OpenSearch as the scroll parameter of every request. Cancelling ctx
+// aborts the stream; a caller that stops reading before io.EOF has to cancel ctx, otherwise the
+// background goroutine stays blocked on the unread data.
+func (c *Client) SearchStream(indexName string, requestBody []byte, scrollTimeout time.Duration, ctx context.Context) (io.Reader, error) {
+	searchResponse, err := c.openSearchProjectClient.Search(
+		ctx,
+		&opensearchapi.SearchReq{
+			Indices: []string{indexName},
+			Body:    bytes.NewReader(requestBody),
+			Params: opensearchapi.SearchParams{
+				Scroll: scrollTimeout,
+			},
+		},
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	response := searchResponse.Inspect().Response
+	var failure error
+	switch {
+	case searchResponse.Errors || response.IsError():
+		log.Error().Msgf("search response: %s: %s", response.Status(), response.String())
+		failure = fmt.Errorf("search failed")
+	case searchResponse.ScrollID == nil:
+		failure = fmt.Errorf("search response contained no scroll ID")
+	}
+	if failure != nil {
+		if response.Body != nil {
+			_ = response.Body.Close()
+		}
+		return nil, failure
+	}
+	scrollID := *searchResponse.ScrollID
+
+	reader, writer := io.Pipe()
+	go func() {
+		// A blocked pipe write does not notice a cancelled ctx by itself, so close the pipe
+		// explicitly in that case. The watcher ends together with the streaming goroutine.
+		finished := make(chan struct{})
+		defer close(finished)
+		go func() {
+			select {
+			case <-ctx.Done():
+				writer.CloseWithError(ctx.Err())
+			case <-finished:
+			}
+		}()
+
+		lastScrollID, err := c.streamScrollPages(ctx, writer, response.Body, scrollID, scrollTimeout)
+		// CloseWithError(nil) is equivalent to Close: the reader then ends with io.EOF.
+		writer.CloseWithError(err)
+		// Release the scroll context on success and on failure alike.
+		c.deleteScrollContext(lastScrollID)
+	}()
+	return reader, nil
+}
+
+// streamScrollPages copies firstPage, the body of the initial search response, to writer and
+// then keeps requesting further pages of the scroll until a page without hits is received.
+// It closes firstPage. The returned scroll ID is the most recent one reported by OpenSearch
+// and is also valid if an error is returned.
+func (c *Client) streamScrollPages(ctx context.Context, writer *io.PipeWriter, firstPage io.ReadCloser, scrollID string, scrollTimeout time.Duration) (string, error) {
+	_, err := io.Copy(writer, firstPage)
+	_ = firstPage.Close()
+	if err != nil {
+		return scrollID, err
+	}
+
+	for page := 1; ; page++ {
+		log.Debug().Msgf("Scrolling %d", page)
+		scrollReq := opensearchapi.ScrollGetReq{
+			ScrollID: scrollID,
+			Params: opensearchapi.ScrollGetParams{
+				Scroll: scrollTimeout,
+			},
+		}
+
+		scrollResult, err := c.openSearchProjectClient.Scroll.Get(ctx, scrollReq)
+		if err != nil {
+			log.Err(err).Msgf("scroll-request failed: %v", scrollReq)
+			return scrollID, err
+		}
+
+		scrollResponse := scrollResult.Inspect().Response
+		if scrollResponse.IsError() {
+			log.Error().Msgf("search response: %s: %s", scrollResponse.Status(), scrollResponse.String())
+			return scrollID, fmt.Errorf("scroll-result error")
+		}
+
+		// Each response may carry a new scroll ID; always continue (and clean up) with the latest.
+		if scrollResult.ScrollID != nil {
+			scrollID = *scrollResult.ScrollID
+		} else {
+			log.Warn().Msg("No scroll ID found in response")
+		}
+
+		noMoreHits, err := processResponse(scrollResult, writer)
+		if err != nil {
+			log.Err(err).Msgf("process response failed: %v", scrollResult)
+			return scrollID, err
+		}
+		if noMoreHits {
+			return scrollID, nil
+		}
+	}
+}
+
+// deleteScrollContext asks OpenSearch to release the scroll context with the given ID.
+// A failure is only logged because the context expires on its own after the scroll timeout.
+func (c *Client) deleteScrollContext(scrollID string) {
+	// The ctx of the search may already be cancelled here, so the cleanup gets its own context.
+	clearScrollReq := opensearchapi.ScrollDeleteReq{
+		ScrollIDs: []string{scrollID},
+	}
+	_, err := c.openSearchProjectClient.Scroll.Delete(context.Background(), clearScrollReq)
+	if err != nil {
+		log.Warn().Err(err).Msgf("failed to delete scroll context")
+	}
+}
+
+// processResponse writes the raw body of one scroll page to writer.
+// If the page contains no hits, nothing is written and noMoreHits is true.
+// The response body is closed in either case.
+func processResponse(response *opensearchapi.ScrollGetResp, writer *io.PipeWriter) (noMoreHits bool, err error) {
+	body := response.Inspect().Response.Body
+	defer body.Close()
+
+	if len(response.Hits.Hits) == 0 {
+		return true, nil
+	}
+	_, err = io.Copy(writer, body)
+	return false, err
+}
+
 // Update updates documents in the given index using UpdateQueue (which is also part of this package).
 // It does not wait for the update to finish before returning.
 // It returns the response body as or an error in case something went wrong.
@@ -160,7 +281,7 @@ func (c *Client) BulkUpdate(indexName string, requestBody []byte) error { Index: indexName, Body: bytes.NewReader(requestBody), Params: opensearchapi.BulkParams{ - Refresh: "true", + Refresh: "false", }, }, ) diff --git a/pkg/openSearch/openSearchClient/client_test.go b/pkg/openSearch/openSearchClient/client_test.go index df7000d..e89ea64 100644 --- a/pkg/openSearch/openSearchClient/client_test.go +++ b/pkg/openSearch/openSearchClient/client_test.go @@ -6,6 +6,8 @@ package openSearchClient import ( "context" + "encoding/json" + "io" "testing" "time" @@ -43,10 +45,10 @@ func (v *Vulnerability) SetId(id string) { func TestClient(t *testing.T) { type testCase struct { - testFunc func(t *testing.T, client *Client) + testFunc func(t *testing.T, client *Client, iFunc *IndexFunction) } tcs := map[string]testCase{ - "TestBulkUpdate": {func(t *testing.T, client *Client) { + "TestBulkUpdate": {func(t *testing.T, client *Client, iFunc *IndexFunction) { // given bulkRequest, err := SerializeDocumentsForBulkUpdate(indexName, []*Vulnerability{&aVulnerability}) require.NoError(t, err) @@ -56,6 +58,10 @@ func TestClient(t *testing.T) { // then require.NoError(t, err) + + err = iFunc.RefreshIndex(indexName) + require.NoError(t, err) + require.EventuallyWithT(t, func(c *assert.CollectT) { searchResponse := searchAllVulnerabilities(c, client) assert.Equal(c, uint(1), searchResponse.Hits.Total.Value) @@ -63,7 +69,7 @@ func TestClient(t *testing.T) { assert.Equal(c, aVulnerability, *searchResponse.GetResults()[0]) }, 10*time.Second, 500*time.Millisecond) }}, - "TestUpdate": {func(t *testing.T, client *Client) { + "TestUpdate": {func(t *testing.T, client *Client, _ *IndexFunction) { // given createDataInIndex(t, client, []*Vulnerability{&aVulnerability}, 1) updateRequest := `{ @@ -93,7 +99,7 @@ func TestClient(t *testing.T) { assert.Equal(c, aVulnerability.Oid, searchResponse.GetResults()[0].Oid) }, 10*time.Second, 500*time.Millisecond) }}, - 
"TestAsyncDeleteByQuery": {func(t *testing.T, client *Client) { + "TestAsyncDeleteByQuery": {func(t *testing.T, client *Client, _ *IndexFunction) { // given createDataInIndex(t, client, []*Vulnerability{&aVulnerability}, 1) @@ -109,7 +115,7 @@ func TestClient(t *testing.T) { assert.Equal(c, 0, len(searchResponse.GetResults())) }, 10*time.Second, 500*time.Millisecond) }}, - "TestDeleteByQuery": {func(t *testing.T, client *Client) { + "TestDeleteByQuery": {func(t *testing.T, client *Client, _ *IndexFunction) { // given createDataInIndex(t, client, []*Vulnerability{&aVulnerability}, 1) @@ -125,7 +131,7 @@ func TestClient(t *testing.T) { assert.Equal(c, 0, len(searchResponse.GetResults())) }, 10*time.Second, 500*time.Millisecond) }}, - "TestSearch": {func(t *testing.T, client *Client) { + "TestSearch": {func(t *testing.T, client *Client, _ *IndexFunction) { // given createDataInIndex(t, client, []*Vulnerability{&aVulnerability}, 1) @@ -152,6 +158,48 @@ func TestClient(t *testing.T) { assert.Equal(t, uint(0), searchResponse.Hits.Total.Value) assert.Equal(t, 0, len(searchResponse.GetResults())) }}, + "TestSearchStream": {func(t *testing.T, client *Client, _ *IndexFunction) { + var searchResponse SearchResponse[*Vulnerability] + + // given + createDataInIndex(t, client, []*Vulnerability{&aVulnerability}, 1) + + // when + query := `{"query":{"bool":{"filter":[{"term":{"oid":{"value":"1.3.6.1.4.1.25623.1.0.117842"}}}]}}}` + responseReader, err := client.SearchStream(indexName, []byte(query), time.Millisecond, context.Background()) + + // then + require.NoError(t, err) + + decoder := json.NewDecoder(responseReader) + + // first read + err = decoder.Decode(&searchResponse) + require.NoError(t, err) + + // second read + err = decoder.Decode(&searchResponse) + require.Equal(t, io.EOF, err) + + assert.Equal(t, uint(1), searchResponse.Hits.Total.Value) + assert.Equal(t, 1, len(searchResponse.GetResults())) + assert.Equal(t, aVulnerability, *searchResponse.GetResults()[0]) + + // 
when + query = `{"query":{"bool":{"filter":[{"term":{"oid":{"value":"doesNotExist"}}}]}}}` + responseReader, err = client.SearchStream(indexName, []byte(query), time.Millisecond, context.Background()) + + // then + require.NoError(t, err) + + decoder = json.NewDecoder(responseReader) + err = decoder.Decode(&searchResponse) + require.NoError(t, err) + + assert.Empty(t, searchResponse.Hits.SearchHits) + assert.Equal(t, uint(0), searchResponse.Hits.Total.Value) + assert.Equal(t, 0, len(searchResponse.GetResults())) + }}, } ctx := context.Background() @@ -178,7 +226,7 @@ func TestClient(t *testing.T) { schema := folder.GetContent(t, "testdata/testSchema.json") err = iFunc.CreateIndex(indexName, []byte(schema)) require.NoError(t, err) - testCase.testFunc(t, client) + testCase.testFunc(t, client, iFunc) }) } } diff --git a/pkg/openSearch/openSearchClient/opensearchTestContainer.go b/pkg/openSearch/openSearchClient/opensearchTestContainer.go index dbc02b4..fbf169f 100644 --- a/pkg/openSearch/openSearchClient/opensearchTestContainer.go +++ b/pkg/openSearch/openSearchClient/opensearchTestContainer.go @@ -30,7 +30,7 @@ const openSearchTestDefaultHttpPort = "9200/tcp" // ctx is the context to use for the container. func StartOpensearchTestContainer(ctx context.Context) (testcontainers.Container, config.OpensearchClientConfig, error) { req := testcontainers.ContainerRequest{ - Image: "opensearchproject/opensearch:2.11.0", + Image: "opensearchproject/opensearch:2.18.0", ExposedPorts: []string{openSearchTestDefaultHttpPort, "9300/tcp"}, WaitingFor: createWaitStrategyFor(), Env: map[string]string{