You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

58 lines
2.3 KiB

from app.chains.qa_chain import QAChain
from app.chains.rag_chain import RAGChain
from app.services.vector_store_service import VectorStoreService
from app.core.logger import logger
from app.core.config import settings
from typing import List, Dict, Optional
class QAService:
"""问答服务类,处理问答相关的业务逻辑"""
def __init__(self, model_name: str = None, temperature: float = None):
"""
初始化问答服务
Args:
model_name (str, optional): 使用的模型名称,默认使用配置中的值
temperature (float, optional): 模型温度参数,默认使用配置中的值
"""
self.qa_chain = QAChain(
model_name=model_name or settings.DEFAULT_MODEL,
temperature=temperature or settings.TEMPERATURE
)
# 初始化 RAG 相关服务
self.vector_store_service = VectorStoreService()
self.rag_chain = RAGChain(self.vector_store_service)
def get_answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None, use_rag: bool = True) -> str:
"""
获取问题的回答
Args:
question (str): 用户问题
chat_history (List[Dict[str, str]], optional): 聊天历史记录
use_rag (bool): 是否使用 RAG 增强回答,默认为 True
Returns:
str: AI回答
"""
try:
logger.info(f"处理问题: {question}")
if use_rag:
# 使用 RAG 获取回答
rag_result = self.rag_chain.query(question)
answer = rag_result["result"]
# 如果有相关文档,添加到回答中
if rag_result["source_documents"]:
sources = [doc.page_content for doc in rag_result["source_documents"]]
answer += "\n\n参考来源:\n" + "\n".join(sources)
else:
# 使用基础 QA 获取回答
answer = self.qa_chain.answer(question, chat_history)
logger.info(f"生成回答完成,长度: {len(answer)}")
return answer
except Exception as e:
logger.error(f"生成回答时发生错误: {str(e)}", exc_info=True)
raise