From 9f8548fd9dc143b4a870626e2c24304a34d61246 Mon Sep 17 00:00:00 2001 From: Shane Krueger Date: Mon, 19 Aug 2024 00:05:33 -0400 Subject: [PATCH] Change CSRF protection to on by default (#67) --- README.md | 25 +++++++ .../GraphQLHttpMiddlewareOptions.cs | 4 +- src/Tests/AuthorizationTests.cs | 2 +- src/Tests/Middleware/AuthorizationTests.cs | 4 +- src/Tests/Middleware/Cors/EndpointTests.cs | 2 +- src/Tests/Middleware/FileUploadTests.cs | 5 +- src/Tests/Middleware/GetTests.cs | 48 +++++++++++- src/Tests/Middleware/PostTests.cs | 74 +++++++++++++++---- src/Tests/TestServerExtensions.cs | 4 +- src/Tests/UserContextBuilderTests.cs | 4 +- 10 files changed, 148 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 3b9f202..84fd213 100644 --- a/README.md +++ b/README.md @@ -489,6 +489,31 @@ app.UseEndpoints(endpoints => { await app.RunAsync(); ``` +In order to ensure that all requests trigger CORS preflight requests, by default the server +will reject requests that do not meet one of the following criteria: + +- The request is a POST request that includes a Content-Type header that is not + `application/x-www-form-urlencoded`, `multipart/form-data`, or `text/plain`. +- The request includes a non-empty `GraphQL-Require-Preflight` header. + +To disable this behavior, set the `CsrfProtectionEnabled` option to `false` in the `GraphQLServerOptions`. + +```csharp +app.UseGraphQL("/graphql", config => +{ + config.CsrfProtectionEnabled = false; +}); +``` + +You may also change the allowed headers by modifying the `CsrfProtectionHeaders` option. + +```csharp +app.UseGraphQL("/graphql", config => +{ + config.CsrfProtectionHeaders = ["MyCustomHeader"]; +}); +``` + ### Response compression ASP.NET Core supports response compression independently of GraphQL, with brotli and gzip diff --git a/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs b/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs index 4f90d05..cce98ba 100644 --- a/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs +++ b/src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs @@ -67,7 +67,7 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions /// present, or a POST request with a Content-Type header that is not text/plain, /// application/x-www-form-urlencoded, or multipart/form-data. /// - public bool CsrfProtectionEnabled { get; set; } + public bool CsrfProtectionEnabled { get; set; } = true; /// /// When is enabled, requests require a non-empty @@ -75,7 +75,7 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions /// text/plain, application/x-www-form-urlencoded, or multipart/form-data. /// Defaults to GraphQL-Require-Preflight. /// - public List CsrfProtectionHeaders { get; set; } = new() { "GraphQL-Require-Preflight" }; // see https://github.com/graphql/graphql-over-http/pull/303 + public List CsrfProtectionHeaders { get; set; } = ["GraphQL-Require-Preflight"]; // see https://github.com/graphql/graphql-over-http/pull/303 /// /// Enables reading variables from the query string. diff --git a/src/Tests/AuthorizationTests.cs b/src/Tests/AuthorizationTests.cs index c6cad3f..cac6b0d 100644 --- a/src/Tests/AuthorizationTests.cs +++ b/src/Tests/AuthorizationTests.cs @@ -739,7 +739,7 @@ public async Task EndToEnd(bool authenticated) context.User = _principal; return next(context); }); - app.UseGraphQL(); + app.UseGraphQL(configureMiddleware: c => c.CsrfProtectionEnabled = false); }); using var server = new TestServer(hostBuilder); diff --git a/src/Tests/Middleware/AuthorizationTests.cs b/src/Tests/Middleware/AuthorizationTests.cs index b4906e0..6cece03 100644 --- a/src/Tests/Middleware/AuthorizationTests.cs +++ b/src/Tests/Middleware/AuthorizationTests.cs @@ -98,7 +98,9 @@ public async Task NotAuthorized_Get() { _options.AuthorizationRequired = true; var client = _server.CreateClient(); - using var response = await client.GetAsync("/graphql?query={ __typename }"); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={ __typename }"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); response.StatusCode.ShouldBe(HttpStatusCode.Unauthorized); var actual = await response.Content.ReadAsStringAsync(); actual.ShouldBe(@"{""errors"":[{""message"":""Access denied for schema."",""extensions"":{""code"":""ACCESS_DENIED"",""codes"":[""ACCESS_DENIED""]}}]}"); diff --git a/src/Tests/Middleware/Cors/EndpointTests.cs b/src/Tests/Middleware/Cors/EndpointTests.cs index e2aca2e..9f76683 100644 --- a/src/Tests/Middleware/Cors/EndpointTests.cs +++ b/src/Tests/Middleware/Cors/EndpointTests.cs @@ -82,7 +82,7 @@ public async Task NoCorsConfig(string httpMethod, string url) httpMethod == "POST" ? HttpMethod.Post : httpMethod == "OPTIONS" ? HttpMethod.Options : httpMethod == "GET" ? HttpMethod.Get : throw new ArgumentOutOfRangeException(nameof(httpMethod)), configureCors: _ => { }, configureCorsPolicy: _ => { }, - configureGraphQl: _ => { }, + configureGraphQl: o => o.CsrfProtectionEnabled = false, configureGraphQlEndpoint: _ => { }, configureHeaders: headers => { headers.Add("Origin", "http://www.example.com"); diff --git a/src/Tests/Middleware/FileUploadTests.cs b/src/Tests/Middleware/FileUploadTests.cs index ebafdda..51941c9 100644 --- a/src/Tests/Middleware/FileUploadTests.cs +++ b/src/Tests/Middleware/FileUploadTests.cs @@ -48,7 +48,10 @@ public async Task Basic(bool withOtherVariables) var fileContent = new ByteArrayContent(fileData); fileContent.Headers.ContentType = new("application/octet-stream"); content.Add(fileContent, "file", "filename.bin"); - using var response = await client.PostAsync("/graphql", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql"); + request.Content = content; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); if (withOtherVariables) { await response.ShouldBeAsync(@"{""data"":{""convertToBase64"":""pre-filename.bin-YWJjZA==""}}"); } else { diff --git a/src/Tests/Middleware/GetTests.cs b/src/Tests/Middleware/GetTests.cs index 175d324..37ccbb8 100644 --- a/src/Tests/Middleware/GetTests.cs +++ b/src/Tests/Middleware/GetTests.cs @@ -28,9 +28,11 @@ public GetTests() hostBuilder.Configure(app => { app.UseWebSockets(); app.UseGraphQL("/graphql", opts => { + opts.CsrfProtectionEnabled = false; _options = opts; }); app.UseGraphQL("/graphql2", opts => { + opts.CsrfProtectionEnabled = false; _options2 = opts; }); }); @@ -61,6 +63,48 @@ public async Task BasicTest() await response.ShouldBeAsync(@"{""data"":{""count"":0}}"); } + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CsrfBasicTests(bool requireCsrf, bool sendCsrf) + { + _options.CsrfProtectionEnabled = requireCsrf; + var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}"); + if (sendCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (requireCsrf && !sendCsrf) + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); + else + await response.ShouldBeAsync("""{"data":{"count":0}}"""); + } + + [Theory] + [InlineData(null, null, false)] + [InlineData("Header1", "true", true)] + [InlineData("Header1", "", false)] + [InlineData("Header1", null, false)] + [InlineData("Header2", "true", true)] + [InlineData("Header3", "true", false)] + [InlineData("GraphQL-Require-Preflight", "true", false)] + public async Task CsrfCustomTests(string? header, string? value, bool success) + { + _options.CsrfProtectionEnabled = true; + _options.CsrfProtectionHeaders = ["Header1", "Header2"]; + var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}"); + if (header != null) + request.Headers.Add(header, value); + using var response = await client.SendAsync(request); + if (!success) + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027Header1\u0027, \u0027Header2\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); + else + await response.ShouldBeAsync("""{"data":{"count":0}}"""); + } + [Fact] public async Task NoUseWebSockets() { @@ -78,7 +122,9 @@ public async Task NoUseWebSockets() using var server = new TestServer(hostBuilder); var client = server.CreateClient(); - using var response = await client.GetAsync("/graphql?query={count}"); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); await response.ShouldBeAsync(@"{""data"":{""count"":0}}"); } diff --git a/src/Tests/Middleware/PostTests.cs b/src/Tests/Middleware/PostTests.cs index d19fa8f..a576f5f 100644 --- a/src/Tests/Middleware/PostTests.cs +++ b/src/Tests/Middleware/PostTests.cs @@ -143,9 +143,15 @@ public async Task AltCharset_Invalid() } #endif - [Fact] - public async Task FormMultipart_Legacy() + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task FormMultipart_Legacy(bool requireCsrf, bool supplyCsrf) { + if (!requireCsrf) + _options2.CsrfProtectionEnabled = false; var client = _server.CreateClient(); var content = new MultipartFormDataContent(); var queryContent = new StringContent(@"query op1{ext} query op2($test:String!){ext var(test:$test)}"); @@ -160,13 +166,25 @@ public async Task FormMultipart_Legacy() content.Add(variablesContent, "variables"); content.Add(extensionsContent, "extensions"); content.Add(operationNameContent, "operationName"); - using var response = await client.PostAsync("/graphql2", content); - await response.ShouldBeAsync(@"{""data"":{""ext"":""2"",""var"":""1""}}"); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + if (supplyCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (!requireCsrf || supplyCsrf) + await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + else + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); } - [Fact] - public async Task FormMultipart_Upload() + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task FormMultipart_Upload(bool requireCsrf, bool supplyCsrf) { + if (!requireCsrf) + _options2.CsrfProtectionEnabled = false; var client = _server.CreateClient(); using var content = new MultipartFormDataContent(); var jsonContent = new StringContent(""" @@ -178,8 +196,14 @@ public async Task FormMultipart_Upload() } """, Encoding.UTF8, "application/json"); content.Add(jsonContent, "operations"); - using var response = await client.PostAsync("/graphql2", content); - await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + if (supplyCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (!requireCsrf || supplyCsrf) + await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + else + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); } // successful queries @@ -341,7 +365,9 @@ public async Task FormMultipart_Upload_Matrix(int testIndex, string? operations, content.Add(new StringContent("test1", Encoding.UTF8, "text/text"), "file0", "example1.txt"); if (file1) content.Add(new StringContent("test2", Encoding.UTF8, "text/html"), "file1", "example2.html"); - using var response = await client.PostAsync("/graphql2", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); await response.ShouldBeAsync((HttpStatusCode)expectedStatusCode, expectedResponse); } @@ -362,13 +388,21 @@ public async Task FormMultipart_Upload_Validation(int? maxFileCount, int? maxFil { new StringContent("test1", Encoding.UTF8, "text/text"), "file0", "example1.txt" }, { new StringContent("test2", Encoding.UTF8, "text/html"), "file1", "example2.html" } }; - using var response = await client.PostAsync("/graphql2", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); await response.ShouldBeAsync(expectedStatusCode, expectedResponse); } - [Fact] - public async Task FormUrlEncoded() + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task FormUrlEncoded(bool requireCsrf, bool supplyCsrf) { + if (!requireCsrf) + _options2.CsrfProtectionEnabled = false; var client = _server.CreateClient(); var content = new FormUrlEncodedContent(new[] { new KeyValuePair("query", @"query op1{ext} query op2($test:String!){ext var(test:$test)}"), @@ -376,8 +410,14 @@ public async Task FormUrlEncoded() new KeyValuePair("extensions", @"{""test"":""2""}"), new KeyValuePair("operationName", @"op2"), }); - using var response = await client.PostAsync("/graphql2", content); - await response.ShouldBeAsync(@"{""data"":{""ext"":""2"",""var"":""1""}}"); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + if (supplyCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (requireCsrf && !supplyCsrf) + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); + else + await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); } [Theory] @@ -391,7 +431,9 @@ public async Task FormUrlEncoded_DeserializationError(bool badRequest) new KeyValuePair("query", @"{ext}"), new KeyValuePair("variables", @"{"), }); - using var response = await client.PostAsync("/graphql2", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); // always returns BadRequest here await response.ShouldBeAsync(true, @"{""errors"":[{""message"":""JSON body text could not be parsed. Expected depth to be zero at the end of the JSON payload. There is an open JSON object or array that should be closed. Path: $ | LineNumber: 0 | BytePositionInLine: 1."",""extensions"":{""code"":""JSON_INVALID"",""codes"":[""JSON_INVALID""]}}]}"); } @@ -431,6 +473,7 @@ public async Task ContentType_GraphQLJson(string contentType) [InlineData(true, false, "application/x-www-form-urlencoded")] public async Task UnknownContentType(bool badRequest, bool allowFormBody, string contentType) { + _options.CsrfProtectionEnabled = false; _options.ValidationErrorsReturnBadRequest = badRequest; _options.ReadFormOnPost = allowFormBody; var client = _server.CreateClient(); @@ -459,6 +502,7 @@ public async Task CannotParseContentType(bool badRequest) var client = _server.CreateClient(); var content = new StringContent(""); content.Headers.ContentType = null; + content.Headers.Add("GraphQL-Require-Preflight", "true"); var response = await client.PostAsync("/graphql2", content); // always returns unsupported media type response.StatusCode.ShouldBe(HttpStatusCode.UnsupportedMediaType); diff --git a/src/Tests/TestServerExtensions.cs b/src/Tests/TestServerExtensions.cs index f767d48..16eb8c8 100644 --- a/src/Tests/TestServerExtensions.cs +++ b/src/Tests/TestServerExtensions.cs @@ -5,7 +5,9 @@ internal static class TestServerExtensions public static async Task ExecuteGet(this TestServer server, string url) { var client = server.CreateClient(); - using var response = await client.GetAsync(url); + using var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); response.EnsureSuccessStatusCode(); var str = await response.Content.ReadAsStringAsync(); return str; diff --git a/src/Tests/UserContextBuilderTests.cs b/src/Tests/UserContextBuilderTests.cs index 33c3200..756d910 100644 --- a/src/Tests/UserContextBuilderTests.cs +++ b/src/Tests/UserContextBuilderTests.cs @@ -100,7 +100,9 @@ public async Task Async_Payload_Works() private async Task Test(string name) { - using var response = await _client.GetAsync("/graphql?query={test}"); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={test}"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await _client.SendAsync(request); response.EnsureSuccessStatusCode(); var actual = await response.Content.ReadAsStringAsync(); actual.ShouldBe(@"{""data"":{""test"":""" + name + @"""}}");