You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

135 lines
4.8 KiB

1 month ago
import html
import uuid
from typing import Dict, Optional

from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel

from app.core.logger import logger
from app.services.qa_service import QAService
from app.middleware.error_handler import ErrorHandler
from app.middleware.request_logger import RequestLogger
from app.api.rag_api import router as rag_router
# Request and response models


class ChatRequest(BaseModel):
    """Body of ``POST /api/v1/chat``: one question, optionally in an existing chat."""

    # The user's question text.
    question: str
    # Existing chat session to continue; when omitted the server creates one.
    chat_id: Optional[str] = None
    # Forwarded to the QA service as ``use_rag``; defaults to True.
    use_rag: Optional[bool] = True
class ChatResponse(BaseModel):
    """Successful reply of ``POST /api/v1/chat``."""

    # Echo of the question that was asked.
    question: str
    # Answer produced by the QA service.
    answer: str
    # Chat session the exchange was recorded under (reuse it to continue).
    chat_id: str
class ErrorResponse(BaseModel):
    """Error body documented for the 400/500 responses of ``POST /api/v1/chat``."""

    # Human-readable description of what went wrong.
    error: str
# In-memory store of active chat sessions: chat id -> list of
# {"role": ..., "content": ...} message dicts. Process-local and never pruned.
active_chats: Dict[str, list] = dict()
def _new_chat_id() -> str:
    """Return a fresh, unguessable chat session id."""
    # uuid4 rather than a counter: sequential ids collide with client-chosen
    # ids and let any caller guess (and append to) another user's history.
    return uuid.uuid4().hex


def _answer_with_history(
    qa_service: QAService, chat_id, question: str, use_rag: Optional[bool]
) -> str:
    """Answer ``question`` within chat ``chat_id`` and record both turns.

    The user turn is appended to ``active_chats[chat_id]`` (created on demand)
    *before* calling the service, so ``get_answer`` receives a history that
    already contains the current question; the assistant turn is appended
    after the service returns.

    Args:
        qa_service: Service whose ``get_answer(question, history, use_rag=...)``
            produces the answer.
        chat_id: Key of the session in ``active_chats``.
        question: The user's question.
        use_rag: Forwarded unchanged to ``get_answer``.

    Returns:
        The answer returned by ``qa_service.get_answer``.
    """
    history = active_chats.setdefault(chat_id, [])
    history.append({"role": "user", "content": question})
    # NOTE(review): get_answer is called synchronously from async handlers, so
    # a slow answer blocks the event loop. Consider run_in_threadpool once
    # QAService is confirmed thread-safe.
    answer = qa_service.get_answer(question, history, use_rag=use_rag)
    history.append({"role": "assistant", "content": answer})
    return answer


def register_routes(app, qa_service: QAService):
    """Register all HTTP and WebSocket routes on ``app``.

    Registered directly on ``app``:
      * ``GET /``  - renders ``chat.html`` via ``app.state.templates``.
      * ``WS /ws`` - chat over a WebSocket, one JSON message per question.
    Registered under the ``/api/v1`` prefix:
      * ``GET /api/v1/health`` - health check.
      * ``POST /api/v1/chat``  - ask one question, receive one answer.
    Finally the router imported from ``app.api.rag_api`` is included.

    Args:
        app: The FastAPI application to register routes on.
        qa_service: Service used to answer questions.
    """
    # Router for the versioned API endpoints.
    router = APIRouter(prefix="/api/v1")

    @app.get("/", response_class=HTMLResponse)
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def index(request: Request):
        """Return the chat page."""
        logger.info("访问根路径")
        try:
            return app.state.templates.TemplateResponse("chat.html", {"request": request})
        except Exception as e:
            logger.error(f"渲染模板时发生错误: {str(e)}", exc_info=True)
            # FastAPI does not understand Flask-style ``(body, status)`` tuples,
            # so build the 500 response explicitly; escape the text because it
            # is embedded in an HTML response.
            return HTMLResponse(content=html.escape(str(e)), status_code=500)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Answer chat questions over a WebSocket.

        Each incoming JSON object may carry ``message`` (the question),
        ``chatId`` and ``use_rag`` (default True). Blank messages are ignored;
        malformed payloads are logged and skipped. Each answered message gets
        one JSON reply with ``question``, ``answer`` and ``chatId``.
        """
        await websocket.accept()
        # Fallback session for messages without ``chatId``. It is scoped to
        # this connection so anonymous clients no longer share one global
        # history stored under the ``None`` key.
        connection_chat_id = _new_chat_id()
        try:
            while True:
                data = await websocket.receive_json()
                if not isinstance(data, dict):
                    logger.warning("忽略格式错误的 WebSocket 消息")
                    continue
                message = data.get('message', '')
                if not isinstance(message, str):
                    logger.warning("忽略格式错误的 WebSocket 消息")
                    continue
                message = message.strip()
                if not message:
                    continue
                chat_id = data.get('chatId') or connection_chat_id
                use_rag = data.get('use_rag', True)  # RAG is on by default
                answer = _answer_with_history(qa_service, chat_id, message, use_rag)
                await websocket.send_json({
                    "question": message,
                    "answer": answer,
                    "chatId": chat_id
                })
        except WebSocketDisconnect:
            logger.info("WebSocket 连接断开")
        except Exception as e:
            logger.error(f"WebSocket 处理错误: {str(e)}", exc_info=True)
            try:
                await websocket.close()
            except RuntimeError:
                # The socket may already be closed (e.g. the failure happened
                # while sending); there is nothing left to clean up.
                pass

    @router.get("/health")
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def health_check(request: Request):
        """Health check endpoint."""
        logger.info("收到健康检查请求")
        return {"status": "healthy"}

    @router.post("/chat", response_model=ChatResponse, responses={
        200: {"model": ChatResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse}
    })
    @RequestLogger.log_request
    @ErrorHandler.handle_validation_error
    @ErrorHandler.handle_error
    async def chat(request: Request, chat_request: ChatRequest):
        """Chat endpoint: accept a question and return the AI answer."""
        try:
            logger.info(f"收到问题: {chat_request.question}")
            chat_id = chat_request.chat_id or _new_chat_id()
            answer = _answer_with_history(
                qa_service,
                chat_id,
                chat_request.question,
                chat_request.use_rag
            )
            logger.info(f"问题处理完成: {chat_request.question}")
            return ChatResponse(
                question=chat_request.question,
                answer=answer,
                chat_id=chat_id
            )
        except Exception as e:
            logger.error(f"处理请求时发生错误: {str(e)}", exc_info=True)
            # Returning an ErrorResponse model here would be validated against
            # response_model=ChatResponse and fail. Return the documented 500
            # error body directly, bypassing response_model validation.
            return JSONResponse(status_code=500, content={"error": str(e)})

    # Attach the versioned API router and the RAG router to the app.
    app.include_router(router)
    app.include_router(rag_router)