You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

295 lines
11 KiB

using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using CellularManagement.Application.Services;
using CellularManagement.Domain.Entities;
using CellularManagement.Infrastructure.Monitoring;
using CellularManagement.Infrastructure.Pooling;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
namespace CellularManagement.Infrastructure.WebSocket;
/// <summary>
/// <see cref="IWebSocketService"/> implementation that keeps live sockets in the local
/// <see cref="ICacheService"/> and keeps connection records, user associations and per-node
/// connection-id lists in <see cref="IDistributedCache"/>.
/// </summary>
/// <remarks>
/// Every instance generates its own node id in the constructor; all connections accepted by an
/// instance are listed under that node's key.
/// NOTE(review): because the node id is per instance, this presumably expects a singleton
/// registration — confirm against the DI setup.
/// </remarks>
public class WebSocketService : IWebSocketService
{
    private const string ConnectionPrefix = "ws_connection_";
    private const string WebSocketPrefix = "ws_socket_";
    private const string UserPrefix = "ws_user_";
    private const string NodePrefix = "ws_node_";

    /// <summary>Sliding idle timeout applied to every per-connection cache entry.</summary>
    private static readonly TimeSpan ConnectionIdleTimeout = TimeSpan.FromMinutes(30);

    private readonly ILogger<WebSocketService> _logger;
    private readonly WebSocketMetrics _metrics;
    private readonly WebSocketMessagePool _messagePool;
    private readonly IDistributedCache _distributedCache;
    private readonly ICacheService _cacheService;
    private readonly string _nodeId;

    // Serializes the read-modify-write of this node's connection-id list so that concurrent
    // accepts/closes on this instance cannot overwrite each other's updates.
    private readonly SemaphoreSlim _nodeListLock = new(1, 1);

    public WebSocketService(
        ILogger<WebSocketService> logger,
        WebSocketMetrics metrics,
        WebSocketMessagePool messagePool,
        IDistributedCache distributedCache,
        ICacheService cacheService)
    {
        _logger = logger;
        _metrics = metrics;
        _messagePool = messagePool;
        _distributedCache = distributedCache;
        _cacheService = cacheService;
        _nodeId = Guid.NewGuid().ToString();
    }

    /// <summary>
    /// Registers a newly accepted socket: stores its connection record in the distributed cache,
    /// the socket itself in the local cache, and adds its id to this node's connection list.
    /// </summary>
    /// <param name="webSocket">The accepted socket.</param>
    /// <returns>The generated connection id (a GUID string).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="webSocket"/> is null.</exception>
    public async Task<string> AcceptConnectionAsync(System.Net.WebSockets.WebSocket webSocket)
    {
        ArgumentNullException.ThrowIfNull(webSocket);

        var connectionId = Guid.NewGuid().ToString();
        var connection = WebSocketConnection.Create(connectionId);

        await _distributedCache.SetStringAsync(
            ConnectionKey(connectionId),
            JsonSerializer.Serialize(connection),
            CreateConnectionEntryOptions());

        // NOTE(review): with a sliding expiration the socket can drop out of the local cache while
        // still open if it is not accessed within the timeout; no eviction callback closes it —
        // confirm this is intended.
        _cacheService.Set(WebSocketKey(connectionId), webSocket, new MemoryCacheEntryOptions
        {
            SlidingExpiration = ConnectionIdleTimeout
        });

        await AddConnectionToNodeAsync(connectionId);

        _metrics.ConnectionEstablished();
        _logger.LogInformation("WebSocket connection accepted: {ConnectionId} on node {NodeId}",
            connectionId, _nodeId);
        return connectionId;
    }

    /// <summary>
    /// Closes the connection's socket (when this node holds it) and removes every cache entry
    /// kept for the connection: node-list membership, connection record, socket and user key.
    /// </summary>
    /// <param name="connectionId">Id returned by <see cref="AcceptConnectionAsync"/>.</param>
    /// <returns>Always <c>true</c>; cleanup runs even if the socket could not be closed.</returns>
    public async Task<bool> CloseConnectionAsync(string connectionId)
    {
        var webSocketKey = WebSocketKey(connectionId);

        // Close the live socket independently of the distributed record: a socket whose record
        // has expired or cannot be deserialized still has to be closed before it is dropped.
        if (_cacheService.TryGetValue(webSocketKey, out System.Net.WebSockets.WebSocket? webSocket)
            && webSocket != null)
        {
            await TryCloseSocketAsync(webSocket, connectionId);
        }

        // The connection record is deleted outright, so it is not rewritten in a "closed" state first.
        await RemoveConnectionFromNodeAsync(connectionId);
        await _distributedCache.RemoveAsync(ConnectionKey(connectionId));
        _cacheService.Remove(webSocketKey);
        await _distributedCache.RemoveAsync(UserKey(connectionId));

        _metrics.ConnectionClosed();
        _logger.LogInformation("WebSocket connection closed: {ConnectionId}", connectionId);
        return true;
    }

