From 6854067179940f094df10be40f036b3d1efc35af Mon Sep 17 00:00:00 2001
From: vkrd <49703203+vkrd@users.noreply.github.com>
Date: Mon, 29 Jul 2024 18:29:08 -0700
Subject: [PATCH] Add initial support for ingesting visual content (#1026)

Co-authored-by: Vikram Duvvur
---
 requirements-dev.txt        |   3 +-
 scripts/data_preparation.py |  23 +++-
 scripts/data_utils.py       | 248 ++++++++++++++++++++++++++++++------
 3 files changed, 230 insertions(+), 44 deletions(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5c6d8330aa..d2ed32f7dc 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,5 @@
 -r requirements.txt
-azure-ai-formrecognizer==3.2.1
+azure-ai-documentintelligence==1.0.0b2
 Markdown==3.4.4
 requests==2.31.0
 tqdm==4.66.1
@@ -9,6 +9,7 @@ bs4==0.0.1
 urllib3==2.1.0
 pytest==7.4.0
 pytest-asyncio==0.23.2
+PyMuPDF==1.24.5
 azure-storage-blob
 chardet
 azure-keyvault-secrets
diff --git a/scripts/data_preparation.py b/scripts/data_preparation.py
index a4d2e224a8..a5974f69f3 100644
--- a/scripts/data_preparation.py
+++ b/scripts/data_preparation.py
@@ -7,7 +7,7 @@
 import time
 
 import requests
-from azure.ai.formrecognizer import DocumentAnalysisClient
+from azure.ai.documentintelligence import DocumentIntelligenceClient
 from azure.core.credentials import AzureKeyCredential
 from azure.identity import AzureCliCredential
 from azure.search.documents import SearchClient
@@ -209,6 +209,14 @@ def create_or_update_search_index(
                 "type": "Edm.String",
                 "searchable": True,
             },
+            {
+                "name": "image_mapping",
+                "type": "Edm.String",
+                "searchable": False,
+                "sortable": False,
+                "facetable": False,
+                "filterable": False
+            }
         ],
         "suggesters": [],
         "scoringProfiles": [],
@@ -356,7 +364,7 @@ def validate_index(service_name, subscription_id, resource_group, index_name):
                 print(f"Request failed. Please investigate. Status code: {response.status_code}")
             break
 
-def create_index(config, credential, form_recognizer_client=None, embedding_model_endpoint=None, use_layout=False, njobs=4):
+def create_index(config, credential, form_recognizer_client=None, embedding_model_endpoint=None, use_layout=False, njobs=4, captioning_model_endpoint=None, captioning_model_key=None):
     service_name = config["search_service_name"]
     subscription_id = config["subscription_id"]
     resource_group = config["resource_group"]
@@ -410,7 +418,8 @@ def create_index(config, credential, form_recognizer_client=None, embedding_mode
     elif os.path.exists(data_config["path"]):
         result = chunk_directory(data_config["path"], num_tokens=config["chunk_size"], token_overlap=config.get("token_overlap",0),
                                  azure_credential=credential, form_recognizer_client=form_recognizer_client, use_layout=use_layout, njobs=njobs,
-                                 add_embeddings=add_embeddings, embedding_endpoint=embedding_model_endpoint, url_prefix=data_config["url_prefix"])
+                                 add_embeddings=add_embeddings, embedding_endpoint=embedding_model_endpoint, url_prefix=data_config["url_prefix"],
+                                 captioning_model_endpoint=captioning_model_endpoint, captioning_model_key=captioning_model_key)
     else:
         raise Exception(f"Path {data_config['path']} does not exist and is not a blob URL. Please check the path and try again.")
 
@@ -443,11 +452,13 @@ def valid_range(n):
     parser.add_argument("--config", type=str, help="Path to config file containing settings for data preparation")
     parser.add_argument("--form-rec-resource", type=str, help="Name of your Form Recognizer resource to use for PDF cracking.")
     parser.add_argument("--form-rec-key", type=str, help="Key for your Form Recognizer resource to use for PDF cracking.")
-    parser.add_argument("--form-rec-use-layout", default=False, action='store_true', help="Whether to use Layout model for PDF cracking, if False will use Read model.")
+    parser.add_argument("--form-rec-use-layout", default=True, action='store_true', help="Whether to use Layout model for PDF cracking, if False will use Read model.")
     parser.add_argument("--njobs", type=valid_range, default=4, help="Number of jobs to run (between 1 and 32). Default=4")
     parser.add_argument("--embedding-model-endpoint", type=str, help="Endpoint for the embedding model to use for vector search. Format: 'https://.openai.azure.com/openai/deployments//embeddings?api-version=2024-03-01-Preview'")
     parser.add_argument("--embedding-model-key", type=str, help="Key for the embedding model to use for vector search.")
     parser.add_argument("--search-admin-key", type=str, help="Admin key for the search service. If not provided, will use Azure CLI to get the key.")
+    parser.add_argument("--azure-openai-endpoint", type=str, help="Endpoint for the (Azure) OpenAI API. Format: 'https://.openai.azure.com/openai/deployments//chat/completions?api-version=2024-04-01-preview'")
+    parser.add_argument("--azure-openai-key", type=str, help="Key for the (Azure) OpenAI API.")
     args = parser.parse_args()
 
     with open(args.config) as f:
@@ -464,7 +475,7 @@ def valid_range(n):
         os.environ["FORM_RECOGNIZER_ENDPOINT"] = f"https://{args.form_rec_resource}.cognitiveservices.azure.com/"
         os.environ["FORM_RECOGNIZER_KEY"] = args.form_rec_key
         if args.njobs==1:
-            form_recognizer_client = DocumentAnalysisClient(endpoint=f"https://{args.form_rec_resource}.cognitiveservices.azure.com/", credential=AzureKeyCredential(args.form_rec_key))
+            form_recognizer_client = DocumentIntelligenceClient(endpoint=f"https://{args.form_rec_resource}.cognitiveservices.azure.com/", credential=AzureKeyCredential(args.form_rec_key))
         print(f"Using Form Recognizer resource {args.form_rec_resource} for PDF cracking, with the {'Layout' if args.form_rec_use_layout else 'Read'} model.")
 
     for index_config in config:
@@ -472,7 +483,7 @@ def valid_range(n):
         if index_config.get("vector_config_name") and not args.embedding_model_endpoint:
             raise Exception("ERROR: Vector search is enabled in the config, but no embedding model endpoint and key were provided. Please provide these values or disable vector search.")
 
-        create_index(index_config, credential, form_recognizer_client, embedding_model_endpoint=args.embedding_model_endpoint, use_layout=args.form_rec_use_layout, njobs=args.njobs)
+        create_index(index_config, credential, form_recognizer_client, embedding_model_endpoint=args.embedding_model_endpoint, use_layout=args.form_rec_use_layout, njobs=args.njobs, captioning_model_endpoint=args.azure_openai_endpoint, captioning_model_key=args.azure_openai_key)
         print("Data preparation for index", index_config["index_name"], "completed")
 
     print(f"Data preparation script completed. {len(config)} indexes updated.")
\ No newline at end of file
diff --git a/scripts/data_utils.py b/scripts/data_utils.py
index 550f6019d8..dde9b6ece8 100644
--- a/scripts/data_utils.py
+++ b/scripts/data_utils.py
@@ -14,11 +14,15 @@
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from azure.ai.documentintelligence.models import AnalyzeDocumentRequest
+import fitz
+import requests
+import base64
 
 import markdown
 import requests
 import tiktoken
-from azure.ai.formrecognizer import DocumentAnalysisClient
+from azure.ai.documentintelligence import DocumentIntelligenceClient
 from azure.core.credentials import AzureKeyCredential
 from azure.identity import DefaultAzureCredential
 from azure.storage.blob import ContainerClient
@@ -40,7 +44,12 @@
     "py": "python",
     "pdf": "pdf",
     "docx": "docx",
-    "pptx": "pptx"
+    "pptx": "pptx",
+    "png": "png",
+    "jpg": "jpg",
+    "jpeg": "jpeg",
+    "gif": "gif",
+    "webp": "webp"
 }
 
 RETRY_COUNT = 5
