Skip to content

Commit

Permalink
Merge pull request #47 from CambioML/refined-extract
Browse files Browse the repository at this point in the history
feat: add PRO and ULTRA model options to sync and async extract
  • Loading branch information
CambioML authored Oct 4, 2024
2 parents a6263ae + 0d3157a commit 58c96ac
Show file tree
Hide file tree
Showing 4 changed files with 638 additions and 22 deletions.
3 changes: 2 additions & 1 deletion any_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""AnyParser module for parsing data."""

from any_parser.any_parser import ModelType # Import ModelType here
from any_parser.any_parser import AnyParser

__all__ = ["AnyParser"]
__all__ = ["AnyParser", "ModelType"] # Add ModelType to __all__

__version__ = "0.0.15"
58 changes: 54 additions & 4 deletions any_parser/any_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import base64
import json
import time
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Tuple

Expand All @@ -23,6 +24,19 @@
]


class ModelType(Enum):
BASE = "base"
PRO = "pro"
ULTRA = "ultra"


class ProcessType(Enum):
FILE = "file"
TABLE = "table"
FILE_REFINED = "file_refined"
FILE_REFINED_QUICK = "file_refined_quick"


class AnyParser:
"""AnyParser RT: Real-time parser for any data format."""

Expand All @@ -37,6 +51,7 @@ def __init__(self, api_key: str, base_url: str = PUBLIC_SHARED_BASE_URL) -> None
None
"""
self._sync_url = f"{base_url}/extract"
self._sync_refined_url = f"{base_url}/refined_extract"
self._async_upload_url = f"{base_url}/async/upload"
self._async_fetch_url = f"{base_url}/async/fetch"
self._api_key = api_key
Expand All @@ -46,7 +61,10 @@ def __init__(self, api_key: str, base_url: str = PUBLIC_SHARED_BASE_URL) -> None
}

def extract(
self, file_path: str, extract_args: Optional[Dict] = None
self,
file_path: str,
model: ModelType = ModelType.BASE,
extract_args: Optional[Dict] = None,
) -> Tuple[str, str]:
"""Extract data in real-time.
Expand All @@ -70,6 +88,8 @@ def extract(
None,
)

self._check_model(model)

# Encode the file content in base64
with open(file_path, "rb") as file:
encoded_file = base64.b64encode(file.read()).decode("utf-8")
Expand All @@ -83,10 +103,17 @@ def extract(
if extract_args is not None and isinstance(extract_args, dict):
payload["extract_args"] = extract_args

if model == ModelType.BASE:
url = self._sync_url
elif model == ModelType.PRO or model == ModelType.ULTRA:
url = self._sync_refined_url
if model == ModelType.PRO:
payload["quick_mode"] = True

# Send the POST request
start_time = time.time()
response = requests.post(
self._sync_url,
url,
headers=self._headers,
data=json.dumps(payload),
timeout=TIMEOUT,
Expand All @@ -110,8 +137,13 @@ def extract(
else:
return f"Error: {response.status_code} {response.text}", None

def async_extract(self, file_path: str, extract_args: Optional[Dict] = None) -> str:
"""Extract data asyncronously.
def async_extract(
self,
file_path: str,
model: ModelType = ModelType.BASE,
extract_args: Optional[Dict] = None,
) -> str:
"""Extract data asynchronously.
Args:
file_path (str): The path to the file to be parsed.
Expand All @@ -130,10 +162,21 @@ def async_extract(self, file_path: str, extract_args: Optional[Dict] = None) ->
supported_types = ", ".join(SUPPORTED_FILE_EXTENSIONS)
return f"Error: Unsupported file type: {file_extension}. Supported file types include {supported_types}."

self._check_model(model)

file_name = Path(file_path).name

if model == ModelType.BASE:
process_type = ProcessType.FILE
elif model == ModelType.PRO:
process_type = ProcessType.FILE_REFINED_QUICK
elif model == ModelType.ULTRA:
process_type = ProcessType.FILE_REFINED

# Create the JSON payload
payload = {
"file_name": file_name,
"process_type": process_type.value,
}

if extract_args is not None and isinstance(extract_args, dict):
Expand Down Expand Up @@ -220,3 +263,10 @@ def async_fetch(
if response.status_code == 202:
return None
return f"Error: {response.status_code} {response.text}"

def _check_model(self, model: ModelType) -> None:
if model not in {ModelType.BASE, ModelType.PRO, ModelType.ULTRA}:
valid_models = ", ".join(["`" + model.value + "`" for model in ModelType])
raise ValueError(
f"Invalid model type: {model}. Supported `model` types include {valid_models}."
)
377 changes: 371 additions & 6 deletions examples/async_pdf_to_markdown.ipynb

Large diffs are not rendered by default.

222 changes: 211 additions & 11 deletions examples/pdf_to_markdown.ipynb

Large diffs are not rendered by default.

0 comments on commit 58c96ac

Please sign in to comment.