Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add evaluation service module for RAG and Agent #2070

Merged
merged 10 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions assets/schema/dbgpt.sql
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ CREATE TABLE IF NOT EXISTS `knowledge_document`
`id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
`doc_name` varchar(100) NOT NULL COMMENT 'document path name',
`doc_type` varchar(50) NOT NULL COMMENT 'doc type',
`doc_token` varchar(100) NOT NULL COMMENT 'doc token',
`doc_token` varchar(100) NULL COMMENT 'doc token',
`space` varchar(50) NOT NULL COMMENT 'knowledge space',
`chunk_size` int NOT NULL COMMENT 'chunk size',
`last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
Expand All @@ -56,7 +56,7 @@ CREATE TABLE IF NOT EXISTS `document_chunk`
`document_id` int NOT NULL COMMENT 'document parent id',
`content` longtext NOT NULL COMMENT 'chunk content',
`questions` text NULL COMMENT 'chunk related questions',
`meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
`meta_info` text NOT NULL COMMENT 'metadata info',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
Expand Down
8 changes: 8 additions & 0 deletions dbgpt/app/initialization/serve_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,11 @@ def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int)
system_app.register(FileServe)

# ################################ File Serve Register End ########################################

# ################################ Evaluate Serve Register Begin #######################################
from dbgpt.serve.evaluate.serve import Serve as EvaluateServe

# Register serve Evaluate
system_app.register(EvaluateServe)

# ################################ Evaluate Serve Register End ########################################
2 changes: 1 addition & 1 deletion dbgpt/app/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def chunk_list(
"doc_type": query_request.doc_type,
"content": query_request.content,
}
chunk_res = service.get_chunk_list(
chunk_res = service.get_chunk_list_page(
query, query_request.page, query_request.page_size
)
res = ChunkQueryResponse(
Expand Down
28 changes: 28 additions & 0 deletions dbgpt/client/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Evaluation."""
from typing import List

from dbgpt.core.schema.api import Result

from ..core.interface.evaluation import EvaluationResult
from ..serve.evaluate.api.schemas import EvaluateServeRequest
from .client import Client, ClientException


async def run_evaluation(
    client: Client, request: EvaluateServeRequest
) -> List[EvaluationResult]:
    """Run evaluation.

    Args:
        client (Client): The dbgpt client.
        request (EvaluateServeRequest): The Evaluate Request.

    Returns:
        List[EvaluationResult]: The evaluation results.

    Raises:
        ClientException: If the server reports a failure (carrying the
            server's err_code/reason) or the request cannot be completed.
    """
    try:
        res = await client.post("/evaluate/evaluation", request.dict())
        result: Result = res.json()
        if result["success"]:
            return list(result["data"])
        raise ClientException(status=result["err_code"], reason=result)
    except ClientException:
        # Do not re-wrap: the blanket handler below would otherwise swallow
        # the structured status/reason raised just above.
        raise
    except Exception as e:
        raise ClientException(f"Failed to run evaluation: {e}") from e
4 changes: 2 additions & 2 deletions dbgpt/core/interface/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def __init__(self):

def register_metric(self, cls: Type[EvaluationMetric]):
"""Register metric."""
self.metrics[cls.name] = cls
self.metrics[cls.name()] = cls

def get_by_name(self, name: str) -> Type[EvaluationMetric]:
"""Get by name."""
Expand All @@ -308,4 +308,4 @@ def all_metric_infos(self):
return result


metric_mange = MetricManage()
metric_manage = MetricManage()
2 changes: 1 addition & 1 deletion dbgpt/rag/evaluation/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ async def _do_evaluation(
contexts=contexts,
passing=result.passing,
raw_dataset=raw_dataset,
metric_name=metric.name,
metric_name=metric.name(),
feedback=result.feedback,
)
)
Expand Down
19 changes: 15 additions & 4 deletions dbgpt/rag/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ async def aload_document_with_limit(
max_threads,
)

@abstractmethod
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
Expand All @@ -196,16 +197,26 @@ def similar_search(
Return:
List[Chunk]: The similar documents.
"""
return self.similar_search_with_scores(text, topk, 0.0, filters)

async def asimilar_search(
self,
query: str,
topk: int,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Async similar_search in vector database."""
return await blocking_func_to_async_no_executor(
self.similar_search, query, topk, filters
)

async def asimilar_search_with_scores(
self,
doc: str,
query: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Aynsc similar_search_with_score in vector database."""
"""Async similar_search_with_score in vector database."""
return await blocking_func_to_async_no_executor(
self.similar_search_with_scores, doc, topk, score_threshold, filters
self.similar_search_with_scores, query, topk, score_threshold, filters
)
2 changes: 1 addition & 1 deletion dbgpt/rag/operators/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def _do_evaluation(
contexts=contexts,
passing=result.passing,
raw_dataset=raw_dataset,
metric_name=metric.name,
metric_name=metric.name(),
)
)
return results
5 changes: 1 addition & 4 deletions dbgpt/rag/retriever/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer


Expand Down Expand Up @@ -241,9 +240,7 @@ async def _similarity_search(
"query": query,
},
):
return await blocking_func_to_async_no_executor(
self._index_store.similar_search, query, self._top_k, filters
)
return await self._index_store.asimilar_search(query, self._top_k, filters)

async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
Expand Down
23 changes: 23 additions & 0 deletions dbgpt/serve/agent/agents/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DefaultAWELLayoutManager,
GptsMemory,
LLMConfig,
ResourceType,
ShortTermMemory,
UserProxyAgent,
get_agent_manager,
Expand All @@ -43,6 +44,7 @@
from dbgpt.util.json_utils import serialize
from dbgpt.util.tracer import TracerManager

from ...rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from ..db import GptsMessagesDao
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
from ..db.gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
Expand Down Expand Up @@ -602,5 +604,26 @@ async def topic_terminate(
last_gpts_conversation.conv_id, Status.COMPLETE.value
)

async def get_knowledge_resources(self, app_code: str, question: str):
    """Get the knowledge resources.

    Retrieves context chunks for *question* from every knowledge-type
    resource attached to the app identified by *app_code*.

    Args:
        app_code (str): Code of the GptsApp to inspect.
        question (str): The question to retrieve context for.

    Returns:
        List[str]: Contents of all retrieved chunks (empty when the app
            has no knowledge resources or nothing is retrieved).
    """
    context = []
    app: GptsApp = self.get_app(app_code)
    # Guard clauses replace the original nested len(...) > 0 checks and
    # the dead `else: continue` branch; behavior is unchanged.
    if not app or not app.details:
        return context
    for detail in app.details:
        if not detail or not detail.resources:
            continue
        for resource in detail.resources:
            if resource.type != ResourceType.Knowledge:
                continue
            retriever = KnowledgeSpaceRetriever(
                space_id=str(resource.value),
                top_k=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
            )
            # NOTE(review): 0.3 threshold is hard-coded — presumably an
            # empirically chosen relevance cutoff; confirm with the author.
            chunks = await retriever.aretrieve_with_scores(
                question, score_threshold=0.3
            )
            context.extend(chunk.content for chunk in chunks)
    return context


multi_agents = MultiAgents(system_app)
3 changes: 2 additions & 1 deletion dbgpt/serve/agent/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ async def _do_evaluation(
contexts=contexts,
passing=result.passing,
raw_dataset=raw_dataset,
metric_name=metric.name,
metric_name=metric.name(),
prediction_cost=prediction_cost,
feedback=result.feedback,
)
)
return results
Expand Down
14 changes: 11 additions & 3 deletions dbgpt/serve/agent/evaluation/evaluation_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from dbgpt.core.interface.evaluation import (
BaseEvaluationResult,
EvaluationMetric,
metric_mange,
metric_manage,
)
from dbgpt.rag.evaluation.answer import AnswerRelevancyMetric
from dbgpt.rag.evaluation.retriever import (
RetrieverHitRateMetric,
RetrieverMRRMetric,
RetrieverSimilarityMetric,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -116,5 +122,7 @@ def sync_compute(
)


metric_mange.register_metric(IntentMetric)
metric_mange.register_metric(AppLinkMetric)
metric_manage.register_metric(RetrieverHitRateMetric)
metric_manage.register_metric(RetrieverMRRMetric)
metric_manage.register_metric(RetrieverSimilarityMetric)
metric_manage.register_metric(AnswerRelevancyMetric)
Empty file.
Empty file.
155 changes: 155 additions & 0 deletions dbgpt/serve/evaluate/api/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import logging
from functools import cache
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from dbgpt.component import ComponentType, SystemApp
from dbgpt.core.interface.evaluation import metric_manage
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from dbgpt.rag.evaluation.answer import AnswerRelevancyMetric
from dbgpt.serve.core import Result
from dbgpt.serve.evaluate.api.schemas import EvaluateServeRequest, EvaluateServeResponse
from dbgpt.serve.evaluate.config import SERVE_SERVICE_COMPONENT_NAME
from dbgpt.serve.evaluate.service.service import Service

from ...prompt.service.service import Service as PromptService

router = APIRouter()

# Add your API endpoints here

global_system_app: Optional[SystemApp] = None
logger = logging.getLogger(__name__)


def get_service() -> Service:
    """Return the registered evaluate Service component."""
    service: Service = global_system_app.get_component(
        SERVE_SERVICE_COMPONENT_NAME, Service
    )
    return service


def get_prompt_service() -> PromptService:
    """Return the registered prompt Service component."""
    prompt_service: PromptService = global_system_app.get_component(
        "dbgpt_serve_prompt_service", PromptService
    )
    return prompt_service


def get_worker_manager() -> WorkerManager:
    """Create a WorkerManager from the registered worker-manager factory."""
    factory = global_system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    )
    return factory.create()


def get_model_controller() -> BaseModelController:
    """Return the registered model controller component."""
    return global_system_app.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )


get_bearer_token = HTTPBearer(auto_error=False)


@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list

Args:
api_keys (str): The string api keys

Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the bearer token against the configured api keys.

    If no api keys are configured on the service, every request is allowed.

    Pass the token in the request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    """
    if not service.config.api_keys:
        # api_keys not set; allow all
        return None
    valid_keys = _parse_api_keys(service.config.api_keys)
    token = auth.credentials if auth is not None else None
    if token is None or token not in valid_keys:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token


@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Liveness probe: always reports a static ok status."""
    return dict(status="ok")


@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Auth smoke-test endpoint: reachable only with a valid api key."""
    return dict(status="ok")


@router.get("/scenes")
async def get_scenes():
    """List the evaluation scenes supported by this service."""
    return Result.succ([{"recall": "召回评测"}, {"app": "应用评测"}])


@router.post("/evaluation")
async def evaluation(
    request: EvaluateServeRequest,
    service: Service = Depends(get_service),
) -> Result:
    """Evaluate results by the scene.

    Args:
        request (EvaluateServeRequest): The request
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    results = await service.run_evaluation(
        request.scene_key,
        request.scene_value,
        request.datasets,
        request.context,
        request.evaluate_metrics,
    )
    return Result.succ(results)


def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints.

    Registers the evaluate Service component on *system_app* and stores the
    app in the module-level ``global_system_app`` so the ``Depends``
    providers above can resolve components at request time.
    """
    global global_system_app
    system_app.register(Service)
    global_system_app = system_app
Loading
Loading