@@ -108,23 +117,35 @@ def extract_caption(self, text):
 
         return caption
 
-    def mask_urls(self, text) -> Tuple[Dict[str, str], str]:
+    def mask_urls_and_imgs(self, text) -> Tuple[Dict[str, str], str]:
         def find_urls(string):
            regex = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^()\s<>]+|\(([^()\s<>]+|(\([^()\s<>]+\)))*\))+(?:\(([^()\s<>]+|(\([^()\s<>]+\)))*\)|[^()\s`!()\[\]{};:'\".,<>?«»“”‘’]))"
             urls = re.findall(regex, string)
             return [x[0] for x in urls]
-        url_dict = {}
+
+        def find_imgs(string):
+            regex = r'(<img[^>]*>.*?</img>)'
+            imgs = re.findall(regex, string, re.DOTALL)
+            return imgs
+
+        content_dict = {}
        masked_text = text
 
         urls = set(find_urls(text))
 
         for i, url in enumerate(urls):
             masked_text = masked_text.replace(url, f"##URL{i}##")
-            url_dict[f"##URL{i}##"] = url
-        return url_dict, masked_text
+            content_dict[f"##URL{i}##"] = url
+
+        imgs = set(find_imgs(text))
+        for i, img in enumerate(imgs):
+            masked_text = masked_text.replace(img, f"##IMG{i}##")
+            content_dict[f"##IMG{i}##"] = img
+
+        return content_dict, masked_text
 
     def split_text(self, text: str) -> List[str]:
-        url_dict, masked_text = self.mask_urls(text)
+        content_dict, masked_text = self.mask_urls_and_imgs(text)
         start_tag = self._table_tags["table_open"]
         end_tag = self._table_tags["table_close"]
         splits = masked_text.split(start_tag)
@@ -148,7 +169,7 @@ def split_text(self, text: str) -> List[str]:
 
                 table_caption_prefix = ""
 
-        final_final_chunks = [chunk for chunk, chunk_size in merge_chunks_serially(final_chunks, self._chunk_size, url_dict)]
+        final_final_chunks = [chunk for chunk, chunk_size in merge_chunks_serially(final_chunks, self._chunk_size, content_dict)]
 
         return final_final_chunks
 
@@ -244,6 +265,7 @@ class Document(object):
     url: Optional[str] = None
     metadata: Optional[Dict] = None
     contentVector: Optional[List[float]] = None
+    image_mapping: Optional[Dict] = None
 
 def cleanup_content(content: str) -> str:
     """Cleans up the given content using regexes
@@ -429,13 +451,22 @@ def parse(self, content: str, file_name: Optional[str] = None) -> Document:
 
     def __init__(self) -> None:
         super().__init__()
+class ImageParser(BaseParser):
+    def parse(self, content: str, file_name: Optional[str] = None) -> Document:
+        return Document(content=content, title=file_name)
+
 
 class ParserFactory:
     def __init__(self):
         self._parsers = {
             "html": HTMLParser(),
             "text": TextParser(),
             "markdown": MarkdownParser(),
-            "python": PythonParser()
+            "python": PythonParser(),
+            "png": ImageParser(),
+            "jpg": ImageParser(),
+            "jpeg": ImageParser(),
"gif": ImageParser(), + "webp": ImageParser() } @property @@ -545,19 +576,27 @@ def table_to_html(table): for cell in row_cells: tag = "th" if (cell.kind == "columnHeader" or cell.kind == "rowHeader") else "td" cell_spans = "" - if cell.column_span > 1: cell_spans += f" colSpan={cell.column_span}" - if cell.row_span > 1: cell_spans += f" rowSpan={cell.row_span}" + if cell.column_span and cell.column_span > 1: cell_spans += f" colSpan={cell.column_span}" + if cell.row_span and cell.row_span > 1: cell_spans += f" rowSpan={cell.row_span}" table_html += f"<{tag}{cell_spans}>{html.escape(cell.content)}" table_html +="" table_html += "" return table_html +def polygon_to_bbox(polygon, dpi=72): + x_coords = polygon[0::2] + y_coords = polygon[1::2] + x0, y0 = min(x_coords)*dpi, min(y_coords)*dpi + x1, y1 = max(x_coords)*dpi, max(y_coords)*dpi + return x0, y0, x1, y1 + def extract_pdf_content(file_path, form_recognizer_client, use_layout=False): offset = 0 page_map = [] model = "prebuilt-layout" if use_layout else "prebuilt-read" - with open(file_path, "rb") as f: - poller = form_recognizer_client.begin_analyze_document(model, document = f) + + base64file = base64.b64encode(open(file_path, "rb").read()).decode() + poller = form_recognizer_client.begin_analyze_document(model, AnalyzeDocumentRequest(bytes_source=base64file)) form_recognizer_results = poller.result() # (if using layout) mark all the positions of headers @@ -571,11 +610,20 @@ def extract_pdf_content(file_path, form_recognizer_client, use_layout=False): roles_end[para_end] = paragraph.role for page_num, page in enumerate(form_recognizer_results.pages): - tables_on_page = [table for table in form_recognizer_results.tables if table.bounding_regions[0].page_number == page_num + 1] - - # (if using layout) mark all positions of the table spans in the page page_offset = page.spans[0].offset page_length = page.spans[0].length + + if use_layout: + tables_on_page = [] + for table in form_recognizer_results.tables: + table_offset = table.spans[0].offset + table_length = table.spans[0].length + if page_offset <= table_offset and table_offset + table_length < page_offset + page_length: + tables_on_page.append(table) + else: + tables_on_page = [] + + # (if using layout) mark all positions of the table spans in the page table_chars = [-1]*page_length for table_id, table in enumerate(tables_on_page): for span in table.spans: @@ -611,19 +659,58 @@ def extract_pdf_content(file_path, form_recognizer_client, use_layout=False): offset += len(page_text) full_text = "".join([page_text for _, _, page_text in page_map]) - return full_text -def merge_chunks_serially(chunked_content_list: List[str], num_tokens: int, url_dict: Dict[str, str]={}) -> Generator[Tuple[str, int], None, None]: - def unmask_urls(text, url_dict={}): - if "##URL" in text: - for key, value in url_dict.items(): + # Extract any images + image_mapping = {} + + if "figures" in form_recognizer_results.keys() and file_path.endswith(".pdf"): + document = fitz.open(file_path) + + for figure in form_recognizer_results["figures"]: + bounding_box = figure.bounding_regions[0] + + page_number = bounding_box['pageNumber'] - 1 # Page numbers in PyMuPDF start from 0 + x0, y0, x1, y1 = polygon_to_bbox(bounding_box['polygon']) + + # Select the figure and upscale it by 200% for higher resolution + page = document.load_page(page_number) + bbox = fitz.Rect(x0, y0, x1, y1) + + zoom = 2.0 + mat = fitz.Matrix(zoom, zoom) + image = page.get_pixmap(matrix=mat, clip=bbox) + + # Save the extracted image to a base64 
string + image_data = image.tobytes(output='jpg') + image_base64 = base64.b64encode(image_data).decode("utf-8") + image_base64 = f"data:image/jpg;base64,{image_base64}" + + # Add the image tag to the full text + replace_start = figure["spans"][0]["offset"] + replace_end = figure["spans"][0]["offset"] + figure["spans"][0]["length"] + original_text = form_recognizer_results.content[replace_start:replace_end] + + if original_text not in full_text: + continue + + img_tag = image_content_to_tag(original_text) + + full_text = full_text.replace(original_text, img_tag) + image_mapping[img_tag] = image_base64 + + return full_text, image_mapping + +def merge_chunks_serially(chunked_content_list: List[str], num_tokens: int, content_dict: Dict[str, str]={}) -> Generator[Tuple[str, int], None, None]: + def unmask_urls_and_imgs(text, content_dict={}): + if "##URL" in text or "##IMG" in text: + for key, value in content_dict.items(): text = text.replace(key, value) return text # TODO: solve for token overlap current_chunk = "" total_size = 0 for chunked_content in chunked_content_list: - chunked_content = unmask_urls(chunked_content, url_dict) + chunked_content = unmask_urls_and_imgs(chunked_content, content_dict) chunk_size = TOKEN_ESTIMATOR.estimate_tokens(chunked_content) if total_size > 0: new_size = total_size + chunk_size @@ -675,7 +762,7 @@ def get_embedding(text, embedding_model_endpoint=None, embedding_model_key=None, input=text, dimensions=int(os.getenv("VECTOR_DIMENSION", 1536))) - return embeddings.dict()['data'][0]['embedding'] + return embeddings.model_dump()['data'][0]['embedding'] if FLAG_EMBEDDING_MODEL == "COHERE": if FLAG_COHERE == "MULTILINGUAL": @@ -709,7 +796,7 @@ def chunk_content_helper( doc = parser.parse(content, file_name=file_name) # if the original doc after parsing is < num_tokens return as it is doc_content_size = TOKEN_ESTIMATOR.estimate_tokens(doc.content) - if doc_content_size < num_tokens: + if doc_content_size < num_tokens or file_format in ["png", "jpg", "jpeg", "gif", "webp"]: yield doc.content, doc_content_size, doc else: if file_format == "markdown": @@ -750,7 +837,8 @@ def chunk_content( use_layout = False, add_embeddings = False, azure_credential = None, - embedding_endpoint = None + embedding_endpoint = None, + image_mapping = {} ) -> ChunkingResult: """Chunks the given content. 
     If ignore_errors is true, returns None in case of an error
@@ -799,13 +887,18 @@ def chunk_content(
 
             if doc.contentVector is None:
                 raise Exception(f"Error getting embedding for chunk={chunk}")
-
+        doc.image_mapping = {}
+        for key, value in image_mapping.items():
+            if key in chunk:
+                doc.image_mapping[key] = value
         chunks.append(
             Document(
                 content=chunk,
                 title=doc.title,
                 url=url,
-                contentVector=doc.contentVector
+                contentVector=doc.contentVector,
+                metadata=doc.metadata,
+                image_mapping=doc.image_mapping
             )
         )
     else:
@@ -829,6 +922,69 @@ def chunk_content(
         skipped_chunks=skipped_chunks,
     )
 
+def image_content_to_tag(image_content: str) -> str:
+    # We encode the images in an XML-like format to make the replacement very unlikely to conflict with other text
+    # This also lets us preserve the content with minimal escaping, just escaping the tags
+    random_id = str(time.time()).replace(".", "")[-4:]
+    img_tag = f'<img_{random_id}>{image_content.replace("<img>", "&lt;img&gt;").replace("</img>", "&lt;/img&gt;")}</img>'
+    return img_tag
+
+def get_caption(image_path, captioning_model_endpoint, captioning_model_key):
+    encoded_image = base64.b64encode(open(image_path, 'rb').read()).decode('ascii')
+    file_ext = image_path.split(".")[-1]
+    headers = {
+        "Content-Type": "application/json",
+        "api-key": captioning_model_key,
+    }
+
+    payload = {
+        "messages": [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "You are a captioning model that helps users find descriptive captions."
+                    }
+                ]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Describe this image as if you were describing it to someone who can't see it. "
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{file_ext};base64,{encoded_image}"
+                        }
+                    }
+                ]
+            }
+        ],
+        "temperature": 0
+    }
+
+    for i in range(RETRY_COUNT):
+        try:
+            response = requests.post(captioning_model_endpoint, headers=headers, json=payload)
+            response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
+            break
+        except Exception as e:
+            print(f"Error getting caption with error={e}, retrying, current at {i + 1} retry, {RETRY_COUNT - (i + 1)} retries left")
+            time.sleep(15)
+
+    if response.status_code != 200:
+        raise Exception(f"Error getting caption with status_code={response.status_code}")
+
+    caption = response.json()["choices"][0]["message"]["content"]
+    img_tag = image_content_to_tag(caption)
+    mapping = {img_tag: f"data:image/{file_ext};base64,{encoded_image}"}
+
+    return img_tag, mapping
+
 def chunk_file(
     file_path: str,
     ignore_errors: bool = True,
@@ -841,7 +997,9 @@ def chunk_file(
     use_layout = False,
     add_embeddings=False,
     azure_credential = None,
-    embedding_endpoint = None
+    embedding_endpoint = None,
+    captioning_model_endpoint = None,
+    captioning_model_key = None
 ) -> ChunkingResult:
     """Chunks the given file.
     Args:
@@ -851,6 +1009,7 @@ def chunk_file(
     """
     file_name = os.path.basename(file_path)
     file_format = _get_file_format(file_name, extensions_to_process)
+    image_mapping = {}
     if not file_format:
         if ignore_errors:
             return ChunkingResult(
@@ -863,8 +1022,13 @@ def chunk_file(
     if file_format in ["pdf", "docx", "pptx"]:
         if form_recognizer_client is None:
             raise UnsupportedFormatError("form_recognizer_client is required for pdf files")
-        content = extract_pdf_content(file_path, form_recognizer_client, use_layout=use_layout)
+        content, image_mapping = extract_pdf_content(file_path, form_recognizer_client, use_layout=use_layout)
         cracked_pdf = True
+    elif file_format in ["png", "jpg", "jpeg", "webp"]:
+        # Make call to LLM for a descriptive caption
+        if captioning_model_endpoint is None or captioning_model_key is None:
+            raise Exception("CAPTIONING_MODEL_ENDPOINT and CAPTIONING_MODEL_KEY are required for images")
+        content, image_mapping = get_caption(file_path, captioning_model_endpoint, captioning_model_key)
     else:
         try:
             with open(file_path, "r", encoding="utf8") as f:
@@ -889,7 +1053,8 @@ def chunk_file(
         use_layout=use_layout,
         add_embeddings=add_embeddings,
         azure_credential=azure_credential,
-        embedding_endpoint=embedding_endpoint
+        embedding_endpoint=embedding_endpoint,
+        image_mapping=image_mapping
     )
 
 
@@ -906,7 +1071,9 @@ def process_file(
     use_layout = False,
     add_embeddings = False,
     azure_credential = None,
-    embedding_endpoint = None
+    embedding_endpoint = None,
+    captioning_model_endpoint = None,
+    captioning_model_key = None
 ):
 
     if not form_recognizer_client:
@@ -932,11 +1099,14 @@ def process_file(
             use_layout=use_layout,
             add_embeddings=add_embeddings,
             azure_credential=azure_credential,
-            embedding_endpoint=embedding_endpoint
+            embedding_endpoint=embedding_endpoint,
+            captioning_model_endpoint=captioning_model_endpoint,
+            captioning_model_key=captioning_model_key
         )
         for chunk_idx, chunk_doc in enumerate(result.chunks):
             chunk_doc.filepath = rel_file_path
             chunk_doc.metadata = json.dumps({"chunk_id": str(chunk_idx)})
+            chunk_doc.image_mapping = json.dumps(chunk_doc.image_mapping) if chunk_doc.image_mapping else None
     except Exception as e:
         print(e)
         if not ignore_errors:
@@ -999,7 +1169,9 @@ def chunk_directory(
     njobs=4,
     add_embeddings = False,
     azure_credential = None,
-    embedding_endpoint = None
+    embedding_endpoint = None,
+    captioning_model_endpoint = None,
+    captioning_model_key = None
 ):
     """
     Chunks the given directory recursively
@@ -1041,7 +1213,8 @@ def chunk_directory(
                                             token_overlap=token_overlap,
                                             extensions_to_process=extensions_to_process,
                                             form_recognizer_client=form_recognizer_client, use_layout=use_layout, add_embeddings=add_embeddings,
-                                            azure_credential=azure_credential, embedding_endpoint=embedding_endpoint)
+                                            azure_credential=azure_credential, embedding_endpoint=embedding_endpoint,
+                                            captioning_model_endpoint=captioning_model_endpoint, captioning_model_key=captioning_model_key)
             if is_error:
                 num_files_with_errors += 1
                 continue
@@ -1057,7 +1230,8 @@ def chunk_directory(
                                      token_overlap=token_overlap,
                                      extensions_to_process=extensions_to_process,
                                      form_recognizer_client=None, use_layout=use_layout, add_embeddings=add_embeddings,
-                                     azure_credential=azure_credential, embedding_endpoint=embedding_endpoint)
+                                     azure_credential=azure_credential, embedding_endpoint=embedding_endpoint,
+                                     captioning_model_endpoint=captioning_model_endpoint, captioning_model_key=captioning_model_key)
         with ProcessPoolExecutor(max_workers=njobs) as executor:
             futures = list(tqdm(executor.map(process_file_partial, files_to_process), total=len(files_to_process)))
             for result, is_error in futures:
@@ -1087,7 +1261,7 @@ def __new__(cls, *args, **kwargs):
             url = os.getenv("FORM_RECOGNIZER_ENDPOINT")
             key = os.getenv("FORM_RECOGNIZER_KEY")
             if url and key:
-                cls.instance = DocumentAnalysisClient(
+                cls.instance = DocumentIntelligenceClient(
                     endpoint=url, credential=AzureKeyCredential(key), headers={"x-ms-useragent": "sample-app-aoai-chatgpt/1.0.0"})
             else:
                 print("SingletonFormRecognizerClient: Skipping since credentials not provided. Assuming NO form recognizer extensions(like .pdf) in directory")
@@ -1099,4 +1273,4 @@ def __getstate__(self):
 
     def __setstate__(self, state):
         url, key = state
-        self.instance = DocumentAnalysisClient(endpoint=url, credential=AzureKeyCredential(key), headers={"x-ms-useragent": "sample-app-aoai-chatgpt/1.0.0"})
+        self.instance = DocumentIntelligenceClient(endpoint=url, credential=AzureKeyCredential(key), headers={"x-ms-useragent": "sample-app-aoai-chatgpt/1.0.0"})