Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CSRF protection #66

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,8 @@ endpoint.
| `AuthorizationRequired` | Requires `HttpContext.User` to represent an authenticated user. | False |
| `AuthorizedPolicy` | If set, requires `HttpContext.User` to pass authorization of the specified policy. | |
| `AuthorizedRoles` | If set, requires `HttpContext.User` to be a member of any one of a list of roles. | |
| `CsrfProtectionEnabled` | Enables cross-site request forgery (CSRF) protection for both GET and POST requests. | True |
| `CsrfProtectionHeaders` | Sets the headers used for CSRF protection when necessary. | `GraphQL-Require-Preflight` |
| `EnableBatchedRequests` | Enables handling of batched GraphQL requests for POST requests when formatted as JSON. | True |
| `ExecuteBatchedRequestsInParallel` | Enables parallel execution of batched GraphQL requests. | True |
| `HandleGet` | Enables handling of GET requests. | True |
Expand Down
16 changes: 16 additions & 0 deletions src/GraphQL.AspNetCore3/Errors/CsrfProtectionError.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace GraphQL.AspNetCore3.Errors;

/// <summary>
/// Represents an error indicating that the request may not have triggered a CORS preflight request.
/// </summary>
public class CsrfProtectionError : RequestError
{
/// <inheritdoc cref="CsrfProtectionError"/>
public CsrfProtectionError(IEnumerable<string> headersRequired) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}.") { }

/// <inheritdoc cref="CsrfProtectionError"/>
public CsrfProtectionError(IEnumerable<string> headersRequired, Exception innerException) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}. {innerException.Message}") { }

private static string FormatHeaders(IEnumerable<string> headersRequired)
=> string.Join(", ", headersRequired.Select(x => $"'{x}'"));
}
51 changes: 48 additions & 3 deletions src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Security.Claims;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
using static System.Net.Mime.MediaTypeNames;

namespace GraphQL.AspNetCore3;

Expand Down Expand Up @@ -125,6 +126,10 @@ public virtual async Task InvokeAsync(HttpContext context)
return;
}

// Perform CSRF protection if necessary
if (await HandleCsrfProtectionAsync(context, _next))
return;

// Authenticate request if necessary
if (await HandleAuthorizeAsync(context, _next))
return;
Expand Down Expand Up @@ -423,6 +428,32 @@ static void ApplyFileToRequest(IFormFile file, string target, GraphQLRequest? re
}
}

/// <summary>
/// Performs CSRF protection, if required, and returns <see langword="true"/> if the
/// request was handled (typically by returning an error message). If <see langword="false"/>
/// is returned, the request is processed normally.
/// </summary>
protected virtual async ValueTask<bool> HandleCsrfProtectionAsync(HttpContext context, RequestDelegate next)
{
if (!_options.CsrfProtectionEnabled)
return false;
if (context.Request.Headers.TryGetValue("Content-Type", out var contentTypes) && contentTypes.Count > 0 && contentTypes[0] != null) {
var contentType = contentTypes[0]!;
if (contentType.IndexOf(';') > 0) {
contentType = contentType.Substring(0, contentType.IndexOf(';'));
}
contentType = contentType.Trim().ToLowerInvariant();
if (!(contentType == "text/plain" || contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data"))
return false;
}
foreach (var header in _options.CsrfProtectionHeaders) {
if (context.Request.Headers.TryGetValue(header, out var values) && values.Count > 0 && values[0]?.Length > 0)
return false;
}
await HandleCsrfProtectionErrorAsync(context, next);
return true;
}

/// <summary>
/// Perform authentication, if required, and return <see langword="true"/> if the
/// request was handled (typically by returning an error message). If <see langword="false"/>
Expand Down Expand Up @@ -769,21 +800,29 @@ protected virtual Task HandleNotAuthorizedPolicyAsync(HttpContext context, Reque
/// </summary>
protected virtual async ValueTask<bool> HandleDeserializationErrorAsync(HttpContext context, RequestDelegate next, Exception exception)
{
await WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new JsonInvalidError(exception));
await WriteErrorResponseAsync(context, new JsonInvalidError(exception));
return true;
}

/// <summary>
/// Writes a '.' message to the output.
/// </summary>
protected virtual async Task HandleCsrfProtectionErrorAsync(HttpContext context, RequestDelegate next)
{
await WriteErrorResponseAsync(context, new CsrfProtectionError(_options.CsrfProtectionHeaders));
}

/// <summary>
/// Writes a '400 Batched requests are not supported.' message to the output.
/// </summary>
protected virtual Task HandleBatchedRequestsNotSupportedAsync(HttpContext context, RequestDelegate next)
=> WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new BatchedRequestsNotSupportedError());
=> WriteErrorResponseAsync(context, new BatchedRequestsNotSupportedError());

/// <summary>
/// Writes a '400 Invalid requested WebSocket sub-protocol(s).' message to the output.
/// </summary>
protected virtual Task HandleWebSocketSubProtocolNotSupportedAsync(HttpContext context, RequestDelegate next)
=> WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols));
=> WriteErrorResponseAsync(context, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols));

