commit
546b2f5045
56 changed files with 1664 additions and 0 deletions
@@ -0,0 +1,194 @@
# LangChain Ollama API

An AI service API built on locally hosted Ollama models and LangChain.

## Features

- Inference with a local Ollama model
- Prompt chains built on the LangChain framework
- FastAPI service endpoints
- Custom prompt templates
- Fully local deployment; no external APIs required
- Complete API documentation (Swagger UI)
- JSON-formatted logging

## Requirements

- Python 3.8+
- Conda
- Ollama service (running locally)
- Qwen2.5 model (installed via Ollama)

## Installation

1. Clone the project and enter the project directory:
```bash
git clone <repository-url>
cd <project-directory>
```

2. Create and activate the Conda environment:
```bash
# Create the environment
conda create -n langchain-ollama python=3.8

# Activate the environment
conda activate langchain-ollama
```

3. Install the dependencies:
```bash
pip install -r requirements.txt
```

4. Install the Qwen2.5 model (if not already installed; see the verification command after these steps):
```bash
ollama pull qwen2.5:latest
```

5. Configure environment variables:
Create a `.env` file and set the following variables (optional; defaults are provided):
```env
OLLAMA_BASE_URL=http://localhost:11434
DEFAULT_MODEL=qwen2.5:latest
HOST=0.0.0.0
PORT=5000
DEBUG=False
MAX_TOKENS=2048
TEMPERATURE=0.7
```
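Before moving on, you can optionally confirm that the model pulled in step 4 is available locally:
```bash
ollama list
```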
## Running the Service

### Development

1. Make sure the Conda environment is activated:
```bash
conda activate langchain-ollama
```

2. Make sure the Ollama service is up and running locally

3. Start the development server:
```bash
python run.py
```

4. Open the API documentation:
- Swagger UI: http://localhost:5000/docs
- ReDoc: http://localhost:5000/redoc

### Production

1. Install a production server:
```bash
pip install gunicorn
```

2. Start the service with Gunicorn (the app is ASGI, so use Uvicorn's worker class):
```bash
# Basic startup
gunicorn -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:5000 run:app

# Start with a configuration file (recommended)
gunicorn -c gunicorn.conf.py run:app
```

3. Create the Gunicorn configuration file `gunicorn.conf.py`:
```python
# Number of worker processes
workers = 4
# Worker class (ASGI worker for the FastAPI app)
worker_class = 'uvicorn.workers.UvicornWorker'
# Bind address
bind = '0.0.0.0:5000'
# Request timeout (seconds)
timeout = 120
# Recycle a worker after this many requests
max_requests = 1000
# Random jitter added to max_requests to avoid simultaneous restarts
max_requests_jitter = 50
# Access log file
accesslog = 'access.log'
# Error log file
errorlog = 'error.log'
# Log level
loglevel = 'info'
```

4. Manage the service with systemd (Linux); the startup commands follow this list:
```ini
# /etc/systemd/system/langchain-ollama.service
[Unit]
Description=LangChain Ollama API Service
After=network.target

[Service]
User=your_user
Group=your_group
WorkingDirectory=/path/to/your/app
Environment="PATH=/path/to/your/conda/env/bin"
ExecStart=/path/to/your/conda/env/bin/gunicorn -c gunicorn.conf.py run:app
Restart=always

[Install]
WantedBy=multi-user.target
```
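With the unit file above installed, the service can be enabled and started with systemctl:
```bash
sudo systemctl daemon-reload
sudo systemctl enable langchain-ollama
sudo systemctl start langchain-ollama
sudo systemctl status langchain-ollama
```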
## API Endpoints

### Health Check
- GET `/api/v1/health`
- Returns the service status

### Chat
- POST `/api/v1/chat`
- Request body (`chat_id` and `use_rag` are optional):
```json
{
  "question": "your question"
}
```
- Response:
```json
{
  "question": "the original question",
  "answer": "the AI answer",
  "chat_id": "the session id"
}
```
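For example, with the server running locally on the default port, the endpoints can be exercised with curl:
```bash
# Health check
curl http://localhost:5000/api/v1/health

# Ask a question
curl -X POST http://localhost:5000/api/v1/chat \
  -H "Content-Type: application/json" \
  -d '{"question": "What is LangChain?"}'
```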
## Logging

The service logs in JSON format, including:
- Timestamp
- Log level
- File name and line number
- Function name
- Log message

## Notes

1. Make sure the Ollama service is correctly installed and running
2. The qwen2.5:latest model is used by default; this can be changed via environment variables
3. Set an appropriate temperature and max-token limit in production
4. When using a Conda environment, make sure it is activated before every run
5. The development server is for testing only; use Gunicorn for production deployments

## Custom Prompt Chains

Custom prompt chains can be created by subclassing `BaseChain` (a usage sketch follows this example):

```python
from app.chains.base_chain import BaseChain

class CustomChain(BaseChain):
    def __init__(self, model_name="qwen2.5:latest", temperature=0.7):
        super().__init__(model_name, temperature)
        self.chain = self.create_chain(
            template="your prompt template",
            input_variables=["your input variable"]
        )
```
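A minimal usage sketch under the `BaseChain` API above; the `SummaryChain` class, its template text, and the sample input are illustrative placeholders, not part of the codebase:

```python
from app.chains.base_chain import BaseChain

class SummaryChain(BaseChain):
    def __init__(self, model_name="qwen2.5:latest", temperature=0.7):
        super().__init__(model_name, temperature)
        # Hypothetical template with a single input variable
        self.chain = self.create_chain(
            template="Summarize the following text in one sentence:\n{text}",
            input_variables=["text"]
        )

chain = SummaryChain()
# BaseChain.run() invokes the composed prompt | llm chain with a dict of inputs
print(chain.run(chain.chain, {"text": "LangChain composes prompts and models into chains."}))
```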
## License

MIT
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,49 @@
import os
from fastapi import FastAPI
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings
from app.api.routes import register_routes
from app.services.qa_service import QAService
from app.core.logger import logger


def create_app() -> FastAPI:
    """Create and configure the FastAPI application instance."""
    logger.info("Starting application initialization...")

    app = FastAPI(
        title=settings.APP_NAME,
        description=settings.APP_DESCRIPTION,
        version=settings.APP_VERSION,
        docs_url="/docs",
        redoc_url="/redoc"
    )

    logger.info("Configuring CORS middleware...")
    # Configure CORS
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=settings.CORS_CREDENTIALS,
        allow_methods=settings.CORS_METHODS,
        allow_headers=settings.CORS_HEADERS,
    )

    logger.info("Setting up templates...")
    # Configure templates
    templates = Jinja2Templates(directory=os.path.join(os.path.dirname(__file__), 'templates'))
    app.state.templates = templates

    logger.info(f"Initializing QA service with model: {settings.DEFAULT_MODEL}")
    # Initialize services
    qa_service = QAService(
        model_name=settings.DEFAULT_MODEL,
        temperature=settings.TEMPERATURE
    )

    logger.info("Registering routes...")
    # Register routes
    register_routes(app, qa_service)

    logger.info("Application initialization completed successfully!")
    return app
Binary file not shown.
@@ -0,0 +1 @@
# API package initialization
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,72 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, validator
from app.services.document_service import DocumentService
from app.services.vector_store_service import VectorStoreService
from app.chains.rag_chain import RAGChain
from app.core.logger import logger
import os
from pathlib import Path

router = APIRouter(prefix="/rag", tags=["RAG"])

# Initialize services
document_service = DocumentService()
vector_store_service = VectorStoreService()
rag_chain = RAGChain(vector_store_service)


class DirectoryRequest(BaseModel):
    directory: str

    @validator('directory')
    def validate_directory(cls, v):
        # Normalize path separators
        path = Path(v)
        if not path.exists():
            raise ValueError(f"Path does not exist: {v}")
        if not path.is_file() and not path.is_dir():
            raise ValueError(f"Invalid path: {v}")
        return str(path.absolute())


class QuestionRequest(BaseModel):
    question: str


@router.post("/upload")
async def upload_documents(request: DirectoryRequest):
    """Upload and process documents."""
    try:
        logger.info(f"Processing documents: {request.directory}")
        # Process the documents
        documents = document_service.process_documents(request.directory)
        if not documents:
            return {"message": "No new documents to process"}
        # Add them to the vector store
        vector_store_service.add_documents(documents)
        logger.info(f"Successfully processed {len(documents)} document chunks")
        return {"message": f"Successfully processed {len(documents)} document chunks"}
    except Exception as e:
        logger.error(f"Error while processing documents: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/query")
async def query(request: QuestionRequest):
    """Answer a question."""
    try:
        result = rag_chain.query(request.question)
        return {
            "answer": result["result"],
            "sources": [doc.page_content for doc in result["source_documents"]]
        }
    except Exception as e:
        logger.error(f"Error while answering question: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/clear")
async def clear_vector_store():
    """Clear the vector store."""
    try:
        vector_store_service.clear()
        document_service.clear_processed_files()  # Also clear the processed-file records
        return {"message": "Vector store and processed-file records cleared"}
    except Exception as e:
        logger.error(f"Error while clearing vector store: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
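For reference, these RAG endpoints can be exercised with curl once the service is running (assuming the default localhost:5000 from the README; the `./docs` path and the question are placeholders):

```bash
# Index a file or a directory of documents
curl -X POST http://localhost:5000/rag/upload \
  -H "Content-Type: application/json" \
  -d '{"directory": "./docs"}'

# Query the indexed documents
curl -X POST http://localhost:5000/rag/query \
  -H "Content-Type: application/json" \
  -d '{"question": "What do these documents describe?"}'

# Clear the vector store and processed-file records
curl -X POST http://localhost:5000/rag/clear
```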
@@ -0,0 +1,135 @@
from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from app.core.logger import logger
from app.services.qa_service import QAService
from app.middleware.error_handler import ErrorHandler
from app.middleware.request_logger import RequestLogger
from app.api.rag_api import router as rag_router
from typing import Dict, Optional


# Request and response models
class ChatRequest(BaseModel):
    question: str
    chat_id: Optional[str] = None
    use_rag: Optional[bool] = True


class ChatResponse(BaseModel):
    question: str
    answer: str
    chat_id: str


class ErrorResponse(BaseModel):
    error: str


# Active chat sessions
active_chats: Dict[str, list] = {}


def register_routes(app, qa_service: QAService):
    """Register all routes."""

    # Create the router
    router = APIRouter(prefix="/api/v1")

    @app.get("/", response_class=HTMLResponse)
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def index(request: Request):
        """Return the chat page."""
        logger.info("Root path accessed")
        try:
            return app.state.templates.TemplateResponse("chat.html", {"request": request})
        except Exception as e:
            logger.error(f"Error while rendering template: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        await websocket.accept()
        try:
            while True:
                data = await websocket.receive_json()
                message = data.get('message', '').strip()
                chat_id = data.get('chatId')
                use_rag = data.get('use_rag', True)  # Use RAG by default

                if message:
                    # Get or create the chat history
                    if chat_id not in active_chats:
                        active_chats[chat_id] = []

                    # Add the user message to the history
                    active_chats[chat_id].append({"role": "user", "content": message})

                    # Get the answer
                    answer = qa_service.get_answer(message, active_chats[chat_id], use_rag=use_rag)

                    # Add the AI answer to the history
                    active_chats[chat_id].append({"role": "assistant", "content": answer})

                    # Send the response
                    await websocket.send_json({
                        "question": message,
                        "answer": answer,
                        "chatId": chat_id
                    })
        except WebSocketDisconnect:
            logger.info("WebSocket connection closed")
        except Exception as e:
            logger.error(f"WebSocket handling error: {str(e)}", exc_info=True)
            await websocket.close()

    @router.get("/health")
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def health_check(request: Request):
        """Health-check endpoint."""
        logger.info("Received health-check request")
        return {"status": "healthy"}

    @router.post("/chat", response_model=ChatResponse, responses={
        200: {"model": ChatResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse}
    })
    @RequestLogger.log_request
    @ErrorHandler.handle_validation_error
    @ErrorHandler.handle_error
    async def chat(request: Request, chat_request: ChatRequest):
        """Chat endpoint: accepts a question and returns the AI answer."""
        try:
            logger.info(f"Received question: {chat_request.question}")

            # Get or create the chat history
            chat_id = chat_request.chat_id or str(len(active_chats) + 1)
            if chat_id not in active_chats:
                active_chats[chat_id] = []

            # Add the user message to the history
            active_chats[chat_id].append({"role": "user", "content": chat_request.question})

            # Get the answer
            answer = qa_service.get_answer(
                chat_request.question,
                active_chats[chat_id],
                use_rag=chat_request.use_rag
            )

            # Add the AI answer to the history
            active_chats[chat_id].append({"role": "assistant", "content": answer})

            logger.info(f"Question handled: {chat_request.question}")

            return ChatResponse(
                question=chat_request.question,
                answer=answer,
                chat_id=chat_id
            )

        except Exception as e:
            logger.error(f"Error while handling request: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

    # Register the routers
    app.include_router(router)
    app.include_router(rag_router)  # Add the RAG routes
@@ -0,0 +1 @@
# Chains package initialization
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
from typing import Dict, Any
from langchain.prompts import PromptTemplate
from langchain_ollama import OllamaLLM
from langchain.schema.runnable import RunnablePassthrough
from app.core.logger import logger
import requests
from app.core.config import settings


class BaseChain:
    """
    Base chain class providing the core functionality for talking to an Ollama model.

    Attributes:
        llm: The Ollama language-model instance
    """

    def __init__(self, model_name: str = "llama2", temperature: float = 0.7):
        """
        Initialize the base chain.

        Args:
            model_name (str): Name of the model to use; defaults to llama2
            temperature (float): Model temperature controlling output randomness; defaults to 0.7
        """
        logger.info(f"Initializing base chain with model: {model_name}, temperature: {temperature}")

        # Check that the Ollama service is reachable
        try:
            response = requests.get(f"{settings.OLLAMA_BASE_URL}/api/tags")
            if response.status_code != 200:
                raise ConnectionError(f"Unexpected response from Ollama service: {response.status_code}")
            logger.info("Ollama service connection OK")
        except requests.exceptions.ConnectionError as e:
            logger.error(f"Cannot connect to Ollama service: {str(e)}")
            raise ConnectionError(f"Make sure the Ollama service is running and reachable at {settings.OLLAMA_BASE_URL}")

        # Check whether the model has been downloaded (Ollama's /api/show expects a POST)
        try:
            response = requests.post(f"{settings.OLLAMA_BASE_URL}/api/show", json={"name": model_name})
            if response.status_code != 200:
                logger.warning(f"Model {model_name} may not be downloaded; trying to use it anyway")
        except Exception as e:
            logger.warning(f"Error while checking model status: {str(e)}")

        self.llm = OllamaLLM(
            model=model_name,
            temperature=temperature,
            base_url=settings.OLLAMA_BASE_URL
        )

    def create_chain(self, template: str, input_variables: list):
        """
        Create a new chain instance.

        Args:
            template (str): The prompt template
            input_variables (list): List of input-variable names

        Returns:
            RunnableSequence: The created chain instance
        """
        logger.debug(f"Creating new chain with template variables: {input_variables}")
        prompt = PromptTemplate(
            template=template,
            input_variables=input_variables
        )
        # Use the new pipe-style chain composition
        return prompt | self.llm

    def run(self, chain, inputs: Dict[str, Any]) -> str:
        """
        Run a chain and return the result.

        Args:
            chain: The chain instance to run
            inputs (Dict[str, Any]): Dictionary of input values

        Returns:
            str: The model-generated answer
        """
        try:
            logger.debug(f"Running chain with inputs: {inputs}")
            return chain.invoke(inputs)
        except Exception as e:
            logger.error(f"Error while running chain: {str(e)}", exc_info=True)
            raise
@@ -0,0 +1,66 @@
from app.chains.base_chain import BaseChain
from app.core.logger import logger
from app.core.config import settings
from typing import List, Dict, Optional


class QAChain(BaseChain):
    """
    Question-answering chain built on the base chain.

    Uses a predefined prompt template to format the question and obtains the answer
    through the base chain.
    """

    def __init__(self, model_name: str = None, temperature: float = None):
        """
        Initialize the QA chain.

        Args:
            model_name (str, optional): Model name; falls back to the configured default
            temperature (float, optional): Model temperature controlling output randomness; falls back to the configured default
        """
        super().__init__(
            model_name=model_name or settings.DEFAULT_MODEL,
            temperature=temperature or settings.TEMPERATURE
        )
        logger.info("Initializing QA chain and creating prompt template")
        self.chain = self.create_chain(
            template="""You are a helpful AI assistant. Answer the question based on the conversation history below:

Conversation history:
{chat_history}

Current question: {question}

Answer:""",
            input_variables=["question", "chat_history"]
        )

    def answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None) -> str:
        """
        Get the answer to a question.

        Args:
            question (str): The user's question
            chat_history (List[Dict[str, str]], optional): Chat history

        Returns:
            str: The model's answer
        """
        logger.info(f"Handling question: {question}")

        # Format the chat history
        formatted_history = ""
        if chat_history:
            formatted_history = "\n".join([
                f"{'User' if msg['role'] == 'user' else 'AI'}: {msg['content']}"
                for msg in chat_history
            ])

        # Run the chain
        answer = self.run(self.chain, {
            "question": question,
            "chat_history": formatted_history
        })

        logger.info(f"Answer generated, length: {len(answer)}")
        return answer
@@ -0,0 +1,48 @@
from typing import List
from langchain.chains import RetrievalQA
from langchain_ollama import OllamaLLM
from langchain.prompts import PromptTemplate
from app.services.vector_store_service import VectorStoreService


class RAGChain:
    def __init__(self, vector_store_service: VectorStoreService):
        self.vector_store = vector_store_service
        self.llm = OllamaLLM(model="qwen2.5:latest")
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self) -> RetrievalQA:
        """Create the question-answering chain."""
        prompt_template = """You are a professional question-answering assistant. Answer the question based on the context below. If the context does not contain enough information, state clearly: "Based on the provided context, I cannot answer this question."

Follow these rules:
1. Use only the provided context to answer
2. If the context is insufficient, do not make up an answer
3. If the context contains conflicting information, point this out
4. Keep answers concise, accurate, and professional

Context:
{context}

Question: {question}

Answer:"""

        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )

        chain_type_kwargs = {"prompt": PROMPT}

        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vector_store.vector_store.as_retriever(
                search_kwargs={"k": 6}
            ),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True
        )

    def query(self, question: str) -> dict:
        """Answer a question."""
        return self.qa_chain({"query": question})
@@ -0,0 +1 @@
# Core package initialization
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,40 @@
import os
from dotenv import load_dotenv
from typing import Optional

# Load environment variables
load_dotenv()


class Settings:
    # Application configuration
    APP_NAME: str = "LangChain Ollama API"
    APP_VERSION: str = "1.0.0"
    APP_DESCRIPTION: str = "An AI service API built on local Ollama models and LangChain"
    API_PREFIX: str = os.getenv("API_PREFIX", "/api/v1")

    # Server configuration
    HOST: str = os.getenv("HOST", "0.0.0.0")
    PORT: int = int(os.getenv("PORT", "5000"))
    DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"

    # Ollama configuration
    OLLAMA_BASE_URL: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
    DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "qwen2.5:latest")
    TEMPERATURE: float = float(os.getenv("TEMPERATURE", "0.7"))
    MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "2048"))

    # Security configuration
    CORS_ORIGINS: list = ["*"]
    CORS_CREDENTIALS: bool = True
    CORS_METHODS: list = ["*"]
    CORS_HEADERS: list = ["*"]

    # Logging configuration
    LOG_LEVEL: str = "INFO" if DEBUG else "WARNING"
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    class Config:
        case_sensitive = True


# Create the global settings instance
settings = Settings()
@@ -0,0 +1,55 @@
from pydantic import BaseModel, Field, validator
from typing import Optional
import os


class OllamaConfig(BaseModel):
    """Validation for the Ollama configuration."""
    base_url: str = Field(default="http://localhost:11434")
    model_name: str = Field(default="qwen2.5:latest")
    temperature: float = Field(default=0.7, ge=0.0, le=1.0)
    max_tokens: int = Field(default=2048, gt=0)

    # Validators must live on the model that owns the field;
    # pydantic does not support dotted field names such as 'ollama.base_url'
    @validator('base_url')
    def validate_base_url(cls, v):
        if not v.startswith(('http://', 'https://')):
            raise ValueError('base_url must start with http:// or https://')
        return v


class FlaskConfig(BaseModel):
    """Validation for the web-server configuration."""
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=5000, gt=0, lt=65536)
    debug: bool = Field(default=False)
    api_prefix: str = Field(default="/api/v1")

    @validator('api_prefix')
    def validate_api_prefix(cls, v):
        if not v.startswith('/'):
            raise ValueError('api_prefix must start with /')
        return v


class Config(BaseModel):
    """Validation for the application configuration."""
    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
    flask: FlaskConfig = Field(default_factory=FlaskConfig)


def validate_config() -> Config:
    """Validate and return the configuration."""
    try:
        config = Config(
            ollama=OllamaConfig(
                base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
                model_name=os.getenv("DEFAULT_MODEL", "qwen2.5:latest"),
                temperature=float(os.getenv("TEMPERATURE", "0.7")),
                max_tokens=int(os.getenv("MAX_TOKENS", "2048"))
            ),
            flask=FlaskConfig(
                host=os.getenv("FLASK_HOST", "0.0.0.0"),
                port=int(os.getenv("FLASK_PORT", "5000")),
                debug=os.getenv("FLASK_DEBUG", "False").lower() == "true",
                api_prefix=os.getenv("API_PREFIX", "/api/v1")
            )
        )
        return config
    except Exception as e:
        raise ValueError(f"Configuration validation failed: {str(e)}")
@@ -0,0 +1,41 @@
import logging
import sys
from pythonjsonlogger import jsonlogger
from datetime import datetime
from app.core.config import settings


def setup_logger(name: str = __name__) -> logging.Logger:
    """
    Set up and configure a logger that emits JSON-formatted output.

    Args:
        name (str): Logger name; defaults to the module name

    Returns:
        logging.Logger: The configured logger instance
    """
    # Create the logger
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, settings.LOG_LEVEL))

    # Create a console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(getattr(logging, settings.LOG_LEVEL))

    # Create a JSON formatter that includes extra detail
    formatter = jsonlogger.JsonFormatter(
        settings.LOG_FORMAT,
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(formatter)

    # Attach the handler to the logger
    logger.addHandler(console_handler)

    # Prevent duplicate log records
    logger.propagate = False

    return logger


# Create the default logger instance
logger = setup_logger('langchain-ollama')
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,43 @@
from functools import wraps
from fastapi import HTTPException
from app.core.logger import logger


class ErrorHandler:
    """Error-handling middleware."""

    @staticmethod
    def handle_error(f):
        @wraps(f)
        async def decorated_function(*args, **kwargs):
            try:
                return await f(*args, **kwargs)
            except HTTPException as e:
                logger.error(f"HTTP error: {str(e)}")
                raise e
            except Exception as e:
                logger.error(f"Error while handling request: {str(e)}", exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": str(e),
                        "message": "Internal server error"
                    }
                )
        return decorated_function

    @staticmethod
    def handle_validation_error(f):
        @wraps(f)
        async def decorated_function(*args, **kwargs):
            try:
                return await f(*args, **kwargs)
            except ValueError as e:
                logger.error(f"Request parameter validation failed: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": str(e),
                        "message": "Invalid request parameters"
                    }
                )
        return decorated_function
@@ -0,0 +1,33 @@
from functools import wraps
from fastapi import Request
from app.core.logger import logger
import time


class RequestLogger:
    """Request-logging middleware."""

    @staticmethod
    def log_request(f):
        @wraps(f)
        async def decorated_function(request: Request, *args, **kwargs):
            start_time = time.time()

            # Log request details
            logger.info(f"Request received - method: {request.method}, path: {request.url.path}")

            # Try to read the request body
            try:
                body = await request.json()
                logger.debug(f"Request body: {body}")
            except Exception:
                pass

            # Run the request handler
            response = await f(request, *args, **kwargs)

            # Log the response time
            duration = time.time() - start_time
            logger.info(f"Request completed - took {duration:.2f}s")

            return response
        return decorated_function
@@ -0,0 +1 @@
# Services package initialization
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,140 @@
from typing import List, Union, Set
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader, JSONLoader
from langchain.schema import Document
import os
import json
import hashlib
from app.core.logger import logger


class DocumentService:
    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        # Hashes of files that have already been processed
        self.processed_files: Set[str] = set()

    def _calculate_file_hash(self, file_path: str) -> str:
        """Compute a file's hash."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def _is_file_processed(self, file_path: str) -> bool:
        """Check whether a file has already been processed; record it if not."""
        file_hash = self._calculate_file_hash(file_path)
        if file_hash in self.processed_files:
            logger.info(f"File already processed, skipping: {file_path}")
            return True
        self.processed_files.add(file_hash)
        return False

    def load_documents(self, path: str) -> List[Document]:
        """Load documents from a single file or a directory."""
        if os.path.isfile(path):
            # _load_single_file performs its own duplicate check; checking here
            # as well would mark the file as processed and then always skip it
            return self._load_single_file(path)
        else:
            return self._load_directory(path)

    def _load_single_file(self, file_path: str) -> List[Document]:
        """Load a single file."""
        if self._is_file_processed(file_path):
            return []

        file_extension = os.path.splitext(file_path)[1].lower()

        try:
            if file_extension == '.pdf':
                loader = PyPDFLoader(file_path)
            elif file_extension == '.txt':
                loader = TextLoader(file_path, encoding='utf-8')
            elif file_extension == '.json':
                # Read the JSON file
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_content = json.load(f)

                # Convert the JSON to a text representation
                text_content = self._json_to_text(json_content)

                # Create the document
                return [Document(page_content=text_content, metadata={"source": file_path})]
            else:
                logger.warning(f"Unsupported file type: {file_extension}")
                return []

            return loader.load()
        except Exception as e:
            logger.error(f"Error while processing file {file_path}: {str(e)}")
            return []

    def _json_to_text(self, json_content: Union[dict, list], indent: int = 0) -> str:
        """Convert JSON content to a readable text format."""
        if isinstance(json_content, dict):
            text = []
            for key, value in json_content.items():
                if isinstance(value, (dict, list)):
                    text.append(f"{' ' * indent}{key}:")
                    text.append(self._json_to_text(value, indent + 1))
                else:
                    text.append(f"{' ' * indent}{key}: {value}")
            return "\n".join(text)
        elif isinstance(json_content, list):
            text = []
            for i, item in enumerate(json_content):
                text.append(f"{' ' * indent}Item {i + 1}:")
                text.append(self._json_to_text(item, indent + 1))
            return "\n".join(text)
        else:
            return str(json_content)

    def _load_directory(self, directory_path: str) -> List[Document]:
        """Load all documents under a directory."""
        documents = []

        # Load PDF files
        pdf_loader = DirectoryLoader(
            directory_path,
            glob="**/*.pdf",
            loader_cls=PyPDFLoader
        )
        documents.extend(pdf_loader.load())

        # Load text files
        txt_loader = DirectoryLoader(
            directory_path,
            glob="**/*.txt",
            loader_cls=TextLoader,
            loader_kwargs={'encoding': 'utf-8'}
        )
        documents.extend(txt_loader.load())

        # Load JSON files (_load_single_file already skips processed files)
        for root, _, files in os.walk(directory_path):
            for file in files:
                if file.lower().endswith('.json'):
                    file_path = os.path.join(root, file)
                    documents.extend(self._load_single_file(file_path))

        return documents

    def clear_processed_files(self):
        """Clear the processed-file records."""
        self.processed_files.clear()
        logger.info("Processed-file records cleared")

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents into smaller chunks."""
        return self.text_splitter.split_documents(documents)

    def process_documents(self, path: str) -> List[Document]:
        """Process documents: load and split; accepts a single file or a directory."""
        documents = self.load_documents(path)
        return self.split_documents(documents)
@@ -0,0 +1,58 @@
from app.chains.qa_chain import QAChain
from app.chains.rag_chain import RAGChain
from app.services.vector_store_service import VectorStoreService
from app.core.logger import logger
from app.core.config import settings
from typing import List, Dict, Optional


class QAService:
    """Question-answering service encapsulating the QA business logic."""

    def __init__(self, model_name: str = None, temperature: float = None):
        """
        Initialize the QA service.

        Args:
            model_name (str, optional): Model name; falls back to the configured default
            temperature (float, optional): Model temperature; falls back to the configured default
        """
        self.qa_chain = QAChain(
            model_name=model_name or settings.DEFAULT_MODEL,
            temperature=temperature or settings.TEMPERATURE
        )
        # Initialize the RAG services
        self.vector_store_service = VectorStoreService()
        self.rag_chain = RAGChain(self.vector_store_service)

    def get_answer(self, question: str, chat_history: Optional[List[Dict[str, str]]] = None, use_rag: bool = True) -> str:
        """
        Get the answer to a question.

        Args:
            question (str): The user's question
            chat_history (List[Dict[str, str]], optional): Chat history
            use_rag (bool): Whether to enhance the answer with RAG; defaults to True

        Returns:
            str: The AI answer
        """
        try:
            logger.info(f"Handling question: {question}")

            if use_rag:
                # Answer with RAG
                rag_result = self.rag_chain.query(question)
                answer = rag_result["result"]
                # If there are relevant documents, append them to the answer
                if rag_result["source_documents"]:
                    sources = [doc.page_content for doc in rag_result["source_documents"]]
                    answer += "\n\nSources:\n" + "\n".join(sources)
            else:
                # Answer with the plain QA chain
                answer = self.qa_chain.answer(question, chat_history)

            logger.info(f"Answer generated, length: {len(answer)}")
            return answer
        except Exception as e:
            logger.error(f"Error while generating answer: {str(e)}", exc_info=True)
            raise
@@ -0,0 +1,93 @@
from typing import List
from langchain.schema import Document
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings
import chromadb
import os
import shutil
from app.core.logger import logger


class VectorStoreService:
    def __init__(self, persist_directory: str = "data/chroma"):
        self.persist_directory = persist_directory
        self.embeddings = OllamaEmbeddings(model="qwen2.5:latest")
        self.vector_store = None
        self._initialize_vector_store()

    def _initialize_vector_store(self):
        """Initialize the vector store."""
        try:
            # Create the directory if needed
            os.makedirs(self.persist_directory, exist_ok=True)

            # Create the Chroma client
            client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=chromadb.Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )

            # Make sure the collection exists
            try:
                client.get_collection("langchain")
            except Exception:
                # Create a new collection if it does not exist
                client.create_collection("langchain")

            # Initialize the vector store
            self.vector_store = Chroma(
                client=client,
                collection_name="langchain",
                embedding_function=self.embeddings
            )
            logger.info("Vector store initialized successfully")
        except Exception as e:
            logger.error(f"Error while initializing vector store: {str(e)}", exc_info=True)
            raise

    def add_documents(self, documents: List[Document]):
        """Add documents to the vector store."""
        if not documents:
            return
        try:
            self.vector_store.add_documents(documents)
            logger.info(f"Successfully added {len(documents)} documents to the vector store")
        except Exception as e:
            logger.error(f"Error while adding documents to vector store: {str(e)}", exc_info=True)
            raise

    def similarity_search(self, query: str, k: int = 6, score_threshold: float = 0.7) -> List[Document]:
        """Similarity search.

        Args:
            query (str): The query text
            k (int): Number of documents to return
            score_threshold (float): Similarity threshold; only documents scoring above it are returned
        """
        try:
            # Get the search results with scores
            results = self.vector_store.similarity_search_with_score(query, k=k)
            # Filter out documents below the threshold
            filtered_results = [doc for doc, score in results if score >= score_threshold]
            # If no document passes the threshold, return the top result
            if not filtered_results and results:
                return [results[0][0]]
            return filtered_results
        except Exception as e:
            logger.error(f"Error during similarity search: {str(e)}", exc_info=True)
            raise

    def clear(self):
        """Clear the vector store."""
        try:
            if self.vector_store:
                # Delete the collection
                self.vector_store.delete_collection()
                # Re-initialize the vector store
                self._initialize_vector_store()
                logger.info("Vector store cleared and re-initialized")
        except Exception as e:
            logger.error(f"Error while clearing vector store: {str(e)}", exc_info=True)
            raise
@@ -0,0 +1,99 @@
from flask_socketio import SocketIO, emit
from app.core.logger import logger
from app.services.qa_service import QAService
from typing import Dict, Set
import threading
import queue
import time


class WebSocketService:
    """WebSocket service handling connections and messages."""

    def __init__(self, socketio: SocketIO, qa_service: QAService):
        """
        Initialize the WebSocket service.

        Args:
            socketio: The SocketIO instance
            qa_service: The QA service instance
        """
        self.socketio = socketio
        self.qa_service = qa_service
        self.active_connections: Dict[str, Set[str]] = {}
        self.message_queues: Dict[str, queue.Queue] = {}
        self.processing_threads: Dict[str, threading.Thread] = {}
        self.lock = threading.Lock()

    def register_connection(self, sid: str):
        """Register a new WebSocket connection."""
        with self.lock:
            if sid not in self.active_connections:
                self.active_connections[sid] = set()
                self.message_queues[sid] = queue.Queue()
                self._start_processing_thread(sid)
                logger.info(f"New connection registered: {sid}")

    def unregister_connection(self, sid: str):
        """Unregister a WebSocket connection."""
        with self.lock:
            if sid in self.active_connections:
                del self.active_connections[sid]
                del self.message_queues[sid]
                if sid in self.processing_threads:
                    self.processing_threads[sid].join(timeout=1)
                    del self.processing_threads[sid]
                logger.info(f"Connection unregistered: {sid}")

    def _start_processing_thread(self, sid: str):
        """Start the message-processing thread."""
        def process_messages():
            while sid in self.active_connections:
                try:
                    message = self.message_queues[sid].get(timeout=1)
                    if message is None:  # Shutdown signal
                        break

                    try:
                        # Handle the message
                        answer = self.qa_service.get_answer(message)
                        self.socketio.emit('ai_response',
                                           {'message': answer},
                                           room=sid)
                    except Exception as e:
                        logger.error(f"Error while handling message: {str(e)}", exc_info=True)
                        self.socketio.emit('ai_response',
                                           {'message': f'Error while handling message: {str(e)}'},
                                           room=sid)

                except queue.Empty:
                    continue
                except Exception as e:
                    logger.error(f"Message-processing thread error: {str(e)}", exc_info=True)
                    break

        thread = threading.Thread(target=process_messages)
        thread.daemon = True
        thread.start()
        self.processing_threads[sid] = thread

    def handle_message(self, sid: str, message: str):
        """Handle a received message."""
        if not message:
            self.socketio.emit('ai_response',
                               {'message': 'Please enter a question'},
                               room=sid)
            return

        logger.info(f"Message received - session: {sid}, content: {message}")

        if sid in self.message_queues:
            self.message_queues[sid].put(message)
        else:
            logger.error(f"Message queue for session {sid} does not exist")
            self.socketio.emit('ai_response',
                               {'message': 'Server error, please reconnect'},
                               room=sid)

    def broadcast(self, message: str):
        """Broadcast a message to all connected clients."""
        self.socketio.emit('broadcast', {'message': message})
@@ -0,0 +1,340 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI Chat</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <link href="https://cdn.jsdelivr.net/npm/@heroicons/react@2.0.18/outline.min.css" rel="stylesheet">
    <script>
        tailwind.config = {
            theme: {
                extend: {
                    colors: {
                        primary: '#3B82F6',
                        secondary: '#1E40AF',
                        dark: '#1F2937',
                    }
                }
            }
        }
    </script>
</head>
<body class="bg-gray-100 min-h-screen">
    <div class="flex h-screen">
        <!-- Sidebar -->
        <div class="w-64 bg-dark text-white p-4 flex flex-col">
            <div class="flex items-center justify-between mb-8">
                <h1 class="text-xl font-bold">AI Assistant</h1>
                <button onclick="newChat()" class="p-2 hover:bg-gray-700 rounded-lg">
                    <svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5" viewBox="0 0 20 20" fill="currentColor">
                        <path fill-rule="evenodd" d="M10 3a1 1 0 011 1v5h5a1 1 0 110 2h-5v5a1 1 0 11-2 0v-5H4a1 1 0 110-2h5V4a1 1 0 011-1z" clip-rule="evenodd" />
                    </svg>
                </button>
            </div>

            <!-- Chat list -->
            <div id="chat-list" class="flex-1 overflow-y-auto space-y-2">
                <!-- Chats are added here dynamically -->
            </div>

            <!-- Status footer -->
            <div class="mt-4 pt-4 border-t border-gray-700">
                <div id="status" class="text-sm text-gray-400">Connecting...</div>
            </div>
        </div>

        <!-- Main chat area -->
        <div class="flex-1 flex flex-col">
            <!-- Chat header -->
            <div class="bg-white border-b p-4 flex items-center justify-between">
                <h2 id="current-chat-title" class="text-lg font-semibold text-gray-800">New chat</h2>
                <div class="flex items-center space-x-4">
                    <!-- RAG toggle -->
                    <div class="flex items-center space-x-2">
                        <label class="relative inline-flex items-center cursor-pointer">
                            <input type="checkbox" id="use-rag" class="sr-only peer" checked>
                            <div class="w-11 h-6 bg-gray-200 peer-focus:outline-none peer-focus:ring-4 peer-focus:ring-primary/20 rounded-full peer peer-checked:after:translate-x-full peer-checked:after:border-white after:content-[''] after:absolute after:top-[2px] after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all peer-checked:bg-primary"></div>
                            <span class="ml-2 text-sm font-medium text-gray-700">RAG</span>
                        </label>
                    </div>
                    <button onclick="clearChat()" class="p-2 text-gray-600 hover:bg-gray-100 rounded-lg">
                        <svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5" viewBox="0 0 20 20" fill="currentColor">
                            <path fill-rule="evenodd" d="M9 2a1 1 0 00-.894.553L7.382 4H4a1 1 0 000 2v10a2 2 0 002 2h8a2 2 0 002-2V6a1 1 0 100-2h-3.382l-.724-1.447A1 1 0 0011 2H9zM7 8a1 1 0 012 0v6a1 1 0 11-2 0V8zm5-1a1 1 0 00-1 1v6a1 1 0 102 0V8a1 1 0 00-1-1z" clip-rule="evenodd" />
                        </svg>
                    </button>
                </div>
            </div>

            <!-- Chat container -->
            <div id="chat-container" class="flex-1 overflow-y-auto p-4 space-y-4 bg-gray-50">
                <!-- Messages are added here dynamically -->
            </div>

            <!-- Input area -->
            <div class="border-t border-gray-200 p-4 bg-white">
                <div class="flex space-x-4">
                    <div class="flex-1 relative">
                        <textarea
                            id="message-input"
                            class="w-full rounded-lg border border-gray-300 px-4 py-2 focus:outline-none focus:ring-2 focus:ring-primary focus:border-transparent resize-none"
                            placeholder="Type your question..."
                            rows="1"
                            onkeydown="handleKeyPress(event)"
                        ></textarea>
                        <div class="absolute right-2 bottom-2 text-sm text-gray-500">
                            <span id="char-count">0</span>/1000
                        </div>
                    </div>
                    <button
                        onclick="sendMessage()"
                        class="bg-primary hover:bg-secondary text-white font-medium px-6 py-2 rounded-lg transition-colors duration-200 flex items-center space-x-2"
                    >
                        <span>Send</span>
                        <svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5" viewBox="0 0 20 20" fill="currentColor">
                            <path d="M10.894 2.553a1 1 0 00-1.788 0l-7 14a1 1 0 001.169 1.409l5-1.429A1 1 0 009 15.571V11a1 1 0 112 0v4.571a1 1 0 00.725.962l5 1.428a1 1 0 001.17-1.408l-7-14z" />
                        </svg>
                    </button>
                </div>
            </div>
        </div>
    </div>

    <script>
        // Global state
        let ws = null;
        let currentChatId = null;
        let chats = new Map();

        // Initialization
        function init() {
            connect();
            loadChats();
            setupMessageInput();
        }

        // WebSocket connection
        function connect() {
            ws = new WebSocket(`ws://${window.location.host}/ws`);

            ws.onopen = () => {
                updateStatus('Connected', 'text-green-400');
            };

            ws.onclose = () => {
                updateStatus('Disconnected', 'text-red-400');
                setTimeout(connect, 3000);
            };

            ws.onerror = (error) => {
                console.error('WebSocket error:', error);
                updateStatus('Connection error', 'text-red-400');
            };

            ws.onmessage = (event) => {
                const data = JSON.parse(event.data);
                addMessage('AI', data.answer, 'ai');
                saveChat(currentChatId);
            };
        }

        // Update the status display
        function updateStatus(text, colorClass) {
            const statusDiv = document.getElementById('status');
            statusDiv.textContent = text;
            statusDiv.className = `text-sm ${colorClass}`;
        }

        // Send a message
        function sendMessage() {
            const messageInput = document.getElementById('message-input');
            const message = messageInput.value.trim();
            const useRag = document.getElementById('use-rag').checked;

            if (message && ws?.readyState === WebSocket.OPEN) {
                if (!currentChatId) {
                    currentChatId = Date.now().toString();
                    createNewChat(currentChatId, message);
                }

                ws.send(JSON.stringify({
                    message,
                    chatId: currentChatId,
                    use_rag: useRag
                }));

                addMessage('Me', message, 'user');
                messageInput.value = '';
                updateCharCount();
                saveChat(currentChatId);
            }
        }

        // Append a message to the chat area
        function addMessage(sender, text, type) {
            const chatContainer = document.getElementById('chat-container');

            const messageDiv = document.createElement('div');
            messageDiv.className = `flex ${type === 'user' ? 'justify-end' : 'justify-start'}`;

            const messageContent = document.createElement('div');
            messageContent.className = `max-w-[70%] rounded-lg px-4 py-2 ${
                type === 'user'
                    ? 'bg-primary text-white rounded-br-none'
                    : 'bg-white text-gray-800 rounded-bl-none shadow-sm'
            }`;

            const senderSpan = document.createElement('div');
            senderSpan.className = 'text-xs font-medium mb-1';
            senderSpan.textContent = sender;

            const textDiv = document.createElement('div');
            textDiv.className = 'text-sm whitespace-pre-wrap';
            textDiv.textContent = text;

            messageContent.appendChild(senderSpan);
            messageContent.appendChild(textDiv);
            messageDiv.appendChild(messageContent);

            chatContainer.appendChild(messageDiv);
            chatContainer.scrollTop = chatContainer.scrollHeight;
        }

        // Start a new chat
        function newChat() {
            currentChatId = null;
            document.getElementById('chat-container').innerHTML = '';
            document.getElementById('current-chat-title').textContent = 'New chat';
            document.getElementById('message-input').value = '';
            updateCharCount();
        }

        // Delete the current chat
        function clearChat() {
            if (currentChatId) {
                chats.delete(currentChatId);
                saveChats();
                newChat();
            }
        }

        // Create a new chat list item
        function createNewChat(id, firstMessage) {
            const chatList = document.getElementById('chat-list');
            const chatItem = document.createElement('div');
            chatItem.className = 'p-2 hover:bg-gray-700 rounded-lg cursor-pointer';
            chatItem.onclick = () => loadChat(id);

            const title = firstMessage.length > 20 ? firstMessage.substring(0, 20) + '...' : firstMessage;
            chatItem.innerHTML = `
                <div class="text-sm font-medium">${title}</div>
                <div class="text-xs text-gray-400">${new Date().toLocaleString()}</div>
            `;

            chatList.insertBefore(chatItem, chatList.firstChild);

            chats.set(id, {
                title: title,
                messages: [],
                timestamp: Date.now()
            });

            document.getElementById('current-chat-title').textContent = title;
        }

        // Load a chat
        function loadChat(id) {
            const chat = chats.get(id);
            if (chat) {
                currentChatId = id;
                document.getElementById('chat-container').innerHTML = '';
                document.getElementById('current-chat-title').textContent = chat.title;

                chat.messages.forEach(msg => {
                    addMessage(msg.sender, msg.text, msg.type);
                });
            }
        }

        // Save the current chat's messages
        function saveChat(id) {
            const chat = chats.get(id);
            if (chat) {
                const messages = [];
                document.querySelectorAll('#chat-container > div').forEach(div => {
                    const content = div.querySelector('div:last-child');
                    const sender = div.querySelector('div:first-child').textContent;
                    const type = div.classList.contains('justify-end') ? 'user' : 'ai';
                    messages.push({
                        sender,
                        text: content.textContent,
                        type
                    });
                });

                chat.messages = messages;
                saveChats();
            }
        }

        // Persist all chats to local storage
        function saveChats() {
            localStorage.setItem('chats', JSON.stringify(Array.from(chats.entries())));
        }

        // Load chats from local storage
        function loadChats() {
            const savedChats = localStorage.getItem('chats');
            if (savedChats) {
                chats = new Map(JSON.parse(savedChats));
                const chatList = document.getElementById('chat-list');
                chatList.innerHTML = '';

                Array.from(chats.entries())
                    .sort((a, b) => b[1].timestamp - a[1].timestamp)
                    .forEach(([id, chat]) => {
                        const chatItem = document.createElement('div');
                        chatItem.className = 'p-2 hover:bg-gray-700 rounded-lg cursor-pointer';
                        chatItem.onclick = () => loadChat(id);
                        chatItem.innerHTML = `
                            <div class="text-sm font-medium">${chat.title}</div>
                            <div class="text-xs text-gray-400">${new Date(chat.timestamp).toLocaleString()}</div>
                        `;
                        chatList.appendChild(chatItem);
                    });
            }
        }

        // Set up the message input
        function setupMessageInput() {
            const messageInput = document.getElementById('message-input');
            messageInput.addEventListener('input', updateCharCount);
        }

        // Update the character count
        function updateCharCount() {
            const messageInput = document.getElementById('message-input');
            const charCount = document.getElementById('char-count');
            const count = messageInput.value.length;
            charCount.textContent = count;

            if (count > 1000) {
                charCount.className = 'text-red-500';
            } else {
                charCount.className = 'text-gray-500';
            }
        }

        // Handle key presses
        function handleKeyPress(event) {
            if (event.key === 'Enter' && !event.shiftKey) {
                event.preventDefault();
                sendMessage();
            }
        }

        // Initialize the app
        init();
    </script>
</body>
</html>
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,16 @@
name: langchain-ollama
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pip
  - pip:
      - langchain>=0.1.0
      - langchain-community>=0.0.10
      - flask>=2.0.0
      - flask-restx>=1.1.0
      - python-dotenv>=0.19.0
      - requests>=2.31.0
      - pydantic>=2.0.0
      - python-json-logger>=2.0.7
@@ -0,0 +1,33 @@
# Web Framework
flask>=2.0.0
flask-restx>=1.0.0
flask-socketio>=5.3.0
eventlet>=0.33.0
gevent>=22.10.2
fastapi>=0.104.0
uvicorn>=0.24.0

# LangChain
langchain>=0.0.350
langchain-ollama>=0.0.1
langchain-community>=0.0.10
langchain-core>=0.1.0
langchain-chroma>=0.0.1

# Vector Store
chromadb>=0.4.22

# Environment Variables
python-dotenv>=1.0.0

# Logging
python-json-logger>=2.0.0

# Ollama
ollama>=0.1.0

# Document Processing
unstructured>=0.10.30
python-magic>=0.4.27
python-magic-bin>=0.4.14; sys_platform == 'win32'
pypdf>=3.17.0
@@ -0,0 +1,19 @@
import uvicorn
from app import create_app
from app.core.config import settings
from app.core.logger import logger

# Create the application instance
logger.info("Creating application instance...")
app = create_app()

if __name__ == "__main__":
    logger.info(f"Starting server on {settings.HOST}:{settings.PORT}")
    # Start the server
    uvicorn.run(
        "run:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level=settings.LOG_LEVEL.lower()
    )