You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
284 lines
10 KiB
284 lines
10 KiB
"""
WebSocket Channel module.

Implements a message-passing mechanism modelled on .NET's ``Channel``.
"""
|
|
import asyncio
import inspect
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Callable, Dict, List

from app.utils.structured_log import get_structured_logger, LogLevel
|
|
|
|
# Module-level structured logger shared by every channel in this file.
# NOTE(review): LogLevel.INFO is presumably the minimum level emitted, which would
# hide the logger.debug(...) calls below — confirm against structured_log.
logger = get_structured_logger(__name__, LogLevel.INFO)
|
|
|
|
class ChannelState(Enum):
    """Lifecycle states of a WebSocketChannel."""

    IDLE = "idle"  # initial state; also restored by WebSocketChannel.reset()
    CONNECTING = "connecting"  # set while WebSocketChannel.connect() is running
    CONNECTED = "connected"  # the only state in which send/receive are allowed
    DISCONNECTED = "disconnected"  # set by WebSocketChannel.disconnect()
    ERROR = "error"  # set when WebSocketChannel.connect() raised
|
|
|
|
@dataclass
class ChannelMessage:
    """Minimal message envelope carried through a WebSocketChannel."""

    # Free-form message kind/tag; within this module it is only written to debug logs.
    type: str
    # Fully assembled outgoing payload (str / bytes / JSON string, etc.).
    data: Any
