You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
99 lines
3.9 KiB
99 lines
3.9 KiB
from flask_socketio import SocketIO, emit
|
|
from app.core.logger import logger
|
|
from app.services.qa_service import QAService
|
|
from typing import Dict, Set
|
|
import threading
|
|
import queue
|
|
import time
|
|
|
|
class WebSocketService:
    """WebSocket service: tracks connections and answers their messages.

    Each registered session id (``sid``) gets its own FIFO message queue and a
    daemon worker thread. The worker takes messages off the queue, asks
    ``qa_service.get_answer`` for a reply, and emits the reply to the session's
    room as an ``'ai_response'`` event. Messages from one session are therefore
    answered one at a time, in arrival order.
    """

    def __init__(self, socketio: SocketIO, qa_service: QAService):
        """Initialize the WebSocket service.

        Args:
            socketio: SocketIO instance used to emit events to clients.
            qa_service: Question-answering service; its ``get_answer`` is
                called with each incoming message.
        """
        self.socketio = socketio
        self.qa_service = qa_service
        # sid -> set of strings. The sets are created empty and are not
        # populated anywhere in this class.
        self.active_connections: Dict[str, Set[str]] = {}
        # sid -> pending messages for that session's worker thread.
        self.message_queues: Dict[str, queue.Queue] = {}
        # sid -> worker thread draining the matching queue.
        self.processing_threads: Dict[str, threading.Thread] = {}
        # Guards the three dicts above.
        self.lock = threading.Lock()

    def register_connection(self, sid: str):
        """Register a new WebSocket connection and start its worker thread.

        Registering a ``sid`` that is already registered is a no-op.

        Args:
            sid: Socket.IO session id of the connecting client.
        """
        with self.lock:
            if sid in self.active_connections:
                return
            self.active_connections[sid] = set()
            self.message_queues[sid] = queue.Queue()
            self._start_processing_thread(sid)
        logger.info(f"新连接注册: {sid}")

    def unregister_connection(self, sid: str):
        """Unregister a WebSocket connection and stop its worker thread.

        The bookkeeping is removed under ``self.lock``. The worker is then sent
        the ``None`` exit sentinel on its own queue, so it stops promptly
        instead of waiting for its next 1 s poll. Finally the worker is joined
        (for at most 1 s). The join happens outside the lock, so a worker busy
        inside ``get_answer`` does not block other connections. Unknown
        ``sid`` values are ignored.

        Args:
            sid: Socket.IO session id of the disconnecting client.
        """
        with self.lock:
            if sid not in self.active_connections:
                return
            del self.active_connections[sid]
            message_queue = self.message_queues.pop(sid, None)
            thread = self.processing_threads.pop(sid, None)

        if message_queue is not None:
            # Exit signal recognized by the worker loop.
            message_queue.put(None)
        if thread is not None:
            thread.join(timeout=1)
        logger.info(f"连接注销: {sid}")

    def _start_processing_thread(self, sid: str):
        """Start the daemon worker thread that answers messages for ``sid``.

        The caller must hold ``self.lock`` and must already have created
        ``self.message_queues[sid]``.

        The worker runs while ``sid`` is registered. It exits when it receives
        the ``None`` sentinel, or when an unexpected error escapes
        per-message handling.

        Args:
            sid: Session id whose queue the worker drains; also the room that
                replies are emitted to.
        """
        # Capture the queue once. The worker must not look it up by sid again:
        # after unregister_connection removes the entry, that lookup would
        # raise KeyError. If the sid were re-registered, it would also hand
        # this worker the new session's queue.
        message_queue = self.message_queues[sid]

        def process_messages():
            while sid in self.active_connections:
                try:
                    # Finite timeout so the loop re-checks registration
                    # periodically even if no sentinel arrives.
                    message = message_queue.get(timeout=1)
                    if message is None:  # exit signal
                        break

                    try:
                        answer = self.qa_service.get_answer(message)
                        self.socketio.emit('ai_response',
                                           {'message': answer},
                                           room=sid)
                    except Exception as e:
                        # One failing message must not kill the worker:
                        # log it, report it to the client, keep serving.
                        # NOTE(review): this sends str(e) to the client, which
                        # may expose internal details — confirm that is intended.
                        logger.error(f"处理消息时发生错误: {str(e)}", exc_info=True)
                        self.socketio.emit('ai_response',
                                           {'message': f'处理消息时发生错误: {str(e)}'},
                                           room=sid)

                except queue.Empty:
                    continue
                except Exception as e:
                    # Reached only if something outside per-message handling
                    # fails (e.g. the error emit itself); stop the worker.
                    logger.error(f"消息处理线程错误: {str(e)}", exc_info=True)
                    break

        thread = threading.Thread(target=process_messages,
                                  name=f"ws-worker-{sid}",
                                  daemon=True)
        thread.start()
        self.processing_threads[sid] = thread

    def handle_message(self, sid: str, message: str):
        """Queue an incoming message for the session's worker thread.

        An empty or falsy message is rejected immediately with a prompt to
        enter a question. If ``sid`` has no queue (it is not registered), an
        error is logged and the client is told to reconnect.

        Args:
            sid: Session id the message came from.
            message: The client's question text.
        """
        if not message:
            self.socketio.emit('ai_response',
                               {'message': '请输入问题'},
                               room=sid)
            return

        logger.info(f"收到消息 - 会话ID: {sid}, 内容: {message}")

        # Single locked lookup avoids a check-then-index race with
        # unregister_connection (which could raise KeyError).
        with self.lock:
            message_queue = self.message_queues.get(sid)

        if message_queue is not None:
            message_queue.put(message)
        else:
            logger.error(f"会话 {sid} 的消息队列不存在")
            self.socketio.emit('ai_response',
                               {'message': '服务器错误,请重新连接'},
                               room=sid)

    def broadcast(self, message: str):
        """Broadcast a ``'broadcast'`` event carrying ``message`` to all clients.

        Args:
            message: Text placed under the ``'message'`` key of the payload.
        """
        self.socketio.emit('broadcast', {'message': message})
|