using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Caching.Memory;
using CoreAgent.WebSocketTransport.Interfaces;
using CoreAgent.WebSocketTransport.Services;
using CoreAgent.WebSocketTransport.Middleware;
using CoreAgent.WebSocketTransport.Models;

namespace CoreAgent.WebSocketTransport.Extensions
{
    // NOTE(review): in this copy of the file every generic type-argument list appears to have
    // been stripped. Examples: `services.AddSingleton();`, `services.AddTransient();`,
    // `provider.GetRequiredService>()` (note the leftover `>`), and the missing type-parameter
    // list on AddWebSocketMiddleware despite its `where T` constraint. As written this does not
    // compile; restore the type arguments from version control.

    /// <summary>
    /// Dependency-injection registration helpers for the WebSocket transport.
    /// </summary>
    public static class WebSocketTransportExtensions
    {
        /// <summary>
        /// Registers the WebSocket transport. It binds the transport options from a configuration
        /// section, registers the default middleware, and then registers the core services.
        /// </summary>
        /// <param name="services">The service collection to add registrations to.</param>
        /// <param name="configuration">The configuration that contains the transport section.</param>
        /// <param name="configSection">Name of the configuration section to bind; defaults to <c>"WebSocket"</c>.</param>
        /// <returns>The same <paramref name="services"/> instance, for chaining.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="services"/> or <paramref name="configuration"/> is null.
        /// </exception>
        /// <remarks>
        /// Registrations are appended, not de-duplicated. The default middleware goes through
        /// <c>AddWebSocketMiddleware</c>, which calls <c>AddTransient</c>. Calling this method more
        /// than once therefore registers the default middleware again, and it shows up multiple
        /// times when all middleware registrations are enumerated.
        /// </remarks>
        public static IServiceCollection AddWebSocketTransport(
            this IServiceCollection services,
            IConfiguration configuration,
            string configSection = "WebSocket")
        {
            if (services == null) throw new ArgumentNullException(nameof(services));
            if (configuration == null) throw new ArgumentNullException(nameof(configuration));

            // Register the options. The section is bound inside the configure callback. That means
            // the bind, and any problem with configSection (which is not validated here), happens
            // when the options are first resolved, not at registration time.
            services.Configure(options =>
            {
                configuration.GetSection(configSection).Bind(options);
            });

            // Register the default middleware (before the core services).
            RegisterDefaultMiddleware(services);

            // Register the core services.
            RegisterCoreServices(services);

            return services;
        }

        /// <summary>
        /// Registers a WebSocket message middleware of type <typeparamref name="T"/> with a
        /// transient lifetime.
        /// </summary>
        /// <typeparam name="T">The middleware implementation type.</typeparam>
        /// <param name="services">The service collection to add the registration to.</param>
        /// <returns>The same <paramref name="services"/> instance, for chaining.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="services"/> is null.</exception>
        /// <remarks>
        /// Each call appends a new registration; the same type added twice is registered twice.
        /// </remarks>
        public static IServiceCollection AddWebSocketMiddleware(this IServiceCollection services) where T : class, IMessageMiddleware
        {
            if (services == null) throw new ArgumentNullException(nameof(services));

            // NOTE(review): type arguments are missing in this copy. Presumably the service type is
            // IMessageMiddleware and the implementation is T; confirm.
            services.AddTransient();
            return services;
        }

        /// <summary>
        /// Registers the core transport components as singletons:
        /// two components whose types are not recoverable in this copy,
        /// the message channel manager, and the WebSocket transport itself.
        /// </summary>
        /// <param name="services">The service collection to add registrations to.</param>
        private static void RegisterCoreServices(IServiceCollection services)
        {
            // Register the core components.
            // NOTE(review): type arguments are missing in this copy. Judging by what the transport
            // factory below resolves, these are presumably the connection and the serializer; confirm.
            services.AddSingleton();
            services.AddSingleton();

            // Message channel manager. It is constructed from the options' send, receive and
            // priority channel capacities.
            services.AddSingleton(provider =>
            {
                var logger = provider.GetRequiredService>();
                var config = provider.GetRequiredService>().Value;
                return new MessageChannelManager(
                    logger,
                    config.SendChannelCapacity,
                    config.ReceiveChannelCapacity,
                    config.PriorityChannelCapacity);
            });

            // WebSocket transport. It is composed from:
            // the logger, the connection, the serializer, the resolved middleware sequence,
            // the options value, and the channel manager registered above.
            services.AddSingleton(provider =>
            {
                var logger = provider.GetRequiredService>();
                var connection = provider.GetRequiredService();
                var serializer = provider.GetRequiredService();
                // Presumably enumerates the middleware registered via AddWebSocketMiddleware; confirm.
                var middlewares = provider.GetServices();
                // This singleton factory runs once, so the transport is constructed with the options
                // value as it was at first resolution.
                var config = provider.GetRequiredService>().Value;
                var channelManager = provider.GetRequiredService();
                // Fully qualified on purpose. Inside this namespace, the simple name
                // WebSocketTransport would resolve to the CoreAgent.WebSocketTransport namespace,
                // not to the class.
                return new CoreAgent.WebSocketTransport.Services.WebSocketTransport(
                    logger,
                    connection,
                    serializer,
                    middlewares,
                    config,
                    channelManager);
            });
        }

        /// <summary>
        /// Registers the built-in middleware: the logging middleware, then the cache middleware.
        /// </summary>
        /// <param name="services">The service collection to add registrations to.</param>
        private static void RegisterDefaultMiddleware(IServiceCollection services)
        {
            // Logging middleware.
            services.AddWebSocketMiddleware();

            // Cache middleware (registered using the correct registration approach).
            // NOTE(review): no AddMemoryCache() call is visible in this file, even though
            // Microsoft.Extensions.Caching.Memory is imported. If this middleware depends on
            // IMemoryCache, confirm that the host registers it.
            services.AddWebSocketMiddleware();
        }
    }
}