Skip to content

Commit

Permalink
Change CSRF protection to on by default (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shane32 authored Aug 19, 2024
1 parent c429beb commit 9f8548f
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 24 deletions.
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,31 @@ app.UseEndpoints(endpoints => {
await app.RunAsync();
```

In order to ensure that all requests trigger CORS preflight requests, by default the server
will reject requests that do not meet one of the following criteria:

- The request is a POST request that includes a Content-Type header that is not
`application/x-www-form-urlencoded`, `multipart/form-data`, or `text/plain`.
- The request includes a non-empty `GraphQL-Require-Preflight` header.

To disable this behavior, set the `CsrfProtectionEnabled` option to `false` in the `GraphQLServerOptions`.

```csharp
app.UseGraphQL("/graphql", config =>
{
config.CsrfProtectionEnabled = false;
});
```

You may also change the allowed headers by modifying the `CsrfProtectionHeaders` option.

```csharp
app.UseGraphQL("/graphql", config =>
{
config.CsrfProtectionHeaders = ["MyCustomHeader"];
});
```

### Response compression

ASP.NET Core supports response compression independently of GraphQL, with brotli and gzip
Expand Down
4 changes: 2 additions & 2 deletions src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions
/// present, or a POST request with a Content-Type header that is not <c>text/plain</c>,
/// <c>application/x-www-form-urlencoded</c>, or <c>multipart/form-data</c>.
/// </summary>
public bool CsrfProtectionEnabled { get; set; }
public bool CsrfProtectionEnabled { get; set; } = true;

/// <summary>
/// When <see cref="CsrfProtectionEnabled"/> is enabled, requests require a non-empty
/// header from this list or a POST request with a Content-Type header that is not
/// <c>text/plain</c>, <c>application/x-www-form-urlencoded</c>, or <c>multipart/form-data</c>.
/// Defaults to <c>GraphQL-Require-Preflight</c>.
/// </summary>
public List<string> CsrfProtectionHeaders { get; set; } = new() { "GraphQL-Require-Preflight" }; // see https://github.com/graphql/graphql-over-http/pull/303
public List<string> CsrfProtectionHeaders { get; set; } = ["GraphQL-Require-Preflight"]; // see https://github.com/graphql/graphql-over-http/pull/303

/// <summary>
/// Enables reading variables from the query string.
Expand Down
2 changes: 1 addition & 1 deletion src/Tests/AuthorizationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ public async Task EndToEnd(bool authenticated)
context.User = _principal;
return next(context);
});
app.UseGraphQL();
app.UseGraphQL(configureMiddleware: c => c.CsrfProtectionEnabled = false);
});
using var server = new TestServer(hostBuilder);

Expand Down
4 changes: 3 additions & 1 deletion src/Tests/Middleware/AuthorizationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ public async Task NotAuthorized_Get()
{
_options.AuthorizationRequired = true;
var client = _server.CreateClient();
using var response = await client.GetAsync("/graphql?query={ __typename }");
using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={ __typename }");
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
response.StatusCode.ShouldBe(HttpStatusCode.Unauthorized);
var actual = await response.Content.ReadAsStringAsync();
actual.ShouldBe(@"{""errors"":[{""message"":""Access denied for schema."",""extensions"":{""code"":""ACCESS_DENIED"",""codes"":[""ACCESS_DENIED""]}}]}");
Expand Down
2 changes: 1 addition & 1 deletion src/Tests/Middleware/Cors/EndpointTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public async Task NoCorsConfig(string httpMethod, string url)
httpMethod == "POST" ? HttpMethod.Post : httpMethod == "OPTIONS" ? HttpMethod.Options : httpMethod == "GET" ? HttpMethod.Get : throw new ArgumentOutOfRangeException(nameof(httpMethod)),
configureCors: _ => { },
configureCorsPolicy: _ => { },
configureGraphQl: _ => { },
configureGraphQl: o => o.CsrfProtectionEnabled = false,
configureGraphQlEndpoint: _ => { },
configureHeaders: headers => {
headers.Add("Origin", "http://www.example.com");
Expand Down
5 changes: 4 additions & 1 deletion src/Tests/Middleware/FileUploadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ public async Task Basic(bool withOtherVariables)
var fileContent = new ByteArrayContent(fileData);
fileContent.Headers.ContentType = new("application/octet-stream");
content.Add(fileContent, "file", "filename.bin");
using var response = await client.PostAsync("/graphql", content);
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql");
request.Content = content;
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
if (withOtherVariables) {
await response.ShouldBeAsync(@"{""data"":{""convertToBase64"":""pre-filename.bin-YWJjZA==""}}");
} else {
Expand Down
48 changes: 47 additions & 1 deletion src/Tests/Middleware/GetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ public GetTests()
hostBuilder.Configure(app => {
app.UseWebSockets();
app.UseGraphQL("/graphql", opts => {
opts.CsrfProtectionEnabled = false;
_options = opts;
});
app.UseGraphQL<Schema2>("/graphql2", opts => {
opts.CsrfProtectionEnabled = false;
_options2 = opts;
});
});
Expand Down Expand Up @@ -61,6 +63,48 @@ public async Task BasicTest()
await response.ShouldBeAsync(@"{""data"":{""count"":0}}");
}

[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public async Task CsrfBasicTests(bool requireCsrf, bool sendCsrf)
{
_options.CsrfProtectionEnabled = requireCsrf;
var client = _server.CreateClient();
using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}");
if (sendCsrf)
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
if (requireCsrf && !sendCsrf)
await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}""");
else
await response.ShouldBeAsync("""{"data":{"count":0}}""");
}

[Theory]
[InlineData(null, null, false)]
[InlineData("Header1", "true", true)]
[InlineData("Header1", "", false)]
[InlineData("Header1", null, false)]
[InlineData("Header2", "true", true)]
[InlineData("Header3", "true", false)]
[InlineData("GraphQL-Require-Preflight", "true", false)]
public async Task CsrfCustomTests(string? header, string? value, bool success)
{
_options.CsrfProtectionEnabled = true;
_options.CsrfProtectionHeaders = ["Header1", "Header2"];
var client = _server.CreateClient();
using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}");
if (header != null)
request.Headers.Add(header, value);
using var response = await client.SendAsync(request);
if (!success)
await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027Header1\u0027, \u0027Header2\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}""");
else
await response.ShouldBeAsync("""{"data":{"count":0}}""");
}

[Fact]
public async Task NoUseWebSockets()
{
Expand All @@ -78,7 +122,9 @@ public async Task NoUseWebSockets()
using var server = new TestServer(hostBuilder);

var client = server.CreateClient();
using var response = await client.GetAsync("/graphql?query={count}");
using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}");
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
await response.ShouldBeAsync(@"{""data"":{""count"":0}}");
}

Expand Down
74 changes: 59 additions & 15 deletions src/Tests/Middleware/PostTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,15 @@ public async Task AltCharset_Invalid()
}
#endif

[Fact]
public async Task FormMultipart_Legacy()
[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public async Task FormMultipart_Legacy(bool requireCsrf, bool supplyCsrf)
{
if (!requireCsrf)
_options2.CsrfProtectionEnabled = false;
var client = _server.CreateClient();
var content = new MultipartFormDataContent();
var queryContent = new StringContent(@"query op1{ext} query op2($test:String!){ext var(test:$test)}");
Expand All @@ -160,13 +166,25 @@ public async Task FormMultipart_Legacy()
content.Add(variablesContent, "variables");
content.Add(extensionsContent, "extensions");
content.Add(operationNameContent, "operationName");
using var response = await client.PostAsync("/graphql2", content);
await response.ShouldBeAsync(@"{""data"":{""ext"":""2"",""var"":""1""}}");
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content };
if (supplyCsrf)
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
if (!requireCsrf || supplyCsrf)
await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}""");
else
await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}""");
}

[Fact]
public async Task FormMultipart_Upload()
[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public async Task FormMultipart_Upload(bool requireCsrf, bool supplyCsrf)
{
if (!requireCsrf)
_options2.CsrfProtectionEnabled = false;
var client = _server.CreateClient();
using var content = new MultipartFormDataContent();
var jsonContent = new StringContent("""
Expand All @@ -178,8 +196,14 @@ public async Task FormMultipart_Upload()
}
""", Encoding.UTF8, "application/json");
content.Add(jsonContent, "operations");
using var response = await client.PostAsync("/graphql2", content);
await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}""");
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content };
if (supplyCsrf)
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
if (!requireCsrf || supplyCsrf)
await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}""");
else
await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}""");
}

// successful queries
Expand Down Expand Up @@ -341,7 +365,9 @@ public async Task FormMultipart_Upload_Matrix(int testIndex, string? operations,
content.Add(new StringContent("test1", Encoding.UTF8, "text/text"), "file0", "example1.txt");
if (file1)
content.Add(new StringContent("test2", Encoding.UTF8, "text/html"), "file1", "example2.html");
using var response = await client.PostAsync("/graphql2", content);
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content };
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
await response.ShouldBeAsync((HttpStatusCode)expectedStatusCode, expectedResponse);
}

Expand All @@ -362,22 +388,36 @@ public async Task FormMultipart_Upload_Validation(int? maxFileCount, int? maxFil
{ new StringContent("test1", Encoding.UTF8, "text/text"), "file0", "example1.txt" },
{ new StringContent("test2", Encoding.UTF8, "text/html"), "file1", "example2.html" }
};
using var response = await client.PostAsync("/graphql2", content);
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content };
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
await response.ShouldBeAsync(expectedStatusCode, expectedResponse);
}

[Fact]
public async Task FormUrlEncoded()
[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public async Task FormUrlEncoded(bool requireCsrf, bool supplyCsrf)
{
if (!requireCsrf)
_options2.CsrfProtectionEnabled = false;
var client = _server.CreateClient();
var content = new FormUrlEncodedContent(new[] {
new KeyValuePair<string?, string?>("query", @"query op1{ext} query op2($test:String!){ext var(test:$test)}"),
new KeyValuePair<string?, string?>("variables", @"{""test"":""1""}"),
new KeyValuePair<string?, string?>("extensions", @"{""test"":""2""}"),
new KeyValuePair<string?, string?>("operationName", @"op2"),
});
using var response = await client.PostAsync("/graphql2", content);
await response.ShouldBeAsync(@"{""data"":{""ext"":""2"",""var"":""1""}}");
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content };
if (supplyCsrf)
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
if (requireCsrf && !supplyCsrf)
await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}""");
else
await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}""");
}

[Theory]
Expand All @@ -391,7 +431,9 @@ public async Task FormUrlEncoded_DeserializationError(bool badRequest)
new KeyValuePair<string?, string?>("query", @"{ext}"),
new KeyValuePair<string?, string?>("variables", @"{"),
});
using var response = await client.PostAsync("/graphql2", content);
using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content };
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
// always returns BadRequest here
await response.ShouldBeAsync(true, @"{""errors"":[{""message"":""JSON body text could not be parsed. Expected depth to be zero at the end of the JSON payload. There is an open JSON object or array that should be closed. Path: $ | LineNumber: 0 | BytePositionInLine: 1."",""extensions"":{""code"":""JSON_INVALID"",""codes"":[""JSON_INVALID""]}}]}");
}
Expand Down Expand Up @@ -431,6 +473,7 @@ public async Task ContentType_GraphQLJson(string contentType)
[InlineData(true, false, "application/x-www-form-urlencoded")]
public async Task UnknownContentType(bool badRequest, bool allowFormBody, string contentType)
{
_options.CsrfProtectionEnabled = false;
_options.ValidationErrorsReturnBadRequest = badRequest;
_options.ReadFormOnPost = allowFormBody;
var client = _server.CreateClient();
Expand Down Expand Up @@ -459,6 +502,7 @@ public async Task CannotParseContentType(bool badRequest)
var client = _server.CreateClient();
var content = new StringContent("");
content.Headers.ContentType = null;
content.Headers.Add("GraphQL-Require-Preflight", "true");
var response = await client.PostAsync("/graphql2", content);
// always returns unsupported media type
response.StatusCode.ShouldBe(HttpStatusCode.UnsupportedMediaType);
Expand Down
4 changes: 3 additions & 1 deletion src/Tests/TestServerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ internal static class TestServerExtensions
public static async Task<string> ExecuteGet(this TestServer server, string url)
{
var client = server.CreateClient();
using var response = await client.GetAsync(url);
using var request = new HttpRequestMessage(HttpMethod.Get, url);
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await client.SendAsync(request);
response.EnsureSuccessStatusCode();
var str = await response.Content.ReadAsStringAsync();
return str;
Expand Down
4 changes: 3 additions & 1 deletion src/Tests/UserContextBuilderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ public async Task Async_Payload_Works()

private async Task Test(string name)
{
using var response = await _client.GetAsync("/graphql?query={test}");
using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={test}");
request.Headers.Add("GraphQL-Require-Preflight", "true");
using var response = await _client.SendAsync(request);
response.EnsureSuccessStatusCode();
var actual = await response.Content.ReadAsStringAsync();
actual.ShouldBe(@"{""data"":{""test"":""" + name + @"""}}");
Expand Down

0 comments on commit 9f8548f

Please sign in to comment.