"""HTTP and WebSocket route registration for the chat application."""
import uuid
from typing import Dict, Optional

from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from pydantic import BaseModel

from app.core.logger import logger
from app.services.qa_service import QAService
from app.middleware.error_handler import ErrorHandler
from app.middleware.request_logger import RequestLogger
from app.api.rag_api import router as rag_router


# Request and response models
class ChatRequest(BaseModel):
    question: str
    # Conversation to continue; a new one is started when omitted/empty.
    chat_id: Optional[str] = None
    # Forwarded to QAService.get_answer as ``use_rag``.
    use_rag: Optional[bool] = True


class ChatResponse(BaseModel):
    question: str
    answer: str
    chat_id: str


class ErrorResponse(BaseModel):
    error: str


# Active chat sessions: chat id -> list of {"role": ..., "content": ...}
# messages, shared by the HTTP and WebSocket endpoints.
# NOTE(review): process-local and never pruned, so it grows without bound.
active_chats: Dict[str, list] = {}


def _new_chat_id() -> str:
    """Return a fresh random chat id (hex uuid4)."""
    # len(active_chats) + 1 could hand out an id that is already in use
    # (e.g. one chosen by a WebSocket client), merging two histories.
    return uuid.uuid4().hex


def register_routes(app, qa_service: QAService):
    """Register all HTTP and WebSocket routes on *app*.

    Args:
        app: The FastAPI application. ``index`` reads
            ``app.state.templates`` when ``/`` is requested.
        qa_service: Service whose ``get_answer`` produces chat answers.
    """
    router = APIRouter(prefix="/api/v1")

    @app.get("/", response_class=HTMLResponse)
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def index(request: Request):
        """Render the chat page (``chat.html``)."""
        logger.info("访问根路径")
        try:
            return app.state.templates.TemplateResponse("chat.html", {"request": request})
        except Exception as e:
            logger.error(f"渲染模板时发生错误: {str(e)}", exc_info=True)
            # FastAPI does not understand Flask-style ``(body, status)``
            # tuples (it would send them as a JSON list with status 200),
            # so build the error response explicitly.
            return PlainTextResponse(content=str(e), status_code=500)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Answer chat messages over a WebSocket until the client disconnects.

        Each incoming JSON frame is read for ``message``, ``chatId`` and
        ``use_rag``; each answer is sent back as
        ``{"question", "answer", "chatId"}``.
        """
        await websocket.accept()
        # Fallback id for frames without a chatId, so anonymous clients each
        # get their own history instead of all sharing the ``None`` key.
        connection_chat_id = _new_chat_id()
        try:
            while True:
                data = await websocket.receive_json()
                # ``or ''`` also covers an explicit ``"message": null``.
                message = str(data.get('message') or '').strip()
                chat_id = data.get('chatId') or connection_chat_id
                use_rag = data.get('use_rag', True)  # RAG is on by default
                if not message:
                    continue

                history = active_chats.setdefault(chat_id, [])
                history.append({"role": "user", "content": message})
                # NOTE(review): get_answer is called synchronously inside an
                # async handler, so it blocks the event loop while it runs.
                answer = qa_service.get_answer(message, history, use_rag=use_rag)
                history.append({"role": "assistant", "content": answer})

                await websocket.send_json({
                    "question": message,
                    "answer": answer,
                    "chatId": chat_id,
                })
        except WebSocketDisconnect:
            logger.info("WebSocket 连接断开")
        except Exception as e:
            logger.error(f"WebSocket 处理错误: {str(e)}", exc_info=True)
            try:
                await websocket.close()
            except Exception:
                # Best effort: the socket may already be closed, e.g. when
                # the failure came from the connection itself.
                pass

    @router.get("/health")
    @RequestLogger.log_request
    @ErrorHandler.handle_error
    async def health_check(request: Request):
        """Health check endpoint; always reports ``{"status": "healthy"}``."""
        logger.info("收到健康检查请求")
        return {"status": "healthy"}

    @router.post("/chat", response_model=ChatResponse, responses={
        200: {"model": ChatResponse},
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    })
    @RequestLogger.log_request
    @ErrorHandler.handle_validation_error
    @ErrorHandler.handle_error
    async def chat(request: Request, chat_request: ChatRequest):
        """Answer a question and record the exchange in its chat history.

        A new chat id is generated when ``chat_id`` is missing or empty.
        On failure, responds with status 500 and an ``ErrorResponse`` body.
        """
        try:
            logger.info(f"收到问题: {chat_request.question}")
            chat_id = chat_request.chat_id or _new_chat_id()
            history = active_chats.setdefault(chat_id, [])

            history.append({"role": "user", "content": chat_request.question})
            # NOTE(review): synchronous call; blocks the event loop (see /ws).
            answer = qa_service.get_answer(
                chat_request.question,
                history,
                use_rag=chat_request.use_rag,
            )
            history.append({"role": "assistant", "content": answer})

            logger.info(f"问题处理完成: {chat_request.question}")
            return ChatResponse(
                question=chat_request.question,
                answer=answer,
                chat_id=chat_id,
            )
        except Exception as e:
            logger.error(f"处理请求时发生错误: {str(e)}", exc_info=True)
            # Returning ErrorResponse directly would be validated against
            # response_model=ChatResponse and fail; a Response object skips
            # that validation and carries the documented 500 status.
            return JSONResponse(status_code=500, content={"error": str(e)})

    app.include_router(router)
    app.include_router(rag_router)  # RAG endpoints