diff --git a/README.md b/README.md
index 6c92a52..3b9f202 100644
--- a/README.md
+++ b/README.md
@@ -564,6 +564,8 @@ endpoint.
| `AuthorizationRequired` | Requires `HttpContext.User` to represent an authenticated user. | False |
| `AuthorizedPolicy` | If set, requires `HttpContext.User` to pass authorization of the specified policy. | |
| `AuthorizedRoles` | If set, requires `HttpContext.User` to be a member of any one of a list of roles. | |
+| `CsrfProtectionEnabled` | Enables cross-site request forgery (CSRF) protection for both GET and POST requests. | True |
+| `CsrfProtectionHeaders` | Sets the headers used for CSRF protection when necessary. | `GraphQL-Require-Preflight` |
| `EnableBatchedRequests` | Enables handling of batched GraphQL requests for POST requests when formatted as JSON. | True |
| `ExecuteBatchedRequestsInParallel` | Enables parallel execution of batched GraphQL requests. | True |
| `HandleGet` | Enables handling of GET requests. | True |
diff --git a/src/GraphQL.AspNetCore3/Errors/CsrfProtectionError.cs b/src/GraphQL.AspNetCore3/Errors/CsrfProtectionError.cs
new file mode 100644
index 0000000..068868e
--- /dev/null
+++ b/src/GraphQL.AspNetCore3/Errors/CsrfProtectionError.cs
@@ -0,0 +1,16 @@
+namespace GraphQL.AspNetCore3.Errors;
+
+///
+/// Represents an error indicating that the request may not have triggered a CORS preflight request.
+///
+public class CsrfProtectionError : RequestError
+{
+ ///
+ public CsrfProtectionError(IEnumerable headersRequired) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}.") { }
+
+ ///
+ public CsrfProtectionError(IEnumerable headersRequired, Exception innerException) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}. {innerException.Message}") { }
+
+ private static string FormatHeaders(IEnumerable headersRequired)
+ => string.Join(", ", headersRequired.Select(x => $"'{x}'"));
+}
diff --git a/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs b/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs
index cf79386..7695cc8 100644
--- a/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs
+++ b/src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs
@@ -5,6 +5,7 @@
using System.Security.Claims;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
+using static System.Net.Mime.MediaTypeNames;
namespace GraphQL.AspNetCore3;
@@ -125,6 +126,10 @@ public virtual async Task InvokeAsync(HttpContext context)
return;
}
+ // Perform CSRF protection if necessary
+ if (await HandleCsrfProtectionAsync(context, _next))
+ return;
+
// Authenticate request if necessary
if (await HandleAuthorizeAsync(context, _next))
return;
@@ -423,6 +428,32 @@ static void ApplyFileToRequest(IFormFile file, string target, GraphQLRequest? re
}
}
+ ///
+ /// Performs CSRF protection, if required, and returns if the
+ /// request was handled (typically by returning an error message). If
+ /// is returned, the request is processed normally.
+ ///
+ protected virtual async ValueTask HandleCsrfProtectionAsync(HttpContext context, RequestDelegate next)
+ {
+ if (!_options.CsrfProtectionEnabled)
+ return false;
+ if (context.Request.Headers.TryGetValue("Content-Type", out var contentTypes) && contentTypes.Count > 0 && contentTypes[0] != null) {
+ var contentType = contentTypes[0]!;
+ if (contentType.IndexOf(';') > 0) {
+ contentType = contentType.Substring(0, contentType.IndexOf(';'));
+ }
+ contentType = contentType.Trim().ToLowerInvariant();
+ if (!(contentType == "text/plain" || contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data"))
+ return false;
+ }
+ foreach (var header in _options.CsrfProtectionHeaders) {
+ if (context.Request.Headers.TryGetValue(header, out var values) && values.Count > 0 && values[0]?.Length > 0)
+ return false;
+ }
+ await HandleCsrfProtectionErrorAsync(context, next);
+ return true;
+ }
+
///
/// Perform authentication, if required, and return if the
/// request was handled (typically by returning an error message). If
@@ -769,21 +800,29 @@ protected virtual Task HandleNotAuthorizedPolicyAsync(HttpContext context, Reque
///
protected virtual async ValueTask HandleDeserializationErrorAsync(HttpContext context, RequestDelegate next, Exception exception)
{
- await WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new JsonInvalidError(exception));
+ await WriteErrorResponseAsync(context, new JsonInvalidError(exception));
return true;
}
+ ///
+ /// Writes a '.' message to the output.
+ ///
+ protected virtual async Task HandleCsrfProtectionErrorAsync(HttpContext context, RequestDelegate next)
+ {
+ await WriteErrorResponseAsync(context, new CsrfProtectionError(_options.CsrfProtectionHeaders));
+ }
+
///
/// Writes a '400 Batched requests are not supported.' message to the output.
///
protected virtual Task HandleBatchedRequestsNotSupportedAsync(HttpContext context, RequestDelegate next)
- => WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new BatchedRequestsNotSupportedError());
+ => WriteErrorResponseAsync(context, new BatchedRequestsNotSupportedError());
///
/// Writes a '400 Invalid requested WebSocket sub-protocol(s).' message to the output.
///
protected virtual Task HandleWebSocketSubProtocolNotSupportedAsync(HttpContext context, RequestDelegate next)
- => WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols));
+ => WriteErrorResponseAsync(context, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols));
///
/// Writes a '415 Invalid Content-Type header: could not be parsed.' message to the output.
@@ -814,6 +853,12 @@ protected virtual Task HandleInvalidHttpMethodErrorAsync(HttpContext context, Re
return next(context);
}
+ ///
+ /// Writes the specified error as a JSON-formatted GraphQL response.
+ ///
+ protected virtual Task WriteErrorResponseAsync(HttpContext context, ExecutionError executionError)
+ => WriteErrorResponseAsync(context, executionError is IHasPreferredStatusCode withCode ? withCode.PreferredStatusCode : HttpStatusCode.BadRequest, executionError);
+
///
/// Writes the specified error message as a JSON-formatted GraphQL response, with the specified HTTP status code.
///
diff --git a/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs b/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs
index ab07f17..4f90d05 100644
--- a/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs
+++ b/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs
@@ -61,6 +61,22 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions
///
public bool ReadFormOnPost { get; set; } = true;
+ ///
+ /// Enables cross-site request forgery (CSRF) protection for both GET and POST requests.
+ /// Requires a non-empty header from the list to be
+ /// present, or a POST request with a Content-Type header that is not text/plain,
+ /// application/x-www-form-urlencoded, or multipart/form-data.
+ ///
+ public bool CsrfProtectionEnabled { get; set; }
+
+ ///
+ /// When is enabled, requests require a non-empty
+ /// header from this list or a POST request with a Content-Type header that is not
+ /// text/plain, application/x-www-form-urlencoded, or multipart/form-data.
+ /// Defaults to GraphQL-Require-Preflight.
+ ///
+ public List CsrfProtectionHeaders { get; set; } = new() { "GraphQL-Require-Preflight" }; // see https://github.com/graphql/graphql-over-http/pull/303
+
///
/// Enables reading variables from the query string.
/// Variables are interpreted as JSON and deserialized before being
diff --git a/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt b/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt
index d009bee..9db0348 100644
--- a/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt
+++ b/src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt
@@ -132,6 +132,8 @@ namespace GraphQL.AspNetCore3
protected virtual System.Threading.Tasks.Task HandleBatchRequestAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Collections.Generic.IList gqlRequests) { }
protected virtual System.Threading.Tasks.Task HandleBatchedRequestsNotSupportedAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.Task HandleContentTypeCouldNotBeParsedErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
+ protected virtual System.Threading.Tasks.ValueTask HandleCsrfProtectionAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
+ protected virtual System.Threading.Tasks.Task HandleCsrfProtectionErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.ValueTask HandleDeserializationErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Exception exception) { }
protected virtual System.Threading.Tasks.Task HandleInvalidContentTypeErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.Task HandleInvalidHttpMethodErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
@@ -147,6 +149,7 @@ namespace GraphQL.AspNetCore3
"BatchRequest"})]
protected virtual System.Threading.Tasks.Task?>?> ReadPostContentAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, string? mediaType, System.Text.Encoding? sourceEncoding) { }
protected virtual string SelectResponseContentType(Microsoft.AspNetCore.Http.HttpContext context) { }
+ protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, GraphQL.ExecutionError executionError) { }
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, GraphQL.ExecutionError executionError) { }
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, string errorMessage) { }
protected virtual System.Threading.Tasks.Task WriteJsonResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, TResult result) { }
@@ -158,6 +161,8 @@ namespace GraphQL.AspNetCore3
public bool AuthorizationRequired { get; set; }
public string? AuthorizedPolicy { get; set; }
public System.Collections.Generic.List AuthorizedRoles { get; set; }
+ public bool CsrfProtectionEnabled { get; set; }
+ public System.Collections.Generic.List CsrfProtectionHeaders { get; set; }
public bool EnableBatchedRequests { get; set; }
public bool ExecuteBatchedRequestsInParallel { get; set; }
public bool HandleGet { get; set; }
@@ -224,6 +229,11 @@ namespace GraphQL.AspNetCore3.Errors
{
public BatchedRequestsNotSupportedError() { }
}
+ public class CsrfProtectionError : GraphQL.Execution.RequestError
+ {
+ public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired) { }
+ public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired, System.Exception innerException) { }
+ }
public class FileCountExceededError : GraphQL.Execution.RequestError, GraphQL.AspNetCore3.Errors.IHasPreferredStatusCode
{
public FileCountExceededError() { }