You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

99 lines
3.9 KiB

from flask_socketio import SocketIO, emit
from app.core.logger import logger
from app.services.qa_service import QAService
from typing import Dict, Set
import threading
import queue
import time
class WebSocketService:
    """WebSocket service: tracks client connections and answers their messages.

    Each registered connection (identified by its Socket.IO session id ``sid``)
    gets its own FIFO message queue and a dedicated daemon worker thread. The
    worker takes messages off the queue, passes each one to
    ``qa_service.get_answer`` and emits the result back to that session as an
    ``'ai_response'`` event, so one client's messages are answered in the
    order they were queued.
    """

    def __init__(self, socketio: SocketIO, qa_service: QAService):
        """
        Initialize the WebSocket service.

        Args:
            socketio: SocketIO instance used to emit events to clients.
            qa_service: Question-answering service; its ``get_answer`` is
                called with every queued message.
        """
        self.socketio = socketio
        self.qa_service = qa_service
        # sid -> set created at registration; nothing in this class adds to it.
        self.active_connections: Dict[str, Set[str]] = {}
        # sid -> queue of pending messages for that connection's worker.
        self.message_queues: Dict[str, queue.Queue] = {}
        # sid -> worker thread draining the matching queue.
        self.processing_threads: Dict[str, threading.Thread] = {}
        # Guards the three dictionaries above against concurrent (un)registration.
        self.lock = threading.Lock()

    def register_connection(self, sid: str):
        """Register a new connection and start its worker thread.

        Registering an ``sid`` that is already registered is a no-op.
        """
        with self.lock:
            if sid not in self.active_connections:
                self.active_connections[sid] = set()
                self.message_queues[sid] = queue.Queue()
                self._start_processing_thread(sid)
                logger.info(f"新连接注册: {sid}")

    def unregister_connection(self, sid: str):
        """Unregister a connection and stop its worker thread.

        The worker is woken with a ``None`` stop sentinel. We then wait up to
        one second for it to exit. The wait happens *outside* the lock, so a
        slow ``get_answer`` call cannot stall other connections. Unknown
        ``sid`` values are ignored.
        """
        with self.lock:
            if sid not in self.active_connections:
                return
            del self.active_connections[sid]
            message_queue = self.message_queues.pop(sid, None)
            thread = self.processing_threads.pop(sid, None)
            logger.info(f"连接注销: {sid}")

        if message_queue is not None:
            # Wake the worker now instead of letting it sit in get(timeout=1).
            message_queue.put(None)
        # Joining the current thread would raise RuntimeError.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=1)

    def _start_processing_thread(self, sid: str):
        """Start the daemon worker thread that answers messages for ``sid``.

        Expects ``self.message_queues[sid]`` to exist already; it is called
        from ``register_connection`` while ``self.lock`` is held.
        """
        # Bind the queue once. The worker never re-indexes the dict, so removing
        # the entry in unregister_connection cannot make it fail with KeyError.
        message_queue = self.message_queues[sid]

        def is_current() -> bool:
            # The worker stays alive only while its queue is still the one
            # registered for this sid. This also stops a stale worker when the
            # same sid is re-registered with a fresh queue.
            return self.message_queues.get(sid) is message_queue

        def process_messages():
            while is_current():
                try:
                    message = message_queue.get(timeout=1)
                except queue.Empty:
                    continue
                # None is the stop sentinel. Also stop if the connection was
                # dropped while we were waiting.
                if message is None or not is_current():
                    break
                try:
                    self._process_message(sid, message)
                except Exception as e:
                    # Even reporting the error to the client failed; give up.
                    logger.error(f"消息处理线程错误: {str(e)}", exc_info=True)
                    break

        thread = threading.Thread(
            target=process_messages, name=f"ws-worker-{sid}", daemon=True
        )
        thread.start()
        self.processing_threads[sid] = thread

    def _process_message(self, sid: str, message: str):
        """Answer one message and emit the reply (or the error text) to ``sid``.

        Errors from ``get_answer`` or from emitting the answer are logged and
        reported to the client. Only a failure of that error emit propagates.
        """
        try:
            answer = self.qa_service.get_answer(message)
            self.socketio.emit('ai_response',
                               {'message': answer},
                               room=sid)
        except Exception as e:
            logger.error(f"处理消息时发生错误: {str(e)}", exc_info=True)
            self.socketio.emit('ai_response',
                               {'message': f'处理消息时发生错误: {str(e)}'},
                               room=sid)

    def handle_message(self, sid: str, message: str):
        """Queue an incoming message for the connection's worker thread.

        A falsy (e.g. empty) message is answered immediately with a prompt to
        enter a question. If ``sid`` has no registered queue, an error is
        logged and the client is asked to reconnect.
        """
        if not message:
            self.socketio.emit('ai_response',
                               {'message': '请输入问题'},
                               room=sid)
            return
        logger.info(f"收到消息 - 会话ID: {sid}, 内容: {message}")
        # Single lookup: a check-then-index could raise KeyError if the
        # connection is unregistered concurrently.
        with self.lock:
            message_queue = self.message_queues.get(sid)
        if message_queue is None:
            logger.error(f"会话 {sid} 的消息队列不存在")
            self.socketio.emit('ai_response',
                               {'message': '服务器错误,请重新连接'},
                               room=sid)
            return
        message_queue.put(message)

    def broadcast(self, message: str):
        """Broadcast a message to all connected clients as a 'broadcast' event."""
        self.socketio.emit('broadcast', {'message': message})