You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

272 lines
11 KiB

"""
WebSocket客户端模块
只负责WebSocket连接和数据收发,不管理Channel
遵循单一职责原则
"""
import asyncio
import json
import uuid
from typing import Any, Optional, Dict, Callable
from enum import Enum
from datetime import datetime
import websockets
from app.utils.structured_log import get_structured_logger, LogLevel
logger = get_structured_logger(__name__, LogLevel.INFO)
class WebSocketClientState(Enum):
"""WebSocket客户端状态枚举"""
DISCONNECTED = "disconnected"
CONNECTING = "connecting"
CONNECTED = "connected"
RECONNECTING = "reconnecting"
ERROR = "error"
class WebSocketClient:
"""WebSocket客户端 - 只负责连接和数据收发
单一职责:
- 建立和维护WebSocket连接
- 发送原始数据到WebSocket服务器
- 接收原始数据从WebSocket服务器
- 管理连接状态和重连逻辑
不负责:
- Channel管理
- 心跳管理
- 数据路由
- 业务逻辑处理
"""
def __init__(self, url: str, name: str = "default"):
self.url = url
self.name = name
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
self._state = WebSocketClientState.DISCONNECTED
self._reconnect_attempts = 0
self._max_reconnect_attempts = 5
self._reconnect_delay = 1.0
self._message_handlers: Dict[str, Callable] = {}
self._receive_task: Optional[asyncio.Task] = None
self._created_at = datetime.now()
self._last_message_at: Optional[datetime] = None
self.heartbeat_interval: Optional[int] = None # 心跳间隔配置
logger.info(f"创建WebSocket客户端: {name} -> {url}")
@property
def state(self) -> WebSocketClientState:
"""获取客户端状态"""
return self._state
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._state == WebSocketClientState.CONNECTED
async def connect(self) -> bool:
"""连接到WebSocket服务器"""
try:
if self.is_connected:
logger.warning(f"WebSocket客户端 {self.name} 已经连接")
return True
self._state = WebSocketClientState.CONNECTING
logger.info(f"WebSocket客户端 {self.name} 正在连接... URL: {self.url}")
# 验证URL格式
if not self.url.startswith(('ws://', 'wss://')):
raise ValueError(f"无效的WebSocket URL: {self.url},必须以ws://或wss://开头")
# 建立WebSocket连接
ssl_context = None
if self.url.startswith('wss://'):
from app.core.config.settings import config
import ssl
ssl_context = ssl.create_default_context()
if not config.websocket.ssl_verify_hostname:
ssl_context.check_hostname = False
if not config.websocket.ssl_verify_certificate:
ssl_context.verify_mode = ssl.CERT_NONE
# 添加连接超时配置
from app.core.config.settings import config
connection_timeout = config.websocket.connection_timeout
logger.info(f"WebSocket客户端 {self.name} 开始连接,超时时间: {connection_timeout}")
# 使用websockets.connect的超时参数
self._websocket = await asyncio.wait_for(
websockets.connect(
self.url,
ssl=ssl_context,
ping_interval=None, # 禁用自动ping,由外部管理心跳
ping_timeout=None, # 禁用自动ping超时
close_timeout=10 # 关闭超时
),
timeout=connection_timeout
)
self._state = WebSocketClientState.CONNECTED
self._reconnect_attempts = 0
# 启动接收任务
self._receive_task = asyncio.create_task(self._receive_messages())
logger.info(f"WebSocket客户端 {self.name} 连接成功")
return True
except asyncio.TimeoutError:
self._state = WebSocketClientState.ERROR
logger.error(f"WebSocket客户端 {self.name} 连接超时: {connection_timeout}秒,URL: {self.url}")
return False
except ValueError as e:
self._state = WebSocketClientState.ERROR
logger.error(f"WebSocket客户端 {self.name} URL格式错误: {e}")
return False
except Exception as e:
self._state = WebSocketClientState.ERROR
logger.error(f"WebSocket客户端 {self.name} 连接失败: {e},URL: {self.url}")
return False
async def disconnect(self):
"""断开WebSocket连接"""
try:
self._state = WebSocketClientState.DISCONNECTED
# 停止接收任务
if self._receive_task:
self._receive_task.cancel()
# 关闭WebSocket连接
if self._websocket:
await self._websocket.close()
self._websocket = None
# 清理消息处理器
self._message_handlers.clear()
logger.info(f"WebSocket客户端 {self.name} 已断开")
except Exception as e:
logger.error(f"WebSocket客户端 {self.name} 断开连接时出错: {e}")
async def send_raw(self, payload: Any) -> bool:
"""发送原始数据到WebSocket服务器
单一职责:只负责发送数据,不处理业务逻辑
"""
try:
if not self.is_connected:
logger.warning(f"WebSocket客户端 {self.name} 未连接,无法发送消息")
return False
# 直接发送预组装的数据
if isinstance(payload, (bytes, bytearray)):
await self._websocket.send(payload)
else:
# 其他类型(如str),按原样发送
await self._websocket.send(payload)
self._last_message_at = datetime.now()
logger.debug(f"WebSocket客户端 {self.name} 发送数据成功")
return True
except Exception as e:
logger.error(f"WebSocket客户端 {self.name} 发送数据失败: {e}")
return False
def register_message_handler(self, message_type: str, handler: Callable):
"""注册消息处理器 - 用于处理接收到的消息"""
self._message_handlers[message_type] = handler
logger.info(f"WebSocket客户端 {self.name} 注册消息处理器: {message_type}")
def unregister_message_handler(self, message_type: str):
"""取消注册消息处理器"""
if message_type in self._message_handlers:
del self._message_handlers[message_type]
logger.info(f"WebSocket客户端 {self.name} 取消注册消息处理器: {message_type}")
async def _receive_messages(self):
"""接收消息循环 - 从WebSocket服务器接收数据
单一职责:只负责接收数据并调用处理器,不处理业务逻辑
"""
try:
while self.is_connected:
try:
message = await self._websocket.recv()
self._last_message_at = datetime.now()
# 调用消息处理器
await self._handle_message(message)
except websockets.exceptions.ConnectionClosed:
logger.warning(f"WebSocket客户端 {self.name} 连接已关闭")
break
except Exception as e:
logger.error(f"WebSocket客户端 {self.name} 接收消息时出错: {e}")
break
except asyncio.CancelledError:
logger.info(f"WebSocket客户端 {self.name} 接收任务已取消")
except Exception as e:
logger.error(f"WebSocket客户端 {self.name} 接收任务异常: {e}")
async def _handle_message(self, message: Any):
"""处理接收到的消息
单一职责:只负责调用注册的处理器,不处理业务逻辑
"""
try:
# 尝试解析JSON消息
if isinstance(message, str):
try:
parsed_message = json.loads(message)
message_type = parsed_message.get("type", "data")
except json.JSONDecodeError:
message_type = "raw"
parsed_message = {"type": "raw", "data": message}
else:
message_type = "raw"
parsed_message = {"type": "raw", "data": message}
# 调用对应的处理器
handler = self._message_handlers.get(message_type)
if handler:
if asyncio.iscoroutinefunction(handler):
await handler(parsed_message)
else:
handler(parsed_message)
else:
# 调用通配符处理器
wildcard_handler = self._message_handlers.get("*")
if wildcard_handler:
if asyncio.iscoroutinefunction(wildcard_handler):
await wildcard_handler(parsed_message)
else:
wildcard_handler(parsed_message)
else:
logger.debug(f"WebSocket客户端 {self.name} 收到未处理的消息类型: {message_type}")
except Exception as e:
logger.error(f"WebSocket客户端 {self.name} 处理消息时出错: {e}")
def get_stats(self) -> Dict[str, Any]:
"""获取客户端统计信息"""
return {
"name": self.name,
"url": self.url,
"state": self.state.value,
"is_connected": self.is_connected,
"reconnect_attempts": self._reconnect_attempts,
"max_reconnect_attempts": self._max_reconnect_attempts,
"reconnect_delay": self._reconnect_delay,
"message_handler_count": len(self._message_handlers),
"receive_task_running": self._receive_task and not self._receive_task.done(),
"created_at": self._created_at,
"last_message_at": self._last_message_at
}
# 移除不必要的全局变量声明