You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
58 lines
2.3 KiB
58 lines
2.3 KiB
1 month ago
|
from app.chains.qa_chain import QAChain
|
||
|
from app.chains.rag_chain import RAGChain
|
||
|
from app.services.vector_store_service import VectorStoreService
|
||
|
from app.core.logger import logger
|
||
|
from app.core.config import settings
|
||
|
from typing import List, Dict, Optional
|
||
|
|
||
|
class QAService:
|
||
|
"""问答服务类,处理问答相关的业务逻辑"""
|
||
|
|
||
|
def __init__(self, model_name: str = None, temperature: float = None):
|
||
|
"""
|
||
|
初始化问答服务
|
||
|
|
||
|
Args:
|
||
|
model_name (str, optional): 使用的模型名称,默认使用配置中的值
|
||
|
temperature (float, optional): 模型温度参数,默认使用配置中的值
|
||
|
"""
|
||
|
self.qa_chain = QAChain(
|
||
|
model_name=model_name or settings.DEFAULT_MODEL,
|
||
|
temperature=temperature or settings.TEMPERATURE
|
||
|
)
|
||
|
# 初始化 RAG 相关服务
|
||
|
self.vector_store_service = VectorStoreService()
|
||
|
self.rag_chain = RAGChain(self.vector_store_service)
|
||
|
|
||
|
def get_answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None, use_rag: bool = True) -> str:
|
||
|
"""
|
||
|
获取问题的回答
|
||
|
|
||
|
Args:
|
||
|
question (str): 用户问题
|
||
|
chat_history (List[Dict[str, str]], optional): 聊天历史记录
|
||
|
use_rag (bool): 是否使用 RAG 增强回答,默认为 True
|
||
|
|
||
|
Returns:
|
||
|
str: AI回答
|
||
|
"""
|
||
|
try:
|
||
|
logger.info(f"处理问题: {question}")
|
||
|
|
||
|
if use_rag:
|
||
|
# 使用 RAG 获取回答
|
||
|
rag_result = self.rag_chain.query(question)
|
||
|
answer = rag_result["result"]
|
||
|
# 如果有相关文档,添加到回答中
|
||
|
if rag_result["source_documents"]:
|
||
|
sources = [doc.page_content for doc in rag_result["source_documents"]]
|
||
|
answer += "\n\n参考来源:\n" + "\n".join(sources)
|
||
|
else:
|
||
|
# 使用基础 QA 获取回答
|
||
|
answer = self.qa_chain.answer(question, chat_history)
|
||
|
|
||
|
logger.info(f"生成回答完成,长度: {len(answer)}")
|
||
|
return answer
|
||
|
except Exception as e:
|
||
|
logger.error(f"生成回答时发生错误: {str(e)}", exc_info=True)
|
||
|
raise
|