You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

257 lines
10 KiB

"""
WebSocket客户端模块
只负责WebSocket连接和数据收发,不管理Channel
"""
import asyncio
import inspect
import json
import ssl
import uuid
from typing import Any, Optional, Dict, Callable
from enum import Enum
from datetime import datetime

import websockets

from app.utils.log import get_logger
# Logger for this module, obtained from the project's get_logger helper
# using this module's __name__.
logger = get_logger(__name__)
class WebSocketClientState(Enum):
    """Lifecycle states of a WebSocketClient; each value is a lowercase string."""

    # No connection, either never opened or closed via disconnect().
    DISCONNECTED = "disconnected"
    # connect() is in progress.
    CONNECTING = "connecting"
    # The socket is open and the receive loop is running.
    CONNECTED = "connected"
    # A dropped connection is being re-established.
    RECONNECTING = "reconnecting"
    # The last connection attempt failed.
    ERROR = "error"
class WebSocketClient:
    """WebSocket client responsible only for the connection and raw data I/O.

    It does not manage channels. Incoming frames are parsed as JSON and
    dispatched to the handler registered for the message's ``"type"`` value,
    falling back to a handler registered under ``"*"``. When the receive loop
    ends without ``disconnect()`` having been called, the client retries the
    connection with a linearly growing delay, up to ``_max_reconnect_attempts``
    consecutive times.
    """

    def __init__(self, url: str, name: str = "default"):
        """Create an unconnected client; call ``connect()`` to open it.

        Args:
            url: Server URL. A ``wss://`` URL gets an SSL context built from
                the project config (see ``_build_ssl_context``).
            name: Label used in log messages and in ``get_stats()``.
        """
        self.url = url
        self.name = name
        # Connection object returned by ``websockets.connect`` (a client-side
        # connection); None while not connected.
        self._websocket: Optional[Any] = None
        self._state = WebSocketClientState.DISCONNECTED
        self._reconnect_attempts = 0
        self._max_reconnect_attempts = 5
        # Base delay in seconds; reconnect attempt k sleeps k * this value.
        self._reconnect_delay = 1.0
        self._message_handlers: Dict[str, Callable] = {}
        # Not used within this class; kept so external code reading it still works.
        self._connection_task: Optional[asyncio.Task] = None
        self._receive_task: Optional[asyncio.Task] = None
        # Only set if someone starts _heartbeat_loop; connect() does not, since
        # heartbeats are sent by the adapter as priority-channel messages.
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._created_at = datetime.now()
        self._last_heartbeat = None
        logger.info(f"创建WebSocket客户端: {name} -> {url}")

    @property
    def state(self) -> "WebSocketClientState":
        """Current lifecycle state."""
        return self._state

    @property
    def is_connected(self) -> bool:
        """True only while the state is CONNECTED."""
        return self._state == WebSocketClientState.CONNECTED

    async def connect(self) -> bool:
        """Open the WebSocket connection and start the receive loop.

        Automatic pings are disabled because heartbeats are managed by the
        adapter. The handshake is bounded by
        ``config.websocket.connection_timeout``.

        Returns:
            True if connected (or already connected). False on timeout or any
            other connection error; errors are logged, not raised.
        """
        # Bound before the try so the timeout handler can always log it.
        connection_timeout = None
        try:
            if self.is_connected:
                logger.warning(f"WebSocket客户端 {self.name} 已经连接")
                return True
            self._state = WebSocketClientState.CONNECTING
            logger.info(f"WebSocket客户端 {self.name} 正在连接...")
            # Project config is imported lazily, as the original code did
            # (presumably to avoid an import cycle at module load; unverified).
            from app.core.config.settings import config
            ssl_context = self._build_ssl_context(config)
            connection_timeout = config.websocket.connection_timeout
            self._websocket = await asyncio.wait_for(
                websockets.connect(
                    self.url,
                    ssl=ssl_context,
                    ping_interval=None,  # heartbeats are managed by the adapter
                    ping_timeout=None,
                    close_timeout=10,
                ),
                timeout=connection_timeout,
            )
            self._state = WebSocketClientState.CONNECTED
            # A successful connection restarts the reconnect budget.
            self._reconnect_attempts = 0
            self._receive_task = asyncio.create_task(self._receive_messages())
            logger.info(f"WebSocket客户端 {self.name} 连接成功")
            return True
        except asyncio.TimeoutError:
            self._state = WebSocketClientState.ERROR
            logger.error(f"WebSocket客户端 {self.name} 连接超时: {connection_timeout}")
            return False
        except Exception as e:
            self._state = WebSocketClientState.ERROR
            logger.error(f"WebSocket客户端 {self.name} 连接失败: {e}")
            return False

    def _build_ssl_context(self, config) -> Optional[ssl.SSLContext]:
        """Return an SSL context for ``wss://`` URLs, or None otherwise.

        Verification can be relaxed with ``config.websocket.ssl_verify_hostname``
        and ``config.websocket.ssl_verify_certificate``.
        """
        if not self.url.lower().startswith("wss://"):
            return None
        verify_certificate = config.websocket.ssl_verify_certificate
        ssl_context = ssl.create_default_context()
        # check_hostname must be False before verify_mode may become CERT_NONE
        # (ssl raises ValueError otherwise). Without certificate verification the
        # hostname cannot be verified either, so disabling certificate checks also
        # disables the hostname check instead of failing every connection.
        if not config.websocket.ssl_verify_hostname or not verify_certificate:
            ssl_context.check_hostname = False
        if not verify_certificate:
            ssl_context.verify_mode = ssl.CERT_NONE
        return ssl_context

    async def disconnect(self):
        """Close the connection and stop background tasks; errors are logged.

        The state is set to DISCONNECTED first, which tells the receive loop
        not to attempt a reconnect.
        """
        try:
            self._state = WebSocketClientState.DISCONNECTED
            current = asyncio.current_task()
            for task in (self._receive_task, self._heartbeat_task):
                # Cancelling the task we are running in (e.g. disconnect() called
                # from a message handler) would abort this coroutine at its next
                # await, before the socket is closed. The DISCONNECTED state
                # already makes those loops exit on their own.
                if task is not None and task is not current:
                    task.cancel()
            self._receive_task = None
            self._heartbeat_task = None
            await self._close_websocket()
            logger.info(f"WebSocket客户端 {self.name} 已断开")
        except Exception as e:
            logger.error(f"WebSocket客户端 {self.name} 断开连接时出错: {e}")

    async def _close_websocket(self):
        """Close the current connection object, if any, and drop the reference.

        The reference is cleared before awaiting ``close()`` so a failing close
        never leaves a half-closed socket attached to the client.
        """
        websocket, self._websocket = self._websocket, None
        if websocket is not None:
            await websocket.close()

    async def _send_raw(self, payload: Any) -> bool:
        """Send a pre-assembled str or bytes payload (private; used by the adapter).

        Returns:
            True on success. False if not connected or the send failed (logged).
        """
        if not self.is_connected or self._websocket is None:
            logger.warning(f"WebSocket客户端 {self.name} 未连接,无法发送消息")
            return False
        try:
            # str goes out as a text frame, bytes/bytearray as a binary frame.
            await self._websocket.send(payload)
        except Exception as e:
            logger.error(f"WebSocket客户端 {self.name} 发送消息失败: {e}")
            return False
        logger.debug(f"WebSocket客户端 {self.name} 发送载荷成功")
        return True

    def register_message_handler(self, message_type: str, handler: Callable):
        """Register ``handler`` for messages whose ``"type"`` equals ``message_type``.

        ``"*"`` registers a fallback for types without a dedicated handler.
        Handlers may be plain or async callables; registering again replaces
        the previous handler.
        """
        self._message_handlers[message_type] = handler
        logger.info(f"WebSocket客户端 {self.name} 注册消息处理器: {message_type}")

    def unregister_message_handler(self, message_type: str):
        """Remove the handler for ``message_type``; unknown types are ignored."""
        if message_type in self._message_handlers:
            del self._message_handlers[message_type]
            logger.info(f"WebSocket客户端 {self.name} 取消注册消息处理器: {message_type}")

    async def _receive_messages(self):
        """Read frames until the connection fails or the client disconnects.

        If the loop ends for any reason other than ``disconnect()`` (state
        DISCONNECTED) or cancellation of this task, a reconnect is attempted
        from this task.
        """
        cancelled = False
        try:
            while self.is_connected:
                try:
                    message_data = await self._websocket.recv()
                    await self._handle_received_message(message_data)
                except Exception as e:
                    logger.error(f"WebSocket客户端 {self.name} 接收消息时出错: {e}")
                    break
        except asyncio.CancelledError:
            cancelled = True
            logger.info(f"WebSocket客户端 {self.name} 接收任务已取消")
        except Exception as e:
            logger.error(f"WebSocket客户端 {self.name} 接收任务异常: {e}")
        finally:
            # Reconnecting after an external cancellation (e.g. event-loop
            # shutdown) would work against whoever cancelled us.
            if not cancelled and self._state != WebSocketClientState.DISCONNECTED:
                await self._try_reconnect()

    async def _handle_received_message(self, message_data: Any):
        """Parse one frame as JSON and pass it to the matching handler.

        The handler is looked up by the message's ``"type"`` value, falling
        back to ``"*"``. Parse and handler errors are logged, not raised.
        Non-object JSON (e.g. a list) is logged as a processing failure.
        """
        try:
            message = json.loads(message_data)
            message_type = message.get("type")
            if message_type in self._message_handlers:
                handler = self._message_handlers[message_type]
            else:
                handler = self._message_handlers.get("*")
            if handler is not None:
                result = handler(message)
                # Covers sync handlers, ``async def`` handlers, and any callable
                # returning an awaitable (e.g. a functools.partial of a coroutine
                # function, which iscoroutinefunction does not recognize).
                if inspect.isawaitable(result):
                    await result
            logger.debug(f"WebSocket客户端 {self.name} 处理消息: {message_type}")
        except json.JSONDecodeError as e:
            logger.error(f"WebSocket客户端 {self.name} 解析消息失败: {e}")
        except Exception as e:
            logger.error(f"WebSocket客户端 {self.name} 处理消息失败: {e}")

    async def _heartbeat_loop(self):
        """Send a heartbeat every 30 seconds while connected.

        Not started by ``connect()``; heartbeats are normally sent by the
        adapter. Stops at the first failed send.
        """
        try:
            while self.is_connected:
                # NOTE(review): the original called a non-existent
                # ``_send_message("heartbeat", {...})``. This JSON shape mirrors
                # those (type, data) arguments; confirm it against the server
                # protocol before enabling this loop.
                payload = json.dumps(
                    {"type": "heartbeat", "data": {"timestamp": datetime.now().isoformat()}}
                )
                if not await self._send_raw(payload):
                    logger.error(f"WebSocket客户端 {self.name} 心跳发送失败")
                    break
                self._last_heartbeat = datetime.now()
                await asyncio.sleep(30)
        except asyncio.CancelledError:
            logger.info(f"WebSocket客户端 {self.name} 心跳任务已取消")
        except Exception as e:
            logger.error(f"WebSocket客户端 {self.name} 心跳任务异常: {e}")

    async def _try_reconnect(self):
        """Re-establish a dropped connection with linear back-off.

        Attempt ``k`` sleeps ``k * _reconnect_delay`` seconds and then calls
        ``connect()``. This stops early if the client was disconnected. After
        ``_max_reconnect_attempts`` consecutive failures it gives up and leaves
        the state at ERROR. Implemented as a loop rather than recursion.
        """
        # Mark the link as down before awaiting anything, so concurrent senders
        # see is_connected == False instead of writing to the dead socket.
        self._state = WebSocketClientState.RECONNECTING
        try:
            await self._close_websocket()
        except Exception as e:
            logger.warning(f"WebSocket客户端 {self.name} 关闭旧连接时出错: {e}")

        while self._reconnect_attempts < self._max_reconnect_attempts:
            self._state = WebSocketClientState.RECONNECTING
            self._reconnect_attempts += 1
            logger.info(f"WebSocket客户端 {self.name} 尝试重连 ({self._reconnect_attempts}/{self._max_reconnect_attempts})")
            await asyncio.sleep(self._reconnect_delay * self._reconnect_attempts)
            if self._state == WebSocketClientState.DISCONNECTED:
                # disconnect() ran while we were waiting; do not reopen.
                return
            if await self.connect():
                logger.info(f"WebSocket客户端 {self.name} 重连成功")
                return

        logger.error(f"WebSocket客户端 {self.name} 重连次数已达上限")
        if self._state != WebSocketClientState.DISCONNECTED:
            self._state = WebSocketClientState.ERROR

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of identity, state and counters.

        ``created_at`` and ``last_heartbeat`` are returned as datetime
        objects (``last_heartbeat`` may be None), not strings.
        """
        return {
            "name": self.name,
            "url": self.url,
            "state": self.state.value,
            "is_connected": self.is_connected,
            "reconnect_attempts": self._reconnect_attempts,
            "handler_count": len(self._message_handlers),
            "created_at": self._created_at,
            "last_heartbeat": self._last_heartbeat,
        }