from flask_socketio import SocketIO, emit
from app.core.logger import logger
from app.services.qa_service import QAService
from typing import Dict, Set
import threading
import queue
import time


class WebSocketService:
    """WebSocket service: tracks connections and answers incoming messages.

    Every registered session id (``sid``) gets its own FIFO message queue and
    a daemon worker thread. The worker passes queued messages to
    ``qa_service.get_answer`` and emits the result back to that session as an
    ``ai_response`` event.
    """

    def __init__(self, socketio: SocketIO, qa_service: QAService):
        """
        Initialize the WebSocket service.

        Args:
            socketio: SocketIO instance used to emit events to clients.
            qa_service: Question-answering service whose ``get_answer`` produces replies.
        """
        self.socketio = socketio
        self.qa_service = qa_service
        # sid -> set of strings. Nothing in this class adds to the set.
        # NOTE(review): presumably reserved for per-connection rooms/tags — confirm.
        self.active_connections: Dict[str, Set[str]] = {}
        # sid -> queue of pending messages; a None item tells the worker to exit.
        self.message_queues: Dict[str, queue.Queue] = {}
        # sid -> worker thread draining that sid's queue.
        self.processing_threads: Dict[str, threading.Thread] = {}
        # Guards the three dicts above against concurrent register/unregister/handle calls.
        self.lock = threading.Lock()

    def register_connection(self, sid: str):
        """Register a new WebSocket connection and start its worker thread.

        Registering an already-known ``sid`` is a no-op.
        """
        with self.lock:
            if sid not in self.active_connections:
                self.active_connections[sid] = set()
                self.message_queues[sid] = queue.Queue()
                self._start_processing_thread(sid)
                logger.info(f"新连接注册: {sid}")

    def unregister_connection(self, sid: str):
        """Unregister a WebSocket connection and stop its worker thread.

        The session state is removed under the lock. The worker is then sent
        the ``None`` exit signal and joined (for up to 1 second) outside the
        lock, so a worker that is still busy does not block other sessions.
        Unknown ``sid`` values are ignored.
        """
        with self.lock:
            if sid not in self.active_connections:
                return
            del self.active_connections[sid]
            message_queue = self.message_queues.pop(sid, None)
            thread = self.processing_threads.pop(sid, None)

        # Wake the worker immediately with the exit signal. Without it, the worker
        # would only notice the disconnect on its next poll.
        if message_queue is not None:
            message_queue.put(None)
        # A thread cannot join itself; guard in case this is called from the worker.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=1)
        logger.info(f"连接注销: {sid}")

    def _start_processing_thread(self, sid: str):
        """Start the daemon worker thread that drains ``sid``'s message queue.

        Must be called while holding ``self.lock``, after
        ``self.message_queues[sid]`` has been created.
        """
        # Bind the queue once: the worker must keep working even after
        # unregister_connection removes the dict entry. It must also never pick
        # up a different queue if the same sid re-registers later.
        message_queue = self.message_queues[sid]

        def process_messages():
            try:
                while sid in self.active_connections:
                    try:
                        message = message_queue.get(timeout=1)
                    except queue.Empty:
                        continue
                    if message is None:  # exit signal
                        break
                    try:
                        answer = self.qa_service.get_answer(message)
                        self.socketio.emit('ai_response', {'message': answer}, room=sid)
                    except Exception as e:
                        # Report the failure to the client but keep serving the session.
                        logger.error(f"处理消息时发生错误: {str(e)}", exc_info=True)
                        self.socketio.emit('ai_response', {'message': f'处理消息时发生错误: {str(e)}'}, room=sid)
            except Exception as e:
                # Last-resort guard (e.g. the error emit itself failed): log and stop the worker.
                logger.error(f"消息处理线程错误: {str(e)}", exc_info=True)

        thread = threading.Thread(target=process_messages, name=f"ws-worker-{sid}", daemon=True)
        thread.start()
        self.processing_threads[sid] = thread

    def handle_message(self, sid: str, message: str):
        """Queue an incoming message for ``sid``'s worker thread.

        An empty message gets an immediate '请输入问题' reply. If ``sid`` has no
        queue (it is not registered, or has already been unregistered), an
        error is logged and a server-error reply is emitted.
        """
        if not message:
            self.socketio.emit('ai_response', {'message': '请输入问题'}, room=sid)
            return

        logger.info(f"收到消息 - 会话ID: {sid}, 内容: {message}")
        # Look up under the lock so a concurrent unregister cannot cause a KeyError.
        with self.lock:
            message_queue = self.message_queues.get(sid)

        if message_queue is None:
            logger.error(f"会话 {sid} 的消息队列不存在")
            self.socketio.emit('ai_response', {'message': '服务器错误,请重新连接'}, room=sid)
            return
        message_queue.put(message)

    def broadcast(self, message: str):
        """Emit ``message`` as a 'broadcast' event (no room argument is passed)."""
        self.socketio.emit('broadcast', {'message': message})