    /// <summary>
    /// Sends a normal-closure handshake when the socket is Open or CloseReceived; any failure is
    /// logged and swallowed so that connection cleanup can continue.
    /// </summary>
    private async Task TryCloseSocketAsync(System.Net.WebSockets.WebSocket webSocket, string connectionId)
    {
        try
        {
            if (webSocket.State is WebSocketState.Open or WebSocketState.CloseReceived)
            {
                await webSocket.CloseAsync(
                    WebSocketCloseStatus.NormalClosure,
                    "Connection closed by server",
                    CancellationToken.None);
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error closing WebSocket connection: {ConnectionId}", connectionId);
        }
    }

    /// <summary>
    /// Sends <paramref name="message"/> as a single, complete Text frame to a connection whose
    /// socket is held by this node.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the frame was sent; <c>false</c> if the socket or record is missing, the
    /// record cannot be deserialized, its state is not Open, or the send failed (send failures
    /// are logged and counted as "SendMessage" errors).
    /// </returns>
    public async Task<bool> SendMessageAsync(string connectionId, byte[] message)
    {
        var connectionJson = await _distributedCache.GetStringAsync(ConnectionKey(connectionId));
        if (!_cacheService.TryGetValue(WebSocketKey(connectionId), out System.Net.WebSockets.WebSocket? webSocket)
            || webSocket == null
            || connectionJson == null)
        {
            return false;
        }

        WebSocketConnection? connection;
        try
        {
            connection = JsonSerializer.Deserialize<WebSocketConnection>(connectionJson);
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to deserialize WebSocketConnection for connection: {ConnectionId}, skipping message send", connectionId);
            return false;
        }

        if (connection?.State != WebSocketState.Open)
        {
            return false;
        }

        try
        {
            await webSocket.SendAsync(
                new ArraySegment<byte>(message),
                WebSocketMessageType.Text,
                true,
                CancellationToken.None);
            _metrics.MessageProcessed(TimeSpan.Zero);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error sending message to connection: {ConnectionId}", connectionId);
            _metrics.ErrorOccurred("SendMessage");
            return false;
        }
    }

    /// <summary>
    /// Sends <paramref name="message"/> to every connection listed for every known node.
    /// </summary>
    /// <returns><c>true</c> only if every individual send succeeded.</returns>
    public async Task<bool> BroadcastMessageAsync(byte[] message)
    {
        var allSucceeded = true;
        foreach (var nodeId in await GetAllNodesAsync())
        {
            foreach (var connectionId in await GetNodeConnectionIdsAsync(nodeId))
            {
                if (!await SendMessageAsync(connectionId, message))
                {
                    allSucceeded = false;
                }
            }
        }
        return allSucceeded;
    }

    /// <summary>Sends <paramref name="message"/> to every connection associated with <paramref name="userId"/>.</summary>
    /// <returns><c>true</c> only if every individual send succeeded (also <c>true</c> when the user has no connections).</returns>
    public async Task<bool> SendMessageToUserAsync(string userId, byte[] message)
    {
        var userConnections = await GetUserConnectionsAsync(userId);
        var success = true;
        foreach (var connection in userConnections)
        {
            if (!await SendMessageAsync(connection.ConnectionId, message))
            {
                success = false;
            }
        }
        return success;
    }

    /// <summary>
    /// Associates <paramref name="userId"/> with an existing connection record and stores the
    /// user id under the connection's user key. Does nothing if the record is missing.
    /// </summary>
    public async Task AssociateUserAsync(string connectionId, string userId)
    {
        var connectionKey = ConnectionKey(connectionId);
        var connectionJson = await _distributedCache.GetStringAsync(connectionKey);
        if (connectionJson == null)
        {
            return;
        }

        var connection = JsonSerializer.Deserialize<WebSocketConnection>(connectionJson);
        if (connection == null)
        {
            return;
        }

        connection.AssociateUser(userId);

        // Pass the sliding-expiration options explicitly: the options-less SetStringAsync overload
        // stores entries with no expiration, which would let these keys outlive the connection.
        await _distributedCache.SetStringAsync(
            connectionKey,
            JsonSerializer.Serialize(connection),
            CreateConnectionEntryOptions());
        await _distributedCache.SetStringAsync(
            UserKey(connectionId),
            userId,
            CreateConnectionEntryOptions());
    }

    /// <summary>Reads a connection record from the distributed cache.</summary>
    /// <returns>The deserialized record, or <c>null</c> if no record is stored.</returns>
    public async Task<WebSocketConnection?> GetConnectionAsync(string connectionId)
    {
        var connectionJson = await _distributedCache.GetStringAsync(ConnectionKey(connectionId));
        return connectionJson != null
            ? JsonSerializer.Deserialize<WebSocketConnection>(connectionJson)
            : null;
    }

    /// <summary>
    /// Collects the connection records of all known nodes whose <c>UserId</c> equals
    /// <paramref name="userId"/>. Missing records are skipped, so the result never contains null.
    /// </summary>
    public async Task<IEnumerable<WebSocketConnection>> GetUserConnectionsAsync(string userId)
    {
        var matches = new List<WebSocketConnection>();
        foreach (var nodeId in await GetAllNodesAsync())
        {
            foreach (var connectionId in await GetNodeConnectionIdsAsync(nodeId))
            {
                var connection = await GetConnectionAsync(connectionId);
                if (connection != null && connection.UserId == userId)
                {
                    matches.Add(connection);
                }
            }
        }
        return matches;
    }

    /// <summary>Adds a connection id to this node's connection list.</summary>
    private async Task AddConnectionToNodeAsync(string connectionId)
    {
        await UpdateLocalNodeConnectionsAsync(ids => ids.Add(connectionId));
    }

    /// <summary>Removes a connection id from this node's connection list.</summary>
    private async Task RemoveConnectionFromNodeAsync(string connectionId)
    {
        await UpdateLocalNodeConnectionsAsync(ids => ids.Remove(connectionId));
    }

    /// <summary>
    /// Applies <paramref name="mutate"/> to this node's connection-id list and writes it back,
    /// holding <see cref="_nodeListLock"/> for the whole read-modify-write.
    /// </summary>
    private async Task UpdateLocalNodeConnectionsAsync(Action<List<string>> mutate)
    {
        await _nodeListLock.WaitAsync();
        try
        {
            var connectionIds = await GetNodeConnectionIdsAsync(_nodeId);
            mutate(connectionIds);
            await _distributedCache.SetStringAsync(
                NodeKey(_nodeId),
                JsonSerializer.Serialize(connectionIds));
        }
        finally
        {
            _nodeListLock.Release();
        }
    }

    /// <summary>
    /// Reads the connection-id list stored for <paramref name="nodeId"/>; returns an empty list
    /// when nothing is stored or the stored JSON is <c>null</c>.
    /// </summary>
    private async Task<List<string>> GetNodeConnectionIdsAsync(string nodeId)
    {
        var connectionsJson = await _distributedCache.GetStringAsync(NodeKey(nodeId));
        return connectionsJson == null
            ? new List<string>()
            : JsonSerializer.Deserialize<List<string>>(connectionsJson) ?? new List<string>();
    }

    /// <summary>Returns the ids of all nodes to consult for broadcasts and user lookups.</summary>
    private Task<List<string>> GetAllNodesAsync()
    {
        // TODO: service discovery is not implemented yet (e.g. via Redis Pub/Sub or another
        // mechanism). Until it is, only this node is returned, so broadcasts and user lookups
        // only see connections accepted by this instance.
        return Task.FromResult(new List<string> { _nodeId });
    }

    /// <summary>Creates the sliding-expiration options used for per-connection distributed entries.</summary>
    private static DistributedCacheEntryOptions CreateConnectionEntryOptions() =>
        new DistributedCacheEntryOptions { SlidingExpiration = ConnectionIdleTimeout };

    private static string ConnectionKey(string connectionId) => $"{ConnectionPrefix}{connectionId}";

    private static string WebSocketKey(string connectionId) => $"{WebSocketPrefix}{connectionId}";

    private static string UserKey(string connectionId) => $"{UserPrefix}{connectionId}";

    private static string NodeKey(string nodeId) => $"{NodePrefix}{nodeId}";
}