You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

250 lines
9.3 KiB

"""
WebSocket管理器模块
统一管理多个WebSocket客户端和Channel的连接
"""
import asyncio
from typing import Dict, List, Optional, Any
from datetime import datetime
from app.core.websocket.client import WebSocketClient
from app.core.websocket.channel import WebSocketChannel
from app.core.websocket.adapter import WebSocketAdapter
from app.utils.log import get_logger
# Module-level logger, obtained from the project's app.utils.log.get_logger helper
# and keyed by this module's __name__.
logger = get_logger(__name__)
class WebSocketManager:
    """Central registry for WebSocket clients, Channels, and the adapters bridging them.

    Clients and Channels are registered by name. An adapter links one client to one
    Channel and is registered under the key ``"<client_name>:<channel_name>"``.
    Registry-mutating operations (create/remove/cleanup) are serialized with a single
    ``asyncio.Lock``.
    """

    def __init__(self):
        # Registered clients, keyed by client name.
        self._clients: Dict[str, WebSocketClient] = {}
        # Registered Channels, keyed by channel name.
        self._channels: Dict[str, WebSocketChannel] = {}
        # Registered adapters, keyed by "<client_name>:<channel_name>".
        self._adapters: Dict[str, WebSocketAdapter] = {}
        # Adapter key -> (client_name, channel_name). Names may themselves contain ":",
        # so the joined key cannot be split back reliably; ownership checks use this map
        # instead of prefix/suffix matching on the key.
        self._adapter_endpoints: Dict[str, tuple] = {}
        self._created_at = datetime.now()
        # asyncio.Lock is not reentrant: code holding it must not call another method
        # that also acquires it.
        self._lock = asyncio.Lock()
        logger.info("WebSocket管理器初始化完成")

    async def create_client(self, name: str, url: str) -> WebSocketClient:
        """Create and register a WebSocket client (this method does not connect it).

        If a client named ``name`` already exists, a warning is logged and the existing
        client is returned unchanged; ``url`` is ignored in that case.
        """
        async with self._lock:
            if name in self._clients:
                logger.warning(f"WebSocket客户端 {name} 已存在")
                return self._clients[name]
            client = WebSocketClient(url, name)
            self._clients[name] = client
            logger.info(f"WebSocket管理器创建客户端: {name} -> {url}")
            return client

    async def remove_client(self, name: str) -> bool:
        """Stop the client's adapters, disconnect the client, and unregister it.

        Returns:
            True if the client existed and was removed, False if it was unknown.
        """
        async with self._lock:
            if name not in self._clients:
                return False
            # Adapters using this client are stopped before the client is disconnected.
            await self._stop_client_adapters(name)
            await self._clients[name].disconnect()
            del self._clients[name]
            logger.info(f"WebSocket管理器移除客户端: {name}")
            return True

    def get_client(self, name: str) -> Optional[WebSocketClient]:
        """Return the client registered as ``name``, or None."""
        return self._clients.get(name)

    def get_all_clients(self) -> Dict[str, WebSocketClient]:
        """Return a shallow copy of the name -> client registry."""
        return self._clients.copy()

    async def connect_client(self, name: str) -> bool:
        """Connect the named client.

        Returns:
            The client's ``connect()`` result, or False (with an error log) if unknown.
        """
        client = self.get_client(name)
        if not client:
            logger.error(f"WebSocket客户端 {name} 不存在")
            return False
        return await client.connect()

    async def disconnect_client(self, name: str) -> bool:
        """Disconnect the named client (it stays registered).

        Returns:
            True after disconnecting, or False (with an error log) if unknown.
        """
        client = self.get_client(name)
        if not client:
            logger.error(f"WebSocket客户端 {name} 不存在")
            return False
        await client.disconnect()
        return True

    async def connect_all_clients(self) -> Dict[str, bool]:
        """Connect every registered client, one after another.

        Returns:
            Mapping of client name -> that client's ``connect()`` result.
        """
        results: Dict[str, bool] = {}
        # Iterate a snapshot: each await yields to the event loop, and a concurrent
        # create_client/remove_client would otherwise mutate the dict mid-iteration
        # and raise RuntimeError.
        for name, client in list(self._clients.items()):
            results[name] = await client.connect()
        return results

    async def disconnect_all_clients(self):
        """Disconnect every registered client (clients stay registered)."""
        # Snapshot for the same reason as in connect_all_clients.
        for client in list(self._clients.values()):
            await client.disconnect()
        logger.info("WebSocket管理器断开所有客户端")

    async def create_channel(self, name: str, max_size: int = 1000) -> WebSocketChannel:
        """Create and register a Channel; ``max_size`` is passed to WebSocketChannel.

        If a Channel named ``name`` already exists, a warning is logged and the existing
        Channel is returned unchanged; ``max_size`` is ignored in that case.
        """
        async with self._lock:
            if name in self._channels:
                logger.warning(f"Channel {name} 已存在")
                return self._channels[name]
            channel = WebSocketChannel(name, max_size)
            self._channels[name] = channel
            logger.info(f"WebSocket管理器创建Channel: {name}")
            return channel

    async def remove_channel(self, name: str) -> bool:
        """Stop the Channel's adapters, disconnect the Channel, and unregister it.

        Returns:
            True if the Channel existed and was removed, False if it was unknown.
        """
        async with self._lock:
            if name not in self._channels:
                return False
            # Adapters using this Channel are stopped before the Channel is disconnected.
            await self._stop_channel_adapters(name)
            await self._channels[name].disconnect()
            del self._channels[name]
            logger.info(f"WebSocket管理器移除Channel: {name}")
            return True

    def get_channel(self, name: str) -> Optional[WebSocketChannel]:
        """Return the Channel registered as ``name``, or None."""
        return self._channels.get(name)

    def get_all_channels(self) -> Dict[str, WebSocketChannel]:
        """Return a shallow copy of the name -> Channel registry."""
        return self._channels.copy()

    async def create_adapter(self, client_name: str, channel_name: str, heartbeat_interval: int = 120) -> Optional[WebSocketAdapter]:
        """Create (or revive) the adapter linking a registered client to a registered Channel.

        If an adapter for this pair already exists, the Channel's ``connect()`` is called
        again, and the adapter is restarted when its stats report a falsy
        ``send_task_running``. ``heartbeat_interval`` is only used for newly created
        adapters (it is passed through to WebSocketAdapter).

        Returns:
            The adapter. Returns None if the client or Channel is unknown, or if the
            adapter key is already taken by a different (client, Channel) pair.
        """
        async with self._lock:
            client = self.get_client(client_name)
            channel = self.get_channel(channel_name)
            if not client:
                logger.error(f"WebSocket客户端 {client_name} 不存在")
                return None
            if not channel:
                logger.error(f"Channel {channel_name} 不存在")
                return None
            adapter_key = f"{client_name}:{channel_name}"
            endpoints = (client_name, channel_name)

            # If the adapter already exists, make sure it is running; restart it if not.
            existing = self._adapters.get(adapter_key)
            if existing is not None:
                owner = self._adapter_endpoints.get(adapter_key, endpoints)
                if owner != endpoints:
                    # Different pairs can join to the same key, e.g. ("a", "b:c") and
                    # ("a:b", "c") both give "a:b:c"; never hand out another pair's adapter.
                    logger.error(f"WebSocket管理器适配器键冲突: {adapter_key} 已被 {owner[0]} -> {owner[1]} 占用")
                    return None
                # Make sure the Channel is connected.
                await channel.connect()
                # Restart the adapter if its send task is not running.
                stats = existing.get_stats()
                if not stats.get("send_task_running"):
                    await existing.start()
                    logger.info(f"WebSocket管理器重新启动适配器: {adapter_key}")
                else:
                    logger.info(f"WebSocket管理器适配器已存在且运行中: {adapter_key}")
                return existing

            # Create a new adapter and start it.
            # NOTE(review): the Channel is passed for two parameters — presumably the
            # adapter's inbound and outbound sides; confirm against WebSocketAdapter.
            adapter = WebSocketAdapter(client, channel, channel, heartbeat_interval)
            self._adapters[adapter_key] = adapter
            self._adapter_endpoints[adapter_key] = endpoints
            await channel.connect()
            await adapter.start()
            logger.info(f"WebSocket管理器创建适配器: {adapter_key}")
            return adapter

    async def remove_adapter(self, client_name: str, channel_name: str) -> bool:
        """Stop and unregister the adapter for (client_name, channel_name).

        Returns:
            True if such an adapter existed and was removed, otherwise False.
        """
        async with self._lock:
            adapter_key = f"{client_name}:{channel_name}"
            adapter = self._adapters.get(adapter_key)
            if adapter is None:
                return False
            await adapter.stop()
            del self._adapters[adapter_key]
            self._adapter_endpoints.pop(adapter_key, None)
            logger.info(f"WebSocket管理器移除适配器: {adapter_key}")
            return True

    def get_adapter(self, client_name: str, channel_name: str) -> Optional[WebSocketAdapter]:
        """Return the adapter for (client_name, channel_name), or None."""
        adapter_key = f"{client_name}:{channel_name}"
        return self._adapters.get(adapter_key)

    def get_all_adapters(self) -> Dict[str, WebSocketAdapter]:
        """Return a shallow copy of the adapter-key -> adapter registry."""
        return self._adapters.copy()

    async def _stop_client_adapters(self, client_name: str):
        """Stop and unregister every adapter whose client is ``client_name``."""
        await self._stop_adapters_where(client_name=client_name)

    async def _stop_channel_adapters(self, channel_name: str):
        """Stop and unregister every adapter whose Channel is ``channel_name``."""
        await self._stop_adapters_where(channel_name=channel_name)

    def _adapter_matches(self, adapter_key: str, client_name: Optional[str], channel_name: Optional[str]) -> bool:
        """Return True if the adapter under ``adapter_key`` uses the given client or Channel."""
        endpoints = self._adapter_endpoints.get(adapter_key)
        if endpoints is None:
            # No recorded endpoints (entry not created via create_adapter): fall back to
            # matching on the joined key.
            return (
                (client_name is not None and adapter_key.startswith(f"{client_name}:"))
                or (channel_name is not None and adapter_key.endswith(f":{channel_name}"))
            )
        owner_client, owner_channel = endpoints
        return owner_client == client_name or owner_channel == channel_name

    async def _stop_adapters_where(self, *, client_name: Optional[str] = None, channel_name: Optional[str] = None):
        """Stop, then unregister, each adapter attached to ``client_name`` or ``channel_name``.

        Each adapter is unregistered right after its own ``stop()`` returns. If a
        ``stop()`` raises, the exception propagates, and that adapter plus any not yet
        processed remain registered.
        """
        # Collect keys first; the registry is mutated while awaiting below.
        matching = [key for key in self._adapters if self._adapter_matches(key, client_name, channel_name)]
        for adapter_key in matching:
            adapter = self._adapters.get(adapter_key)
            if adapter is None:
                continue
            await adapter.stop()
            del self._adapters[adapter_key]
            self._adapter_endpoints.pop(adapter_key, None)

    # Strict architecture: no direct send capability is offered; everything goes through a Channel.
    # Strict architecture: no direct broadcast capability is offered; everything goes through a Channel.

    def get_stats(self) -> Dict[str, Any]:
        """Return registry counts plus each client's, Channel's, and adapter's own stats."""
        return {
            "created_at": self._created_at,
            "client_count": len(self._clients),
            "channel_count": len(self._channels),
            "adapter_count": len(self._adapters),
            "clients": {name: client.get_stats() for name, client in self._clients.items()},
            "channels": {name: channel.get_stats() for name, channel in self._channels.items()},
            "adapters": {key: adapter.get_stats() for key, adapter in self._adapters.items()},
        }

    @staticmethod
    async def _best_effort(action, failure_message: str) -> None:
        """Await ``action()``; on ``Exception`` log ``failure_message`` with traceback instead of raising.

        ``asyncio.CancelledError`` is not an ``Exception`` subclass (Python 3.8+), so
        cancellation still propagates.
        """
        try:
            await action()
        except Exception:
            logger.exception(failure_message)

    async def cleanup(self):
        """Stop all adapters, disconnect all clients and Channels, then empty every registry.

        Best effort: a failure while stopping or disconnecting one item is logged, and
        the remaining items are still processed, so one faulty connection does not leak
        the others. Runs under the manager lock so nothing can be registered half-way
        through and then dropped from the registries without being stopped.
        """
        async with self._lock:
            # Stop all adapters.
            for adapter_key, adapter in list(self._adapters.items()):
                await self._best_effort(adapter.stop, f"WebSocket管理器停止适配器失败: {adapter_key}")
            # Disconnect all clients.
            for name, client in list(self._clients.items()):
                await self._best_effort(client.disconnect, f"WebSocket管理器断开客户端失败: {name}")
            logger.info("WebSocket管理器断开所有客户端")
            # Disconnect all Channels.
            for name, channel in list(self._channels.items()):
                await self._best_effort(channel.disconnect, f"WebSocket管理器断开Channel失败: {name}")
            self._clients.clear()
            self._channels.clear()
            self._adapters.clear()
            self._adapter_endpoints.clear()
            logger.info("WebSocket管理器清理完成")
# Global WebSocket manager instance, created once when this module is imported;
# presumably other modules share it by importing `websocket_manager`.
# NOTE(review): WebSocketManager.__init__ creates an asyncio.Lock here, outside any
# running event loop. On Python < 3.10 that lock binds to the loop returned by
# asyncio.get_event_loop() at import time, which can fail under contention if the app
# later runs on a different loop (e.g. via asyncio.run). Confirm the target Python version.
websocket_manager = WebSocketManager()