You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

135 lines
4.8 KiB

import uuid
from typing import Dict, Optional

from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel

from app.core.logger import logger
from app.services.qa_service import QAService
from app.middleware.error_handler import ErrorHandler
from app.middleware.request_logger import RequestLogger
from app.api.rag_api import router as rag_router
# Request and response models


class ChatRequest(BaseModel):
    """Request body accepted by ``POST /api/v1/chat``.

    Fields:
        question: The user's question text.
        chat_id: Conversation to continue. When omitted (or empty), the chat
            handler creates a new conversation id.
        use_rag: Forwarded to the QA service as ``use_rag``; defaults to ``True``.
    """

    question: str
    chat_id: Optional[str] = None
    use_rag: Optional[bool] = True


class ChatResponse(BaseModel):
    """Successful reply returned by ``POST /api/v1/chat``.

    Echoes the question, carries the generated answer, and reports the id of
    the conversation the exchange was stored under; send that id back as
    ``chat_id`` to continue the same conversation.
    """

    question: str
    answer: str
    chat_id: str


class ErrorResponse(BaseModel):
    """Error payload documented for the chat endpoint's 400/500 responses."""

    error: str  # error message text


# In-memory store of active chat sessions: chat id -> ordered list of message
# dicts shaped like {"role": "user" | "assistant", "content": <text>}.
# NOTE(review): this lives in process memory and nothing in this module ever
# removes entries, so it grows for the process lifetime; with several worker
# processes each one has its own independent copy.
active_chats: Dict[str, list] = dict()


def register_routes(app, qa_service: QAService):
    """Register all HTTP and WebSocket routes on *app*.

    Routes registered here:

    * ``GET /``                -- renders the ``chat.html`` template.
    * ``WS  /ws``              -- JSON chat over a WebSocket.
    * ``GET /api/v1/health``   -- health check.
    * ``POST /api/v1/chat``    -- request/response chat.

    The RAG router imported as ``rag_router`` is included as well.

    Args:
        app: The FastAPI application. ``app.state.templates`` is used to render
            the index page.
        qa_service: Service whose ``get_answer(question, history, use_rag=...)``
            produces the assistant's reply.
    """
    # Versioned API router; it is mounted on the app at the end of this function.
    router = APIRouter(prefix="/api/v1")

    def _new_chat_id() -> str:
        """Return a fresh conversation id that cannot collide with existing ones."""
        return str(uuid.uuid4())

    def _ask(chat_id, question, use_rag):
        """Record *question* in chat *chat_id*, ask the QA service, record the answer.

        The user turn is appended *before* calling the service, so the service
        sees the history including the current question (the original contract).
        If the service raises, the history is rolled back to its previous length
        so it never keeps an unanswered question; the exception propagates.

        Returns:
            The answer produced by ``qa_service.get_answer``.
        """
        history = active_chats.setdefault(chat_id, [])
        previous_length = len(history)
        history.append({"role": "user", "content": question})
        try:
            answer = qa_service.get_answer(question, history, use_rag=use_rag)
        except Exception:
            del history[previous_length:]
            raise
        history.append({"role": "assistant", "content": answer})
        return answer

    @app.get("/", response_class=HTMLResponse)
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def index(request: Request):
        """Render the chat page (``chat.html``)."""
        logger.info("访问根路径")
        try:
            return app.state.templates.TemplateResponse("chat.html", {"request": request})
        except Exception as e:
            logger.error(f"渲染模板时发生错误: {str(e)}", exc_info=True)
            # A Flask-style ``(body, 500)`` tuple is not a valid FastAPI return
            # value and does not produce a 500 response; return a real one.
            # Exception details stay in the log rather than being sent to the client.
            return HTMLResponse(content="Internal Server Error", status_code=500)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Chat over a WebSocket.

        Each incoming JSON object may carry ``message`` (the question),
        ``chatId`` and ``use_rag`` (defaults to ``True``). Every non-empty
        message is answered with ``{"question", "answer", "chatId"}``.
        """
        await websocket.accept()
        # Messages without a ``chatId`` used to share the ``None`` history key,
        # mixing conversations from unrelated clients. Each connection now gets
        # its own fallback id, which is reported back to the client as ``chatId``.
        connection_chat_id = _new_chat_id()
        try:
            while True:
                data = await websocket.receive_json()
                if not isinstance(data, dict):
                    # Ignore JSON frames that are not objects instead of letting
                    # ``data.get`` raise and tear down the connection.
                    continue
                message = str(data.get('message') or '').strip()
                chat_id = data.get('chatId') or connection_chat_id
                use_rag = data.get('use_rag', True)  # RAG is enabled by default
                if not message:
                    continue

                answer = _ask(chat_id, message, use_rag)

                await websocket.send_json({
                    "question": message,
                    "answer": answer,
                    "chatId": chat_id
                })
        except WebSocketDisconnect:
            logger.info("WebSocket 连接断开")
        except Exception as e:
            logger.error(f"WebSocket 处理错误: {str(e)}", exc_info=True)
            try:
                await websocket.close()
            except Exception:
                # Best effort only: the socket may already be closed.
                pass

    @router.get("/health")
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def health_check(request: Request):
        """Health-check endpoint; always reports ``{"status": "healthy"}``."""
        logger.info("收到健康检查请求")
        return {"status": "healthy"}

    @router.post("/chat", response_model=ChatResponse, responses={
        200: {"model": ChatResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse}
    })
    @RequestLogger.log_request
    @ErrorHandler.handle_validation_error
    @ErrorHandler.handle_error
    async def chat(request: Request, chat_request: ChatRequest):
        """Answer one question, continuing the conversation named by ``chat_id``.

        Returns:
            ``ChatResponse`` on success, or a 500 ``{"error": ...}`` JSON body
            if answering fails.
        """
        try:
            logger.info(f"收到问题: {chat_request.question}")
            # ``str(len(active_chats) + 1)`` could collide with an existing
            # (possibly client-chosen) id and merge unrelated conversations.
            chat_id = chat_request.chat_id or _new_chat_id()

            answer = _ask(chat_id, chat_request.question, chat_request.use_rag)

            logger.info(f"问题处理完成: {chat_request.question}")
            return ChatResponse(
                question=chat_request.question,
                answer=answer,
                chat_id=chat_id
            )
        except Exception as e:
            logger.error(f"处理请求时发生错误: {str(e)}", exc_info=True)
            # Returning an ``ErrorResponse`` model here would be validated
            # against ``response_model=ChatResponse`` and fail; returning a
            # Response object bypasses the response model and sends the
            # documented 500 payload.
            return JSONResponse(status_code=500, content={"error": str(e)})

    # Mount the versioned API router and the RAG router.
    app.include_router(router)
    app.include_router(rag_router)