You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
287 lines
12 KiB
287 lines
12 KiB
"""
|
|
WebSocket客户端模块
|
|
只负责WebSocket连接和数据收发,不管理Channel
|
|
遵循单一职责原则
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from typing import Any, Optional, Dict, Callable
|
|
from enum import Enum
|
|
from datetime import datetime
|
|
import websockets
|
|
from app.utils.structured_log import get_structured_logger, LogLevel
|
|
|
|
logger = get_structured_logger(__name__, LogLevel.INFO)
|
|
|
|
class WebSocketClientState(Enum):
|
|
"""WebSocket客户端状态枚举"""
|
|
DISCONNECTED = "disconnected"
|
|
CONNECTING = "connecting"
|
|
CONNECTED = "connected"
|
|
RECONNECTING = "reconnecting"
|
|
ERROR = "error"
|
|
|
|
class WebSocketClient:
|
|
"""WebSocket客户端 - 只负责连接和数据收发
|
|
|
|
单一职责:
|
|
- 建立和维护WebSocket连接
|
|
- 发送原始数据到WebSocket服务器
|
|
- 接收原始数据从WebSocket服务器
|
|
- 管理连接状态和重连逻辑
|
|
|
|
不负责:
|
|
- Channel管理
|
|
- 心跳管理
|
|
- 数据路由
|
|
- 业务逻辑处理
|
|
"""
|
|
|
|
def __init__(self, url: str, name: str = "default"):
|
|
self.url = url
|
|
self.name = name
|
|
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
|
|
self._state = WebSocketClientState.DISCONNECTED
|
|
self._reconnect_attempts = 0
|
|
self._max_reconnect_attempts = 5
|
|
self._reconnect_delay = 1.0
|
|
self._message_handlers: Dict[str, Callable] = {}
|
|
self._receive_task: Optional[asyncio.Task] = None
|
|
self._created_at = datetime.now()
|
|
self._last_message_at: Optional[datetime] = None
|
|
self.heartbeat_interval: Optional[int] = None # 心跳间隔配置
|
|
|
|
logger.info(f"创建WebSocket客户端: {name} -> {url}")
|
|
|
|
@property
|
|
def state(self) -> WebSocketClientState:
|
|
"""获取客户端状态"""
|
|
return self._state
|
|
|
|
@property
|
|
def is_connected(self) -> bool:
|
|
"""检查是否已连接"""
|
|
return self._state == WebSocketClientState.CONNECTED
|
|
|
|
async def connect(self) -> bool:
|
|
"""连接到WebSocket服务器"""
|
|
try:
|
|
if self.is_connected:
|
|
logger.warning(f"WebSocket客户端 {self.name} 已经连接")
|
|
return True
|
|
|
|
self._state = WebSocketClientState.CONNECTING
|
|
logger.info(f"WebSocket客户端 {self.name} 正在连接... URL: {self.url}")
|
|
|
|
# 验证URL格式
|
|
if not self.url.startswith(('ws://', 'wss://')):
|
|
raise ValueError(f"无效的WebSocket URL: {self.url},必须以ws://或wss://开头")
|
|
|
|
# 建立WebSocket连接
|
|
ssl_context = None
|
|
if self.url.startswith('wss://'):
|
|
from app.core.config.settings import config
|
|
import ssl
|
|
ssl_context = ssl.create_default_context()
|
|
|
|
if not config.websocket.ssl_verify_hostname:
|
|
ssl_context.check_hostname = False
|
|
|
|
if not config.websocket.ssl_verify_certificate:
|
|
ssl_context.verify_mode = ssl.CERT_NONE
|
|
|
|
# 添加连接超时配置
|
|
from app.core.config.settings import config
|
|
connection_timeout = config.websocket.connection_timeout
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 开始连接,超时时间: {connection_timeout}秒")
|
|
|
|
# 使用websockets.connect的超时参数
|
|
self._websocket = await asyncio.wait_for(
|
|
websockets.connect(
|
|
self.url,
|
|
ssl=ssl_context,
|
|
ping_interval=None, # 禁用自动ping,由外部管理心跳
|
|
ping_timeout=None, # 禁用自动ping超时
|
|
close_timeout=10 # 关闭超时
|
|
),
|
|
timeout=connection_timeout
|
|
)
|
|
self._state = WebSocketClientState.CONNECTED
|
|
self._reconnect_attempts = 0
|
|
|
|
# 启动接收任务
|
|
self._receive_task = asyncio.create_task(self._receive_messages())
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 连接成功")
|
|
return True
|
|
|
|
except asyncio.TimeoutError:
|
|
self._state = WebSocketClientState.ERROR
|
|
logger.error(f"WebSocket客户端 {self.name} 连接超时: {connection_timeout}秒,URL: {self.url}")
|
|
return False
|
|
except ValueError as e:
|
|
self._state = WebSocketClientState.ERROR
|
|
logger.error(f"WebSocket客户端 {self.name} URL格式错误: {e}")
|
|
return False
|
|
except Exception as e:
|
|
self._state = WebSocketClientState.ERROR
|
|
logger.error(f"WebSocket客户端 {self.name} 连接失败: {e},URL: {self.url}")
|
|
return False
|
|
|
|
async def disconnect(self):
|
|
"""断开WebSocket连接"""
|
|
try:
|
|
self._state = WebSocketClientState.DISCONNECTED
|
|
|
|
# 停止接收任务
|
|
if self._receive_task:
|
|
self._receive_task.cancel()
|
|
|
|
# 关闭WebSocket连接
|
|
if self._websocket:
|
|
await self._websocket.close()
|
|
self._websocket = None
|
|
|
|
# 清理消息处理器
|
|
self._message_handlers.clear()
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 已断开")
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 断开连接时出错: {e}")
|
|
|
|
async def send_raw(self, payload: Any) -> bool:
|
|
"""发送原始数据到WebSocket服务器
|
|
|
|
单一职责:只负责发送数据,不处理业务逻辑
|
|
"""
|
|
try:
|
|
if not self.is_connected:
|
|
logger.warning(f"WebSocket客户端 {self.name} 未连接,无法发送消息")
|
|
return False
|
|
|
|
# 直接发送预组装的数据
|
|
if isinstance(payload, (bytes, bytearray)):
|
|
await self._websocket.send(payload)
|
|
else:
|
|
# 其他类型(如str),按原样发送
|
|
await self._websocket.send(payload)
|
|
|
|
self._last_message_at = datetime.now()
|
|
logger.debug(f"WebSocket客户端 {self.name} 发送数据成功")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 发送数据失败: {e}")
|
|
return False
|
|
|
|
def register_message_handler(self, message_type: str, handler: Callable):
|
|
"""注册消息处理器 - 用于处理接收到的消息"""
|
|
self._message_handlers[message_type] = handler
|
|
logger.info(f"WebSocket客户端 {self.name} 注册消息处理器: {message_type}")
|
|
|
|
def unregister_message_handler(self, message_type: str):
|
|
"""取消注册消息处理器"""
|
|
if message_type in self._message_handlers:
|
|
del self._message_handlers[message_type]
|
|
logger.info(f"WebSocket客户端 {self.name} 取消注册消息处理器: {message_type}")
|
|
|
|
async def _receive_messages(self):
|
|
"""接收消息循环 - 从WebSocket服务器接收数据
|
|
|
|
单一职责:只负责接收数据并调用处理器,不处理业务逻辑
|
|
"""
|
|
logger.info(f"WebSocket客户端 {self.name} 开始接收消息循环")
|
|
try:
|
|
while self.is_connected:
|
|
try:
|
|
logger.debug(f"WebSocket客户端 {self.name} 等待接收消息...")
|
|
message = await self._websocket.recv()
|
|
self._last_message_at = datetime.now()
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 收到原始消息: {message}")
|
|
|
|
# 调用消息处理器
|
|
await self._handle_message(message)
|
|
|
|
except websockets.exceptions.ConnectionClosed:
|
|
logger.warning(f"WebSocket客户端 {self.name} 连接已关闭")
|
|
break
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 接收消息时出错: {e}")
|
|
break
|
|
|
|
except asyncio.CancelledError:
|
|
logger.info(f"WebSocket客户端 {self.name} 接收任务已取消")
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 接收任务异常: {e}")
|
|
finally:
|
|
logger.info(f"WebSocket客户端 {self.name} 接收消息循环结束")
|
|
|
|
async def _handle_message(self, message: Any):
|
|
"""处理接收到的消息
|
|
|
|
单一职责:只负责调用注册的处理器,不处理业务逻辑
|
|
"""
|
|
try:
|
|
logger.debug(f"WebSocket客户端 {self.name} 开始处理消息: {message}")
|
|
|
|
# 尝试解析JSON消息
|
|
if isinstance(message, str):
|
|
try:
|
|
parsed_message = json.loads(message)
|
|
message_type = parsed_message.get("type", "data")
|
|
logger.debug(f"WebSocket客户端 {self.name} 解析JSON消息成功: type={message_type}")
|
|
except json.JSONDecodeError:
|
|
message_type = "raw"
|
|
parsed_message = {"type": "raw", "data": message}
|
|
logger.debug(f"WebSocket客户端 {self.name} JSON解析失败,作为raw消息处理")
|
|
else:
|
|
message_type = "raw"
|
|
parsed_message = {"type": "raw", "data": message}
|
|
logger.debug(f"WebSocket客户端 {self.name} 非字符串消息,作为raw消息处理")
|
|
|
|
# 调用对应的处理器
|
|
handler = self._message_handlers.get(message_type)
|
|
if handler:
|
|
logger.debug(f"WebSocket客户端 {self.name} 调用消息处理器: {message_type}")
|
|
if asyncio.iscoroutinefunction(handler):
|
|
await handler(parsed_message)
|
|
else:
|
|
handler(parsed_message)
|
|
logger.debug(f"WebSocket客户端 {self.name} 消息处理器调用完成: {message_type}")
|
|
else:
|
|
# 调用通配符处理器
|
|
wildcard_handler = self._message_handlers.get("*")
|
|
if wildcard_handler:
|
|
logger.debug(f"WebSocket客户端 {self.name} 调用通配符处理器")
|
|
if asyncio.iscoroutinefunction(wildcard_handler):
|
|
await wildcard_handler(parsed_message)
|
|
else:
|
|
wildcard_handler(parsed_message)
|
|
logger.debug(f"WebSocket客户端 {self.name} 通配符处理器调用完成")
|
|
else:
|
|
logger.debug(f"WebSocket客户端 {self.name} 收到未处理的消息类型: {message_type}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 处理消息时出错: {e}")
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
"""获取客户端统计信息"""
|
|
return {
|
|
"name": self.name,
|
|
"url": self.url,
|
|
"state": self.state.value,
|
|
"is_connected": self.is_connected,
|
|
"reconnect_attempts": self._reconnect_attempts,
|
|
"max_reconnect_attempts": self._max_reconnect_attempts,
|
|
"reconnect_delay": self._reconnect_delay,
|
|
"message_handler_count": len(self._message_handlers),
|
|
"receive_task_running": self._receive_task and not self._receive_task.done(),
|
|
"created_at": self._created_at,
|
|
"last_message_at": self._last_message_at
|
|
}
|
|
|
|
# 移除不必要的全局变量声明
|