|
|
|
|
class WebSocketChannel:
    """In-process message channel modelled on .NET's ``Channel``.

    Messages are buffered in a bounded ``asyncio.Queue``. Producers call
    :meth:`send_message`. Consumers pull with :meth:`receive_message`,
    :meth:`receive_messages` or :meth:`try_receive_message`. Every successfully
    queued message is also pushed to the registered subscriber callbacks,
    provided it passes every registered filter.

    Sending and receiving only work while the channel is ``CONNECTED``. The
    connection itself is currently simulated (see :meth:`connect`).
    """

    def __init__(self, name: str, max_size: int = 1000):
        """Create an idle channel.

        Args:
            name: Identifier used in log messages and statistics.
            max_size: Queue capacity, passed straight to ``asyncio.Queue``.
                A value <= 0 therefore means an unbounded queue.
        """
        self.name = name
        self.max_size = max_size
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=max_size)
        self._state = ChannelState.IDLE
        # Callbacks (sync or async) invoked with every queued message.
        self._subscribers: List[Callable] = []
        # filter name -> predicate; subscribers are notified only if all return truthy.
        self._filters: Dict[str, Callable] = {}
        self._created_at = datetime.now()
        self._last_message_at: Optional[datetime] = None
        self._connection_count = 0  # number of successful connect() calls

        logger.info(f"创建WebSocket Channel: {name}")

    @property
    def state(self) -> ChannelState:
        """Current lifecycle state."""
        return self._state

    @property
    def is_connected(self) -> bool:
        """True when the channel is in the CONNECTED state."""
        return self._state == ChannelState.CONNECTED

    @property
    def queue_size(self) -> int:
        """Number of messages currently buffered."""
        return self._queue.qsize()

    @property
    def is_full(self) -> bool:
        """True when the queue has reached ``max_size``."""
        return self._queue.full()

    @property
    def connection_count(self) -> int:
        """Number of successful connect() calls since creation or last reset()."""
        return self._connection_count

    @staticmethod
    def _describe_callable(callback: Callable) -> str:
        """Return a loggable name for *callback*.

        Falls back to ``repr`` for callables without ``__name__`` (for example
        ``functools.partial`` objects or instances defining ``__call__``).
        """
        return getattr(callback, "__name__", repr(callback))

    async def connect(self) -> bool:
        """Move the channel to CONNECTED.

        Calling this on an already-connected channel is a no-op that returns
        True. From the ERROR state a reconnect is attempted.

        Returns:
            True on success; False if an exception occurred, in which case the
            state becomes ERROR.
        """
        try:
            if self._state == ChannelState.CONNECTED:
                logger.warning(f"Channel {self.name} 已经连接")
                return True

            if self._state == ChannelState.ERROR:
                logger.warning(f"Channel {self.name} 处于错误状态,尝试重新连接")

            self._state = ChannelState.CONNECTING
            logger.info(f"Channel {self.name} 正在连接...")

            # Placeholder for real connection logic; currently just a short delay.
            await asyncio.sleep(0.1)

            self._state = ChannelState.CONNECTED
            self._connection_count += 1
            logger.info(f"Channel {self.name} 连接成功 (第{self._connection_count}次)")
            return True

        except Exception as e:
            self._state = ChannelState.ERROR
            logger.error(f"Channel {self.name} 连接失败: {e}")
            return False

    async def disconnect(self) -> None:
        """Move the channel to DISCONNECTED and drop all buffered messages.

        Errors are logged, not raised.
        """
        try:
            if self._state == ChannelState.DISCONNECTED:
                logger.warning(f"Channel {self.name} 已经断开")
                return

            self._state = ChannelState.DISCONNECTED
            logger.info(f"Channel {self.name} 已断开")

            await self.clear()

        except Exception as e:
            logger.error(f"Channel {self.name} 断开连接时出错: {e}")

    async def reconnect(self) -> bool:
        """Disconnect (clearing the queue), then connect again.

        Returns:
            The result of :meth:`connect`, or False if an exception escaped.
        """
        try:
            logger.info(f"Channel {self.name} 尝试重新连接...")
            await self.disconnect()
            return await self.connect()
        except Exception as e:
            logger.error(f"Channel {self.name} 重新连接失败: {e}")
            return False

    async def send_message(self, message: ChannelMessage) -> bool:
        """Enqueue *message* and push it to subscribers.

        Returns:
            True if the message was queued. False if the channel is not
            connected, the queue is full (the message is dropped), or an
            unexpected error occurred.
        """
        try:
            if not self.is_connected:
                logger.warning(f"Channel {self.name} 未连接,无法发送消息")
                return False

            try:
                # Never block the producer: a full queue drops the message.
                self._queue.put_nowait(message)
            except asyncio.QueueFull:
                logger.warning(f"Channel {self.name} 队列已满,丢弃消息")
                return False

            self._last_message_at = datetime.now()

            await self._notify_subscribers(message)

            logger.debug(f"Channel {self.name} 发送消息成功: {message.type}")
            return True

        except Exception as e:
            logger.error(f"Channel {self.name} 发送消息失败: {e}")
            return False

    async def receive_message(self, timeout: Optional[float] = None) -> Optional[ChannelMessage]:
        """Wait for and return the next queued message.

        Args:
            timeout: Seconds to wait; None waits indefinitely.

        Returns:
            The message, or None if the channel is not connected, the wait
            timed out, or an unexpected error occurred.
        """
        try:
            if not self.is_connected:
                logger.warning(f"Channel {self.name} 未连接,无法接收消息")
                return None

            message = await asyncio.wait_for(self._queue.get(), timeout=timeout)
            self._last_message_at = datetime.now()

            logger.debug(f"Channel {self.name} 接收消息: {message.type}")
            return message

        except asyncio.TimeoutError:
            logger.debug(f"Channel {self.name} 接收消息超时")
            return None
        except asyncio.CancelledError:
            # Never swallow task cancellation (it is an Exception subclass on Python 3.7).
            raise
        except Exception as e:
            logger.error(f"Channel {self.name} 接收消息失败: {e}")
            return None

    def try_receive_message(self) -> Optional[ChannelMessage]:
        """Return the next queued message without waiting.

        Returns:
            The message, or None if the channel is not connected, the queue is
            empty, or an unexpected error occurred.
        """
        try:
            if not self.is_connected:
                logger.warning(f"Channel {self.name} 未连接,无法接收消息")
                return None

            message = self._queue.get_nowait()
            self._last_message_at = datetime.now()
            logger.debug(f"Channel {self.name} 非阻塞接收消息: {message.type}")
            return message

        except asyncio.QueueEmpty:
            return None
        except Exception as e:
            logger.error(f"Channel {self.name} 非阻塞接收失败: {e}")
            return None

    async def receive_messages(self, count: int = 1, timeout: Optional[float] = None) -> List[ChannelMessage]:
        """Receive up to *count* messages.

        Stops early at the first receive that yields None (timeout,
        disconnected, or error).

        Args:
            count: Maximum number of messages to collect.
            timeout: Per-message wait in seconds, forwarded to receive_message.

        Returns:
            The messages collected so far, in arrival order.
        """
        messages: List[ChannelMessage] = []
        try:
            for _ in range(count):
                message = await self.receive_message(timeout)
                if message is None:
                    break
                messages.append(message)
        except Exception as e:
            logger.error(f"Channel {self.name} 批量接收消息失败: {e}")

        return messages

    def subscribe(self, callback: Callable[[ChannelMessage], None]):
        """Register *callback* to be invoked with every queued message.

        Registering the same callback twice has no effect. The callback may be
        a plain function or may return an awaitable (e.g. an ``async def``).
        """
        if callback not in self._subscribers:
            self._subscribers.append(callback)
            logger.info(f"Channel {self.name} 添加订阅者: {self._describe_callable(callback)}")

    def unsubscribe(self, callback: Callable[[ChannelMessage], None]):
        """Remove *callback* if it is registered; otherwise do nothing."""
        if callback in self._subscribers:
            self._subscribers.remove(callback)
            logger.info(f"Channel {self.name} 移除订阅者: {self._describe_callable(callback)}")

    def add_filter(self, filter_name: str, filter_func: Callable[[ChannelMessage], bool]):
        """Add or replace the filter named *filter_name*.

        Filters only gate subscriber notification; filtered messages are
        still queued for receivers.
        """
        self._filters[filter_name] = filter_func
        logger.info(f"Channel {self.name} 添加过滤器: {filter_name}")

    def remove_filter(self, filter_name: str):
        """Remove the filter named *filter_name* if present."""
        if filter_name in self._filters:
            del self._filters[filter_name]
            logger.info(f"Channel {self.name} 移除过滤器: {filter_name}")

    async def _notify_subscribers(self, message: ChannelMessage):
        """Push *message* to every subscriber if it passes all filters.

        A filter that raises is logged and treated as rejecting the message.
        A failing subscriber is logged and does not stop the others.
        """
        # Iterate snapshots so a filter or subscriber may (un)register itself
        # from inside its own callback without corrupting iteration.
        for filter_name, filter_func in list(self._filters.items()):
            try:
                accepted = filter_func(message)
            except Exception as e:
                logger.error(f"Channel {self.name} 过滤器 {filter_name} 执行失败: {e}")
                return
            if not accepted:
                logger.debug(f"Channel {self.name} 消息被过滤器 {filter_name} 过滤")
                return

        for subscriber in list(self._subscribers):
            try:
                result = subscriber(message)
                # Covers async functions as well as partials / callable objects
                # that return a coroutine, which iscoroutinefunction misses.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.error(f"Channel {self.name} 通知订阅者失败: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the channel's state and counters.

        ``created_at`` and ``last_message_at`` are naive local ``datetime``
        objects (``last_message_at`` is None until a message is sent or received).
        """
        return {
            "name": self.name,
            "state": self.state.value,
            "queue_size": self.queue_size,
            "max_size": self.max_size,
            "is_full": self.is_full,
            "subscriber_count": len(self._subscribers),
            "filter_count": len(self._filters),
            "connection_count": self._connection_count,
            "created_at": self._created_at,
            "last_message_at": self._last_message_at,
        }

    async def clear(self):
        """Discard every buffered message, logging how many were dropped."""
        cleared_count = 0
        while not self._queue.empty():
            try:
                self._queue.get_nowait()
                cleared_count += 1
            except asyncio.QueueEmpty:
                break
        logger.info(f"Channel {self.name} 队列已清空,清除了 {cleared_count} 条消息")

    def reset(self):
        """Return to IDLE and drop subscribers, filters and counters.

        Buffered messages are left in the queue; call clear() (or
        disconnect()) to drop them.
        """
        try:
            self._state = ChannelState.IDLE
            self._subscribers.clear()
            self._filters.clear()
            self._connection_count = 0
            self._last_message_at = None
            logger.info(f"Channel {self.name} 状态已重置")
        except Exception as e:
            logger.error(f"Channel {self.name} 重置状态失败: {e}")

    async def destroy(self):
        """Disconnect (clearing the queue) and then reset all state."""
        try:
            await self.disconnect()
            self.reset()
            logger.info(f"Channel {self.name} 已销毁")
        except Exception as e:
            logger.error(f"Channel {self.name} 销毁失败: {e}")
|