diff --git a/src/dfcx_scrapi/core/search.py b/src/dfcx_scrapi/core/search.py index 4d955c2a..a1941a8f 100644 --- a/src/dfcx_scrapi/core/search.py +++ b/src/dfcx_scrapi/core/search.py @@ -14,9 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any, List, Union -from google.cloud.discoveryengine_v1beta import SearchServiceClient -from google.cloud.discoveryengine_v1beta import types +import re +from typing import Dict, Any, List, Union, Optional +from google.cloud.discoveryengine import ( + SearchServiceClient, + DocumentServiceClient, + SearchRequest, + UserInfo, + Interval, + Document, + ListDocumentsRequest + ) from dfcx_scrapi.core import scrapi_base @@ -40,11 +48,11 @@ def __init__( @staticmethod def build_image_query( search_request: Dict[str, Any] - ) -> Union[types.SearchRequest.ImageQuery, None]: + ) -> Union[SearchRequest.ImageQuery, None]: image_query = search_request.get("image_query", None) if image_query: image_bytes = image_query.get("image_bytes", None) - return types.SearchRequest.ImageQuery(image_bytes=image_bytes) + return SearchRequest.ImageQuery(image_bytes=image_bytes) else: return None @@ -52,41 +60,61 @@ def build_image_query( @staticmethod def build_user_info( search_request: Dict[str, Any] - ) -> Union[types.UserInfo, None]: + ) -> Union[UserInfo, None]: user_info = search_request.get("user_info", None) if user_info: user_id = user_info.get("user_id", None) user_agent = user_info.get("user_agent", None) - return types.UserInfo(user_id=user_id, user_agent=user_agent) + return UserInfo(user_id=user_id, user_agent=user_agent) else: return None @staticmethod - def build_interval(interval_dict: Dict[str, Any]) -> types.Interval: + def build_interval(interval_dict: Dict[str, Any]) -> Interval: for k, v in interval_dict.items(): if k == "minimum": - return types.Interval(minimum=v) + return Interval(minimum=v) elif k == "exclusive_minimum": - return types.Interval(exclusive_minimum=v) + return Interval(exclusive_minimum=v) elif k == "maximum": - return types.Interval(maximum=v) + return Interval(maximum=v) elif k == "exclusive_maximum": - return types.Interval(exclusive_maximum=v) + return Interval(exclusive_maximum=v) else: return None + + @staticmethod + def search_url(urls: List[str], url: str, regex: bool = False) -> List[str]: + """Searches a url in a list of urls.""" + matched_urls: List[str] = [] + + if regex: + pattern = re.compile(url) + for item in urls: + if pattern.search(item): + matched_urls.append(item) + print(item) + + else: + for item in urls: + if url in item: + matched_urls.append(item) + print(item) + + return matched_urls def build_facet_key( self, facet_key_dict: Dict[str, Any] - ) -> types.SearchRequest.FacetSpec.FacetKey: + ) -> SearchRequest.FacetSpec.FacetKey: intervals_list = facet_key_dict.get("intervals", None) if intervals_list: all_intervals = [] for interval in intervals_list: all_intervals.append(self.build_interval(interval)) - return types.SearchRequest.FacetSpec.FacetKey( + return SearchRequest.FacetSpec.FacetKey( key=facet_key_dict.get("key", None), intervals=all_intervals, restricted_values=facet_key_dict.get("restricted_values", None), @@ -98,7 +126,7 @@ def build_facet_key( def build_single_facet_spec( self, spec: Dict[str, Any] - ) -> types.SearchRequest.FacetSpec: + ) -> SearchRequest.FacetSpec: facet_key_dict = spec.get("facet_key", None) if not facet_key_dict: raise ValueError( @@ -106,7 +134,7 @@ def build_single_facet_spec( ) facet_key = self.build_facet_key(facet_key_dict) - return types.SearchRequest.FacetSpec( + return SearchRequest.FacetSpec( facet_key=facet_key, limit=spec.get("limit", None), excluded_filter_keys=spec.get("excluded_filter_keys", None), @@ -115,7 +143,7 @@ def build_single_facet_spec( def build_facet_specs( self, search_request: Dict[str, Any] - ) -> Union[List[types.SearchRequest.FacetSpec], None]: + ) -> Union[List[SearchRequest.FacetSpec], None]: facet_specs = search_request.get("facet_specs", None) if facet_specs: all_specs = [] @@ -129,14 +157,14 @@ def build_facet_specs( def build_condition_boost_spec( self, spec: Dict[str, Any] - ) -> types.SearchRequest.BoostSpec.ConditionBoostSpec: - return types.SearchRequest.BoostSpec.ConditionBoostSpec( + ) -> SearchRequest.BoostSpec.ConditionBoostSpec: + return SearchRequest.BoostSpec.ConditionBoostSpec( condition=spec.get("condition", None), boost=spec.get("boost", None) ) def build_boost_spec( self, search_request: Dict[str, Any] - ) -> Union[types.SearchRequest.BoostSpec, None]: + ) -> Union[SearchRequest.BoostSpec, None]: boost_spec_dict = search_request.get("boost_spec", None) if boost_spec_dict: condition_boost_specs = boost_spec_dict.get( @@ -145,7 +173,7 @@ def build_boost_spec( all_boost_specs = [] for spec in condition_boost_specs: all_boost_specs.append(self.build_condition_boost_spec(spec)) - return types.SearchRequest.BoostSpec( + return SearchRequest.BoostSpec( condition_boost_specs=all_boost_specs ) @@ -154,10 +182,10 @@ def build_boost_spec( def get_condition_from_map( self, exp_spec_dict: Dict[str, Any] - ) -> types.SearchRequest.QueryExpansionSpec.Condition: + ) -> SearchRequest.QueryExpansionSpec.Condition: condition_map = { - "DISABLED": types.SearchRequest.QueryExpansionSpec.Condition.DISABLED, # pylint: disable=C0301 - "AUTO": types.SearchRequest.QueryExpansionSpec.Condition.AUTO, + "DISABLED": SearchRequest.QueryExpansionSpec.Condition.DISABLED, # pylint: disable=C0301 + "AUTO": SearchRequest.QueryExpansionSpec.Condition.AUTO, } condition_value = exp_spec_dict.get("condition", "DISABLED") @@ -166,7 +194,7 @@ def get_condition_from_map( def build_query_expansion_spec( self, search_request: Dict[str, Any] - ) -> Union[types.SearchRequest.QueryExpansionSpec, None]: + ) -> Union[SearchRequest.QueryExpansionSpec, None]: exp_spec_dict = search_request.get("query_expansion_spec", None) if exp_spec_dict: condition = self.get_condition_from_map(exp_spec_dict) @@ -174,7 +202,7 @@ def build_query_expansion_spec( "pin_unexpanded_results", False ) - return types.SearchRequest.QueryExpansionSpec( + return SearchRequest.QueryExpansionSpec( condition=condition, pin_unexpanded_results=pin_unexpanded_results, ) @@ -184,10 +212,10 @@ def build_query_expansion_spec( def get_spell_correct_mode_from_map( self, spell_spec_dict: Dict[str, Any] - ) -> types.SearchRequest.SpellCorrectionSpec.Mode: + ) -> SearchRequest.SpellCorrectionSpec.Mode: mode_map = { - "SUGGESTION_ONLY": types.SearchRequest.SpellCorrectionSpec.Mode.SUGGESTION_ONLY, # pylint: disable=C0301 - "AUTO": types.SearchRequest.SpellCorrectionSpec.Mode.AUTO, + "SUGGESTION_ONLY": SearchRequest.SpellCorrectionSpec.Mode.SUGGESTION_ONLY, # pylint: disable=C0301 + "AUTO": SearchRequest.SpellCorrectionSpec.Mode.AUTO, } mode_value = spell_spec_dict.get("mode", "AUTO") @@ -196,23 +224,23 @@ def get_spell_correct_mode_from_map( def build_spell_correction_spec( self, search_request: Dict[str, Any] - ) -> Union[types.SearchRequest.SpellCorrectionSpec, None]: + ) -> Union[SearchRequest.SpellCorrectionSpec, None]: spell_spec_dict = search_request.get("spell_correction_spec", None) if spell_spec_dict: mode = self.get_spell_correct_mode_from_map(spell_spec_dict) - return types.SearchRequest.SpellCorrectionSpec(mode=mode) + return SearchRequest.SpellCorrectionSpec(mode=mode) else: return None def build_model_prompt_spec( self, content_spec_dict: Dict[str, Any] - ) -> types.SearchRequest.ContentSearchSpec.SummarySpec.ModelPromptSpec: + ) -> SearchRequest.ContentSearchSpec.SummarySpec.ModelPromptSpec: model_prompt_spec_dict = content_spec_dict.get( "model_prompt_spec", None ) if model_prompt_spec_dict: - return types.SearchRequest.ContentSearchSpec.SummarySpec.ModelPromptSpec( # pylint: disable=C0301 + return SearchRequest.ContentSearchSpec.SummarySpec.ModelPromptSpec( # pylint: disable=C0301 preamble=model_prompt_spec_dict.get("preamble", None) ) @@ -221,10 +249,10 @@ def build_model_prompt_spec( def build_model_spec( self, content_spec_dict: Dict[str, Any] - ) -> types.SearchRequest.ContentSearchSpec.SummarySpec.ModelSpec: + ) -> SearchRequest.ContentSearchSpec.SummarySpec.ModelSpec: model_spec_dict = content_spec_dict.get("model_spec", None) if model_spec_dict: - return types.SearchRequest.ContentSearchSpec.SummarySpec.ModelSpec( + return SearchRequest.ContentSearchSpec.SummarySpec.ModelSpec( version=model_spec_dict.get("version", "stable") ) @@ -232,19 +260,19 @@ def build_model_spec( return None def build_snippet_spec( - self) -> types.SearchRequest.ContentSearchSpec.SnippetSpec: - return types.SearchRequest.ContentSearchSpec.SnippetSpec( + self) -> SearchRequest.ContentSearchSpec.SnippetSpec: + return SearchRequest.ContentSearchSpec.SnippetSpec( return_snippet=True ) def build_summary_spec( self, content_spec_dict: Dict[str, Any] - ) -> types.SearchRequest.ContentSearchSpec.SummarySpec: + ) -> SearchRequest.ContentSearchSpec.SummarySpec: model_prompt_spec = self.build_model_prompt_spec(content_spec_dict) model_spec = self.build_model_spec(content_spec_dict) - return types.SearchRequest.ContentSearchSpec.SummarySpec( + return SearchRequest.ContentSearchSpec.SummarySpec( summary_result_count=content_spec_dict.get( "summary_result_count", 10 ), @@ -263,12 +291,12 @@ def build_summary_spec( def build_extractive_content_spec( self, content_spec_dict: Dict[str, Any] ) -> Union[ - types.SearchRequest.ContentSearchSpec.ExtractiveContentSpec, None + SearchRequest.ContentSearchSpec.ExtractiveContentSpec, None ]: ext_spec_dict = content_spec_dict.get("extractive_content_spec", None) if ext_spec_dict: - return types.SearchRequest.ContentSearchSpec.ExtractiveContentSpec( + return SearchRequest.ContentSearchSpec.ExtractiveContentSpec( max_extractive_answer_count=ext_spec_dict.get( "max_extractive_answer_count", 5 ), @@ -288,7 +316,7 @@ def build_extractive_content_spec( def build_content_search_spec( self, search_request: Dict[str, Any] - ) -> Union[types.SearchRequest.ContentSearchSpec, None]: + ) -> Union[SearchRequest.ContentSearchSpec, None]: content_spec_dict = search_request.get("content_search_spec", None) if content_spec_dict: snippet_spec = self.build_snippet_spec() @@ -297,7 +325,7 @@ def build_content_search_spec( content_spec_dict ) - return types.SearchRequest.ContentSearchSpec( + return SearchRequest.ContentSearchSpec( snippet_spec=snippet_spec, summary_spect=summary_spec, extractive_content_spec=extractive_content_spec, @@ -308,27 +336,105 @@ def build_content_search_spec( def build_embedding_vector( self, vector_dict: Dict[str, Any] - ) -> types.SearchRequest.EmbeddingSpec.EmbeddingVector: - return types.SearchRequest.EmbeddingSpec.EmbeddingVector( + ) -> SearchRequest.EmbeddingSpec.EmbeddingVector: + return SearchRequest.EmbeddingSpec.EmbeddingVector( field_path=vector_dict.get("field_path", None), vector=vector_dict.get("vector", None), ) def build_embedding_spec( self, search_request: Dict[str, Any] - ) -> Union[types.SearchRequest.EmbeddingSpec, None]: + ) -> Union[SearchRequest.EmbeddingSpec, None]: embedding_vectors_dict = search_request.get("embedding_vectors", None) if embedding_vectors_dict: vector_list = embedding_vectors_dict.get("embedding_vectors", None) all_vectors = [] for vector_dict in vector_list: all_vectors.append(self.build_embedding_vector(vector_dict)) - return types.SearchRequest.EmbeddingSpec( + return SearchRequest.EmbeddingSpec( embedding_vectors=all_vectors ) else: return None + + def list_documents( + self, datastore_id: str, page_size: int = 1000) -> List[Document]: + """List all documents in the provided datastore.""" + client_options = self._client_options_discovery_engine(datastore_id) + client = DocumentServiceClient( + credentials=self.creds, + client_options=client_options + ) + + request = ListDocumentsRequest( + parent=f"{datastore_id}/branches/default_branch", + page_size=page_size + ) + + response = client.list_documents(request) + + all_docs: List[Document] = [] + for page in response.pages: + for doc in page.documents: + all_docs.append(doc) + + return all_docs + + def list_indexed_urls( + self, datastore_id: str, docs: Optional[List[Document]] = None + ) -> List[str]: + """List all indexed URLs from the provided datastore.""" + if not docs: + docs = self.list_documents(datastore_id) + + urls: List[str] = [doc.content.uri for doc in docs] + + return urls + + def search_doc_id( + self, + document_id: str, + datastore_id: str = None, + docs: Optional[List[Document]] = None + ) -> List[str]: + if not docs and not datastore_id: + raise ValueError("Must provide either `docs` or `datastore_id`") + + elif not docs and datastore_id: + docs = self.list_documents(datastore_id) + + doc_found = False + for doc in docs: + if doc.parent_document_id == document_id: + doc_found = True + print(doc) + break + + if not doc_found: + print(f"Document not found for Doc ID: `{document_id}`") + + def check_datastore_index_status(self, datastore_id: str): + """Checks the current indexing status of your datastore.""" + + PENDING_MESSAGE = "No docs found.\n" \ + "It\'s likely one of two issues:\n" \ + "\t[1] Your data store is not finished indexing.\n" \ + "\t[2] Your data store failed indexing.\n\n" \ + "If you just added your data store, it can take up to 4 hours" \ + " before it will become available." + + SUCCESS_MESSAGE = "Success! 🎉\n" \ + "Your indexing is complete.\n" \ + "Your index contains {DOCS} documents." + + docs = self.list_documents(datastore_id) + + if len(docs) == 0: + print(PENDING_MESSAGE) + else: + print(SUCCESS_MESSAGE.replace("{DOCS}", str(len(docs)))) + # pylint: disable=C0301 def search(self, search_config: Dict[str, Any], total_results: int = 10): @@ -337,7 +443,7 @@ def search(self, search_config: Dict[str, Any], total_results: int = 10): Args: search_config: A dictionary containing keys that correspond to the SearchRequest attributes as defined in: - https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1beta.types.SearchRequest + https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine.SearchRequest For complex attributes that require nested fields, you can pass in another Dictionary as the value. @@ -381,7 +487,7 @@ def search(self, search_config: Dict[str, Any], total_results: int = 10): branch_stub = "/".join(serving_config.split("/")[0:8]) branch = branch_stub + "/branches/0" - request = types.SearchRequest( + request = SearchRequest( serving_config=serving_config, branch=branch, query=search_config.get("query", None),