Skip to content

Commit

Permalink
Merge pull request #1667 from hlohaus/phind2
Browse files Browse the repository at this point in the history
Expire cache, Fix multiple websocket conversations in OpenaiChat
  • Loading branch information
hlohaus authored Mar 9, 2024
2 parents d1a8164 + 74a33f1 commit b3d19c5
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 230 deletions.
18 changes: 8 additions & 10 deletions g4f/Provider/GeminiPro.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,35 @@ async def create_async_generator(
stream: bool = False,
proxy: str = None,
api_key: str = None,
api_base: str = None,
use_auth_header: bool = True,
api_base: str = "https://generativelanguage.googleapis.com/v1beta",
use_auth_header: bool = False,
image: ImageType = None,
connector: BaseConnector = None,
**kwargs
) -> AsyncResult:
model = "gemini-pro-vision" if not model and image else model
model = "gemini-pro-vision" if not model and image is not None else model
model = cls.get_model(model)

if not api_key:
raise MissingAuthError('Missing "api_key"')

headers = params = None
if api_base and use_auth_header:
if use_auth_header:
headers = {"Authorization": f"Bearer {api_key}"}
else:
params = {"key": api_key}

if not api_base:
api_base = f"https://generativelanguage.googleapis.com/v1beta"

method = "streamGenerateContent" if stream else "generateContent"
url = f"{api_base.rstrip('/')}/models/{model}:{method}"
async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
contents = [
{
"role": "model" if message["role"] == "assistant" else message["role"],
"role": "model" if message["role"] == "assistant" else "user",
"parts": [{"text": message["content"]}]
}
for message in messages
]
if image:
if image is not None:
image = to_bytes(image)
contents[-1]["parts"].append({
"inline_data": {
Expand Down Expand Up @@ -87,7 +84,8 @@ async def create_async_generator(
lines = [b"{\n"]
elif chunk == b",\r\n" or chunk == b"]":
try:
data = json.loads(b"".join(lines))
data = b"".join(lines)
data = json.loads(data)
yield data["candidates"][0]["content"]["parts"][0]["text"]
except:
data = data.decode() if isinstance(data, bytes) else data
Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/bing/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv
}
for k, v in headers.items():
session.headers[k] = v
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1579.2'
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1626.1'
async with session.get(url, headers=headers, proxy=proxy) as response:
try:
data = await response.json()
Expand Down
103 changes: 55 additions & 48 deletions g4f/Provider/needs_auth/OpenaiChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import json
import os
import base64
import time
from aiohttp import ClientWebSocketResponse

try:
from py_arkose_generator.arkose import get_values_for_request
from async_property import async_cached_property
has_requirements = True
has_arkose_generator = True
except ImportError:
async_cached_property = property
has_requirements = False
has_arkose_generator = False

try:
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
Expand All @@ -33,7 +33,7 @@

class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"""A class for creating and managing conversations with OpenAI chat service"""

url = "https://chat.openai.com"
working = True
needs_auth = True
Expand All @@ -47,7 +47,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
_api_key: str = None
_headers: dict = None
_cookies: Cookies = None
_last_message: int = 0
_expires: int = None

@classmethod
async def create(
Expand Down Expand Up @@ -80,7 +80,7 @@ async def create(
A Response object that contains the generator, action, messages, and options
"""
# Add the user input to the messages list
if prompt:
if prompt is not None:
messages.append({
"role": "user",
"content": prompt
Expand All @@ -102,7 +102,7 @@ async def create(
messages,
kwargs
)

@classmethod
async def upload_image(
cls,
Expand Down Expand Up @@ -162,7 +162,7 @@ async def upload_image(
response.raise_for_status()
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)

@classmethod
async def get_default_model(cls, session: StreamSession, headers: dict):
"""
Expand All @@ -185,7 +185,7 @@ async def get_default_model(cls, session: StreamSession, headers: dict):
return cls.default_model
raise RuntimeError(f"Response: {data}")
return cls.default_model

@classmethod
def create_messages(cls, messages: Messages, image_request: ImageRequest = None):
"""
Expand Down Expand Up @@ -334,9 +334,7 @@ async def create_async_generator(
Raises:
RuntimeError: If an error occurs during processing.
"""
if not has_requirements:
raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package')
if not parent_id:
if parent_id is None:
parent_id = str(uuid.uuid4())

# Read api_key from arguments
Expand All @@ -348,7 +346,7 @@ async def create_async_generator(
timeout=timeout
) as session:
# Read api_key and cookies from cache / browser config
if cls._headers is None:
if cls._headers is None or cls._expires is None or time.time() > cls._expires:
if api_key is None:
# Read api_key from cookies
cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
Expand All @@ -357,8 +355,8 @@ async def create_async_generator(
else:
api_key = cls._api_key if api_key is None else api_key
# Read api_key with session cookies
if api_key is None and cookies:
api_key = await cls.fetch_access_token(session, cls._headers)
#if api_key is None and cookies:
# api_key = await cls.fetch_access_token(session, cls._headers)
# Load default model
if cls.default_model is None and api_key is not None:
try:
Expand All @@ -384,6 +382,19 @@ async def create_async_generator(
else:
cls._set_api_key(api_key)

async with session.post(
f"{cls.url}/backend-api/sentinel/chat-requirements",
json={"conversation_mode_kind": "primary_assistant"},
headers=cls._headers
) as response:
response.raise_for_status()
data = await response.json()
need_arkose = data["arkose"]["required"]
chat_token = data["token"]

if need_arkose and not has_arkose_generator:
raise MissingRequirementsError('Install "py-arkose-generator" package')

try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e:
Expand All @@ -394,12 +405,10 @@ async def create_async_generator(
model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
fields = ResponseFields()
while fields.finish_reason is None:
arkose_token = await cls.get_arkose_token(session)
conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
parent_id = parent_id if fields.message_id is None else fields.message_id
data = {
"action": action,
"arkose_token": arkose_token,
"conversation_mode": {"kind": "primary_assistant"},
"force_paragen": False,
"force_rate_limit": False,
Expand All @@ -417,7 +426,8 @@ async def create_async_generator(
json=data,
headers={
"Accept": "text/event-stream",
"OpenAI-Sentinel-Arkose-Token": arkose_token,
**({"OpenAI-Sentinel-Arkose-Token": await cls.get_arkose_token(session)} if need_arkose else {}),
"OpenAI-Sentinel-Chat-Requirements-Token": chat_token,
**cls._headers
}
) as response:
Expand All @@ -437,17 +447,20 @@ async def create_async_generator(
await cls.delete_conversation(session, cls._headers, fields.conversation_id)

@staticmethod
async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator:
while True:
yield base64.b64decode((await ws.receive_json())["body"])
message = await ws.receive_json()
if message["conversation_id"] == conversation_id:
yield base64.b64decode(message["body"])

@classmethod
async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
last_message: int = 0
async for message in messages:
if message.startswith(b'{"wss_url":'):
async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
message = json.loads(message)
async with session.ws_connect(message["wss_url"]) as ws:
async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields):
yield chunk
break
async for chunk in cls.iter_messages_line(session, message, fields):
Expand All @@ -467,6 +480,8 @@ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: R
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
if fields.finish_reason is None:
fields.finish_reason = "error"
return
try:
line = json.loads(line[6:])
Expand Down Expand Up @@ -589,22 +604,13 @@ def _update_request_args(cls, session: StreamSession):
@classmethod
def _set_api_key(cls, api_key: str):
cls._api_key = api_key
cls._expires = int(time.time()) + 60 * 60 * 4
cls._headers["Authorization"] = f"Bearer {api_key}"

@classmethod
def _update_cookie_header(cls):
cls._headers["Cookie"] = cls._format_cookies(cls._cookies)

class EndTurn:
"""
Class to represent the end of a conversation turn.
"""
def __init__(self):
self.is_end = False

def end(self):
self.is_end = True

class ResponseFields:
"""
Class to encapsulate response fields.
Expand Down Expand Up @@ -633,8 +639,8 @@ def __init__(
self._options = options
self._fields = None

async def generator(self):
if self._generator:
async def generator(self) -> AsyncIterator:
if self._generator is not None:
self._generator = None
chunks = []
async for chunk in self._generator:
Expand All @@ -644,27 +650,29 @@ async def generator(self):
yield chunk
chunks.append(str(chunk))
self._message = "".join(chunks)
if not self._fields:
if self._fields is None:
raise RuntimeError("Missing response fields")
self.is_end = self._fields.end_turn
self.is_end = self._fields.finish_reason == "stop"

def __aiter__(self):
return self.generator()

@async_cached_property
async def message(self) -> str:
async def get_message(self) -> str:
await self.generator()
return self._message

async def get_fields(self):
async def get_fields(self) -> dict:
await self.generator()
return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
return {
"conversation_id": self._fields.conversation_id,
"parent_id": self._fields.message_id
}

async def next(self, prompt: str, **kwargs) -> Response:
async def create_next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create(
**self._options,
prompt=prompt,
messages=await self.messages,
messages=await self.get_messages(),
action="next",
**await self.get_fields(),
**kwargs
Expand All @@ -676,13 +684,13 @@ async def do_continue(self, **kwargs) -> Response:
raise RuntimeError("Can't continue message. Message already finished.")
return await OpenaiChat.create(
**self._options,
messages=await self.messages,
messages=await self.get_messages(),
action="continue",
**fields,
**kwargs
)

async def variant(self, **kwargs) -> Response:
async def create_variant(self, **kwargs) -> Response:
if self.action != "next":
raise RuntimeError("Can't create variant from continue or variant request.")
return await OpenaiChat.create(
Expand All @@ -693,8 +701,7 @@ async def variant(self, **kwargs) -> Response:
**kwargs
)

@async_cached_property
async def messages(self):
async def get_messages(self) -> list:
messages = self._messages
messages.append({"role": "assistant", "content": await self.message})
messages.append({"role": "assistant", "content": await self.message()})
return messages
Loading

0 comments on commit b3d19c5

Please sign in to comment.