diff --git a/README.md b/README.md
index 23c5d37..19ba4cc 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,7 @@ A reverse-engineered asynchronous python wrapper for [Google Gemini](https://gem
     - [Retrieve images in response](#retrieve-images-in-response)
     - [Generate images with ImageFx](#generate-images-with-imagefx)
     - [Save images to local files](#save-images-to-local-files)
+    - [Specify language model version](#specify-language-model-version)
     - [Generate contents with Gemini extensions](#generate-contents-with-gemini-extensions)
     - [Check and switch to other reply candidates](#check-and-switch-to-other-reply-candidates)
     - [Control log level](#control-log-level)
@@ -92,9 +93,9 @@ pip install -U browser-cookie3
 ```yaml
 services:
-  main:
-    volumes:
-      - ./gemini_cookies:/usr/local/lib/python3.12/site-packages/gemini_webapi/utils/temp
+    main:
+        volumes:
+            - ./gemini_cookies:/usr/local/lib/python3.12/site-packages/gemini_webapi/utils/temp
 ```

 > [!NOTE]
@@ -255,6 +256,46 @@ async def main():
 asyncio.run(main())
 ```
+
+### Specify language model version
+
+You can specify the language model version by passing the `model` argument to `GeminiClient.generate_content` or `GeminiClient.start_chat`. The default value is `unspecified`.
+
+Currently available models (as of Dec 21, 2024):
+
+- `unspecified` - Default model (Gemini 1.5 Flash)
+- `gemini-1.5-flash` - Gemini 1.5 Flash
+- `gemini-2.0-flash-exp` - Gemini 2.0 Flash Experimental
+
+```python
+from gemini_webapi.constants import Model
+
+async def main():
+    response1 = await client.generate_content(
+        "What's your language model version? Reply version number only.",
+        model="gemini-1.5-flash",
+    )
+    print(f"Model version (gemini-1.5-flash): {response1.text}")
+
+    chat = client.start_chat(model=Model.G_2_0_FLASH_EXP)
+    response2 = await chat.send_message("What's your language model version? Reply version number only.")
+    print(f"Model version ({Model.G_2_0_FLASH_EXP.model_name}): {response2.text}")
+
+asyncio.run(main())
+```
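+
+Under the hood, a model name string is resolved with `Model.from_name`, which raises a `ValueError` listing the available models when the name is unknown. For illustration (`"gemini-3.0"` below is a deliberately invalid name):
+
+```python
+from gemini_webapi.constants import Model
+
+# Name strings map to the same enum members used above.
+assert Model.from_name("gemini-1.5-flash") is Model.G_1_5_FLASH
+
+try:
+    Model.from_name("gemini-3.0")
+except ValueError as e:
+    print(e)  # Unknown model name: gemini-3.0. Available models: ...
+```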
+
 ### Generate contents with Gemini extensions

 > [!IMPORTANT]
diff --git a/src/gemini_webapi/client.py b/src/gemini_webapi/client.py
index 7131602..1da64df 100644
--- a/src/gemini_webapi/client.py
+++ b/src/gemini_webapi/client.py
@@ -8,7 +8,7 @@
 from httpx import AsyncClient, ReadTimeout

-from .constants import Endpoint, Headers
+from .constants import Endpoint, Headers, Model
 from .exceptions import AuthError, APIError, TimeoutError, GeminiError
 from .types import WebImage, GeneratedImage, Candidate, ModelOutput
 from .utils import (
@@ -79,6 +79,9 @@ class GeminiClient:
         __Secure-1PSIDTS cookie value, some google accounts don't require this value, provide only if it's in the cookie list.
     proxy: `str`, optional
         Proxy URL.
+    kwargs: `dict`, optional
+        Additional arguments which will be passed to the http client.
+        Refer to `httpx.AsyncClient` for more information.

     Raises
     ------
@@ -98,6 +101,7 @@ class GeminiClient:
         "close_task",
         "auto_refresh",
         "refresh_interval",
+        "kwargs",
     ]

     def __init__(
@@ -105,6 +109,7 @@
         secure_1psid: str | None = None,
         secure_1psidts: str | None = None,
         proxy: str | None = None,
+        **kwargs,
     ):
         self.cookies = {}
         self.proxy = proxy
@@ -117,6 +122,7 @@
         self.close_task: Task | None = None
         self.auto_refresh: bool = True
         self.refresh_interval: float = 540
+        self.kwargs = kwargs

         # Validate cookies
         if secure_1psid:
@@ -173,6 +179,7 @@ async def init(
             follow_redirects=True,
             headers=Headers.GEMINI.value,
             cookies=valid_cookies,
+            **self.kwargs,
         )
         self.access_token = access_token
         self.cookies = valid_cookies
@@ -256,7 +263,9 @@ async def generate_content(
         self,
         prompt: str,
         images: list[bytes | str | Path] | None = None,
+        model: Model | str = Model.UNSPECIFIED,
         chat: Optional["ChatSession"] = None,
+        **kwargs,
     ) -> ModelOutput:
         """
         Generates contents with prompt.
@@ -267,8 +276,14 @@
             Prompt provided by user.
         images: `list[bytes | str | Path]`, optional
             List of image file paths or file data in bytes.
+        model: `Model` | `str`, optional
+            Specify the model to use for generation.
+            Pass either a `gemini_webapi.constants.Model` enum or a model name string.
         chat: `ChatSession`, optional
             Chat data to retrieve conversation history. If None, will automatically generate a new chat id when sending post request.
+        kwargs: `dict`, optional
+            Additional arguments which will be passed to the post request.
+            Refer to `httpx.AsyncClient.request` for more information.

         Returns
         -------
@@ -291,12 +306,16 @@
         assert prompt, "Prompt cannot be empty."

+        if not isinstance(model, Model):
+            model = Model.from_name(model)
+
         if self.auto_close:
             await self.reset_close_task()

         try:
             response = await self.client.post(
                 Endpoint.GENERATE.value,
+                headers=model.model_header,
                 data={
                     "at": self.access_token,
                     "f.req": json.dumps(
@@ -325,6 +344,7 @@
                         ]
                     ),
                 },
+                **kwargs,
             )
         except ReadTimeout:
             raise TimeoutError(
@@ -431,12 +451,13 @@ def start_chat(self, **kwargs) -> "ChatSession":
         Parameters
         ----------
         kwargs: `dict`, optional
-            Other arguments to pass to `ChatSession.__init__`.
+            Additional arguments which will be passed to the chat session.
+            Refer to `gemini_webapi.ChatSession` for more information.

         Returns
         -------
         :class:`ChatSession`
-            Empty chat object for retrieving conversation history.
+            Empty chat session object for retrieving conversation history.
         """

         return ChatSession(geminiclient=self, **kwargs)
@@ -458,9 +479,17 @@ class ChatSession:
         Reply id, if provided together with metadata, will override the second value in it.
     rcid: `str`, optional
         Reply candidate id, if provided together with metadata, will override the third value in it.
+    model: `Model` | `str`, optional
+        Specify the model to use for generation.
+        Pass either a `gemini_webapi.constants.Model` enum or a model name string.
""" - __slots__ = ["__metadata", "geminiclient", "last_output"] + __slots__ = [ + "__metadata", + "geminiclient", + "last_output", + "model", + ] def __init__( self, @@ -469,10 +498,12 @@ def __init__( cid: str | None = None, # chat id rid: str | None = None, # reply id rcid: str | None = None, # reply candidate id + model: Model | str = Model.UNSPECIFIED, ): self.__metadata: list[str | None] = [None, None, None] self.geminiclient: GeminiClient = geminiclient self.last_output: ModelOutput | None = None + self.model = model if metadata: self.metadata = metadata @@ -499,6 +530,7 @@ async def send_message( self, prompt: str, images: list[bytes | str | Path] | None = None, + **kwargs, ) -> ModelOutput: """ Generates contents with prompt. @@ -510,6 +542,9 @@ async def send_message( Prompt provided by user. images: `list[bytes | str | Path]`, optional List of image file paths or file data in bytes. + kwargs: `dict`, optional + Additional arguments which will be passed to the post request. + Refer to `httpx.AsyncClient.request` for more information. Returns ------- @@ -531,7 +566,7 @@ async def send_message( """ return await self.geminiclient.generate_content( - prompt=prompt, images=images, chat=self + prompt=prompt, images=images, model=self.model, chat=self, **kwargs ) def choose_candidate(self, index: int) -> ModelOutput: diff --git a/src/gemini_webapi/constants.py b/src/gemini_webapi/constants.py index ec2b35c..5a16e4a 100644 --- a/src/gemini_webapi/constants.py +++ b/src/gemini_webapi/constants.py @@ -21,3 +21,28 @@ class Headers(Enum): "Content-Type": "application/json", } UPLOAD = {"Push-ID": "feeds/mcudyrk2a4khkz"} + + +class Model(Enum): + UNSPECIFIED = ("unspecified", {}) + G_1_5_FLASH = ( + "gemini-1.5-flash", + {"x-goog-ext-525001261-jspb": '[null,null,null,null,"7daceb7ef88130f5"]'}, + ) + G_2_0_FLASH_EXP = ( + "gemini-2.0-flash-exp", + {"x-goog-ext-525001261-jspb": '[null,null,null,null,"948b866104ccf484"]'}, + ) + + def __init__(self, name, header): + self.model_name = name + self.model_header = header + + @classmethod + def from_name(cls, name: str): + for model in cls: + if model.model_name == name: + return model + raise ValueError( + f"Unknown model name: {name}. Available models: {', '.join([model.model_name for model in cls])}" + ) diff --git a/tests/test_client_features.py b/tests/test_client_features.py index 758e1e0..05c5f38 100644 --- a/tests/test_client_features.py +++ b/tests/test_client_features.py @@ -6,6 +6,7 @@ from loguru import logger from gemini_webapi import GeminiClient, AuthError, set_log_level +from gemini_webapi.constants import Model logging.getLogger("asyncio").setLevel(logging.ERROR) set_log_level("DEBUG") @@ -27,10 +28,20 @@ async def test_successful_request(self): response = await self.geminiclient.generate_content("Hello World!") self.assertTrue(response.text) + @logger.catch(reraise=True) + async def test_switch_model(self): + for model in Model: + response = await self.geminiclient.generate_content( + "What's you language model version? 
Reply version number only.", + model=model, + ) + logger.debug(f"Model version ({model.model_name}): {response.text}") + @logger.catch(reraise=True) async def test_upload_image(self): response = await self.geminiclient.generate_content( - "Describe these images", images=[Path("assets/banner.png"), "assets/favicon.png"] + "Describe these images", + images=[Path("assets/banner.png"), "assets/favicon.png"], ) logger.debug(response.text) @@ -86,9 +97,7 @@ async def test_ai_image_generation(self): @logger.catch(reraise=True) async def test_card_content(self): - response = await self.geminiclient.generate_content( - "How is today's weather?" - ) + response = await self.geminiclient.generate_content("How is today's weather?") logger.debug(response.text) @logger.catch(reraise=True)