from app.chains.base_chain import BaseChain from app.core.logger import logger from app.core.config import settings from typing import List, Dict, Optional class QAChain(BaseChain): """ 问答链类,继承自基础链,专门用于处理问答任务 该类使用预定义的提示模板来格式化问题,并通过基础链获取回答 """ def __init__(self, model_name: str = None, temperature: float = None): """ 初始化问答链 Args: model_name (str, optional): 使用的模型名称,默认使用配置中的值 temperature (float, optional): 模型温度参数,控制输出的随机性,默认使用配置中的值 """ super().__init__( model_name=model_name or settings.DEFAULT_MODEL, temperature=temperature or settings.TEMPERATURE ) logger.info("初始化问答链,创建提示模板") self.chain = self.create_chain( template="""你是一个有帮助的AI助手。请根据以下对话历史回答问题: 历史对话: {chat_history} 当前问题: {question} 回答:""", input_variables=["question", "chat_history"] ) def answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None) -> str: """ 获取问题的回答 Args: question (str): 用户的问题 chat_history (List[Dict[str, str]], optional): 聊天历史记录 Returns: str: 模型的回答 """ logger.info(f"处理问题: {question}") # 格式化聊天历史 formatted_history = "" if chat_history: formatted_history = "\n".join([ f"{'用户' if msg['role'] == 'user' else 'AI'}: {msg['content']}" for msg in chat_history ]) # 运行链 answer = self.run(self.chain, { "question": question, "chat_history": formatted_history }) logger.info(f"生成回答完成,长度: {len(answer)}") return answer