Skip to content
This repository has been archived by the owner on Jun 12, 2024. It is now read-only.

Commit

Permalink
feat: add docstring and type hint for all funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
dsdanielpark committed Mar 11, 2024
1 parent 7d591d3 commit 07a39ff
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 40 deletions.
20 changes: 20 additions & 0 deletions gemini/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Gemini:
auto_cookies (bool): Automatically manage cookies if True.
timeout (int): Timeout for requests, defaults to 30 seconds.
proxies (Optional[dict]): Proxy configuration for requests, if any.
rcid (str): Response candidate ID.
"""

def __init__(
Expand Down Expand Up @@ -179,6 +180,7 @@ def _construct_payload(
Parameters:
prompt (str): The user prompt to send.
image (Union[bytes, str]): The image data as bytes or file path. Supported formats: webp, jpeg, png.
nonce (str): A one-time token used for request verification.
Returns:
Expand Down Expand Up @@ -261,6 +263,15 @@ def generate_content(
return self._create_model_output(parsed_response)

def _create_model_output(self, parsed_response: dict) -> GeminiModelOutput:
"""
Creates model output from parsed response.
Args:
parsed_response (dict): The parsed response data.
Returns:
GeminiModelOutput: The model output containing metadata, candidates, and response dictionary.
"""
candidates = self.collect_candidates(parsed_response)
metadata = parsed_response.get("metadata", [])
try:
Expand All @@ -277,6 +288,15 @@ def _create_model_output(self, parsed_response: dict) -> GeminiModelOutput:

@staticmethod
def collect_candidates(data):
"""
Collects candidate data from parsed response.
Args:
data: The parsed response data.
Returns:
List: A list of GeminiCandidate objects.
"""
collected = []
stack = [data]

Expand Down
2 changes: 2 additions & 0 deletions gemini/src/misc/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class Tool(Enum):
YOUTUBE = ["youtube_tool"]


IMAGE_PUSH_ID = "feeds/mcudyrk2a4khkz"

DEFAULT_LANGUAGE = "en"
POST_ENDPOINT = "https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate"
HOST = "https://gemini.google.com"
Expand Down
50 changes: 26 additions & 24 deletions gemini/src/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import requests
from typing import Union
from .constants import REPLIT_SUPPORT_PROGRAM_LANGUAGES
from .constants import REPLIT_SUPPORT_PROGRAM_LANGUAGES, IMAGE_PUSH_ID


def extract_code(text: str, language: str) -> str:
Expand Down Expand Up @@ -68,7 +68,7 @@ def upload_image(file: Union[bytes, str]) -> str:
response = requests.post(
url="https://content-push.googleapis.com/upload/",
headers={
"Push-ID": "feeds/mcudyrk2a4khkz",
"Push-ID": IMAGE_PUSH_ID,
"Content-Type": "application/octet-stream",
},
data=file_data,
Expand All @@ -79,6 +79,30 @@ def upload_image(file: Union[bytes, str]) -> str:
return response.text


def build_replit_structure(instructions: str, code: str, filename: str) -> list:
"""
Creates and returns the input image data structure based on provided parameters.
Args:
instructions (str): The instruction text.
code (str): The code.
filename (str): The filename.
Returns:
list: The input image data structure.
"""
return [
[
[
"qACoKe",
json.dumps([instructions, 5, code, [[filename, code]]]),
None,
"generic",
]
]
]


def max_token(text: str, n: int) -> str:
"""
Return the first 'n' tokens (words) of the given text.
Expand Down Expand Up @@ -122,25 +146,3 @@ def max_sentence(text: str, n: int):
if sentence_count == n:
result = "".join(sentences).strip()
return result


def build_replit_data(instructions: str, code: str, filename: str) -> list:
"""
Creates and returns the input_image_data_struct based on provided parameters.
:param instructions: The instruction text.
:param code: The code.
:param filename: The filename.
:return: The input_image_data_struct.
"""
return [
[
[
"qACoKe",
json.dumps([instructions, 5, code, [[filename, code]]]),
None,
"generic",
]
]
]
53 changes: 48 additions & 5 deletions gemini/src/model/decorator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
import functools
import time
import requests
from typing import Callable


def retry(attempts=3, delay=2, backoff=2):
def retry_decorator(func):
def retry(attempts: int = 3, delay: int = 2, backoff: int = 2) -> Callable:
"""
Retries a function call with exponential backoff.
Args:
attempts (int): The maximum number of attempts. Defaults to 3.
delay (int): The initial delay in seconds between retries. Defaults to 2.
backoff (int): The backoff factor. Defaults to 2.
Returns:
Callable: Decorator function.
"""

def retry_decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
_attempts, _delay = attempts, delay
Expand All @@ -23,7 +36,17 @@ def wrapper(*args, **kwargs):
return retry_decorator


def log_method(func):
def log_method(func: Callable) -> Callable:
"""
Logs method entry and exit.
Args:
func (Callable): The function to decorate.
Returns:
Callable: Decorated function.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
className = args[0].__class__.__name__
Expand All @@ -39,7 +62,17 @@ def wrapper(*args, **kwargs):
return wrapper


def time_execution(func):
def time_execution(func: Callable) -> Callable:
"""
Measures the execution time of a function.
Args:
func (Callable): The function to decorate.
Returns:
Callable: Decorated function.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
Expand All @@ -51,7 +84,17 @@ def wrapper(*args, **kwargs):
return wrapper


