-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add model management * fix: model add and update * feat: add model provider qwen
- Loading branch information
1 parent
11a1883
commit 7674333
Showing
10 changed files
with
314 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import importlib


def get_provider_config(provider_name: str):
    """Return the ``PROVIDER_CONFIG`` dict declared in a provider's config module.

    Args:
        provider_name: Directory name of the provider package under
            ``app.core.model_providers`` (e.g. ``"openai"``).

    Returns:
        The module-level ``PROVIDER_CONFIG`` dictionary of that provider.

    Raises:
        ValueError: If the provider's config module cannot be imported, or
            the module does not define ``PROVIDER_CONFIG``.
    """
    try:
        module = importlib.import_module(f'app.core.model_providers.{provider_name}.config')
        return module.PROVIDER_CONFIG
    except (ImportError, AttributeError) as exc:
        # AttributeError covers a config module that exists but lacks
        # PROVIDER_CONFIG; chaining keeps the root cause in the traceback.
        raise ValueError(f"No configuration found for provider: {provider_name}") from exc
48 changes: 48 additions & 0 deletions
48
backend/app/core/model_providers/model_provider_manager.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import importlib
import logging
import os
from typing import Any, Callable, Dict, List

logger = logging.getLogger(__name__)


class ModelProviderManager:
    """Registry of model providers discovered from sibling packages.

    Each subdirectory of this package (not starting with ``__``) is expected
    to contain a ``config`` module defining ``PROVIDER_CONFIG`` (dict),
    ``SUPPORTED_MODELS`` (list) and ``init_model`` (callable). Providers that
    fail to import or lack the required attributes are skipped.
    """

    def __init__(self):
        # provider name -> PROVIDER_CONFIG dict from that provider's config module
        self.providers: Dict[str, Dict[str, Any]] = {}
        # provider name -> list of supported model identifiers
        self.models: Dict[str, List[str]] = {}
        # provider name -> init_model callable used to build a chat client
        self.init_functions: Dict[str, Callable] = {}
        self.load_providers()

    def load_providers(self):
        """Scan this package's directory and register every valid provider.

        Best-effort: an unloadable provider is logged and skipped so one bad
        provider cannot break the whole registry.
        """
        providers_dir = os.path.dirname(os.path.abspath(__file__))
        # sorted() makes registration order deterministic across filesystems
        for item in sorted(os.listdir(providers_dir)):
            if not os.path.isdir(os.path.join(providers_dir, item)) or item.startswith("__"):
                continue
            try:
                module = importlib.import_module(f".{item}.config", package="app.core.model_providers")
            except ImportError as e:
                # Library code should report through logging, not print().
                logger.warning("Failed to load provider config for %s: %s", item, e)
                continue
            provider_config = getattr(module, 'PROVIDER_CONFIG', None)
            supported_models = getattr(module, 'SUPPORTED_MODELS', [])
            init_function = getattr(module, 'init_model', None)
            # Register only fully-formed providers; callable() guards against
            # a module that defines init_model as a non-function attribute.
            if provider_config and callable(init_function):
                self.providers[item] = provider_config
                self.models[item] = supported_models
                self.init_functions[item] = init_function

    def get_provider_config(self, provider_name: str) -> Dict[str, Any]:
        """Return the config dict for *provider_name*, or {} if unknown."""
        return self.providers.get(provider_name, {})

    def get_supported_models(self, provider_name: str) -> List[str]:
        """Return the model list for *provider_name*, or [] if unknown."""
        return self.models.get(provider_name, [])

    def get_all_providers(self) -> Dict[str, Dict[str, Any]]:
        """Return the full provider-name -> config mapping."""
        return self.providers

    def get_all_models(self) -> Dict[str, List[str]]:
        """Return the full provider-name -> supported-models mapping."""
        return self.models

    def init_model(self, provider_name: str, model: str, temperature: float, **kwargs):
        """Instantiate *model* via the provider's registered init function.

        Raises:
            ValueError: If *provider_name* was not registered.
        """
        init_function = self.init_functions.get(provider_name)
        if init_function:
            return init_function(model, temperature, **kwargs)
        else:
            raise ValueError(f"No initialization function found for provider: {provider_name}")


# Module-level singleton shared by the rest of the application.
model_provider_manager = ModelProviderManager()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from langchain_ollama import ChatOllama

# Static metadata describing the Ollama provider.
# NOTE(review): api_key is a placeholder — init_model below only reads
# base_url, so the key is never sent to the server.
PROVIDER_CONFIG = {
    "provider_name": "Ollama",
    "base_url": "http://host.docker.internal:11434",
    "api_key": "fakeapikey",
    "icon": "ollama_icon",
    "description": "Ollama API provider",
}

