from typing import Dict, List, Optional

from app.chains.base_chain import BaseChain
from app.core.config import settings
from app.core.logger import logger


class QAChain(BaseChain):
    """
    Question-answering chain that inherits from the base chain and is
    dedicated to QA tasks.

    Uses a predefined prompt template to format the question and obtains the
    answer through the base chain.
    """

    def __init__(self, model_name: Optional[str] = None, temperature: Optional[float] = None):
        """
        Initialize the QA chain.

        Args:
            model_name (str, optional): Name of the model to use; defaults to the value from settings.
            temperature (float, optional): Sampling temperature controlling output randomness;
                defaults to the value from settings.
        """
        super().__init__(
            model_name=model_name or settings.DEFAULT_MODEL,
            # Explicit None check so a caller-supplied temperature of 0.0 is not
            # silently replaced by the configured default
            temperature=temperature if temperature is not None else settings.TEMPERATURE,
        )
        logger.info("Initializing QA chain and creating prompt template")
        self.chain = self.create_chain(
            template="""You are a helpful AI assistant. Answer the question based on the following conversation history:
Conversation history:
{chat_history}
Current question: {question}
Answer:""",
            input_variables=["question", "chat_history"]
        )

    def answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None) -> str:
        """
        Get the answer to a question.

        Args:
            question (str): The user's question.
            chat_history (List[Dict[str, str]], optional): Previous chat messages.

        Returns:
            str: The model's answer.
        """
        logger.info(f"Processing question: {question}")
        # Format the chat history into "role: content" lines
        formatted_history = ""
        if chat_history:
            formatted_history = "\n".join([
                f"{'User' if msg['role'] == 'user' else 'AI'}: {msg['content']}"
                for msg in chat_history
            ])
        # Run the chain
        answer = self.run(self.chain, {
            "question": question,
            "chat_history": formatted_history
        })
        logger.info(f"Answer generated, length: {len(answer)}")
        return answer
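

# Usage sketch (illustrative; not part of the original module). It assumes the
# application settings and the base chain's create_chain/run helpers behave as
# they are used above, and that chat_history entries are dicts with "role" and
# "content" keys, which is the shape consumed by answer().
if __name__ == "__main__":
    qa_chain = QAChain()
    history = [
        {"role": "user", "content": "What is a QA chain?"},
        {"role": "assistant", "content": "A chain that formats a question and queries a model."},
    ]
    print(qa_chain.answer("How do I pass chat history?", chat_history=history))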