def handle_errors(func):
def handle_errors(func: Callable) -> Callable:
"""
Handles errors that occur during function execution.
Args:
func (Callable): The function to decorate.
Returns:
Callable: Decorated function.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
Expand Down
21 changes: 18 additions & 3 deletions gemini/src/model/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
class PackageError(Exception):
"""Invalid credentials/cookies."""
"""
Exception raised when encountering invalid credentials or cookies.
This exception indicates that the provided credentials or cookies are invalid or insufficient for the operation.
"""

pass


class GeminiAPIError(Exception):
"""Unhandled server error."""
"""
Exception raised for unhandled errors from the Gemini server.
This exception indicates that an unexpected error occurred on the Gemini server side, which was not properly handled by the client.
"""

pass


class TimeoutError(GeminiAPIError):
"""Request timeout."""
"""
Exception raised when a request times out.
This exception is a subclass of GeminiAPIError and is raised when a request to the Gemini server exceeds the specified timeout period without receiving a response.
"""

pass
67 changes: 67 additions & 0 deletions gemini/src/model/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,41 @@


class GeminiImage(BaseModel):
"""
Represents an image with URL, title, and alt text.
Attributes:
url (HttpUrl): The URL of the image.
title (str): The title of the image. Defaults to "[Image]".
alt (str): The alt text of the image. Defaults to "".
Methods:
validate_images(cls, images): Validates the input images list.
save(cls, images: List["GeminiImage"], save_path: str = "cached", cookies: Optional[dict] = None) -> Optional[Path]:
Downloads and saves images asynchronously.
fetch_bytes(url: HttpUrl, cookies: Optional[dict] = None) -> Optional[bytes]:
Fetches bytes of an image asynchronously.
fetch_images_dict(cls, images: List["GeminiImage"], cookies: Optional[dict] = None) -> Dict[str, bytes]:
Fetches images asynchronously and returns a dictionary of image data.
save_images(cls, image_data: Dict[str, bytes], save_path: str = "cached"):
Saves images locally.
"""

url: HttpUrl
title: str = "[Image]"
alt: str = ""

@classmethod
def validate_images(cls, images):
"""
Validates the input images list.
Args:
images: The list of GeminiImage objects.
Raises:
ValueError: If the input images list is empty.
"""
if not images:
raise ValueError(
"Input is empty. Please provide images infomation to proceed."
Expand All @@ -29,6 +58,17 @@ async def save(
save_path: str = "cached",
cookies: Optional[dict] = None,
) -> Optional[Path]:
"""
Downloads and saves images asynchronously.
Args:
images (List["GeminiImage"]): The list of GeminiImage objects to download.
save_path (str): The directory path to save the images. Defaults to "cached".
cookies (Optional[dict]): Cookies to be used for downloading images. Defaults to None.
Returns:
Optional[Path]: The path to the directory where the images are saved, or None if saving fails.
"""
cls.validate_images(images)
image_data = await cls.fetch_images_dict(images, cookies)
await cls.save_images(image_data, save_path)
Expand All @@ -37,6 +77,16 @@ async def save(
async def fetch_bytes(
url: HttpUrl, cookies: Optional[dict] = None
) -> Optional[bytes]:
"""
Fetches bytes of an image asynchronously.
Args:
url (HttpUrl): The URL of the image.
cookies (Optional[dict]): Cookies to be used for fetching the image. Defaults to None.
Returns:
Optional[bytes]: The bytes of the image, or None if fetching fails.
"""
try:
async with httpx.AsyncClient(follow_redirects=True) as client:
response = await client.get(str(url), cookies=cookies)
Expand All @@ -50,13 +100,30 @@ async def fetch_bytes(
async def fetch_images_dict(
cls, images: List["GeminiImage"], cookies: Optional[dict] = None
) -> Dict[str, bytes]:
"""
Fetches images asynchronously and returns a dictionary of image data.
Args:
images (List["GeminiImage"]): The list of GeminiImage objects to fetch.
cookies (Optional[dict]): Cookies to be used for fetching the images. Defaults to None.
Returns:
Dict[str, bytes]: A dictionary containing image titles as keys and image bytes as values.
"""
cls.validate_images(images)
tasks = [cls.fetch_bytes(image.url, cookies=cookies) for image in images]
results = await asyncio.gather(*tasks)
return {image.title: result for image, result in zip(images, results) if result}

@staticmethod
async def save_images(image_data: Dict[str, bytes], save_path: str = "cached"):
"""
Saves images locally.
Args:
image_data (Dict[str, bytes]): A dictionary containing image titles as keys and image bytes as values.
save_path (str): The directory path to save the images. Defaults to "cached".
"""
os.makedirs(save_path, exist_ok=True)
for title, data in image_data.items():
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
Expand Down
Loading

0 comments on commit 07a39ff

Please sign in to comment.