You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
267 lines
11 KiB
267 lines
11 KiB
"""
|
|
WebSocket客户端模块
|
|
只负责WebSocket连接和数据收发,不管理Channel
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from typing import Any, Optional, Dict, Callable
|
|
from enum import Enum
|
|
from datetime import datetime
|
|
import websockets
|
|
from app.utils.log import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
class WebSocketClientState(Enum):
|
|
"""WebSocket客户端状态枚举"""
|
|
DISCONNECTED = "disconnected"
|
|
CONNECTING = "connecting"
|
|
CONNECTED = "connected"
|
|
RECONNECTING = "reconnecting"
|
|
ERROR = "error"
|
|
|
|
class WebSocketClient:
|
|
"""WebSocket客户端 - 只负责连接和数据收发"""
|
|
|
|
def __init__(self, url: str, name: str = "default"):
|
|
self.url = url
|
|
self.name = name
|
|
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
|
|
self._state = WebSocketClientState.DISCONNECTED
|
|
self._reconnect_attempts = 0
|
|
self._max_reconnect_attempts = 5
|
|
self._reconnect_delay = 1.0
|
|
self._message_handlers: Dict[str, Callable] = {}
|
|
self._connection_task: Optional[asyncio.Task] = None
|
|
self._receive_task: Optional[asyncio.Task] = None
|
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
|
self._created_at = datetime.now()
|
|
self._last_heartbeat = None
|
|
|
|
logger.info(f"创建WebSocket客户端: {name} -> {url}")
|
|
|
|
@property
|
|
def state(self) -> WebSocketClientState:
|
|
"""获取客户端状态"""
|
|
return self._state
|
|
|
|
@property
|
|
def is_connected(self) -> bool:
|
|
"""检查是否已连接"""
|
|
return self._state == WebSocketClientState.CONNECTED
|
|
|
|
async def connect(self) -> bool:
|
|
"""连接到WebSocket服务器"""
|
|
try:
|
|
if self.is_connected:
|
|
logger.warning(f"WebSocket客户端 {self.name} 已经连接")
|
|
return True
|
|
|
|
self._state = WebSocketClientState.CONNECTING
|
|
logger.info(f"WebSocket客户端 {self.name} 正在连接... URL: {self.url}")
|
|
|
|
# 验证URL格式
|
|
if not self.url.startswith(('ws://', 'wss://')):
|
|
raise ValueError(f"无效的WebSocket URL: {self.url},必须以ws://或wss://开头")
|
|
|
|
# 建立WebSocket连接
|
|
# 根据配置决定是否跳过SSL证书验证
|
|
ssl_context = None
|
|
if self.url.startswith('wss://'):
|
|
from app.core.config.settings import config
|
|
import ssl
|
|
ssl_context = ssl.create_default_context()
|
|
|
|
# 根据配置决定是否验证证书和主机名
|
|
# 先设置check_hostname,再设置verify_mode
|
|
if not config.websocket.ssl_verify_hostname:
|
|
ssl_context.check_hostname = False
|
|
|
|
if not config.websocket.ssl_verify_certificate:
|
|
ssl_context.verify_mode = ssl.CERT_NONE
|
|
|
|
# 添加连接超时配置
|
|
from app.core.config.settings import config
|
|
connection_timeout = config.websocket.connection_timeout
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 开始连接,超时时间: {connection_timeout}秒")
|
|
|
|
# 使用websockets.connect的超时参数
|
|
self._websocket = await asyncio.wait_for(
|
|
websockets.connect(
|
|
self.url,
|
|
ssl=ssl_context,
|
|
ping_interval=None, # 禁用自动ping,由适配器管理心跳
|
|
ping_timeout=None, # 禁用自动ping超时
|
|
close_timeout=10 # 关闭超时
|
|
),
|
|
timeout=connection_timeout
|
|
)
|
|
self._state = WebSocketClientState.CONNECTED
|
|
self._reconnect_attempts = 0
|
|
|
|
# 启动接收任务
|
|
self._receive_task = asyncio.create_task(self._receive_messages())
|
|
|
|
# 心跳改由Adapter以优先级Channel方式发送,客户端不再直接发送心跳
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 连接成功")
|
|
return True
|
|
|
|
except asyncio.TimeoutError:
|
|
self._state = WebSocketClientState.ERROR
|
|
logger.error(f"WebSocket客户端 {self.name} 连接超时: {connection_timeout}秒,URL: {self.url}")
|
|
return False
|
|
except ValueError as e:
|
|
self._state = WebSocketClientState.ERROR
|
|
logger.error(f"WebSocket客户端 {self.name} URL格式错误: {e}")
|
|
return False
|
|
except Exception as e:
|
|
self._state = WebSocketClientState.ERROR
|
|
logger.error(f"WebSocket客户端 {self.name} 连接失败: {e},URL: {self.url}")
|
|
return False
|
|
|
|
async def disconnect(self):
|
|
"""断开WebSocket连接"""
|
|
try:
|
|
self._state = WebSocketClientState.DISCONNECTED
|
|
|
|
# 停止任务
|
|
if self._receive_task:
|
|
self._receive_task.cancel()
|
|
if self._heartbeat_task:
|
|
self._heartbeat_task.cancel()
|
|
|
|
# 关闭WebSocket连接
|
|
if self._websocket:
|
|
await self._websocket.close()
|
|
self._websocket = None
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 已断开")
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 断开连接时出错: {e}")
|
|
|
|
async def _send_raw(self, payload: Any) -> bool:
|
|
"""发送预组装载荷到WebSocket服务器(私有,供Adapter调用)"""
|
|
try:
|
|
if not self.is_connected:
|
|
logger.warning(f"WebSocket客户端 {self.name} 未连接,无法发送消息")
|
|
return False
|
|
|
|
# 直接发送预组装的字符串或二进制
|
|
if isinstance(payload, (bytes, bytearray)):
|
|
await self._websocket.send(payload)
|
|
else:
|
|
# 其他类型(如str),按原样发送
|
|
await self._websocket.send(payload)
|
|
logger.debug(f"WebSocket客户端 {self.name} 发送载荷成功")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 发送消息失败: {e}")
|
|
return False
|
|
|
|
def register_message_handler(self, message_type: str, handler: Callable):
|
|
"""注册消息处理器 - 用于处理接收到的消息"""
|
|
self._message_handlers[message_type] = handler
|
|
logger.info(f"WebSocket客户端 {self.name} 注册消息处理器: {message_type}")
|
|
|
|
def unregister_message_handler(self, message_type: str):
|
|
"""取消注册消息处理器"""
|
|
if message_type in self._message_handlers:
|
|
del self._message_handlers[message_type]
|
|
logger.info(f"WebSocket客户端 {self.name} 取消注册消息处理器: {message_type}")
|
|
|
|
async def _receive_messages(self):
|
|
"""接收消息循环 - 从WebSocket服务器接收数据"""
|
|
try:
|
|
while self.is_connected:
|
|
try:
|
|
message_data = await self._websocket.recv()
|
|
await self._handle_received_message(message_data)
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 接收消息时出错: {e}")
|
|
break
|
|
except asyncio.CancelledError:
|
|
logger.info(f"WebSocket客户端 {self.name} 接收任务已取消")
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 接收任务异常: {e}")
|
|
finally:
|
|
# 尝试重连
|
|
if self._state != WebSocketClientState.DISCONNECTED:
|
|
await self._try_reconnect()
|
|
|
|
async def _handle_received_message(self, message_data: str):
|
|
"""处理接收到的消息 - 调用注册的处理器"""
|
|
try:
|
|
message = json.loads(message_data)
|
|
message_type = message.get("type")
|
|
|
|
# 调用消息处理器
|
|
handler = None
|
|
if message_type in self._message_handlers:
|
|
handler = self._message_handlers[message_type]
|
|
elif "*" in self._message_handlers:
|
|
handler = self._message_handlers["*"]
|
|
|
|
if handler:
|
|
if asyncio.iscoroutinefunction(handler):
|
|
await handler(message)
|
|
else:
|
|
handler(message)
|
|
|
|
logger.debug(f"WebSocket客户端 {self.name} 处理消息: {message_type}")
|
|
|
|
except json.JSONDecodeError as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 解析消息失败: {e}")
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 处理消息失败: {e}")
|
|
|
|
async def _heartbeat_loop(self):
|
|
"""心跳循环"""
|
|
try:
|
|
while self.is_connected:
|
|
try:
|
|
await self._send_message("heartbeat", {"timestamp": datetime.now().isoformat()})
|
|
self._last_heartbeat = datetime.now()
|
|
await asyncio.sleep(30) # 30秒发送一次心跳
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 心跳发送失败: {e}")
|
|
break
|
|
except asyncio.CancelledError:
|
|
logger.info(f"WebSocket客户端 {self.name} 心跳任务已取消")
|
|
except Exception as e:
|
|
logger.error(f"WebSocket客户端 {self.name} 心跳任务异常: {e}")
|
|
|
|
async def _try_reconnect(self):
|
|
"""尝试重连"""
|
|
if self._reconnect_attempts >= self._max_reconnect_attempts:
|
|
logger.error(f"WebSocket客户端 {self.name} 重连次数已达上限")
|
|
return
|
|
|
|
self._state = WebSocketClientState.RECONNECTING
|
|
self._reconnect_attempts += 1
|
|
|
|
logger.info(f"WebSocket客户端 {self.name} 尝试重连 ({self._reconnect_attempts}/{self._max_reconnect_attempts})")
|
|
|
|
await asyncio.sleep(self._reconnect_delay * self._reconnect_attempts)
|
|
|
|
if await self.connect():
|
|
logger.info(f"WebSocket客户端 {self.name} 重连成功")
|
|
else:
|
|
await self._try_reconnect()
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
"""获取客户端统计信息"""
|
|
return {
|
|
"name": self.name,
|
|
"url": self.url,
|
|
"state": self.state.value,
|
|
"is_connected": self.is_connected,
|
|
"reconnect_attempts": self._reconnect_attempts,
|
|
"handler_count": len(self._message_handlers),
|
|
"created_at": self._created_at,
|
|
"last_heartbeat": self._last_heartbeat
|
|
}
|