Commit af0b84c: [0.3.1]
kyegomez committed Jan 16, 2025 (1 parent: 4d7e04e)
Showing 3 changed files with 49 additions and 23 deletions.
4 changes: 4 additions & 0 deletions swarm_models/__init__.py
@@ -1,9 +1,11 @@
from swarm_models.base_llm import BaseLLM # noqa: E402
from swarm_models.base_multimodal_model import BaseMultiModalModel
from swarm_models.gpt4_vision_api import GPT4VisionAPI # noqa: E402

# from swarm_models.huggingface import HuggingfaceLLM # noqa: E402
# from swarm_models.layoutlm_document_qa import LayoutLMDocumentQA
from swarm_models.llama3_hosted import llama3Hosted

# from swarm_models.llava import LavaMultiModal # noqa: E402
# from swarm_models.nougat import Nougat # noqa: E402
from swarm_models.openai_embeddings import OpenAIEmbeddings
@@ -32,12 +34,14 @@
TextModality,
VideoModality,
)

# from swarm_models.vilt import Vilt # noqa: E402
from swarm_models.popular_llms import FireWorksAI
from swarm_models.openai_function_caller import OpenAIFunctionCaller
from swarm_models.ollama_model import OllamaModel
from swarm_models.sam_two import GroundedSAMTwo
from swarm_models.utils import * # NOQA

# from swarm_models.together_llm import TogetherLLM
# from swarm_models.lite_llm_model import LiteLLM
from swarm_models.tiktoken_wrapper import TikTokenizer
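For context, the import list above defines what the package exposes at the top level. A minimal usage sketch based only on names exported in this file (constructor arguments and API-key handling are assumptions, not shown in the diff):

# Sketch only: imports mirror names exported in swarm_models/__init__.py above.
from swarm_models import GPT4VisionAPI, TikTokenizer

tokenizer = TikTokenizer()   # tokenizer wrapper exported above
vision = GPT4VisionAPI()     # assumes the OpenAI key is supplied via environment or constructor
print(type(tokenizer).__name__, type(vision).__name__)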
64 changes: 42 additions & 22 deletions swarm_models/gpt4_vision_api.py
@@ -83,7 +83,6 @@ def __init__(
self.meta_prompt = meta_prompt
self.system_prompt = system_prompt


if self.logging_enabled:
logging.basicConfig(level=logging.DEBUG)
else:
@@ -94,7 +93,6 @@ def __init__(
if self.meta_prompt:
self.system_prompt = self.meta_prompt_init()


def encode_image(self, img: str):
"""Encode image to base64."""
if not os.path.exists(img):
@@ -113,7 +111,13 @@ def download_img_then_encode(self, img: str):
response = requests.get(img)
return base64.b64encode(response.content).decode("utf-8")

def compose_messages(self, task: str, img: str, img_list: list = None, context: list = None):
def compose_messages(
self,
task: str,
img: str,
img_list: list = None,
context: list = None,
):
"""Compose the payload for the GPT-4 Vision API, if illegal image paths are provided
, None is returned, means the payload is not valid
@@ -135,11 +139,12 @@ def compose_messages(self, task: str, img: str, img_list: list = None, context:
if None is returned, then the payload is not valid
"""


# Compose the messages
messages = []
# Add the system prompt to the messages
messages.append({"role": "system", "content": self.system_prompt})
messages.append(
{"role": "system", "content": self.system_prompt}
)
# Add the context to the messages
messages = messages + context if context else messages

@@ -155,17 +160,30 @@ def compose_messages(self, task: str, img: str, img_list: list = None, context:
if image:
if os.path.exists(image):
encoded_img = self.encode_image(image)
content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}})
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_img}"
},
}
)
elif image.startswith("http"):
content.append({"type": "image_url", "image_url": {"url": f"{image}"}})
content.append(
{
"type": "image_url",
"image_url": {"url": f"{image}"},
}
)
else:
logger.error(f"Image file not found: {image} or not a valid URL")
print(f"Image file not found: {image} or not a valid URL")
logger.error(
f"Image file not found: {image} or not a valid URL"
)
print(
f"Image file not found: {image} or not a valid URL"
)
return None
content = {
"role": "user",
"content": content
}
content = {"role": "user", "content": content}
messages.append(content)
return messages
return None
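For reference, the messages list assembled above follows the OpenAI vision chat format: a system message, optional prior context, and a user message whose content holds one or more image parts (the task text is typically added as a text part in code not shown in this hunk, so that part is an assumption). A sketch of the resulting structure with illustrative values; local files become base64 data URLs, http(s) links are passed through:

messages = [
    {"role": "system", "content": "<system prompt>"},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "<task>"},  # assumed to be added elsewhere in the method
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,<encoded local image>"},
            },
            {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
        ],
    },
]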
@@ -188,22 +206,24 @@ def run(
"Authorization": f"Bearer {self.openai_api_key}",
}
if messages is None:
messages = self.compose_messages(task, img, multi_imgs, messages)
messages = self.compose_messages(
task, img, multi_imgs, messages
)

if messages is None:
raise ValueError("Image path is invalid, please check the image path")
raise ValueError(
"Image path is invalid, please check the image path"
)

payload = {
"model": self.model_name,
"messages": messages
}
payload = {"model": self.model_name, "messages": messages}

response = requests.post(self.openai_proxy, headers=headers, json=payload)
response = requests.post(
self.openai_proxy, headers=headers, json=payload
)

# Get the response as a JSON object
response_json = response.json()


# Return the JSON object if return_json is True
if return_json is True:
print(response_json)
@@ -422,4 +442,4 @@ def print_dashboard(self):
# example, instead of '1 - 4', list as '[1], [2], [3], [4]'. These labels could be
# numbers or letters and typically correspond to specific segments or parts of the image.
# """
# return META_PROMPT
# return META_PROMPT
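Taken together, the changes in this file reformat compose_messages and run and drop stray blank lines without changing behavior. A minimal end-to-end sketch of how the class is driven, using only names visible in the diff; openai_api_key, model_name, and openai_proxy are attributes referenced in run(), and the exact constructor signature is an assumption:

from swarm_models.gpt4_vision_api import GPT4VisionAPI

model = GPT4VisionAPI(
    openai_api_key="<your OpenAI key>",  # referenced as self.openai_api_key in run()
    model_name="gpt-4o",                 # illustrative vision-capable model name
)

# img may be a local path (base64-encoded by compose_messages) or an http(s) URL.
result = model.run(
    task="Describe what is in this image.",
    img="example.jpg",
    return_json=False,
)
print(result)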
4 changes: 3 additions & 1 deletion swarm_models/huggingface.py
@@ -227,7 +227,9 @@ def run(self, task: str, *args, **kwargs):
- Generated text (str).
"""
try:
inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.model.device)
inputs = self.tokenizer.encode(
task, return_tensors="pt"
).to(self.model.device)

if self.decoding:
with torch.no_grad():
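The huggingface.py change is purely a line wrap of the tokenizer call. For orientation, the surrounding run() method follows the standard Hugging Face encode, generate, decode pattern; a self-contained sketch of that pattern, with the checkpoint name and generation arguments as illustrative choices rather than values taken from the diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")           # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Encode the task and move the tensor to the model's device, as in the diff above.
inputs = tokenizer.encode(
    "Write a haiku about swarms.", return_tensors="pt"
).to(model.device)

with torch.no_grad():
    outputs = model.generate(inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))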
