Feat/models manage (#86)
* feat/add model management

* fix:model add and update

* feat:add model provider qwen
Onelevenvy authored Oct 14, 2024
1 parent 11a1883 commit 7674333
Showing 10 changed files with 314 additions and 112 deletions.
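Taken together, the changed files introduce a plugin-style provider layout: each directory under backend/app/core/model_providers/ ships a config.py exposing PROVIDER_CONFIG, SUPPORTED_MODELS, and an init_model factory, and ModelProviderManager discovers these modules at import time. As a rough sketch of that convention (the provider name and values below are hypothetical; the real configs appear in the diffs that follow):

# app/core/model_providers/myprovider/config.py  (hypothetical example)
from langchain_openai import ChatOpenAI

PROVIDER_CONFIG = {
    "provider_name": "myprovider",
    "base_url": "https://example.com/v1",
    "api_key": "fakeapikey",
    "icon": "myprovider_icon",
    "description": "Example provider",
}

SUPPORTED_MODELS = [
    "example-model",
]

def init_model(model: str, temperature: float, **kwargs):
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG["api_key"],
        openai_api_base=PROVIDER_CONFIG["base_url"],
        **kwargs,
    )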
122 changes: 82 additions & 40 deletions backend/app/core/db.py
@@ -1,12 +1,13 @@
import os

from sqlalchemy import text
import importlib
from sqlalchemy import text, func
from sqlmodel import Session, create_engine, select

from app.core.config import settings
from app.core.tools import managed_tools
from app.curd import users
from app.models import Skill, User, UserCreate
from app.models import Skill, User, UserCreate, ModelProvider, Models
from app.core.model_providers.model_provider_manager import model_provider_manager


def get_url():
@@ -25,7 +26,20 @@ def get_url():
# make sure all SQLModel models are imported (app.models) before initializing DB
# otherwise, SQLModel might fail to initialize relationships properly
# for more details: https://github.com/tiangolo/full-stack-fastapi-template/issues/28

def print_skills_info(session: Session) -> None:
    print("\nSkills Information:")
    skills = session.exec(select(Skill).order_by(Skill.id)).all()
    for skill in skills:
        print(f"Skill: {skill.name} (ID: {skill.id})")
        print(f"  Display Name: {skill.display_name}")
        print(f"  Description: {skill.description}")
        print(f"  Managed: {'Yes' if skill.managed else 'No'}")
        print(f"  Owner ID: {skill.owner_id}")
        if skill.input_parameters:
            print("  Input Parameters:")
            for param, param_type in skill.input_parameters.items():
                print(f"    - {param}: {param_type}")
        print()

def init_db(session: Session) -> None:
    # Tables should be created with Alembic migrations
@@ -61,12 +75,15 @@ def init_db(session: Session) -> None:
            if (
                existing_skill.description != skill_info.description
                or existing_skill.display_name != skill_info.display_name
                or existing_skill.input_parameters != skill_info.input_parameters  # check whether the input parameters changed
                or existing_skill.input_parameters
                != skill_info.input_parameters  # check whether the input parameters changed
            ):
                # Update the existing skill's description and input parameters
                existing_skill.description = skill_info.description
                existing_skill.display_name = skill_info.display_name
                existing_skill.input_parameters = skill_info.input_parameters  # update the input parameters
                existing_skill.input_parameters = (
                    skill_info.input_parameters
                )  # update the input parameters
                session.add(existing_skill)  # Mark the modified object for saving
        else:
            new_skill = Skill(
@@ -90,39 +107,64 @@ def init_db(session: Session) -> None:

    session.commit()

    # Print skills info
    print_skills_info(session)



def init_modelprovider_model_db(session: Session) -> None:
    # Insert or update ModelProvider data
    model_provider_sql = """
    INSERT INTO ModelProvider (id, provider_name, base_url, api_key, icon, description)
    VALUES
        (4, 'openai', 'fakeurl', 'fakeapikey', 'string', 'open ai'),
        (1, 'Ollama', 'fakeurl', 'fakeapikey', 'string', 'string fake'),
        (2, 'Siliconflow', 'fakeurl', 'fakeapikey', 'string', 'siliconflow'),
        (3, 'zhipuai', 'https://open.bigmodel.cn/api/paas/v4', 'fakeapikey', 'zhipuai', '智谱AI')
    ON CONFLICT (id) DO NOTHING;
    """

    # Insert Models data
    models_sql = """
    INSERT INTO Models (id, ai_model_name, provider_id)
    VALUES
        (1, 'gpt4', 4),
        (2, 'gpt4o', 4),
        (3, 'gpt4o-mini', 4),
        (4, 'llama3.1:8b', 1),
        (5, 'Qwen/Qwen2-7B-Instruct', 2),
        (6, 'glm-4-alltools', 3),
        (7, 'glm-4-flash', 3),
        (8, 'glm-4-0520', 3),
        (9, 'glm-4-plus', 3),
        (10, 'glm-4v-plus', 3),
        (11, 'glm-4', 3),
        (12, 'glm-4v', 3)
    ON CONFLICT (id) DO NOTHING;
    """

    # Execute the SQL statements
    session.exec(text(model_provider_sql))
    session.exec(text(models_sql))
    session.commit()
    # Get all provider configurations
    providers = model_provider_manager.get_all_providers()

    # Sort by provider name so the processing order is deterministic
    for provider_name in sorted(providers.keys()):
        provider_data = providers[provider_name]

        # Look up the existing provider record
        db_provider = session.exec(
            select(ModelProvider).where(
                ModelProvider.provider_name == provider_data['provider_name']
            )
        ).first()

        if db_provider:
            # Update the provider info, but keep the existing API key and base URL
            db_provider.icon = provider_data['icon']
            db_provider.description = provider_data['description']
            # Note: api_key and base_url are not updated here because the user may have changed them
        else:
            # Create a new record if the provider does not exist yet
            db_provider = ModelProvider(
                provider_name=provider_data['provider_name'],
                base_url=provider_data['base_url'],
                api_key=provider_data['api_key'],
                icon=provider_data['icon'],
                description=provider_data['description']
            )
            session.add(db_provider)

        session.flush()  # make sure provider_id has been generated

        # Models supported by this provider (from its config)
        supported_models = set(model_provider_manager.get_supported_models(provider_name))

        # Models already stored in the database for this provider
        existing_models = set(
            model.ai_model_name
            for model in session.exec(
                select(Models).where(Models.provider_id == db_provider.id)
            )
        )

        # Add newly supported models
        for model_name in sorted(supported_models - existing_models):
            new_model = Models(ai_model_name=model_name, provider_id=db_provider.id)
            session.add(new_model)

        # Remove models that are no longer supported
        for model_name in sorted(existing_models - supported_models):
            obsolete_model = session.exec(
                select(Models).where(
                    Models.ai_model_name == model_name,
                    Models.provider_id == db_provider.id,
                )
            ).first()
            if obsolete_model is not None:
                session.delete(obsolete_model)

    session.commit()

    # Print the current database state for verification
    providers = session.exec(select(ModelProvider).order_by(ModelProvider.id)).all()
    for provider in providers:
        print(f"Provider: {provider.provider_name} (ID: {provider.id})")
        print(f"  Base URL: {provider.base_url}")
        print(f"  API Key: {'*' * len(provider.api_key)}")  # do not print the real API key, for safety
        models = session.exec(
            select(Models).where(Models.provider_id == provider.id).order_by(Models.id)
        ).all()
        for model in models:
            print(f"    - Model: {model.ai_model_name} (ID: {model.id})")

8 changes: 8 additions & 0 deletions backend/app/core/model_providers/__init__.py
@@ -0,0 +1,8 @@
import importlib

def get_provider_config(provider_name: str):
    try:
        module = importlib.import_module(f'app.core.model_providers.{provider_name}.config')
        return module.PROVIDER_CONFIG
    except ImportError:
        raise ValueError(f"No configuration found for provider: {provider_name}")
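For reference, a quick usage sketch of this helper; the provider name comes from the openai config added later in this commit:

# Look up a provider's default configuration by its directory name
config = get_provider_config("openai")
print(config["provider_name"], config["base_url"])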
48 changes: 48 additions & 0 deletions backend/app/core/model_providers/model_provider_manager.py
@@ -0,0 +1,48 @@
import os
import importlib
from typing import Dict, Any, List, Callable

class ModelProviderManager:
    def __init__(self):
        self.providers: Dict[str, Dict[str, Any]] = {}
        self.models: Dict[str, List[str]] = {}
        self.init_functions: Dict[str, Callable] = {}
        self.load_providers()

    def load_providers(self):
        providers_dir = os.path.dirname(os.path.abspath(__file__))
        for item in os.listdir(providers_dir):
            if os.path.isdir(os.path.join(providers_dir, item)) and not item.startswith("__"):
                try:
                    module = importlib.import_module(f".{item}.config", package="app.core.model_providers")
                    provider_config = getattr(module, 'PROVIDER_CONFIG', None)
                    supported_models = getattr(module, 'SUPPORTED_MODELS', [])
                    init_function = getattr(module, 'init_model', None)

                    if provider_config and init_function:
                        self.providers[item] = provider_config
                        self.models[item] = supported_models
                        self.init_functions[item] = init_function
                except ImportError as e:
                    print(f"Failed to load provider config for {item}: {e}")

    def get_provider_config(self, provider_name: str) -> Dict[str, Any]:
        return self.providers.get(provider_name, {})

    def get_supported_models(self, provider_name: str) -> List[str]:
        return self.models.get(provider_name, [])

    def get_all_providers(self) -> Dict[str, Dict[str, Any]]:
        return self.providers

    def get_all_models(self) -> Dict[str, List[str]]:
        return self.models

    def init_model(self, provider_name: str, model: str, temperature: float, **kwargs):
        init_function = self.init_functions.get(provider_name)
        if init_function:
            return init_function(model, temperature, **kwargs)
        else:
            raise ValueError(f"No initialization function found for provider: {provider_name}")

model_provider_manager = ModelProviderManager()
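A short usage sketch of the shared singleton, based only on the methods defined above (provider and model names come from the configs added in this commit; a real API key would still be needed for actual calls):

from app.core.model_providers.model_provider_manager import model_provider_manager

# List every discovered provider and the models it declares in SUPPORTED_MODELS
for provider_name, models in model_provider_manager.get_all_models().items():
    print(provider_name, models)

# Build a chat model through the provider's init_model factory
llm = model_provider_manager.init_model("openai", "gpt-4o-mini", temperature=0.2)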
21 changes: 21 additions & 0 deletions backend/app/core/model_providers/ollama/config.py
@@ -0,0 +1,21 @@
from langchain_ollama import ChatOllama

PROVIDER_CONFIG = {
    'provider_name': 'Ollama',
    'base_url': 'http://host.docker.internal:11434',
    'api_key': 'fakeapikey',
    'icon': 'ollama_icon',
    'description': 'Ollama API provider'
}

SUPPORTED_MODELS = [
    'llama3.1:8b',
]

def init_model(model: str, temperature: float, **kwargs):
    return ChatOllama(
        model=model,
        temperature=temperature,
        base_url=PROVIDER_CONFIG['base_url'],
        **kwargs
    )
29 changes: 29 additions & 0 deletions backend/app/core/model_providers/openai/config.py
@@ -0,0 +1,29 @@
from langchain_openai import ChatOpenAI

PROVIDER_CONFIG = {
    "provider_name": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": "your_api_key_here",
    "icon": "openai_icon",
    "description": "OpenAI API provider",
}

SUPPORTED_MODELS = [
    "gpt-4",
    "gpt-4-0314",
    "gpt-4-32k",
    "gpt-4-32k-0314",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4o-mini",
]


def init_model(model: str, temperature: float, **kwargs):
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG["api_key"],
        openai_api_base=PROVIDER_CONFIG["base_url"],
        **kwargs
    )
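As a usage note, init_model returns a ready-to-use ChatOpenAI instance; a minimal sketch, assuming a real API key has been filled into PROVIDER_CONFIG:

llm = init_model("gpt-4o-mini", temperature=0.2)
print(llm.invoke("Hello").content)  # standard LangChain Runnable interface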
24 changes: 24 additions & 0 deletions backend/app/core/model_providers/qwen/config.py
@@ -0,0 +1,24 @@
from langchain_openai import ChatOpenAI

PROVIDER_CONFIG = {
    "provider_name": "Qwen",
    "base_url": "fakeurl",
    "api_key": "fakeapikey",
    "icon": "qwen_icon",
    "description": "qwen API provider",
}

SUPPORTED_MODELS = [
    "Qwen2-7B-Instruct",
    "Qwen2.5-70B",
]


def init_model(model: str, temperature: float, **kwargs):
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG["api_key"],
        openai_api_base=PROVIDER_CONFIG["base_url"],
        **kwargs
    )
22 changes: 22 additions & 0 deletions backend/app/core/model_providers/siliconflow/config.py
@@ -0,0 +1,22 @@
from langchain_openai import ChatOpenAI

PROVIDER_CONFIG = {
    'provider_name': 'Siliconflow',
    'base_url': 'fakeurl',
    'api_key': 'fakeapikey',
    'icon': 'siliconflow_icon',
    'description': 'Siliconflow API provider'
}

SUPPORTED_MODELS = [
    'Qwen/Qwen2-7B-Instruct',
]

def init_model(model: str, temperature: float, **kwargs):
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG['api_key'],
        openai_api_base=PROVIDER_CONFIG['base_url'],
        **kwargs
    )
29 changes: 29 additions & 0 deletions backend/app/core/model_providers/zhipuai/config.py
@@ -0,0 +1,29 @@
from langchain_openai import ChatOpenAI

PROVIDER_CONFIG = {
    "provider_name": "zhipuai",
    "base_url": "https://open.bigmodel.cn/api/paas/v4",
    "api_key": "5e867fc4396cff20bc3431d39e8c240f.d5Y8YBqIawDigP46",
    "icon": "zhipuai_icon",
    "description": "智谱AI",
}

SUPPORTED_MODELS = [
    "glm-4-alltools",
    "glm-4-flash",
    "glm-4-0520",
    "glm-4-plus",
    "glm-4v-plus",
    "glm-4",
    "glm-4v",
]


def init_model(model: str, temperature: float, **kwargs):
    return ChatOpenAI(
        model=model,
        temperature=temperature,
        openai_api_key=PROVIDER_CONFIG["api_key"],
        openai_api_base=PROVIDER_CONFIG["base_url"],
        **kwargs
    )