You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
4.8 KiB
135 lines
4.8 KiB
from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel
|
|
from app.core.logger import logger
|
|
from app.services.qa_service import QAService
|
|
from app.middleware.error_handler import ErrorHandler
|
|
from app.middleware.request_logger import RequestLogger
|
|
from app.api.rag_api import router as rag_router
|
|
from typing import Dict, Optional
|
|
|
|
# Request and response models

# Request body of POST /api/v1/chat.
class ChatRequest(BaseModel):
    # The user's question.
    question: str
    # Conversation to continue; when omitted (or empty) the chat endpoint assigns a new id.
    chat_id: Optional[str] = None
    # Forwarded to QAService.get_answer as use_rag.
    # NOTE(review): Optional allows an explicit JSON null, which reaches get_answer as
    # use_rag=None rather than the default True — confirm that is intended.
    use_rag: Optional[bool] = True
|
|
|
|
# Successful response body of POST /api/v1/chat.
class ChatResponse(BaseModel):
    # Echo of the question that was asked.
    question: str
    # Answer produced by QAService.get_answer.
    answer: str
    # Conversation id; send it back as chat_id to continue the same conversation.
    chat_id: str
|
|
|
|
# Error body declared for the 400 and 500 responses of POST /api/v1/chat.
class ErrorResponse(BaseModel):
    # Error message text (the caught exception's message).
    error: str
|
|
|
|
# In-memory store of active chat sessions: chat id -> ordered list of
# {"role": "user" | "assistant", "content": ...} message dicts.
# NOTE(review): nothing in the visible code ever removes entries, so this dict grows
# for the lifetime of the process. It is also a plain module-level dict, so each worker
# process (if several are run) has its own copy. Consider an eviction policy or an
# external store.
active_chats: Dict[str, list] = {}
|
|
|
|
def register_routes(app, qa_service: QAService):
    """Register all page, WebSocket and API routes on *app*.

    Routes added:
      * ``GET /``              -- renders ``chat.html`` via ``app.state.templates``.
      * ``WS /ws``             -- chat over a WebSocket, one JSON frame per message.
      * ``GET /api/v1/health`` -- health check.
      * ``POST /api/v1/chat``  -- chat over plain HTTP.
    The RAG router (``rag_router``) is included as well.

    Args:
        app: the FastAPI application to register routes on.
        qa_service: produces answers via
            ``get_answer(question, history, use_rag=...)``; one instance serves
            every route.

    Conversation history lives in the module-level ``active_chats`` dict,
    keyed by chat id.
    """
    # Local imports keep this function self-contained.
    import html
    import uuid

    from fastapi.responses import JSONResponse

    # Router for the versioned JSON API.
    router = APIRouter(prefix="/api/v1")

    def _new_chat_id() -> str:
        """Return a fresh, unguessable chat id."""
        # Client-supplied ids are accepted verbatim, so a counter such as
        # len(active_chats) + 1 can collide with an existing conversation
        # (and is trivially guessable); random ids avoid both problems.
        return uuid.uuid4().hex

    def _ask(chat_id, question: str, use_rag) -> str:
        """Record *question* in the chat's history, fetch an answer, record and return it.

        The history handed to ``qa_service.get_answer`` already ends with the new
        user message. If ``get_answer`` raises, every entry added to the history
        during this call is removed again (so the stored conversation does not keep
        an unanswered question) and the exception is re-raised.
        """
        history = active_chats.setdefault(chat_id, [])
        start = len(history)
        history.append({"role": "user", "content": question})
        # NOTE(review): get_answer is called without await, so it appears to be
        # synchronous and blocks the event loop while it runs. If QAService is
        # thread-safe, consider fastapi.concurrency.run_in_threadpool — TODO confirm.
        try:
            answer = qa_service.get_answer(question, history, use_rag=use_rag)
        except Exception:
            del history[start:]
            raise
        history.append({"role": "assistant", "content": answer})
        return answer

    # Decorators apply bottom-up: the route decorator registers the function
    # already wrapped by the error handler and request logger.
    @app.get("/", response_class=HTMLResponse)
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def index(request: Request):
        """Return the chat page."""
        logger.info("访问根路径")
        try:
            return app.state.templates.TemplateResponse("chat.html", {"request": request})
        except Exception as e:
            logger.error(f"渲染模板时发生错误: {str(e)}", exc_info=True)
            # FastAPI does not interpret a Flask-style (body, status) tuple, so build the
            # 500 response explicitly; escape the text because it is served as HTML.
            return HTMLResponse(content=html.escape(str(e)), status_code=500)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Chat over a WebSocket.

        Each incoming frame should be a JSON object with a string ``message`` and,
        optionally, ``chatId`` and ``use_rag`` (default True). Every non-empty
        message is answered with a JSON frame ``{"question", "answer", "chatId"}``.
        Frames without a ``chatId`` use an id generated once for this connection,
        so anonymous clients never share a history.
        """
        await websocket.accept()
        connection_chat_id = _new_chat_id()
        try:
            while True:
                data = await websocket.receive_json()
                if not isinstance(data, dict):
                    continue  # valid JSON but not an object: nothing to read
                raw_message = data.get('message')
                message = raw_message.strip() if isinstance(raw_message, str) else ''
                if not message:
                    continue
                chat_id = data.get('chatId') or connection_chat_id
                use_rag = data.get('use_rag', True)  # RAG is on by default
                answer = _ask(chat_id, message, use_rag)
                await websocket.send_json({
                    "question": message,
                    "answer": answer,
                    "chatId": chat_id,
                })
        except WebSocketDisconnect:
            logger.info("WebSocket 连接断开")
        except Exception as e:
            logger.error(f"WebSocket 处理错误: {str(e)}", exc_info=True)
            await websocket.close()

    @router.get("/health")
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def health_check(request: Request):
        """Health check endpoint."""
        logger.info("收到健康检查请求")
        return {"status": "healthy"}

    @router.post("/chat", response_model=ChatResponse, responses={
        200: {"model": ChatResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    })
    @RequestLogger.log_request
    @ErrorHandler.handle_validation_error
    @ErrorHandler.handle_error
    async def chat(request: Request, chat_request: ChatRequest):
        """Chat endpoint: accept a question and return the AI's answer."""
        try:
            logger.info(f"收到问题: {chat_request.question}")
            chat_id = chat_request.chat_id or _new_chat_id()
            answer = _ask(chat_id, chat_request.question, chat_request.use_rag)
            logger.info(f"问题处理完成: {chat_request.question}")
            return ChatResponse(
                question=chat_request.question,
                answer=answer,
                chat_id=chat_id,
            )
        except Exception as e:
            logger.error(f"处理请求时发生错误: {str(e)}", exc_info=True)
            # A returned model would be validated against response_model
            # (ChatResponse) and fail; a Response object is sent as-is, giving the
            # client the documented ErrorResponse body with status 500.
            return JSONResponse(status_code=500, content={"error": str(e)})

    # Register the routers.
    app.include_router(router)
    app.include_router(rag_router)  # RAG API routes
|