You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
3.0 KiB
86 lines
3.0 KiB
from typing import Dict, Any
|
|
from langchain.prompts import PromptTemplate
|
|
from langchain_ollama import OllamaLLM
|
|
from langchain.schema.runnable import RunnablePassthrough
|
|
from app.core.logger import logger
|
|
import requests
|
|
from app.core.config import settings
|
|
|
|
class BaseChain:
    """Base chain providing core interaction with an Ollama model.

    Attributes:
        llm: the ``OllamaLLM`` instance used to execute prompts.
    """

    def __init__(self, model_name: str = "llama2", temperature: float = 0.7):
        """Initialize the base chain and verify the Ollama service.

        Args:
            model_name (str): name of the Ollama model to use. Defaults to "llama2".
            temperature (float): sampling temperature controlling output
                randomness. Defaults to 0.7.

        Raises:
            ConnectionError: if the Ollama service is unreachable or responds
                with a non-200 status.
        """
        logger.info(f"初始化基础链,使用模型: {model_name}, 温度: {temperature}")

        # Verify the Ollama service is reachable before building the LLM.
        # A timeout prevents __init__ from hanging forever on a dead host.
        try:
            response = requests.get(f"{settings.OLLAMA_BASE_URL}/api/tags", timeout=5)
            if response.status_code != 200:
                raise ConnectionError(f"Ollama 服务响应异常: {response.status_code}")
            logger.info("Ollama 服务连接正常")
        except requests.exceptions.ConnectionError as e:
            logger.error(f"无法连接到 Ollama 服务: {str(e)}")
            # Bug fix: report the configured base URL instead of a hard-coded
            # localhost address, so the message matches the actual target.
            raise ConnectionError(
                f"请确保 Ollama 服务正在运行,并且可以通过 {settings.OLLAMA_BASE_URL} 访问"
            ) from e

        # Best-effort check that the model has been pulled.
        # Bug fix: the Ollama /api/show endpoint requires a POST with a JSON
        # body; the previous GET-with-query-params form always failed, so this
        # check warned even for installed models.
        try:
            response = requests.post(
                f"{settings.OLLAMA_BASE_URL}/api/show",
                json={"name": model_name},
                timeout=5,
            )
            if response.status_code != 200:
                logger.warning(f"模型 {model_name} 可能未下载,将尝试使用")
        except Exception as e:
            # Non-fatal: the model check is advisory only.
            logger.warning(f"检查模型状态时出错: {str(e)}")

        self.llm = OllamaLLM(
            model=model_name,
            temperature=temperature,
            base_url=settings.OLLAMA_BASE_URL,
        )

    def create_chain(self, template: str, input_variables: list):
        """Create a new runnable chain from a prompt template.

        Args:
            template (str): the prompt template text.
            input_variables (list): names of the variables the template expects.

        Returns:
            RunnableSequence: the ``prompt | llm`` pipeline.
        """
        logger.debug(f"创建新链,模板变量: {input_variables}")
        prompt = PromptTemplate(
            template=template,
            input_variables=input_variables,
        )
        # LCEL composition: pipe the rendered prompt into the model.
        return prompt | self.llm

    def run(self, chain, inputs: Dict[str, Any]) -> str:
        """Invoke a chain with the given inputs and return the model output.

        Args:
            chain: the runnable chain to invoke (as built by ``create_chain``).
            inputs (Dict[str, Any]): mapping of template variables to values.

        Returns:
            str: the model-generated answer.

        Raises:
            Exception: re-raises whatever ``chain.invoke`` raises, after
                logging the error with its traceback.
        """
        try:
            logger.debug(f"运行链,输入参数: {inputs}")
            return chain.invoke(inputs)
        except Exception as e:
            logger.error(f"运行链时发生错误: {str(e)}", exc_info=True)
            raise
|