You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

669 lines
25 KiB

using System.Buffers;
using System.Net.WebSockets;
using System.Text;
using Microsoft.Extensions.Logging;
using CoreAgent.WebSocketTransport.Interfaces;
using CoreAgent.WebSocketTransport.Models;
using CoreAgent.WebSocketTransport.Middleware;
using System.Text.Json.Serialization;
using System.Text.Json;
namespace CoreAgent.WebSocketTransport.Services;
/// <summary>
/// WebSocket 传输实现
/// 单一职责:连接管理和自动数据流转
/// </summary>
public class WebSocketTransport : IWebSocketTransport
{
private readonly ILogger<WebSocketTransport> _logger;
private readonly IWebSocketConnection _connection;
private readonly IMessageSerializer _serializer;
private readonly IEnumerable<IMessageMiddleware> _middlewares;
private readonly WebSocketConfig _config;
private readonly IMessageChannelManager _channelManager;
private readonly CancellationTokenSource _cancellationTokenSource;
private readonly SemaphoreSlim _connectionSemaphore;
// 连接状态管理
private volatile bool _isConnected;
private DateTime? _lastHeartbeat;
private readonly object _heartbeatLock = new object();
private Task? _sendTask;
private Task? _receiveTask;
private Task? _heartbeatTask;
private Task? _reconnectTask;
private int _reconnectAttempts;
private readonly object _reconnectLock = new object();
public bool IsConnected => _isConnected;
public DateTime? LastHeartbeat
{
get
{
lock (_heartbeatLock)
{
return _lastHeartbeat;
}
}
}
public WebSocketTransport(
ILogger<WebSocketTransport> logger,
IWebSocketConnection connection,
IMessageSerializer serializer,
IEnumerable<IMessageMiddleware> middlewares,
WebSocketConfig config,
IMessageChannelManager channelManager)
{
_logger = logger;
_connection = connection;
_serializer = serializer;
_middlewares = middlewares;
_config = config;
_channelManager = channelManager;
_cancellationTokenSource = new CancellationTokenSource();
_connectionSemaphore = new SemaphoreSlim(1, 1);
}
/// <summary>
/// 异步连接 WebSocket 服务器
/// </summary>
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
_logger.LogInformation("开始连接 WebSocket 服务器: {Url}", _config.Url);
await _connectionSemaphore.WaitAsync(cancellationToken);
try
{
if (_isConnected)
{
_logger.LogInformation("WebSocket 已连接,跳过重复连接");
return;
}
await ConnectInternalAsync(cancellationToken);
}
catch (Exception ex)
{
_logger.LogError(ex, "连接 WebSocket 服务器失败: {Url}", _config.Url);
throw;
}
finally
{
_connectionSemaphore.Release();
_logger.LogDebug("已释放连接信号量");
}
}
/// <summary>
/// 异步关闭连接
/// </summary>
public async Task CloseAsync(CancellationToken cancellationToken = default)
{
_logger.LogInformation("开始关闭 WebSocket 连接");
await _connectionSemaphore.WaitAsync(cancellationToken);
try
{
if (!_isConnected)
{
_logger.LogInformation("WebSocket 未连接,无需关闭");
return;
}
_logger.LogInformation("正在关闭 WebSocket 连接: {Url}", _config.Url);
// 停止后台任务
_cancellationTokenSource.Cancel();
_logger.LogDebug("已取消所有操作令牌");
// 等待任务完成
var closeTasks = new List<Task>();
if (_sendTask != null && !_sendTask.IsCompleted)
{
closeTasks.Add(_sendTask);
_logger.LogDebug("等待发送任务完成");
}
if (_receiveTask != null && !_receiveTask.IsCompleted)
{
closeTasks.Add(_receiveTask);
_logger.LogDebug("等待接收任务完成");
}
if (_heartbeatTask != null && !_heartbeatTask.IsCompleted)
{
closeTasks.Add(_heartbeatTask);
_logger.LogDebug("等待心跳任务完成");
}
if (_reconnectTask != null && !_reconnectTask.IsCompleted)
{
closeTasks.Add(_reconnectTask);
_logger.LogDebug("等待重连任务完成");
}
if (closeTasks.Count > 0)
{
_logger.LogInformation("等待 {TaskCount} 个后台任务完成", closeTasks.Count);
await Task.WhenAll(closeTasks);
_logger.LogInformation("所有后台任务已完成");
}
// 关闭连接
try
{
await _connection.CloseAsync(WebSocketCloseStatus.NormalClosure, "正常关闭", cancellationToken);
_logger.LogDebug("WebSocket 连接已正常关闭");
}
catch (Exception ex)
{
_logger.LogWarning(ex, "正常关闭连接失败,强制关闭连接");
_connection.ForceClose();
}
_isConnected = false;
_logger.LogInformation("WebSocket 连接关闭完成");
}
catch (Exception ex)
{
_logger.LogError(ex, "关闭 WebSocket 连接时发生异常");
throw;
}
finally
{
_connectionSemaphore.Release();
_logger.LogDebug("已释放连接信号量");
}
}
/// <summary>
/// 内部连接方法
/// </summary>
private async Task ConnectInternalAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("正在连接 WebSocket 服务器: {Url}, 超时时间: {TimeoutMs}ms", _config.Url, _config.TimeoutMs);
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(_config.TimeoutMs);
try
{
await _connection.ConnectAsync(new Uri(_config.Url), timeoutCts.Token);
_logger.LogDebug("WebSocket 连接建立成功");
}
catch (Exception ex)
{
_logger.LogError(ex, "WebSocket 连接建立失败: {Url}", _config.Url);
// 连接失败时,确保状态正确
_isConnected = false;
_connection.ForceClose();
throw;
}
_isConnected = true;
_reconnectAttempts = 0;
UpdateHeartbeat();
_logger.LogDebug("连接状态已更新,重连次数重置为 0");
// 启动后台任务
_logger.LogDebug("启动后台任务");
_sendTask = Task.Run(() => SendLoopAsync(_cancellationTokenSource.Token));
_receiveTask = Task.Run(() => ReceiveLoopAsync(_cancellationTokenSource.Token));
_heartbeatTask = Task.Run(() => HeartbeatLoopAsync(_cancellationTokenSource.Token));
_logger.LogDebug("后台任务启动完成: 发送={SendTaskId}, 接收={ReceiveTaskId}, 心跳={HeartbeatTaskId}",
_sendTask?.Id, _receiveTask?.Id, _heartbeatTask?.Id);
_logger.LogInformation("WebSocket 连接成功建立,所有后台任务已启动");
}
/// <summary>
/// 触发重连
/// </summary>
private void TriggerReconnect()
{
lock (_reconnectLock)
{
if (_reconnectTask != null && !_reconnectTask.IsCompleted)
{
_logger.LogDebug("重连任务已在运行,跳过重复触发");
return; // 重连任务已在运行
}
_logger.LogInformation("启动重连任务");
_reconnectTask = Task.Run(() => ReconnectLoopAsync(_cancellationTokenSource.Token));
_logger.LogDebug("重连任务已启动: {TaskId}", _reconnectTask.Id);
}
}
/// <summary>
/// 重连循环
/// </summary>
private async Task ReconnectLoopAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("重连循环开始,当前重连次数: {Attempts}", _reconnectAttempts);
_isConnected = false;
while (_reconnectAttempts < _config.MaxReconnectAttempts && !cancellationToken.IsCancellationRequested)
{
_reconnectAttempts++;
var delaySeconds = Math.Min(Math.Pow(2, _reconnectAttempts - 1), 30);
var delay = TimeSpan.FromSeconds(delaySeconds);
_logger.LogWarning("WebSocket 连接断开,{DelaySeconds}秒后进行第{Attempt}次重连",
delaySeconds, _reconnectAttempts);
await Task.Delay(delay, cancellationToken);
try
{
_logger.LogInformation("开始第{Attempt}次重连尝试", _reconnectAttempts);
await ConnectInternalAsync(cancellationToken);
_logger.LogInformation("WebSocket 重连成功,重连次数: {Attempts}", _reconnectAttempts);
return;
}
catch (Exception ex)
{
_logger.LogError(ex, "WebSocket 重连失败,尝试次数: {Attempt}", _reconnectAttempts);
// 重连失败时,确保连接状态正确
_isConnected = false;
_connection.ForceClose();
}
}
_logger.LogError("WebSocket 重连失败,已达到最大尝试次数: {MaxAttempts}", _config.MaxReconnectAttempts);
_isConnected = false;
}
/// <summary>
/// 心跳循环
/// </summary>
private async Task HeartbeatLoopAsync(CancellationToken cancellationToken)
{
try
{
while (!cancellationToken.IsCancellationRequested)
{
if (_isConnected && _connection.IsConnected)
{
_channelManager.PriorityChannel.TryWrite(new HeartbeatMessage());
UpdateHeartbeat();
}
await Task.Delay(TimeSpan.FromSeconds(120), cancellationToken);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "心跳循环异常");
}
}
/// <summary>
/// 更新心跳时间
/// </summary>
private void UpdateHeartbeat()
{
lock (_heartbeatLock)
{
_lastHeartbeat = DateTime.UtcNow;
}
}
/// <summary>
/// 发送循环 - 自动从发送通道读取数据并发送
/// </summary>
private async Task SendLoopAsync(CancellationToken cancellationToken)
{
_logger.LogDebug("发送循环开始运行");
var messageCount = 0;
var priorityMessageCount = 0;
try
{
while (!cancellationToken.IsCancellationRequested && !_channelManager.SendChannel.IsCompleted)
{
// 优先处理优先级消息
while (_channelManager.PriorityChannel.TryRead(out var priorityMessage))
{
priorityMessageCount++;
if (priorityMessage is null) continue;
_logger.LogTrace("处理优先级消息 #{PriorityCount}: {MessageType}",
priorityMessageCount, priorityMessage?.GetType().Name ?? "null");
await SendSingleMessageAsync(priorityMessage, cancellationToken);
}
// 处理普通消息
if (_channelManager.SendChannel.TryRead(out var message))
{
messageCount++;
_logger.LogTrace("处理普通消息 #{Count}: {MessageType}",
messageCount, message?.GetType().Name ?? "null");
await SendSingleMessageAsync(message, cancellationToken);
}
else
{
// 如果没有消息,等待一小段时间
await Task.Delay(10, cancellationToken);
}
}
_logger.LogInformation("发送循环正常结束,共处理 {MessageCount} 条普通消息,{PriorityCount} 条优先级消息",
messageCount, priorityMessageCount);
}
catch (OperationCanceledException)
{
_logger.LogInformation("发送循环被取消,共处理 {MessageCount} 条普通消息,{PriorityCount} 条优先级消息",
messageCount, priorityMessageCount);
}
catch (Exception ex)
{
_logger.LogError(ex, "发送循环异常,共处理 {MessageCount} 条普通消息,{PriorityCount} 条优先级消息",
messageCount, priorityMessageCount);
TriggerReconnect();
}
}
/// <summary>
/// 发送单个消息
/// </summary>
private async Task SendSingleMessageAsync(object message, CancellationToken cancellationToken)
{
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
var messageType = message?.GetType().Name ?? "null";
try
{
_logger.LogTrace("开始处理发送消息: {MessageType}", messageType);
// 通过中间件处理消息
var processedMessage = message;
var middlewareCount = 0;
foreach (var middleware in _middlewares)
{
middlewareCount++;
var middlewareName = middleware.GetType().Name;
_logger.LogTrace("通过中间件 {MiddlewareName} 处理消息", middlewareName);
processedMessage = await middleware.ProcessSendAsync(processedMessage, cancellationToken);
if (processedMessage == null)
{
_logger.LogDebug("消息被中间件 {MiddlewareName} 跳过", middlewareName);
return;
}
}
_logger.LogTrace("消息通过 {MiddlewareCount} 个中间件处理完成", middlewareCount);
if (processedMessage is byte[] binaryData)
{
// 处理二进制消息
_logger.LogTrace("发送二进制消息,大小: {Size} bytes", binaryData.Length);
await _connection.SendAsync(new ArraySegment<byte>(binaryData), WebSocketMessageType.Binary, true, cancellationToken);
}
else
{
// 处理文本消息
var data = _serializer.Serialize(processedMessage);
_logger.LogTrace("发送文本消息,大小: {Size} bytes", data.Length);
await _connection.SendAsync(new ArraySegment<byte>(data), WebSocketMessageType.Text, true, cancellationToken);
}
stopwatch.Stop();
_logger.LogDebug("消息发送成功: {MessageType}, 耗时: {ElapsedMs}ms", messageType, stopwatch.ElapsedMilliseconds);
}
catch (Exception ex)
{
stopwatch.Stop();
_logger.LogError(ex, "发送消息失败: {MessageType}, 耗时: {ElapsedMs}ms", messageType, stopwatch.ElapsedMilliseconds);
throw;
}
}
/// <summary>
/// 接收循环 - 自动接收数据并推送到接收通道
/// </summary>
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
{
_logger.LogDebug("接收循环开始运行");
var buffer = ArrayPool<byte>.Shared.Rent(4096);
var messageBuilder = new StringBuilder();
var binaryStream = new MemoryStream();
var messageCount = 0;
var binaryMessageCount = 0;
try
{
while (!cancellationToken.IsCancellationRequested &&
(_connection.State == WebSocketState.Open || _connection.State == WebSocketState.CloseReceived))
{
WebSocketReceiveResult result;
try
{
result = await _connection.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken);
}
catch (InvalidOperationException ex) when (ex.Message.Contains("WebSocket 未连接"))
{
_logger.LogWarning("WebSocket 连接状态异常,准备触发重连: {State}", _connection.State);
TriggerReconnect();
break;
}
_logger.LogTrace("接收到 WebSocket 消息: 类型={MessageType}, 大小={Count}, 结束={EndOfMessage}",
result.MessageType, result.Count, result.EndOfMessage);
if (result.MessageType == WebSocketMessageType.Text)
{
var text = Encoding.UTF8.GetString(buffer, 0, result.Count);
messageBuilder.Append(text);
if (result.EndOfMessage)
{
var message = messageBuilder.ToString();
messageCount++;
_logger.LogTrace("接收到完整文本消息 #{Count}: 长度={Length}", messageCount, message.Length);
await ProcessReceivedMessageAsync(message, cancellationToken);
messageBuilder.Clear();
}
}
else if (result.MessageType == WebSocketMessageType.Binary)
{
// 处理二进制消息
binaryStream.Write(buffer, 0, result.Count);
if (result.EndOfMessage)
{
var binaryData = binaryStream.ToArray();
binaryMessageCount++;
_logger.LogTrace("接收到完整二进制消息 #{Count}: 大小={Size} bytes", binaryMessageCount, binaryData.Length);
await ProcessReceivedMessageAsync(binaryData, cancellationToken);
binaryStream.SetLength(0);
}
}
else if (result.MessageType == WebSocketMessageType.Close)
{
_logger.LogInformation("收到 WebSocket 关闭消息,准备处理连接关闭");
// 收到关闭消息时,需要强制关闭连接并重新创建WebSocket实例
// 这样可以确保后续能够重新连接
_connection.ForceClose();
_isConnected = false;
_logger.LogInformation("WebSocket 连接已强制关闭,准备触发重连");
// 确保在触发重连之前,当前接收循环能够正常退出
// 重连任务会在后台启动,不会阻塞当前循环的退出
TriggerReconnect();
break;
}
}
_logger.LogInformation("接收循环正常结束,共接收 {TextCount} 条文本消息,{BinaryCount} 条二进制消息",
messageCount, binaryMessageCount);
}
catch (OperationCanceledException)
{
_logger.LogInformation("接收循环被取消,共接收 {TextCount} 条文本消息,{BinaryCount} 条二进制消息",
messageCount, binaryMessageCount);
}
catch (Exception ex)
{
_logger.LogError(ex, "接收循环异常,共接收 {TextCount} 条文本消息,{BinaryCount} 条二进制消息",
messageCount, binaryMessageCount);
TriggerReconnect();
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
binaryStream.Dispose();
_logger.LogDebug("接收循环资源已清理");
}
}
/// <summary>
/// 处理接收到的消息
/// </summary>
private async Task ProcessReceivedMessageAsync(object message, CancellationToken cancellationToken)
{
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
var messageType = message?.GetType().Name ?? "null";
try
{
_logger.LogTrace("开始处理接收消息: {MessageType}", messageType);
// 通过中间件处理消息
var processedMessage = message;
var middlewareCount = 0;
foreach (var middleware in _middlewares)
{
middlewareCount++;
var middlewareName = middleware.GetType().Name;
_logger.LogTrace("通过中间件 {MiddlewareName} 处理接收消息", middlewareName);
processedMessage = await middleware.ProcessReceiveAsync(processedMessage, cancellationToken);
if (processedMessage == null)
{
_logger.LogDebug("接收消息被中间件 {MiddlewareName} 跳过", middlewareName);
return;
}
}
_logger.LogTrace("接收消息通过 {MiddlewareCount} 个中间件处理完成", middlewareCount);
// 推送到接收通道
if (processedMessage != null)
{
await _channelManager.ReceiveChannel.WriteAsync(processedMessage, cancellationToken);
_logger.LogTrace("接收消息已推送到接收通道");
}
else
{
_logger.LogTrace("接收消息为空,跳过推送到通道");
}
stopwatch.Stop();
_logger.LogDebug("接收消息处理完成: {MessageType}, 耗时: {ElapsedMs}ms", messageType, stopwatch.ElapsedMilliseconds);
}
catch (Exception ex)
{
stopwatch.Stop();
_logger.LogError(ex, "处理接收消息异常: {MessageType}, 耗时: {ElapsedMs}ms", messageType, stopwatch.ElapsedMilliseconds);
}
}
private bool _disposed = false;
/// <summary>
/// 释放资源
/// </summary>
public void Dispose()
{
if (_disposed) return;
_logger?.LogInformation("开始释放 WebSocket 传输资源");
try
{
// 取消所有操作
_cancellationTokenSource?.Cancel();
_logger?.LogDebug("已取消所有操作令牌");
// 等待信号量释放
if (_connectionSemaphore != null)
{
try
{
_connectionSemaphore.Wait(TimeSpan.FromSeconds(5));
_logger?.LogDebug("已获取连接信号量");
}
catch (TimeoutException)
{
_logger?.LogWarning("等待连接信号量超时");
}
}
// 收集需要等待的任务
var tasks = new List<Task>();
if (_sendTask != null && !_sendTask.IsCompleted)
{
tasks.Add(_sendTask);
_logger?.LogDebug("添加发送任务到等待列表");
}
if (_receiveTask != null && !_receiveTask.IsCompleted)
{
tasks.Add(_receiveTask);
_logger?.LogDebug("添加接收任务到等待列表");
}
if (_heartbeatTask != null && !_heartbeatTask.IsCompleted)
{
tasks.Add(_heartbeatTask);
_logger?.LogDebug("添加心跳任务到等待列表");
}
if (_reconnectTask != null && !_reconnectTask.IsCompleted)
{
tasks.Add(_reconnectTask);
_logger?.LogDebug("添加重连任务到等待列表");
}
// 等待所有任务完成
if (tasks.Count > 0)
{
_logger?.LogInformation("等待 {TaskCount} 个后台任务完成", tasks.Count);
var waitResult = Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5));
if (waitResult)
{
_logger?.LogInformation("所有后台任务已成功完成");
}
else
{
_logger?.LogWarning("部分后台任务未在超时时间内完成");
}
}
else
{
_logger?.LogDebug("没有需要等待的后台任务");
}
}
catch (Exception ex)
{
_logger?.LogError(ex, "释放资源过程中发生异常");
}
finally
{
try
{
// 释放托管资源
_cancellationTokenSource?.Dispose();
_connectionSemaphore?.Dispose();
_logger?.LogDebug("已释放托管资源");
}
catch (Exception ex)
{
_logger?.LogError(ex, "释放托管资源时发生异常");
}
_disposed = true;
GC.SuppressFinalize(this);
_logger?.LogInformation("WebSocket 传输资源释放完成");
}
}
}