You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

86 lines
3.0 KiB

1 month ago
from typing import Dict, Any
from langchain.prompts import PromptTemplate
from langchain_ollama import OllamaLLM
from langchain.schema.runnable import RunnablePassthrough
from app.core.logger import logger
import requests
from app.core.config import settings
class BaseChain:
"""
基础链类提供与 Ollama 模型交互的基本功能
Attributes:
llm: Ollama 语言模型实例
"""
def __init__(self, model_name: str = "llama2", temperature: float = 0.7):
"""
初始化基础链
Args:
model_name (str): 使用的模型名称默认为 llama2
temperature (float): 模型温度参数控制输出的随机性默认为 0.7
"""
logger.info(f"初始化基础链,使用模型: {model_name}, 温度: {temperature}")
# 检查 Ollama 服务是否可用
try:
response = requests.get(f"{settings.OLLAMA_BASE_URL}/api/tags")
if response.status_code != 200:
raise ConnectionError(f"Ollama 服务响应异常: {response.status_code}")
logger.info("Ollama 服务连接正常")
except requests.exceptions.ConnectionError as e:
logger.error(f"无法连接到 Ollama 服务: {str(e)}")
raise ConnectionError("请确保 Ollama 服务正在运行,并且可以通过 http://localhost:11434 访问")
# 检查模型是否已下载
try:
response = requests.get(f"{settings.OLLAMA_BASE_URL}/api/show", params={"name": model_name})
if response.status_code != 200:
logger.warning(f"模型 {model_name} 可能未下载,将尝试使用")
except Exception as e:
logger.warning(f"检查模型状态时出错: {str(e)}")
self.llm = OllamaLLM(
model=model_name,
temperature=temperature,
base_url=settings.OLLAMA_BASE_URL
)
def create_chain(self, template: str, input_variables: list):
"""
创建新的链实例
Args:
template (str): 提示模板
input_variables (list): 输入变量列表
Returns:
RunnableSequence: 创建的链实例
"""
logger.debug(f"创建新链,模板变量: {input_variables}")
prompt = PromptTemplate(
template=template,
input_variables=input_variables
)
# 使用新的链式调用方式
return prompt | self.llm
def run(self, chain, inputs: Dict[str, Any]) -> str:
"""
运行链并获取结果
Args:
chain: 要运行的链实例
inputs (Dict[str, Any]): 输入参数字典
Returns:
str: 模型生成的回答
"""
try:
logger.debug(f"运行链,输入参数: {inputs}")
return chain.invoke(inputs)
except Exception as e:
logger.error(f"运行链时发生错误: {str(e)}", exc_info=True)
raise