Merge pull request #2569 from hlohaus/12Jan
Read FinishReason and Usage from Gemini API
hlohaus authored Jan 14, 2025
2 parents 03813fb + 0e5c9ed commit 3f39890
Showing 12 changed files with 88 additions and 60 deletions.
24 changes: 13 additions & 11 deletions g4f/Provider/Jmuz.py
@@ -5,7 +5,7 @@

class Jmuz(OpenaiAPI):
label = "Jmuz"
url = "https://jmuz.me"
url = "https://discord.gg/qXfu24JmsB"
login_url = None
api_base = "https://jmuz.me/gpt/api/v2"
api_key = "prod"
@@ -15,7 +15,7 @@ class Jmuz(OpenaiAPI):
supports_stream = True
supports_system_message = False

- default_model = 'gpt-4o'
+ default_model = "gpt-4o"
model_aliases = {
"gemini": "gemini-exp",
"deepseek-chat": "deepseek-2.5",
@@ -29,13 +29,7 @@ def get_models(cls):
return cls.models

@classmethod
- def get_model(cls, model: str, **kwargs) -> str:
- if model in cls.get_models():
- return model
- return cls.default_model
-
- @classmethod
- def create_async_generator(
+ async def create_async_generator(
cls,
model: str,
messages: Messages,
@@ -52,12 +46,20 @@ def create_async_generator(
"cache-control": "no-cache",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}
- return super().create_async_generator(
+ started = False
+ async for chunk in super().create_async_generator(
model=model,
messages=messages,
api_base=cls.api_base,
api_key=cls.api_key,
stream=cls.supports_stream,
headers=headers,
**kwargs
- )
+ ):
+ if isinstance(chunk, str) and cls.url in chunk:
+ continue
+ if isinstance(chunk, str) and not started:
+ chunk = chunk.lstrip()
+ if chunk:
+ started = True
+ yield chunk
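
The rewritten generator above no longer returns the parent coroutine directly; it consumes the stream itself so it can drop housekeeping chunks that embed the provider URL and trim leading whitespace until real output arrives. A minimal consumption sketch, assuming g4f is installed and the Jmuz endpoint is reachable; the prompt is illustrative:

import asyncio
from g4f.Provider import Jmuz

async def main():
    # Chunks containing cls.url are skipped; leading whitespace is stripped
    # until the first non-empty chunk has been yielded.
    async for chunk in Jmuz.create_async_generator(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Say hello"}],
    ):
        print(chunk, end="")

asyncio.run(main())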
4 changes: 2 additions & 2 deletions g4f/Provider/Pizzagpt.py
@@ -10,7 +10,7 @@
class Pizzagpt(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://www.pizzagpt.it"
api_endpoint = "/api/chatx-completion"
- working = True
+ working = False
default_model = 'gpt-4o-mini'

@classmethod
@@ -46,6 +46,6 @@ async def create_async_generator(
response_json = await response.json()
content = response_json.get("answer", response_json).get("content")
if content:
if "misuse detected. please get in touch" in content:
if "Misuse detected. please get in touch" in content:
raise ValueError(content)
yield content
7 changes: 4 additions & 3 deletions g4f/Provider/needs_auth/Custom.py
@@ -3,9 +3,10 @@
from .OpenaiAPI import OpenaiAPI

class Custom(OpenaiAPI):
label = "Custom"
label = "Custom Provider"
url = None
login_url = "http://localhost:8080"
login_url = None
working = True
api_base = "http://localhost:8080/v1"
- needs_auth = False
+ needs_auth = False
+ sort_models = False
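
With working = True and no auth required, the Custom provider now targets any OpenAI-compatible server out of the box. A sketch of one possible call, assuming the g4f client API passes extra kwargs through to the provider; the model name is hypothetical and api_base matches the new default:

from g4f.client import Client
from g4f.Provider import Custom

client = Client(provider=Custom)
response = client.chat.completions.create(
    model="local-model",  # hypothetical id served by the local backend
    messages=[{"role": "user", "content": "Hello"}],
    api_base="http://localhost:8080/v1",  # override per request if needed
)
print(response.choices[0].message.content)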
17 changes: 16 additions & 1 deletion g4f/Provider/needs_auth/GeminiPro.py
@@ -3,12 +3,14 @@
import base64
import json
import requests
+ from typing import Optional
from aiohttp import ClientSession, BaseConnector

from ...typing import AsyncResult, Messages, ImagesType
from ...image import to_bytes, is_accepted_format
from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
+ from ...providers.response import Usage, FinishReason
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_connector
from ... import debug
@@ -62,6 +64,7 @@ async def create_async_generator(
api_base: str = api_base,
use_auth_header: bool = False,
images: ImagesType = None,
+ tools: Optional[list] = None,
connector: BaseConnector = None,
**kwargs
) -> AsyncResult:
@@ -104,7 +107,10 @@
"maxOutputTokens": kwargs.get("max_tokens"),
"topP": kwargs.get("top_p"),
"topK": kwargs.get("top_k"),
- }
+ },
+ "tools": [{
+ "functionDeclarations": tools
+ }] if tools else None
}
system_prompt = "\n".join(
message["content"]
@@ -128,6 +134,15 @@
data = b"".join(lines)
data = json.loads(data)
yield data["candidates"][0]["content"]["parts"][0]["text"]
if "finishReason" in data["candidates"][0]:
yield FinishReason(data["candidates"][0]["finishReason"].lower())
usage = data.get("usageMetadata")
if usage:
yield Usage(
prompt_tokens=usage.get("promptTokenCount"),
completion_tokens=usage.get("candidatesTokenCount"),
total_tokens=usage.get("totalTokenCount")
)
except:
data = data.decode(errors="ignore") if isinstance(data, bytes) else data
raise RuntimeError(f"Read chunk failed: {data}")
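
This is the core of the commit: besides text, the Gemini generator can now yield FinishReason and Usage objects. A consumption sketch, assuming GeminiPro is re-exported from g4f.Provider and that FinishReason exposes a reason attribute and Usage a get_dict() helper; the model id and key are placeholders:

import asyncio
from g4f.Provider import GeminiPro
from g4f.providers.response import Usage, FinishReason

async def main():
    async for chunk in GeminiPro.create_async_generator(
        model="gemini-1.5-pro",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        api_key="YOUR_GEMINI_API_KEY",  # a real key is required
    ):
        if isinstance(chunk, FinishReason):
            print(f"\n[finish: {chunk.reason}]")
        elif isinstance(chunk, Usage):
            print(f"[usage: {chunk.get_dict()}]")
        else:
            print(chunk, end="")

asyncio.run(main())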
10 changes: 7 additions & 3 deletions g4f/Provider/needs_auth/HuggingFace.py
@@ -23,21 +23,25 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
default_model = HuggingChat.default_model
default_image_model = HuggingChat.default_image_model
model_aliases = HuggingChat.model_aliases
+ extra_models = [
+ "meta-llama/Llama-3.2-11B-Vision-Instruct",
+ "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
+ "NousResearch/Hermes-3-Llama-3.1-8B",
+ ]

@classmethod
def get_models(cls) -> list[str]:
if not cls.models:
url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
models = [model["id"] for model in requests.get(url).json()]
models.append("meta-llama/Llama-3.2-11B-Vision-Instruct")
models.append("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
models.extend(cls.extra_models)
models.sort()
if not cls.image_models:
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
cls.image_models.sort()
models.extend(cls.image_models)
- cls.models = models
+ cls.models = list(set(models))
return cls.models

@classmethod
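
extra_models pins models that the warm-inference listing may omit, and the list(set(...)) pass deduplicates in case the API starts returning them too. A quick sketch of the observable effect, assuming network access to huggingface.co:

from g4f.Provider.needs_auth import HuggingFace

models = HuggingFace.get_models()
# The pinned entries are present even if absent from the API listing...
assert "NousResearch/Hermes-3-Llama-3.1-8B" in models
# ...and the set() round-trip leaves no duplicate ids.
assert len(models) == len(set(models))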
8 changes: 6 additions & 2 deletions g4f/Provider/needs_auth/OpenaiAPI.py
@@ -23,6 +23,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
supports_system_message = True
default_model = ""
fallback_models = []
+ sort_models = True

@classmethod
def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
@@ -36,8 +37,11 @@ def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
response = requests.get(f"{api_base}/models", headers=headers)
raise_for_status(response)
data = response.json()
- cls.models = [model.get("id") for model in (data.get("data") if isinstance(data, dict) else data)]
- cls.models.sort()
+ data = data.get("data") if isinstance(data, dict) else data
+ cls.image_models = [model.get("id") for model in data if model.get("image")]
+ cls.models = [model.get("id") for model in data]
+ if cls.sort_models:
+ cls.models.sort()
except Exception as e:
debug.log(e)
cls.models = cls.fallback_models
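
sort_models keeps the long-standing alphabetical ordering by default while letting subclasses such as Custom (which sets sort_models = False above) preserve the server's own ordering, and image models are now split out via the "image" flag. A sketch with a hypothetical subclass:

from g4f.Provider.needs_auth.OpenaiAPI import OpenaiAPI

class LocalProvider(OpenaiAPI):  # hypothetical subclass for illustration
    label = "Local"
    api_base = "http://localhost:8080/v1"
    needs_auth = False
    sort_models = False  # keep the order the /models endpoint returns

print(LocalProvider.get_models())  # also fills image_models as a side effect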
16 changes: 10 additions & 6 deletions g4f/api/__init__.py
@@ -215,12 +215,16 @@ async def read_root_v1():
HTTP_200_OK: {"model": List[ModelResponseModel]},
})
async def models():
- return [{
- 'id': model_id,
- 'object': 'model',
- 'created': 0,
- 'owned_by': model.base_provider
- } for model_id, model in g4f.models.ModelUtils.convert.items()]
+ return {
+ "object": "list",
+ "data": [{
+ "id": model_id,
+ "object": "model",
+ "created": 0,
+ "owned_by": model.base_provider,
+ "image": isinstance(model, g4f.models.ImageModel),
+ } for model_id, model in g4f.models.ModelUtils.convert.items()]
+ }

@self.app.get("/v1/models/{model_name}", responses={
HTTP_200_OK: {"model": ModelResponseModel},
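
The /v1/models endpoint now wraps its results in an OpenAI-style list object and flags image models, so clients can filter without extra lookups. A sketch against a locally running API server, assuming the usual default port:

import requests

# Requires a running server, e.g. python -m g4f.api (port 1337 by default).
data = requests.get("http://localhost:1337/v1/models").json()
assert data["object"] == "list"
image_models = [m["id"] for m in data["data"] if m["image"]]
print("image models:", image_models)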
13 changes: 9 additions & 4 deletions g4f/gui/client/index.html
@@ -143,7 +143,7 @@ <h3>Settings</h3>
<label for="refine" class="toogle" title=""></label>
</div>
<div class="field box">
<label for="systemPrompt" class="label" title="">Default for System prompt</label>
<label for="systemPrompt" class="label" title="">System prompt</label>
<textarea id="systemPrompt" placeholder="You are a helpful assistant."></textarea>
</div>
<div class="field box">
@@ -157,6 +157,14 @@ <h3>Settings</h3>
document.getElementById('recognition-language').placeholder = navigator.language;
</script>
</div>
<div class="field box">
<label for="Custom-api_base" class="label" title="">Custom Provider (Base Url):</label>
<input type="text" id="Custom-api_base" name="Custom[api_base]" placeholder="http://localhost:8080/v1"/>
</div>
<div class="field box hidden">
<label for="Custom-api_key" class="label" title="">Custom Provider:</label>
<input type="text" id="Custom-api_key" name="Custom[api_key]" placeholder="api_key"/>
</div>
<div class="field box hidden">
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
<input type="text" id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder="&quot;_U&quot; cookie"/>
@@ -254,10 +262,7 @@ <h3>Settings</h3>
<option value="gpt-4o">gpt-4o</option>
<option value="gpt-4o-mini">gpt-4o-mini</option>
<option value="llama-3.1-70b">llama-3.1-70b</option>
<option value="llama-3.1-405b">llama-3.1-405b</option>
<option value="mixtral-8x7b">mixtral-8x7b</option>
<option value="gemini-pro">gemini-pro</option>
<option value="gemini-flash">gemini-flash</option>
<option value="claude-3.5-sonnet">claude-3.5-sonnet</option>
<option value="flux">flux (Image Generation)</option>
<option value="dall-e-3">dall-e-3 (Image Generation)</option>
Expand Down
35 changes: 13 additions & 22 deletions g4f/gui/client/static/js/chat.v1.js
@@ -351,11 +351,6 @@ const handle_ask = async () => {
await count_input()
await add_conversation(window.conversation_id);

if ("text" in fileInput.dataset) {
message += '\n```' + fileInput.dataset.type + '\n';
message += fileInput.dataset.text;
message += '\n```'
}
let message_index = await add_message(window.conversation_id, "user", message);
let message_id = get_message_id();

@@ -799,6 +794,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
const files = input && input.files.length > 0 ? input.files : null;
const download_images = document.getElementById("download_images")?.checked;
const api_key = get_api_key_by_provider(provider);
+ const api_base = provider == "Custom" ? document.getElementById(`${provider}-api_base`).value : null;
const ignored = Array.from(settings.querySelectorAll("input.provider:not(:checked)")).map((el)=>el.value);
await api("conversation", {
id: message_id,
@@ -811,6 +807,7 @@
action: action,
download_images: download_images,
api_key: api_key,
+ api_base: api_base,
ignored: ignored,
}, files, message_id, scroll);
content_map.update_timeouts.forEach((timeoutId)=>clearTimeout(timeoutId));
@@ -1066,7 +1063,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
}
buffer = buffer.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
new_content = item.content.replace(/ \[aborted\]$/g, "").replace(/ \[error\]$/g, "");
- buffer += merge_messages(buffer, new_content);
+ buffer = merge_messages(buffer, new_content);
last_model = item.provider?.model;
providers.push(item.provider?.name);
let next_i = parseInt(i) + 1;
@@ -1176,9 +1173,6 @@
if (window.GPTTokenizer_cl100k_base) {
const filtered = prepare_messages(messages, null, true, false);
if (filtered.length > 0) {
- if (GPTTokenizer_o200k_base && last_model?.startsWith("gpt-4o") || last_model?.startsWith("o1")) {
- return GPTTokenizer_o200k_base?.encodeChat(filtered, last_model).length;
- }
last_model = last_model?.startsWith("gpt-3") ? "gpt-3.5-turbo" : "gpt-4"
let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length
if (count_total > 0) {
@@ -1890,7 +1884,6 @@ setTimeout(load_version, 100);

fileInput.addEventListener('click', async (event) => {
fileInput.value = '';
- delete fileInput.dataset.text;
});

async function upload_cookies() {
@@ -1920,7 +1913,6 @@ function formatFileSize(bytes) {
async function upload_files(fileInput) {
const paperclip = document.querySelector(".user-input .fa-paperclip");
const bucket_id = uuid();
- delete fileInput.dataset.text;
paperclip.classList.add("blink");

const formData = new FormData();
@@ -1980,8 +1972,7 @@ fileInput.addEventListener('change', async (event) => {
if (type == "json") {
const reader = new FileReader();
reader.addEventListener('load', async (event) => {
- fileInput.dataset.text = event.target.result;
- const data = JSON.parse(fileInput.dataset.text);
+ const data = JSON.parse(event.target.result);
if (data.options && "g4f" in data.options) {
let count = 0;
Object.keys(data).forEach(key => {
@@ -1990,7 +1981,6 @@ fileInput.addEventListener('change', async (event) => {
count += 1;
}
});
- delete fileInput.dataset.text;
await load_conversations();
fileInput.value = "";
inputCount.innerText = `${count} Conversations were imported successfully`;
@@ -2012,8 +2002,6 @@
});
reader.readAsText(fileInput.files[0]);
}
- } else {
- delete fileInput.dataset.text;
}
});

@@ -2033,16 +2021,19 @@
}

async function api(ressource, args=null, files=null, message_id=null, scroll=true) {
- let api_key;
+ const headers = {};
if (ressource == "models" && args) {
api_key = get_api_key_by_provider(args);
+ if (api_key) {
+ headers.x_api_key = api_key;
+ }
+ api_base = args == "Custom" ? document.getElementById(`${args}-api_base`).value : null;
+ if (api_base) {
+ headers.x_api_base = api_base;
+ }
ressource = `${ressource}/${args}`;
}
const url = `/backend-api/v2/${ressource}`;
- const headers = {};
- if (api_key) {
- headers.x_api_key = api_key;
- }
if (ressource == "conversation") {
let body = JSON.stringify(args);
headers.accept = 'text/event-stream';
@@ -2224,7 +2215,7 @@ if (SpeechRecognition) {
};
recognition.onend = function() {
messageInput.value = `${startValue ? startValue + "\n" : ""}${buffer}`;
- if (!microLabel.classList.contains("recognition")) {
+ if (microLabel.classList.contains("recognition")) {
recognition.start();
} else {
messageInput.readOnly = false;
7 changes: 5 additions & 2 deletions g4f/gui/server/api.py
@@ -37,12 +37,12 @@ def get_models():
for model, providers in models.__models__.values()]

@staticmethod
- def get_provider_models(provider: str, api_key: str = None):
+ def get_provider_models(provider: str, api_key: str = None, api_base: str = None):
if provider in ProviderUtils.convert:
provider = ProviderUtils.convert[provider]
if issubclass(provider, ProviderModelMixin):
if api_key is not None and "api_key" in signature(provider.get_models).parameters:
- models = provider.get_models(api_key=api_key)
+ models = provider.get_models(api_key=api_key, api_base=api_base)
else:
models = provider.get_models()
return [
@@ -90,6 +90,9 @@ def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
api_key = json_data.get("api_key")
if api_key is not None:
kwargs["api_key"] = api_key
+ api_base = json_data.get("api_base")
+ if api_base is not None:
+ kwargs["api_base"] = api_base
kwargs["tool_calls"] = [{
"function": {
"name": "bucket_tool"
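
This closes the loop on the new Custom provider settings: the GUI's api_base value travels through _prepare_conversation_kwargs into the provider call, and model listing accepts a base URL as well. A sketch of the extended lookup with illustrative values:

from g4f.gui.server.api import Api

# Note: api_base is only forwarded when api_key is set, mirroring the
# signature check inside get_provider_models.
models = Api.get_provider_models(
    "Custom",
    api_key="dummy",
    api_base="http://localhost:8080/v1",
)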