# Model identifiers this provider declares support for.
SUPPORTED_MODELS = [
    "llama3.1:8b",
]


def init_model(model: str, temperature: float, **kwargs):
    """Build a ChatOllama client for *model* against the configured base URL."""
    base_url = PROVIDER_CONFIG["base_url"]
    return ChatOllama(
        model=model,
        temperature=temperature,
        base_url=base_url,
        **kwargs,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os

from langchain_openai import ChatOpenAI

# Provider metadata for OpenAI.
# The API key is read from the OPENAI_API_KEY environment variable at import
# time so real credentials never have to be committed to source control; the
# original placeholder remains the fallback for backward compatibility.
PROVIDER_CONFIG = {
    "provider_name": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": os.environ.get("OPENAI_API_KEY", "your_api_key_here"),
    "icon": "openai_icon",
    "description": "OpenAI API provider",
}

# Model identifiers this provider declares support for.
SUPPORTED_MODELS = [
    "gpt-4",
    "gpt-4-0314",
    "gpt-4-32k",
    "gpt-4-32k-0314",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4o-mini",
]


def init_model(model: str, temperature: float, **kwargs):
    """Build a ChatOpenAI client for *model* using this provider's key/URL.

    Extra keyword arguments are forwarded to the ChatOpenAI constructor.
    """
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG["api_key"],
        openai_api_base=PROVIDER_CONFIG["base_url"],
        **kwargs
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from langchain_openai import ChatOpenAI

# Static metadata describing the Qwen provider.
# NOTE(review): base_url/api_key are placeholders ("fakeurl"/"fakeapikey")
# and must be replaced with real values before use.
PROVIDER_CONFIG = {
    "provider_name": "Qwen",
    "base_url": "fakeurl",
    "api_key": "fakeapikey",
    "icon": "qwen_icon",
    "description": "qwen API provider",
}

# Model identifiers this provider declares support for.
SUPPORTED_MODELS = [
    "Qwen2-7B-Instruct",
    "Qwen2.5-70B",
]


def init_model(model: str, temperature: float, **kwargs):
    """Build a ChatOpenAI-compatible client for a Qwen *model*."""
    api_key = PROVIDER_CONFIG["api_key"]
    api_base = PROVIDER_CONFIG["base_url"]
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=api_key,
        openai_api_base=api_base,
        **kwargs,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from langchain_openai import ChatOpenAI

# Static metadata describing the Siliconflow provider.
# NOTE(review): base_url/api_key are placeholders ("fakeurl"/"fakeapikey")
# and must be replaced with real values before use.
PROVIDER_CONFIG = {
    "provider_name": "Siliconflow",
    "base_url": "fakeurl",
    "api_key": "fakeapikey",
    "icon": "siliconflow_icon",
    "description": "Siliconflow API provider",
}

# Model identifiers this provider declares support for.
SUPPORTED_MODELS = [
    "Qwen/Qwen2-7B-Instruct",
]


def init_model(model: str, temperature: float, **kwargs):
    """Build a ChatOpenAI-compatible client for a Siliconflow *model*."""
    api_key = PROVIDER_CONFIG["api_key"]
    api_base = PROVIDER_CONFIG["base_url"]
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=api_key,
        openai_api_base=api_base,
        **kwargs,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os

from langchain_openai import ChatOpenAI

# Provider metadata for ZhipuAI.
# SECURITY: a real API key was previously hard-coded here and committed to
# version control — that key must be treated as compromised and rotated.
# The key is now read from the ZHIPUAI_API_KEY environment variable so
# credentials never live in source.
PROVIDER_CONFIG = {
    "provider_name": "zhipuai",
    "base_url": "https://open.bigmodel.cn/api/paas/v4",
    "api_key": os.environ.get("ZHIPUAI_API_KEY", ""),
    "icon": "zhipuai_icon",
    "description": "智谱AI",
}

# Model identifiers this provider declares support for.
SUPPORTED_MODELS = [
    "glm-4-alltools",
    "glm-4-flash",
    "glm-4-0520",
    "glm-4-plus",
    "glm-4v-plus",
    "glm-4",
    "glm-4v",
]


def init_model(model: str, temperature: float, **kwargs):
    """Build a ChatOpenAI-compatible client for a ZhipuAI *model*.

    Extra keyword arguments are forwarded to the ChatOpenAI constructor.
    """
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG["api_key"],
        openai_api_base=PROVIDER_CONFIG["base_url"],
        **kwargs
    )
Oops, something went wrong.