from app.chains.qa_chain import QAChain
from app.chains.rag_chain import RAGChain
from app.services.vector_store_service import VectorStoreService
from app.core.logger import logger
from app.core.config import settings
from typing import List, Dict, Optional


class QAService:
    """Question-answering service encapsulating QA-related business logic."""

    def __init__(self, model_name: Optional[str] = None, temperature: Optional[float] = None):
        """
        Initialize the QA service.

        Args:
            model_name (str, optional): Name of the model to use; defaults to the value in settings.
            temperature (float, optional): Sampling temperature; defaults to the value in settings.
        """
        self.qa_chain = QAChain(
            model_name=model_name or settings.DEFAULT_MODEL,
            # Check explicitly against None so a caller-supplied 0.0 is not
            # silently replaced by the configured default.
            temperature=temperature if temperature is not None else settings.TEMPERATURE,
        )
        # Initialize RAG-related services
        self.vector_store_service = VectorStoreService()
        self.rag_chain = RAGChain(self.vector_store_service)

    def get_answer(
        self,
        question: str,
        chat_history: Optional[List[Dict[str, str]]] = None,
        use_rag: bool = True,
    ) -> str:
        """
        Generate an answer to a question.

        Args:
            question (str): The user's question.
            chat_history (List[Dict[str, str]], optional): Prior chat messages.
            use_rag (bool): Whether to augment the answer with RAG. Defaults to True.

        Returns:
            str: The AI-generated answer.
        """
        try:
            logger.info(f"Processing question: {question}")

            if use_rag:
                # Answer via the RAG chain
                rag_result = self.rag_chain.query(question)
                answer = rag_result["result"]

                # If relevant documents were retrieved, append them to the answer
                if rag_result["source_documents"]:
                    sources = [doc.page_content for doc in rag_result["source_documents"]]
                    answer += "\n\nSources:\n" + "\n".join(sources)
            else:
                # Answer via the plain QA chain
                answer = self.qa_chain.answer(question, chat_history)

            logger.info(f"Answer generated, length: {len(answer)}")
            return answer

        except Exception as e:
            logger.error(f"Error while generating answer: {str(e)}", exc_info=True)
            raise
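

# Minimal usage sketch (illustrative only, not part of the service API).
# It assumes the app package, settings, and the vector store behind
# VectorStoreService are already configured; the questions and the
# chat-history format below are assumptions for demonstration.
if __name__ == "__main__":
    service = QAService()

    # RAG-augmented answer (default path): the RAG chain retrieves
    # source documents and appends them to the reply.
    print(service.get_answer("How are documents added to the vector store?"))

    # Plain QA-chain answer with prior conversation context, RAG disabled.
    history = [
        {"role": "user", "content": "What does this service do?"},
        {"role": "assistant", "content": "It answers questions, optionally using RAG."},
    ]
    print(service.get_answer("Summarize our conversation.", chat_history=history, use_rag=False))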