/// <summary>
/// Writes a '415 Invalid Content-Type header: could not be parsed.' message to the output.
Expand Down Expand Up @@ -814,6 +853,12 @@ protected virtual Task HandleInvalidHttpMethodErrorAsync(HttpContext context, Re
return next(context);
}

/// <summary>
/// Writes the specified error as a JSON-formatted GraphQL response.
/// </summary>
protected virtual Task WriteErrorResponseAsync(HttpContext context, ExecutionError executionError)
=> WriteErrorResponseAsync(context, executionError is IHasPreferredStatusCode withCode ? withCode.PreferredStatusCode : HttpStatusCode.BadRequest, executionError);

/// <summary>
/// Writes the specified error message as a JSON-formatted GraphQL response, with the specified HTTP status code.
/// </summary>
Expand Down
16 changes: 16 additions & 0 deletions src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions
/// </remarks>
public bool ReadFormOnPost { get; set; } = true;

/// <summary>
/// Enables cross-site request forgery (CSRF) protection for both GET and POST requests.
/// Requires a non-empty header from the <see cref="CsrfProtectionHeaders"/> list to be
/// present, or a POST request with a Content-Type header that is not <c>text/plain</c>,
/// <c>application/x-www-form-urlencoded</c>, or <c>multipart/form-data</c>.
/// </summary>
public bool CsrfProtectionEnabled { get; set; }

/// <summary>
/// When <see cref="CsrfProtectionEnabled"/> is enabled, requests require a non-empty
/// header from this list or a POST request with a Content-Type header that is not
/// <c>text/plain</c>, <c>application/x-www-form-urlencoded</c>, or <c>multipart/form-data</c>.
/// Defaults to <c>GraphQL-Require-Preflight</c>.
/// </summary>
public List<string> CsrfProtectionHeaders { get; set; } = new() { "GraphQL-Require-Preflight" }; // see https://github.com/graphql/graphql-over-http/pull/303

/// <summary>
/// Enables reading variables from the query string.
/// Variables are interpreted as JSON and deserialized before being
Expand Down
10 changes: 10 additions & 0 deletions src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ namespace GraphQL.AspNetCore3
protected virtual System.Threading.Tasks.Task HandleBatchRequestAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Collections.Generic.IList<GraphQL.Transport.GraphQLRequest?> gqlRequests) { }
protected virtual System.Threading.Tasks.Task HandleBatchedRequestsNotSupportedAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.Task HandleContentTypeCouldNotBeParsedErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.ValueTask<bool> HandleCsrfProtectionAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.Task HandleCsrfProtectionErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.ValueTask<bool> HandleDeserializationErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Exception exception) { }
protected virtual System.Threading.Tasks.Task HandleInvalidContentTypeErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
protected virtual System.Threading.Tasks.Task HandleInvalidHttpMethodErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
Expand All @@ -147,6 +149,7 @@ namespace GraphQL.AspNetCore3
"BatchRequest"})]
protected virtual System.Threading.Tasks.Task<System.ValueTuple<GraphQL.Transport.GraphQLRequest?, System.Collections.Generic.IList<GraphQL.Transport.GraphQLRequest?>?>?> ReadPostContentAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, string? mediaType, System.Text.Encoding? sourceEncoding) { }
protected virtual string SelectResponseContentType(Microsoft.AspNetCore.Http.HttpContext context) { }
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, GraphQL.ExecutionError executionError) { }
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, GraphQL.ExecutionError executionError) { }
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, string errorMessage) { }
protected virtual System.Threading.Tasks.Task WriteJsonResponseAsync<TResult>(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, TResult result) { }
Expand All @@ -158,6 +161,8 @@ namespace GraphQL.AspNetCore3
public bool AuthorizationRequired { get; set; }
public string? AuthorizedPolicy { get; set; }
public System.Collections.Generic.List<string> AuthorizedRoles { get; set; }
public bool CsrfProtectionEnabled { get; set; }
public System.Collections.Generic.List<string> CsrfProtectionHeaders { get; set; }
public bool EnableBatchedRequests { get; set; }
public bool ExecuteBatchedRequestsInParallel { get; set; }
public bool HandleGet { get; set; }
Expand Down Expand Up @@ -224,6 +229,11 @@ namespace GraphQL.AspNetCore3.Errors
{
public BatchedRequestsNotSupportedError() { }
}
public class CsrfProtectionError : GraphQL.Execution.RequestError
{
public CsrfProtectionError(System.Collections.Generic.IEnumerable<string> headersRequired) { }
public CsrfProtectionError(System.Collections.Generic.IEnumerable<string> headersRequired, System.Exception innerException) { }
}
public class FileCountExceededError : GraphQL.Execution.RequestError, GraphQL.AspNetCore3.Errors.IHasPreferredStatusCode
{
public FileCountExceededError() { }
Expand Down
Loading