From d8d46a5c879f0a0861e2e055a9b0c18fb4e1abbc Mon Sep 17 00:00:00 2001 From: Shane Krueger Date: Sun, 27 Oct 2024 10:25:30 -0400 Subject: [PATCH] Add KeepAliveMode and SupportedWebSocketSubProtocols options (#80) --- README.md | 57 ++++++++++- .../GraphQLHttpMiddleware.cs | 2 +- .../WebSockets/BaseSubscriptionServer.cs | 70 +++++++++++-- .../WebSockets/GraphQLWebSocketOptions.cs | 19 ++++ .../WebSockets/GraphQLWs/PingPayload.cs | 12 +++ .../GraphQLWs/SubscriptionServer.cs | 97 ++++++++++++++++++- .../WebSockets/KeepAliveMode.cs | 36 +++++++ .../SubscriptionServer.cs | 22 +++++ .../GraphQL.AspNetCore3.approved.txt | 27 ++++++ .../WebSockets/NewSubscriptionServerTests.cs | 79 +++++++++++++-- .../WebSockets/OldSubscriptionServerTests.cs | 8 ++ .../WebSockets/TestBaseSubscriptionServer.cs | 2 + .../WebSockets/TestNewSubscriptionServer.cs | 5 + 13 files changed, 416 insertions(+), 20 deletions(-) create mode 100644 src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/PingPayload.cs create mode 100644 src/GraphQL.AspNetCore3/WebSockets/KeepAliveMode.cs diff --git a/README.md b/README.md index 6347f61..4ec3f36 100644 --- a/README.md +++ b/README.md @@ -631,10 +631,12 @@ endpoint. | Property | Description | Default value | |-----------------------------|----------------------|---------------| | `ConnectionInitWaitTimeout` | The amount of time to wait for a GraphQL initialization packet before the connection is closed. | 10 seconds | -| `KeepAliveTimeout` | The amount of time to wait between sending keep-alive packets. | 30 seconds | | `DisconnectionTimeout` | The amount of time to wait to attempt a graceful teardown of the WebSockets protocol. | 10 seconds | | `DisconnectAfterErrorEvent` | Disconnects a subscription from the client if the subscription source dispatches an `OnError` event. | True | | `DisconnectAfterAnyError` | Disconnects a subscription from the client there are any GraphQL errors during a subscription. | False | +| `KeepAliveMode` | The mode to use for sending keep-alive packets. | protocol-dependent | +| `KeepAliveTimeout` | The amount of time to wait between sending keep-alive packets. | disabled | +| `SupportedWebSocketSubProtocols` | A list of supported WebSocket sub-protocols. | `graphql-ws`, `graphql-transport-ws` | ### Multi-schema configuration @@ -699,6 +701,59 @@ public class MySchema : Schema } ``` +### Keep-alive configuration + +By default, the middleware will not send keep-alive packets to the client. As the underlying +operating system may not detect a disconnected client until a message is sent, you may wish to +enable keep-alive packets to be sent periodically. The default mode for keep-alive packets +differs depending on whether the client connected with the `graphql-ws` or `graphql-transport-ws` +sub-protocol. The `graphql-ws` sub-protocol will send a unidirectional keep-alive packet to the +client on a fixed schedule, while the `graphql-transport-ws` sub-protocol will only send +unidirectional keep-alive packets when the client has not sent a message within a certain time. +The differing behavior is due to the default implementation of the `graphql-ws` sub-protocol +client, which after receiving a single keep-alive packet, expects additional keep-alive packets +to be sent sooner than every 20 seconds, regardless of the client's activity. + +To configure keep-alive packets, set the `KeepAliveMode` and `KeepAliveTimeout` properties +within the `GraphQLWebSocketOptions` object. Set the `KeepAliveTimeout` property to +enable keep-alive packets, or use `TimeSpan.Zero` or `Timeout.InfiniteTimeSpan` to disable it. + +The `KeepAliveMode` property is only applicable to the `graphql-transport-ws` sub-protocol and +can be set to the options listed below: + +| Keep-alive mode | Description | +|-----------------|-------------| +| `Default` | Same as `Timeout`. | +| `Timeout` | Sends a unidirectional keep-alive message when no message has been received within the specified timeout period. | +| `Interval` | Sends a unidirectional keep-alive message at a fixed interval, regardless of message activity. | +| `TimeoutWithPayload` | Sends a bidirectional keep-alive message with a payload on a fixed interval, and validates the payload matches in the response. | + +The `TimeoutWithPayload` model is particularly useful when the server may send messages to the +client at a faster pace than the client can process them. In this case queued messages will be +limited to double the timeout period, as the keep-alive message is queued along with other +packets sent from the server to the client. The client will need to respond to process queued +messages and respond to the keep-alive message within the timeout period or the server will +disconnect the client. When the server forcibly disconnects the client, no graceful teardown +of the WebSocket protocol occurs, and any queued messages are discarded. + +When using the `TimeoutWithPayload` keep-alive mode, you may wish to enforce that the +`graphql-transport-ws` sub-protocol is in use by the client, as the `graphql-ws` sub-protocol +does not support bidirectional keep-alive packets. This can be done by setting the +`SupportedWebSocketSubProtocols` property to only include the `graphql-transport-ws` sub-protocol. + +```csharp +app.UseGraphQL("/graphql", options => +{ + // configure keep-alive packets + options.WebSockets.KeepAliveTimeout = TimeSpan.FromSeconds(10); + options.WebSockets.KeepAliveMode = KeepAliveMode.TimeoutWithPayload; + // set the supported sub-protocols to only include the graphql-transport-ws sub-protocol + options.WebSockets.SupportedWebSocketSubProtocols = [GraphQLWs.SubscriptionServer.SubProtocol]; +}); +``` + +Please note that the included UI packages are configured to use the `graphql-ws` sub-protocol. + ### Customizing middleware behavior GET/POST requests are handled directly by the `GraphQLHttpMiddleware`. diff --git a/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs b/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs index 69512e0..35b6747 100644 --- a/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs +++ b/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs @@ -686,7 +686,7 @@ protected virtual Task WriteJsonResponseAsync(HttpContext context, Http /// /// Gets a list of WebSocket sub-protocols supported. /// - protected virtual IEnumerable SupportedWebSocketSubProtocols => _supportedSubProtocols; + protected virtual IEnumerable SupportedWebSocketSubProtocols => _options.WebSockets.SupportedWebSocketSubProtocols; /// /// Creates an , a WebSocket message pump. diff --git a/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs b/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs index fa97b18..f6940cd 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/BaseSubscriptionServer.cs @@ -256,10 +256,32 @@ protected virtual Task OnNotAuthorizedPolicyAsync(OperationMessage message, Auth ///

/// Otherwise, the connection is acknowledged via , /// is called to indicate that this WebSocket connection is ready to accept requests, - /// and keep-alive messages are sent via if configured to do so. - /// Keep-alive messages are only sent if no messages have been sent over the WebSockets connection for the - /// length of time configured in . + /// and is called to start sending keep-alive messages if configured to do so. ///
+ protected virtual async Task OnConnectionInitAsync(OperationMessage message) + { + if (!await AuthorizeAsync(message)) { + return; + } + await OnConnectionAcknowledgeAsync(message); + if (!TryInitialize()) + return; + + _ = OnKeepAliveLoopAsync(); + } + + /// + /// Executes when the client is attempting to initialize the connection. + ///

+ /// By default, this first checks to validate that the + /// request has passed authentication. If validation fails, the connection is closed with an Access + /// Denied message. + ///

+ /// Otherwise, the connection is acknowledged via , + /// is called to indicate that this WebSocket connection is ready to accept requests, + /// and is called to start sending keep-alive messages if configured to do so. + ///
+ [Obsolete($"Please use the {nameof(OnConnectionInitAsync)}(message) and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")] protected virtual async Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive) { if (!await AuthorizeAsync(message)) { @@ -272,12 +294,48 @@ protected virtual async Task OnConnectionInitAsync(OperationMessage message, boo var keepAliveTimeout = _options.KeepAliveTimeout ?? DefaultKeepAliveTimeout; if (keepAliveTimeout > TimeSpan.Zero) { if (smartKeepAlive) - _ = StartSmartKeepAliveLoopAsync(); + _ = OnKeepAliveLoopAsync(keepAliveTimeout, KeepAliveMode.Timeout); else - _ = StartKeepAliveLoopAsync(); + _ = OnKeepAliveLoopAsync(keepAliveTimeout, KeepAliveMode.Interval); + } + } + + /// + /// Starts sending keep-alive messages if configured to do so. Inspects the configured + /// and passes control to + /// if keep-alive messages are enabled. + /// + protected virtual Task OnKeepAliveLoopAsync() + { + return OnKeepAliveLoopAsync( + _options.KeepAliveTimeout ?? DefaultKeepAliveTimeout, + _options.KeepAliveMode); + } + + /// + /// Sends keep-alive messages according to the specified timeout period and method. + /// See for implementation details for each supported mode. + /// + protected virtual async Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode) + { + if (keepAliveTimeout <= TimeSpan.Zero) + return; + + switch (keepAliveMode) { + case KeepAliveMode.Default: + case KeepAliveMode.Timeout: + await StartSmartKeepAliveLoopAsync(); + break; + case KeepAliveMode.Interval: + await StartDumbKeepAliveLoopAsync(); + break; + case KeepAliveMode.TimeoutWithPayload: + throw new NotImplementedException($"{nameof(KeepAliveMode.TimeoutWithPayload)} is not implemented within the {nameof(BaseSubscriptionServer)} class."); + default: + throw new ArgumentOutOfRangeException(nameof(keepAliveMode)); } - async Task StartKeepAliveLoopAsync() + async Task StartDumbKeepAliveLoopAsync() { while (!CancellationToken.IsCancellationRequested) { await Task.Delay(keepAliveTimeout, CancellationToken); diff --git a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs index a671bd4..50afb65 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs @@ -22,6 +22,12 @@ public class GraphQLWebSocketOptions /// public TimeSpan? KeepAliveTimeout { get; set; } + /// + /// Gets or sets the keep-alive mode used for websocket subscriptions. + /// This property is only applicable when using the GraphQLWs protocol. + /// + public KeepAliveMode KeepAliveMode { get; set; } = KeepAliveMode.Default; + /// /// The amount of time to wait to attempt a graceful teardown of the WebSockets protocol. /// The default is 10 seconds. @@ -38,4 +44,17 @@ public class GraphQLWebSocketOptions /// Disconnects a subscription from the client there are any GraphQL errors during a subscription. /// public bool DisconnectAfterAnyError { get; set; } + + /// + /// The list of supported WebSocket sub-protocols. + /// Defaults to and . + /// Adding other sub-protocols require the method + /// to be overridden to handle the new sub-protocol. + /// + /// + /// When the is set to , you may wish to remove + /// from this list to prevent clients from using + /// protocols which do not support the keep-alive mode. + /// + public List SupportedWebSocketSubProtocols { get; set; } = [GraphQLWs.SubscriptionServer.SubProtocol, SubscriptionsTransportWs.SubscriptionServer.SubProtocol]; } diff --git a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/PingPayload.cs b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/PingPayload.cs new file mode 100644 index 0000000..aed7cb2 --- /dev/null +++ b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/PingPayload.cs @@ -0,0 +1,12 @@ +namespace GraphQL.AspNetCore3.WebSockets.GraphQLWs; + +/// +/// The payload of the ping message. +/// +public class PingPayload +{ + /// + /// The unique identifier of the ping message. + /// + public string? id { get; set; } +} diff --git a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs index 47e6924..47a6427 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/GraphQLWs/SubscriptionServer.cs @@ -6,6 +6,11 @@ namespace GraphQL.AspNetCore3.WebSockets.GraphQLWs; public class SubscriptionServer : BaseSubscriptionServer { private readonly IWebSocketAuthenticationService? _authenticationService; + private readonly IGraphQLSerializer _serializer; + private readonly GraphQLWebSocketOptions _options; + private DateTime _lastPongReceivedUtc; + private string? _lastPingId; + private readonly object _lastPingLock = new(); /// /// The WebSocket sub-protocol used for this protocol. @@ -69,6 +74,8 @@ public SubscriptionServer( UserContextBuilder = userContextBuilder ?? throw new ArgumentNullException(nameof(userContextBuilder)); Serializer = serializer ?? throw new ArgumentNullException(nameof(serializer)); _authenticationService = authenticationService; + _serializer = serializer; + _options = options; } /// @@ -84,7 +91,9 @@ public override async Task OnMessageReceivedAsync(OperationMessage message) if (Initialized) { await ErrorTooManyInitializationRequestsAsync(message); } else { +#pragma warning disable CS0618 // Type or member is obsolete await OnConnectionInitAsync(message, true); +#pragma warning restore CS0618 // Type or member is obsolete } return; } @@ -105,6 +114,64 @@ public override async Task OnMessageReceivedAsync(OperationMessage message) } } + /// + [Obsolete($"Please use the {nameof(OnConnectionInitAsync)} and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")] + protected override Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive) + { + if (smartKeepAlive) + return OnConnectionInitAsync(message); + else + return base.OnConnectionInitAsync(message, smartKeepAlive); + } + + /// + protected override Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode) + { + if (keepAliveMode == KeepAliveMode.TimeoutWithPayload) { + if (keepAliveTimeout <= TimeSpan.Zero) + return Task.CompletedTask; + return SecureKeepAliveLoopAsync(keepAliveTimeout, keepAliveTimeout); + } + return base.OnKeepAliveLoopAsync(keepAliveTimeout, keepAliveMode); + + // pingInterval is the time since the last pong was received before sending a new ping + // pongInterval is the time to wait for a pong after a ping was sent before forcibly closing the connection + async Task SecureKeepAliveLoopAsync(TimeSpan pingInterval, TimeSpan pongInterval) + { + lock (_lastPingLock) + _lastPongReceivedUtc = DateTime.UtcNow; + while (!CancellationToken.IsCancellationRequested) { + // Wait for the next ping interval + TimeSpan interval; + var now = DateTime.UtcNow; + DateTime lastPongReceivedUtc; + lock (_lastPingLock) { + lastPongReceivedUtc = _lastPongReceivedUtc; + } + var nextPing = lastPongReceivedUtc.Add(pingInterval); + interval = nextPing.Subtract(now); + if (interval > TimeSpan.Zero) // could easily be zero or less, if pongInterval is equal or greater than pingInterval + await Task.Delay(interval, CancellationToken); + + // Send a new ping message + await OnSendKeepAliveAsync(); + + // Wait for the pong response + await Task.Delay(pongInterval, CancellationToken); + bool abort; + lock (_lastPingLock) { + abort = _lastPongReceivedUtc == lastPongReceivedUtc; + } + if (abort) { + // Forcibly close the connection if the client has not responded to the keep-alive message. + // Do not send a close message to the client or wait for a response. + Connection.HttpContext.Abort(); + return; + } + } + } + } + /// /// Pong is a required response to a ping, and also a unidirectional keep-alive packet, /// whereas ping is a bidirectional keep-alive packet. @@ -123,11 +190,37 @@ protected virtual Task OnPingAsync(OperationMessage message) /// Executes when a pong message is received. /// protected virtual Task OnPongAsync(OperationMessage message) - => Task.CompletedTask; + { + if (_options.KeepAliveMode == KeepAliveMode.TimeoutWithPayload) { + try { + var pingId = _serializer.ReadNode(message.Payload)?.id; + lock (_lastPingLock) { + if (_lastPingId == pingId) + _lastPongReceivedUtc = DateTime.UtcNow; + } + } catch { } // ignore deserialization errors in case the pong message does not match the expected format + } + return Task.CompletedTask; + } /// protected override Task OnSendKeepAliveAsync() - => Connection.SendMessageAsync(_pongMessage); + { + if (_options.KeepAliveMode == KeepAliveMode.TimeoutWithPayload) { + var lastPingId = Guid.NewGuid().ToString("N"); + lock (_lastPingLock) { + _lastPingId = lastPingId; + } + return Connection.SendMessageAsync( + new() { + Type = MessageType.Ping, + Payload = new PingPayload { id = lastPingId } + } + ); + } else { + return Connection.SendMessageAsync(_pongMessage); + } + } private static readonly OperationMessage _connectionAckMessage = new() { Type = MessageType.ConnectionAck }; /// diff --git a/src/GraphQL.AspNetCore3/WebSockets/KeepAliveMode.cs b/src/GraphQL.AspNetCore3/WebSockets/KeepAliveMode.cs new file mode 100644 index 0000000..3bc39dd --- /dev/null +++ b/src/GraphQL.AspNetCore3/WebSockets/KeepAliveMode.cs @@ -0,0 +1,36 @@ +namespace GraphQL.AspNetCore3.WebSockets; + +/// +/// Specifies the mode of keep-alive behavior. +/// +public enum KeepAliveMode +{ + /// + /// Same as : Sends a unidirectional keep-alive message when no message has been received within the specified timeout period. + /// + Default = 0, + + /// + /// Sends a unidirectional keep-alive message when no message has been received within the specified timeout period. + /// + Timeout = 1, + + /// + /// Sends a unidirectional keep-alive message at a fixed interval, regardless of message activity. + /// + Interval = 2, + + /// + /// Sends a Ping message with a payload after the specified timeout from the last received Pong, + /// and waits for a corresponding Pong response. Requires that the client reflects the payload + /// in the response. Forcibly disconnects the client if the client does not respond with a Pong + /// message within the specified timeout. This means that a dead connection will be closed after + /// a maximum of double the period. + /// + /// + /// This mode is particularly useful when backpressure causes subscription messages to be delayed + /// due to a slow or unresponsive client connection. The server can detect that the client is not + /// processing messages in a timely manner and disconnect the client to free up resources. + /// + TimeoutWithPayload = 3, +} diff --git a/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs b/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs index c52f53b..0f39428 100644 --- a/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs +++ b/src/GraphQL.AspNetCore3/WebSockets/SubscriptionsTransportWs/SubscriptionServer.cs @@ -81,7 +81,9 @@ public override async Task OnMessageReceivedAsync(OperationMessage message) if (Initialized) { await ErrorTooManyInitializationRequestsAsync(message); } else { +#pragma warning disable CS0618 // Type or member is obsolete await OnConnectionInitAsync(message, false); +#pragma warning restore CS0618 // Type or member is obsolete } return; } @@ -102,6 +104,26 @@ public override async Task OnMessageReceivedAsync(OperationMessage message) } } + /// + [Obsolete($"Please use the {nameof(OnConnectionInitAsync)} and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")] + protected override Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive) + { + if (!smartKeepAlive) + return OnConnectionInitAsync(message); + else + return base.OnConnectionInitAsync(message, smartKeepAlive); + } + + /// + /// + /// This implementation overrides to + /// as this protocol does not support the other modes. Override this method to support your own implementation. + /// + protected override Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode) + => base.OnKeepAliveLoopAsync( + keepAliveTimeout, + KeepAliveMode.Interval); + private static readonly OperationMessage _keepAliveMessage = new() { Type = MessageType.GQL_CONNECTION_KEEP_ALIVE }; /// protected override Task OnSendKeepAliveAsync() diff --git a/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt b/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt index f7934fa..fbd14bf 100644 --- a/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt +++ b/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt @@ -297,8 +297,13 @@ namespace GraphQL.AspNetCore3.WebSockets public virtual System.Threading.Tasks.Task InitializeConnectionAsync() { } protected virtual System.Threading.Tasks.Task OnCloseConnectionAsync() { } protected abstract System.Threading.Tasks.Task OnConnectionAcknowledgeAsync(GraphQL.Transport.OperationMessage message); + protected virtual System.Threading.Tasks.Task OnConnectionInitAsync(GraphQL.Transport.OperationMessage message) { } + [System.Obsolete("Please use the OnConnectionInitAsync(message) and OnKeepAliveLoopAsync methods in" + + "stead. This method will be removed in a future version of this library.")] protected virtual System.Threading.Tasks.Task OnConnectionInitAsync(GraphQL.Transport.OperationMessage message, bool smartKeepAlive) { } protected virtual System.Threading.Tasks.Task OnConnectionInitWaitTimeoutAsync() { } + protected virtual System.Threading.Tasks.Task OnKeepAliveLoopAsync() { } + protected virtual System.Threading.Tasks.Task OnKeepAliveLoopAsync(System.TimeSpan keepAliveTimeout, GraphQL.AspNetCore3.WebSockets.KeepAliveMode keepAliveMode) { } public abstract System.Threading.Tasks.Task OnMessageReceivedAsync(GraphQL.Transport.OperationMessage message); protected virtual System.Threading.Tasks.Task OnNotAuthenticatedAsync(GraphQL.Transport.OperationMessage message) { } protected virtual System.Threading.Tasks.Task OnNotAuthorizedPolicyAsync(GraphQL.Transport.OperationMessage message, Microsoft.AspNetCore.Authorization.AuthorizationResult result) { } @@ -322,7 +327,9 @@ namespace GraphQL.AspNetCore3.WebSockets public bool DisconnectAfterAnyError { get; set; } public bool DisconnectAfterErrorEvent { get; set; } public System.TimeSpan? DisconnectionTimeout { get; set; } + public GraphQL.AspNetCore3.WebSockets.KeepAliveMode KeepAliveMode { get; set; } public System.TimeSpan? KeepAliveTimeout { get; set; } + public System.Collections.Generic.List SupportedWebSocketSubProtocols { get; set; } } public interface IOperationMessageProcessor : System.IDisposable { @@ -343,6 +350,13 @@ namespace GraphQL.AspNetCore3.WebSockets System.Threading.Tasks.Task ExecuteAsync(GraphQL.AspNetCore3.WebSockets.IOperationMessageProcessor operationMessageProcessor); System.Threading.Tasks.Task SendMessageAsync(GraphQL.Transport.OperationMessage message); } + public enum KeepAliveMode + { + Default = 0, + Timeout = 1, + Interval = 2, + TimeoutWithPayload = 3, + } public sealed class SubscriptionList : System.IDisposable { public SubscriptionList() { } @@ -386,6 +400,11 @@ namespace GraphQL.AspNetCore3.WebSockets.GraphQLWs public const string Pong = "pong"; public const string Subscribe = "subscribe"; } + public class PingPayload + { + public PingPayload() { } + public string? id { get; set; } + } public class SubscriptionServer : GraphQL.AspNetCore3.WebSockets.BaseSubscriptionServer { public SubscriptionServer(GraphQL.AspNetCore3.WebSockets.IWebSocketConnection connection, GraphQL.AspNetCore3.WebSockets.GraphQLWebSocketOptions options, GraphQL.AspNetCore3.IAuthorizationOptions authorizationOptions, GraphQL.IDocumentExecuter executer, GraphQL.IGraphQLSerializer serializer, Microsoft.Extensions.DependencyInjection.IServiceScopeFactory serviceScopeFactory, GraphQL.AspNetCore3.IUserContextBuilder userContextBuilder, GraphQL.AspNetCore3.WebSockets.IWebSocketAuthenticationService? authenticationService = null) { } @@ -399,6 +418,10 @@ namespace GraphQL.AspNetCore3.WebSockets.GraphQLWs protected override System.Threading.Tasks.Task ExecuteRequestAsync(GraphQL.Transport.OperationMessage message) { } protected virtual System.Threading.Tasks.Task OnCompleteAsync(GraphQL.Transport.OperationMessage message) { } protected override System.Threading.Tasks.Task OnConnectionAcknowledgeAsync(GraphQL.Transport.OperationMessage message) { } + [System.Obsolete("Please use the OnConnectionInitAsync and OnKeepAliveLoopAsync methods instead. Th" + + "is method will be removed in a future version of this library.")] + protected override System.Threading.Tasks.Task OnConnectionInitAsync(GraphQL.Transport.OperationMessage message, bool smartKeepAlive) { } + protected override System.Threading.Tasks.Task OnKeepAliveLoopAsync(System.TimeSpan keepAliveTimeout, GraphQL.AspNetCore3.WebSockets.KeepAliveMode keepAliveMode) { } public override System.Threading.Tasks.Task OnMessageReceivedAsync(GraphQL.Transport.OperationMessage message) { } protected virtual System.Threading.Tasks.Task OnPingAsync(GraphQL.Transport.OperationMessage message) { } protected virtual System.Threading.Tasks.Task OnPongAsync(GraphQL.Transport.OperationMessage message) { } @@ -437,6 +460,10 @@ namespace GraphQL.AspNetCore3.WebSockets.SubscriptionsTransportWs protected override System.Threading.Tasks.Task ErrorAccessDeniedAsync() { } protected override System.Threading.Tasks.Task ExecuteRequestAsync(GraphQL.Transport.OperationMessage message) { } protected override System.Threading.Tasks.Task OnConnectionAcknowledgeAsync(GraphQL.Transport.OperationMessage message) { } + [System.Obsolete("Please use the OnConnectionInitAsync and OnKeepAliveLoopAsync methods instead. Th" + + "is method will be removed in a future version of this library.")] + protected override System.Threading.Tasks.Task OnConnectionInitAsync(GraphQL.Transport.OperationMessage message, bool smartKeepAlive) { } + protected override System.Threading.Tasks.Task OnKeepAliveLoopAsync(System.TimeSpan keepAliveTimeout, GraphQL.AspNetCore3.WebSockets.KeepAliveMode keepAliveMode) { } public override System.Threading.Tasks.Task OnMessageReceivedAsync(GraphQL.Transport.OperationMessage message) { } protected override System.Threading.Tasks.Task OnSendKeepAliveAsync() { } protected virtual System.Threading.Tasks.Task OnStartAsync(GraphQL.Transport.OperationMessage message) { } diff --git a/src/Tests/WebSockets/NewSubscriptionServerTests.cs b/src/Tests/WebSockets/NewSubscriptionServerTests.cs index 86d749b..786ef23 100644 --- a/src/Tests/WebSockets/NewSubscriptionServerTests.cs +++ b/src/Tests/WebSockets/NewSubscriptionServerTests.cs @@ -1,4 +1,5 @@ using System.Security.Claims; +using GraphQL.AspNetCore3.WebSockets.GraphQLWs; namespace Tests.WebSockets; @@ -63,6 +64,14 @@ public async Task Message_Initialize(bool initialized) .Returns(Task.CompletedTask).Verifiable(); } else { _mockServer.Protected().Setup("OnConnectionInitAsync", message, true) + .CallBase().Verifiable(); + _mockServer.Protected().Setup("OnConnectionInitAsync", message) + .CallBase().Verifiable(); + _mockServer.Protected().Setup>("AuthorizeAsync", message) + .Returns(new ValueTask(true)).Verifiable(); + _mockServer.Protected().Setup("OnConnectionAcknowledgeAsync", message) + .Returns(Task.CompletedTask).Verifiable(); + _mockServer.Protected().Setup("OnKeepAliveLoopAsync") .Returns(Task.CompletedTask).Verifiable(); } _mockServer.Setup(x => x.OnMessageReceivedAsync(message)).CallBase().Verifiable(); @@ -115,11 +124,15 @@ public async Task Message_Ping(bool initialized, bool withPayload) } [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task Message_Pong(bool initialized) + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task Message_Pong(bool initialized, bool withPayload) { var message = new OperationMessage { Type = "pong" }; + if (withPayload) + message.Payload = new PingPayload { id = Guid.NewGuid().ToString("N") }; _mockServer.Protected().Setup("OnPongAsync", message) .Returns(Task.CompletedTask).Verifiable(); if (initialized) { @@ -177,11 +190,27 @@ public async Task Message_Unknown(string? messageType) _mockServer.VerifyNoOtherCalls(); } - [Fact] - public async Task OnSendKeepAliveAsync() + [Theory] + [InlineData(KeepAliveMode.Default)] + [InlineData(KeepAliveMode.Interval)] + [InlineData(KeepAliveMode.Timeout)] + [InlineData(KeepAliveMode.TimeoutWithPayload)] + public async Task OnSendKeepAliveAsync(KeepAliveMode keepAliveMode) { + _options.WebSockets.KeepAliveMode = keepAliveMode; _mockStream.Setup(x => x.SendMessageAsync(It.IsAny())) - .Returns(o => o.Type == "pong" ? Task.CompletedTask : Task.FromException(new Exception())) + .Returns(async o => { + o.Id.ShouldBeNull(); + if (keepAliveMode == KeepAliveMode.TimeoutWithPayload) { + o.Type.ShouldBe("ping"); + var payload = o.Payload.ShouldBeOfType(); + var guid = Guid.ParseExact(payload.id.ShouldNotBeNull(), "N"); + guid.ShouldNotBe(Guid.Empty); + } else { + o.Type.ShouldBe("pong"); + o.Payload.ShouldBeNull(); + } + }) .Verifiable(); _mockServer.Protected().Setup("OnSendKeepAliveAsync").CallBase().Verifiable(); await _server.Do_OnSendKeepAliveAsync(); @@ -189,6 +218,25 @@ public async Task OnSendKeepAliveAsync() _mockServer.VerifyNoOtherCalls(); } + [Theory] + [InlineData(KeepAliveMode.Default)] + [InlineData(KeepAliveMode.Interval)] + [InlineData(KeepAliveMode.Timeout)] + [InlineData(KeepAliveMode.TimeoutWithPayload)] + public async Task OnKeepAliveLoopAsync(KeepAliveMode keepAliveMode) + { + _options.WebSockets.KeepAliveMode = keepAliveMode; + var defaultKeepAliveTimeout = TimeSpan.FromSeconds(10); + _mockServer.Protected().SetupGet("DefaultKeepAliveTimeout") + .Returns(defaultKeepAliveTimeout).Verifiable(); + _mockServer.Protected().Setup("OnKeepAliveLoopAsync").CallBase().Verifiable(); + _mockServer.Protected().Setup("OnKeepAliveLoopAsync", defaultKeepAliveTimeout, keepAliveMode) + .Returns(Task.CompletedTask).Verifiable(); + await _server.Do_OnKeepAliveLoopAsync(); + _mockServer.Verify(); + _mockServer.VerifyNoOtherCalls(); + } + [Fact] public async Task OnConnectionAcknowledgeAsync() { @@ -202,13 +250,24 @@ public async Task OnConnectionAcknowledgeAsync() _mockServer.VerifyNoOtherCalls(); } - [Fact] - public async Task OnPingAsync() + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task OnPingAsync(bool withPayload) { + var payload = new { id = Guid.NewGuid().ToString("N") }; _mockStream.Setup(x => x.SendMessageAsync(It.IsAny())) - .Returns(o => o.Type == "pong" ? Task.CompletedTask : Task.FromException(new Exception())) + .Returns(async o => { + o.Id.ShouldBeNull(); + o.Type.ShouldBe("pong"); + if (withPayload) { + o.Payload.ShouldBe(payload); + } else { + o.Payload.ShouldBeNull(); + } + }) .Verifiable(); - var message = new OperationMessage(); + var message = new OperationMessage() { Payload = withPayload ? payload : null }; _mockServer.Protected().Setup("OnPingAsync", message).CallBase().Verifiable(); await _server.Do_OnPingAsync(message); _mockServer.Verify(); diff --git a/src/Tests/WebSockets/OldSubscriptionServerTests.cs b/src/Tests/WebSockets/OldSubscriptionServerTests.cs index 83e91a1..7b9ecd1 100644 --- a/src/Tests/WebSockets/OldSubscriptionServerTests.cs +++ b/src/Tests/WebSockets/OldSubscriptionServerTests.cs @@ -81,6 +81,14 @@ public async Task Message_Initialize(bool initialized) .Returns(Task.CompletedTask).Verifiable(); } else { _mockServer.Protected().Setup("OnConnectionInitAsync", message, false) + .CallBase().Verifiable(); + _mockServer.Protected().Setup("OnConnectionInitAsync", message) + .CallBase().Verifiable(); + _mockServer.Protected().Setup>("AuthorizeAsync", message) + .Returns(new ValueTask(true)).Verifiable(); + _mockServer.Protected().Setup("OnConnectionAcknowledgeAsync", message) + .Returns(Task.CompletedTask).Verifiable(); + _mockServer.Protected().Setup("OnKeepAliveLoopAsync") .Returns(Task.CompletedTask).Verifiable(); } _mockServer.Setup(x => x.OnMessageReceivedAsync(message)).CallBase().Verifiable(); diff --git a/src/Tests/WebSockets/TestBaseSubscriptionServer.cs b/src/Tests/WebSockets/TestBaseSubscriptionServer.cs index 15e6df1..c5749d5 100644 --- a/src/Tests/WebSockets/TestBaseSubscriptionServer.cs +++ b/src/Tests/WebSockets/TestBaseSubscriptionServer.cs @@ -49,7 +49,9 @@ public Task Do_ErrorAccessDeniedAsync() => ErrorAccessDeniedAsync(); public Task Do_OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive) +#pragma warning disable CS0618 // Type or member is obsolete => OnConnectionInitAsync(message, smartKeepAlive); +#pragma warning restore CS0618 // Type or member is obsolete public Task Do_SubscribeAsync(OperationMessage message, bool overwrite) => SubscribeAsync(message, overwrite); diff --git a/src/Tests/WebSockets/TestNewSubscriptionServer.cs b/src/Tests/WebSockets/TestNewSubscriptionServer.cs index 52de220..cd7d1e7 100644 --- a/src/Tests/WebSockets/TestNewSubscriptionServer.cs +++ b/src/Tests/WebSockets/TestNewSubscriptionServer.cs @@ -42,6 +42,9 @@ public Task Do_SendCompletedAsync(string id) public Task Do_ExecuteRequestAsync(OperationMessage message) => ExecuteRequestAsync(message); + public Task Do_OnKeepAliveLoopAsync() + => OnKeepAliveLoopAsync(); + public SubscriptionList Get_Subscriptions => Subscriptions; @@ -56,4 +59,6 @@ public SubscriptionList Get_Subscriptions public IDocumentExecuter Get_DocumentExecuter => DocumentExecuter; public IServiceScopeFactory Get_ServiceScopeFactory => ServiceScopeFactory; + + public TimeSpan Get_DefaultKeepAliveTimeout => DefaultKeepAliveTimeout; }