You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
422 lines
17 KiB
422 lines
17 KiB
"""
|
|
WebSocket管理器模块
|
|
统一管理WebSocket客户端和Channel,支持一个客户端3个Channel架构
|
|
遵循单一职责原则
|
|
"""
|
|
import asyncio
|
|
from typing import Dict, List, Optional, Any
|
|
from datetime import datetime
|
|
from app.core.websocket.client import WebSocketClient
|
|
from app.core.websocket.channel import WebSocketChannel
|
|
from app.core.websocket.adapter import WebSocketAdapter
|
|
from app.utils.structured_log import get_structured_logger, LogLevel
|
|
|
|
logger = get_structured_logger(__name__, LogLevel.INFO)
|
|
|
|
class WebSocketManager:
|
|
"""WebSocket管理器 - 统一管理WebSocket客户端和Channel
|
|
|
|
单一职责:
|
|
- 管理WebSocket客户端的生命周期
|
|
- 管理Channel的创建和销毁
|
|
- 管理适配器的创建和销毁
|
|
- 提供统一的API接口
|
|
|
|
架构设计:
|
|
- 一个客户端只需要3个Channel:心跳、发送、接收
|
|
- 心跳Channel:高优先级,用于心跳消息
|
|
- 发送Channel:正常优先级,用于业务数据发送
|
|
- 接收Channel:接收所有数据
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._clients: Dict[str, WebSocketClient] = {}
|
|
self._channels: Dict[str, WebSocketChannel] = {}
|
|
self._adapters: Dict[str, WebSocketAdapter] = {}
|
|
self._heartbeat_tasks: Dict[str, asyncio.Task] = {} # 心跳任务
|
|
self._created_at = datetime.now()
|
|
self._lock = asyncio.Lock()
|
|
|
|
logger.info("WebSocket管理器初始化完成")
|
|
|
|
async def create_client(self, name: str, url: str, heartbeat_interval: int = 120) -> WebSocketClient:
|
|
"""创建WebSocket客户端并自动创建3个Channel
|
|
|
|
单一职责:只负责创建客户端和Channel,不处理业务逻辑
|
|
"""
|
|
async with self._lock:
|
|
if name in self._clients:
|
|
logger.warning(f"WebSocket客户端 {name} 已存在")
|
|
return self._clients[name]
|
|
|
|
try:
|
|
# 创建客户端
|
|
client = WebSocketClient(url, name)
|
|
self._clients[name] = client
|
|
|
|
# 创建3个Channel
|
|
await self._create_client_channels(name)
|
|
|
|
# 启动心跳任务 - 确保即使其他步骤失败也要尝试启动心跳
|
|
try:
|
|
await self._start_heartbeat_task(name, heartbeat_interval)
|
|
except Exception as e:
|
|
logger.error(f"心跳任务启动失败,但继续创建客户端: {name} - {e}")
|
|
|
|
logger.info(f"WebSocket管理器创建客户端: {name} -> {url}")
|
|
return client
|
|
|
|
except Exception as e:
|
|
# 如果创建过程中出现异常,清理已创建的资源
|
|
if name in self._clients:
|
|
del self._clients[name]
|
|
logger.error(f"WebSocket管理器创建客户端失败: {name} - {e}")
|
|
raise
|
|
|
|
async def _create_client_channels(self, client_name: str):
|
|
"""为客户端创建3个Channel
|
|
|
|
单一职责:只负责创建Channel,不处理业务逻辑
|
|
"""
|
|
try:
|
|
# 创建3个Channel
|
|
channels = [
|
|
(f"{client_name}_heartbeat", 100), # 心跳Channel,小队列
|
|
(f"{client_name}_send", 1000), # 发送Channel,大队列
|
|
(f"{client_name}_receive", 1000) # 接收Channel,大队列
|
|
]
|
|
|
|
for channel_name, max_size in channels:
|
|
if channel_name not in self._channels:
|
|
channel = WebSocketChannel(channel_name, max_size)
|
|
self._channels[channel_name] = channel
|
|
await channel.connect()
|
|
logger.info(f"WebSocket管理器创建Channel: {channel_name}")
|
|
|
|
# 创建适配器
|
|
await self._create_client_adapters(client_name)
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket管理器创建客户端Channel失败: {client_name} - {e}")
|
|
raise
|
|
|
|
async def _create_client_adapters(self, client_name: str):
|
|
"""为客户端创建适配器
|
|
|
|
单一职责:只负责创建适配器,不处理业务逻辑
|
|
"""
|
|
try:
|
|
client = self._clients[client_name]
|
|
|
|
# 获取Channel
|
|
heartbeat_channel = self._channels[f"{client_name}_heartbeat"]
|
|
send_channel = self._channels[f"{client_name}_send"]
|
|
receive_channel = self._channels[f"{client_name}_receive"]
|
|
|
|
# 创建适配器
|
|
adapters = [
|
|
# 心跳适配器:心跳Channel -> WebSocket
|
|
(f"{client_name}:heartbeat", client, heartbeat_channel, heartbeat_channel),
|
|
# 发送适配器:发送Channel -> WebSocket
|
|
(f"{client_name}:send", client, send_channel, receive_channel),
|
|
# 接收适配器:WebSocket -> 接收Channel
|
|
(f"{client_name}:receive", client, receive_channel, receive_channel)
|
|
]
|
|
|
|
for adapter_key, client, outbound_channel, inbound_channel in adapters:
|
|
if adapter_key not in self._adapters:
|
|
adapter = WebSocketAdapter(client, outbound_channel, inbound_channel)
|
|
self._adapters[adapter_key] = adapter
|
|
await adapter.start()
|
|
logger.info(f"WebSocket管理器创建适配器: {adapter_key}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket管理器创建客户端适配器失败: {client_name} - {e}")
|
|
raise
|
|
|
|
async def _start_heartbeat_task(self, client_name: str, heartbeat_interval: int):
|
|
"""启动心跳任务
|
|
|
|
单一职责:只负责心跳任务管理,不处理业务逻辑
|
|
"""
|
|
try:
|
|
# 停止已存在的心跳任务
|
|
if client_name in self._heartbeat_tasks:
|
|
self._heartbeat_tasks[client_name].cancel()
|
|
|
|
# 创建新的心跳任务
|
|
heartbeat_task = asyncio.create_task(self._heartbeat_loop(client_name, heartbeat_interval))
|
|
self._heartbeat_tasks[client_name] = heartbeat_task
|
|
|
|
logger.info(f"WebSocket管理器启动心跳任务: {client_name} 间隔:{heartbeat_interval}秒")
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket管理器启动心跳任务失败: {client_name} - {e}")
|
|
raise
|
|
|
|
async def _heartbeat_loop(self, client_name: str, heartbeat_interval: int):
|
|
"""心跳循环
|
|
|
|
单一职责:只负责心跳数据生成,不处理业务逻辑
|
|
"""
|
|
logger.info(f"心跳循环开始: {client_name} 间隔:{heartbeat_interval}秒")
|
|
try:
|
|
heartbeat_channel = self._channels.get(f"{client_name}_heartbeat")
|
|
if not heartbeat_channel:
|
|
logger.error(f"心跳Channel不存在: {client_name}_heartbeat")
|
|
# 等待一段时间后重试,而不是直接返回
|
|
await asyncio.sleep(5)
|
|
return
|
|
|
|
logger.info(f"心跳循环启动成功: {client_name} -> {heartbeat_channel.name}")
|
|
|
|
while client_name in self._clients and self._clients[client_name].is_connected:
|
|
try:
|
|
# 创建心跳消息
|
|
from app.core.websocket.channel import ChannelMessage
|
|
heartbeat_data = {"Message": "ping"}
|
|
heartbeat_message = ChannelMessage(
|
|
type="heartbeat",
|
|
data=heartbeat_data,
|
|
priority=1 # 高优先级
|
|
)
|
|
|
|
# 发送到心跳Channel
|
|
success = await heartbeat_channel.send_message(heartbeat_message)
|
|
if success:
|
|
logger.debug(f"心跳消息已发送到Channel: {client_name}_heartbeat")
|
|
else:
|
|
logger.warning(f"心跳消息发送失败: {client_name}_heartbeat")
|
|
|
|
# 等待下次心跳
|
|
await asyncio.sleep(heartbeat_interval)
|
|
|
|
except asyncio.CancelledError:
|
|
logger.info(f"心跳任务被取消: {client_name}")
|
|
break
|
|
except Exception as e:
|
|
logger.error(f"心跳循环异常: {client_name} - {e}")
|
|
await asyncio.sleep(5) # 异常时等待5秒后重试
|
|
|
|
except Exception as e:
|
|
logger.error(f"心跳任务异常: {client_name} - {e}")
|
|
finally:
|
|
logger.info(f"心跳循环结束: {client_name}")
|
|
|
|
async def remove_client(self, name: str) -> bool:
|
|
"""移除WebSocket客户端
|
|
|
|
单一职责:只负责移除客户端和相关资源,不处理业务逻辑
|
|
"""
|
|
async with self._lock:
|
|
if name not in self._clients:
|
|
return False
|
|
|
|
# 停止心跳任务
|
|
await self._stop_heartbeat_task(name)
|
|
|
|
# 停止相关的适配器
|
|
await self._stop_client_adapters(name)
|
|
|
|
# 移除相关的Channel
|
|
await self._remove_client_channels(name)
|
|
|
|
# 移除客户端
|
|
client = self._clients[name]
|
|
await client.disconnect()
|
|
del self._clients[name]
|
|
|
|
logger.info(f"WebSocket管理器移除客户端: {name}")
|
|
return True
|
|
|
|
async def _stop_heartbeat_task(self, client_name: str):
|
|
"""停止心跳任务
|
|
|
|
单一职责:只负责停止心跳任务,不处理业务逻辑
|
|
"""
|
|
if client_name in self._heartbeat_tasks:
|
|
self._heartbeat_tasks[client_name].cancel()
|
|
del self._heartbeat_tasks[client_name]
|
|
logger.info(f"WebSocket管理器停止心跳任务: {client_name}")
|
|
|
|
async def _stop_client_adapters(self, client_name: str):
|
|
"""停止指定客户端的所有适配器
|
|
|
|
单一职责:只负责停止适配器,不处理业务逻辑
|
|
"""
|
|
adapters_to_remove = []
|
|
for adapter_key, adapter in self._adapters.items():
|
|
if adapter_key.startswith(f"{client_name}:"):
|
|
await adapter.stop()
|
|
adapters_to_remove.append(adapter_key)
|
|
|
|
for adapter_key in adapters_to_remove:
|
|
del self._adapters[adapter_key]
|
|
|
|
async def _remove_client_channels(self, client_name: str):
|
|
"""移除指定客户端的所有Channel
|
|
|
|
单一职责:只负责移除Channel,不处理业务逻辑
|
|
"""
|
|
channels_to_remove = []
|
|
for channel_name, channel in self._channels.items():
|
|
if channel_name.startswith(f"{client_name}_"):
|
|
await channel.disconnect()
|
|
channels_to_remove.append(channel_name)
|
|
|
|
for channel_name in channels_to_remove:
|
|
del self._channels[channel_name]
|
|
|
|
def get_client(self, name: str) -> Optional[WebSocketClient]:
|
|
"""获取WebSocket客户端"""
|
|
return self._clients.get(name)
|
|
|
|
def get_all_clients(self) -> Dict[str, WebSocketClient]:
|
|
"""获取所有客户端"""
|
|
return self._clients.copy()
|
|
|
|
async def connect_client(self, name: str) -> bool:
|
|
"""连接指定客户端"""
|
|
client = self.get_client(name)
|
|
if not client:
|
|
logger.error(f"WebSocket客户端 {name} 不存在")
|
|
return False
|
|
|
|
return await client.connect()
|
|
|
|
async def disconnect_client(self, name: str) -> bool:
|
|
"""断开指定客户端"""
|
|
client = self.get_client(name)
|
|
if not client:
|
|
logger.error(f"WebSocket客户端 {name} 不存在")
|
|
return False
|
|
|
|
await client.disconnect()
|
|
return True
|
|
|
|
def get_channel(self, name: str) -> Optional[WebSocketChannel]:
|
|
"""获取Channel"""
|
|
return self._channels.get(name)
|
|
|
|
def get_client_channels(self, client_name: str) -> Dict[str, WebSocketChannel]:
|
|
"""获取指定客户端的所有Channel"""
|
|
client_channels = {}
|
|
for channel_name, channel in self._channels.items():
|
|
if channel_name.startswith(f"{client_name}_"):
|
|
client_channels[channel_name] = channel
|
|
return client_channels
|
|
|
|
def get_all_channels(self) -> Dict[str, WebSocketChannel]:
|
|
"""获取所有Channel"""
|
|
return self._channels.copy()
|
|
|
|
def get_adapter(self, client_name: str, channel_type: str) -> Optional[WebSocketAdapter]:
|
|
"""获取适配器"""
|
|
adapter_key = f"{client_name}:{channel_type}"
|
|
return self._adapters.get(adapter_key)
|
|
|
|
def get_client_adapters(self, client_name: str) -> Dict[str, WebSocketAdapter]:
|
|
"""获取指定客户端的所有适配器"""
|
|
client_adapters = {}
|
|
for adapter_key, adapter in self._adapters.items():
|
|
if adapter_key.startswith(f"{client_name}:"):
|
|
client_adapters[adapter_key] = adapter
|
|
return client_adapters
|
|
|
|
def get_all_adapters(self) -> Dict[str, WebSocketAdapter]:
|
|
"""获取所有适配器"""
|
|
return self._adapters.copy()
|
|
|
|
async def send_message(self, client_name: str, message_type: str, data: Any, priority: int = 0) -> bool:
|
|
"""发送消息到指定客户端
|
|
|
|
单一职责:只负责消息路由,不处理业务逻辑
|
|
"""
|
|
try:
|
|
# 根据消息类型选择Channel
|
|
if message_type == "heartbeat":
|
|
channel_name = f"{client_name}_heartbeat"
|
|
else:
|
|
channel_name = f"{client_name}_send"
|
|
|
|
channel = self._channels.get(channel_name)
|
|
if not channel:
|
|
logger.error(f"Channel {channel_name} 不存在")
|
|
return False
|
|
|
|
# 创建消息
|
|
from app.core.websocket.channel import ChannelMessage
|
|
message = ChannelMessage(
|
|
type=message_type,
|
|
data=data,
|
|
priority=priority
|
|
)
|
|
|
|
# 发送到Channel
|
|
success = await channel.send_message(message)
|
|
if success:
|
|
logger.debug(f"WebSocket管理器发送消息成功: {client_name} -> {message_type}")
|
|
else:
|
|
logger.warning(f"WebSocket管理器发送消息失败: {client_name} -> {message_type}")
|
|
|
|
return success
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket管理器发送消息异常: {client_name} -> {message_type} - {e}")
|
|
return False
|
|
|
|
async def send_heartbeat(self, client_name: str) -> bool:
|
|
"""发送心跳消息
|
|
|
|
单一职责:只负责心跳消息发送,不处理业务逻辑
|
|
"""
|
|
heartbeat_data = {"Message": "ping"}
|
|
return await self.send_message(client_name, "heartbeat", heartbeat_data, priority=1)
|
|
|
|
async def cleanup(self):
|
|
"""清理所有资源
|
|
|
|
单一职责:只负责资源清理,不处理业务逻辑
|
|
"""
|
|
try:
|
|
# 停止所有心跳任务
|
|
for client_name in list(self._heartbeat_tasks.keys()):
|
|
await self._stop_heartbeat_task(client_name)
|
|
|
|
# 先停止所有适配器(停止数据转发)
|
|
for adapter in self._adapters.values():
|
|
await adapter.stop()
|
|
|
|
# 再断开所有客户端(清理连接和消息处理器)
|
|
for client in self._clients.values():
|
|
await client.disconnect()
|
|
|
|
# 最后断开所有Channel(清理队列)
|
|
for channel in self._channels.values():
|
|
await channel.disconnect()
|
|
|
|
# 清空所有集合
|
|
self._adapters.clear()
|
|
self._clients.clear()
|
|
self._channels.clear()
|
|
|
|
logger.info("WebSocket管理器清理完成")
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket管理器清理失败: {e}")
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
"""获取管理器统计信息"""
|
|
return {
|
|
"created_at": self._created_at,
|
|
"client_count": len(self._clients),
|
|
"channel_count": len(self._channels),
|
|
"adapter_count": len(self._adapters),
|
|
"heartbeat_task_count": len(self._heartbeat_tasks),
|
|
"clients": {name: client.get_stats() for name, client in self._clients.items()},
|
|
"channels": {name: channel.get_stats() for name, channel in self._channels.items()},
|
|
"adapters": {key: adapter.get_stats() for key, adapter in self._adapters.items()}
|
|
}
|
|
|
|
# 全局WebSocket管理器实例
|
|
websocket_manager = WebSocketManager()
|