You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.2 KiB
66 lines
2.2 KiB
1 month ago
|
from app.chains.base_chain import BaseChain
|
||
|
from app.core.logger import logger
|
||
|
from app.core.config import settings
|
||
|
from typing import List, Dict, Optional
|
||
|
|
||
|
class QAChain(BaseChain):
|
||
|
"""
|
||
|
问答链类,继承自基础链,专门用于处理问答任务
|
||
|
|
||
|
该类使用预定义的提示模板来格式化问题,并通过基础链获取回答
|
||
|
"""
|
||
|
|
||
|
def __init__(self, model_name: str = None, temperature: float = None):
|
||
|
"""
|
||
|
初始化问答链
|
||
|
|
||
|
Args:
|
||
|
model_name (str, optional): 使用的模型名称,默认使用配置中的值
|
||
|
temperature (float, optional): 模型温度参数,控制输出的随机性,默认使用配置中的值
|
||
|
"""
|
||
|
super().__init__(
|
||
|
model_name=model_name or settings.DEFAULT_MODEL,
|
||
|
temperature=temperature or settings.TEMPERATURE
|
||
|
)
|
||
|
logger.info("初始化问答链,创建提示模板")
|
||
|
self.chain = self.create_chain(
|
||
|
template="""你是一个有帮助的AI助手。请根据以下对话历史回答问题:
|
||
|
|
||
|
历史对话:
|
||
|
{chat_history}
|
||
|
|
||
|
当前问题: {question}
|
||
|
|
||
|
回答:""",
|
||
|
input_variables=["question", "chat_history"]
|
||
|
)
|
||
|
|
||
|
def answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None) -> str:
|
||
|
"""
|
||
|
获取问题的回答
|
||
|
|
||
|
Args:
|
||
|
question (str): 用户的问题
|
||
|
chat_history (List[Dict[str, str]], optional): 聊天历史记录
|
||
|
|
||
|
Returns:
|
||
|
str: 模型的回答
|
||
|
"""
|
||
|
logger.info(f"处理问题: {question}")
|
||
|
|
||
|
# 格式化聊天历史
|
||
|
formatted_history = ""
|
||
|
if chat_history:
|
||
|
formatted_history = "\n".join([
|
||
|
f"{'用户' if msg['role'] == 'user' else 'AI'}: {msg['content']}"
|
||
|
for msg in chat_history
|
||
|
])
|
||
|
|
||
|
# 运行链
|
||
|
answer = self.run(self.chain, {
|
||
|
"question": question,
|
||
|
"chat_history": formatted_history
|
||
|
})
|
||
|
|
||
|
logger.info(f"生成回答完成,长度: {len(answer)}")
|
||
|
return answer
|