Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix misuse of class field for handling of base args #165

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions src/flask_classful/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class FlaskView:
route_base = None
route_prefix = None
trailing_slash = True
base_args = []
excluded_methods = (
[]
) # specify the class methods to be explicitly excluded from routing creation
Expand Down Expand Up @@ -382,13 +381,13 @@ def build_rule(cls, rule, method=None):
if cls.route_prefix:
rule_parts.append(cls.route_prefix)

route_base = cls.get_route_base()
route_base, base_args = cls.get_route_base()
if route_base:
rule_parts.append(route_base)
if len(rule) > 0: # the case of rule='' empty string
rule_parts.append(rule)

ignored_rule_args = ["self"] + cls.base_args
ignored_rule_args = {"self"} | base_args

if method and getattr(cls, "inspect_args", True):
argspec = get_true_argspec(method)
Expand All @@ -413,19 +412,17 @@ def build_rule(cls, rule, method=None):
def get_route_base(cls):
"""Returns the route base to use for the current class."""

if cls.route_base is not None:
route_base = cls.route_base
if not route_base.startswith("/"):
route_base = "/" + route_base
base_rule = Rule(route_base)
# Add rule to a dummy map and bind that map so that
# the Rule's arguments field is populated
Map(rules=[base_rule]).bind("")
cls.base_args.extend(base_rule.arguments)
else:
route_base = cls.default_route_base()

return route_base.strip("/")
if cls.route_base is None:
return cls.default_route_base(), set()

route_base = cls.route_base
if not route_base.startswith("/"):
route_base = "/" + route_base
base_rule = Rule(route_base)
# Add rule to a dummy map and bind that map so that
# the Rule's arguments field is populated
Map(rules=[base_rule]).bind("")
return route_base.strip("/"), base_rule.arguments

@classmethod
def default_route_base(cls):
Expand Down
295 changes: 98 additions & 197 deletions tests/test_base_args.py
Original file line number Diff line number Diff line change
@@ -1,216 +1,117 @@
import json

import marshmallow as ma
from flask import Flask
from marshmallow import Schema
from webargs import fields
from webargs.flaskparser import use_args
from flask import jsonify
from pytest import raises

from flask_classful import FlaskView
from flask_classful import route

# we'll make a list to hold some quotes for our app
quotes = [
"A noble spirit embiggens the smallest man! ~ Jebediah Springfield",
"If there is a way to do it better... find it. ~ Thomas Edison",
"No one knows what he can do till he tries. ~ Publilius Syrus",
]

app = Flask(__name__)
app.config["DEBUG"] = True

put_args = {"text": fields.Str(required=True)}


class UserSchema(Schema):
email = ma.fields.Str()

class Meta:
strict = True


def make_user_schema(request):
# Filter based on 'fields' query parameter
only = request.args.get("fields", None)
# Respect partial updates for PATCH requests
partial = request.method == "PATCH"
# Add current request to the schema's context
return UserSchema(only=only, partial=partial, context={"request": request})


class UsersView(FlaskView):
base_args = ["args"]

@use_args(make_user_schema)
def post(self, args):
return args["email"]

@use_args(make_user_schema)
def put(self, args, id):
return args["email"]

@use_args(make_user_schema)
def patch(self, args, id):
return args["email"]


class QuoteSchema(ma.Schema):
id = ma.fields.Int()
text = ma.fields.Str()

class Meta:
strict = True


def make_quote_schema(request):
# Filter based on 'fields' query parameter
only = request.args.get("fields", None)
# Respect partial updates for PATCH requests
partial = request.method == "PATCH"
# Add current request to the schema's context
return QuoteSchema(only=only, partial=partial, context={"request": request})


class QuotesView(FlaskView):
base_args = ["args"]

def index(self):
return "<br>".join(quotes)

def get(self, id):
quote_id = int(id)
if quote_id < len(quotes) - 1:
return quotes[quote_id]
else:
return "Not Found", 404

@use_args(put_args)
def put(self, args, id):
quote_id = int(id)
if quote_id >= len(quotes) - 1:
return "Not Found", 404
quotes[quote_id] = args["text"]
return quotes[quote_id]

@route("<id>/", methods=["PATCH"])
@use_args(make_quote_schema)
def factory(self, args, id):
quote_id = int(id)
if quote_id >= len(quotes) - 1:
return "Not Found", 404
quotes[quote_id] = args["text"]
return quotes[quote_id]


class UglyNameView(FlaskView):
base_args = ["args"]
route_base = "quotes-2"

def index(self):
return "<br>".join(quotes)

def get(self, id):
quote_id = int(id)
if quote_id < len(quotes) - 1:
return quotes[quote_id]
else:
return "Not Found", 404

@use_args(put_args)
def put(self, args, id):
quote_id = int(id)
if quote_id >= len(quotes) - 1:
return "Not Found", 404
quotes[quote_id] = args["text"]
return quotes[quote_id]


QuotesView.register(app)
UglyNameView.register(app)
UsersView.register(app)

client = app.test_client()

input_headers = [("Content-Type", "application/json")]
input_data = {"text": "My quote"}


def test_users_post():
resp = client.post(
"users/", headers=input_headers, data=json.dumps({"email": "[email protected]"})
)
class NoRouteBaseArgsView(FlaskView):
route_base = "/route/without/args"

def get(self, arg_1):
return (
jsonify(
{
"arg_1": arg_1,
}
),
200,
)


class MultiRouteBaseArgsView(FlaskView):
route_base = "/route/<arg_1>/with/<arg_2>/some_args"

def get(self, arg_1, arg_2, arg_3):
return (
jsonify(
{
"arg_1": arg_1,
"arg_2": arg_2,
"arg_3": arg_3,
}
),
200,
)


class OtherRouteBaseArgsView(FlaskView):
route_base = "/route/<arg_1>/other"

def get(self, arg_1, arg_2):
return (
jsonify(
{
"arg_1": arg_1,
"arg_2": arg_2,
}
),
200,
)


class ErroneousRouteBaseArgsView(FlaskView):
route_base = "/route/<arg_1>/error"

def get(self, arg_2):
return (
jsonify(
{
"arg_2": arg_2,
}
),
200,
)


NoRouteBaseArgsView.register(app)
MultiRouteBaseArgsView.register(app)
OtherRouteBaseArgsView.register(app)
ErroneousRouteBaseArgsView.register(app)


def test_no_route_args():
_, base_args = NoRouteBaseArgsView.get_route_base()
# No route base with args == no base args
assert base_args == set()
client = app.test_client()
resp = client.get("/route/without/args/foo/")
assert resp.status_code == 200
assert "[email protected]" == resp.data.decode("ascii")
assert resp.json == {"arg_1": "foo"}


def test_users_put():
resp = client.put(
"users/1/",
headers=input_headers,
data=json.dumps({"email": "[email protected]"}),
)
assert resp.status_code == 200
assert "[email protected]" == resp.data.decode("ascii")
def test_route_args_are_detected():
_, base_args = MultiRouteBaseArgsView.get_route_base()
assert base_args == {"arg_1", "arg_2"}


def test_users_patch():
resp = client.patch(
"users/1/",
headers=input_headers,
data=json.dumps({"email": "[email protected]"}),
)
def test_multi_route_args_values():
client = app.test_client()
resp = client.get("/route/foo/with/bar/some_args/baz/")
assert resp.status_code == 200
assert "[email protected]" == resp.data.decode("ascii")

assert resp.json == {"arg_1": "foo", "arg_2": "bar", "arg_3": "baz"}

def test_quotes_index():
resp = client.get("/quotes/")
num = len(str(resp.data).split("<br>"))
assert 3 == num
resp = client.get("/quotes")
assert resp.status_code == 308

def test_route_args_are_independent_across_views():
_, base_args = OtherRouteBaseArgsView.get_route_base()
# arg_2 does not leak from evaluating the previous view
assert base_args == {"arg_1"}

def test_quotes_get():
resp = client.get("/quotes/0/")
assert quotes[0] == resp.data.decode("ascii")


def test_quotes_put():
resp = client.put("/quotes/1/", headers=input_headers, data=json.dumps(input_data))
assert input_data["text"] == resp.data.decode("ascii")


def test_quotes_factory():
resp = client.patch(
"/quotes/1/", headers=input_headers, data=json.dumps(input_data)
)
assert input_data["text"] == resp.data.decode("ascii")


def test_quotes2_index():
resp = client.get("/quotes-2/")
num = len(str(resp.data).split("<br>"))
assert 3 == num
resp = client.get("/quotes-2")
assert resp.status_code == 308


def test_quotes2_get():
resp = client.get("/quotes-2/0/")
assert quotes[0] == resp.data.decode("ascii")
assert UglyNameView.base_args.count(UglyNameView.route_base) == 0


def test_quotes2_put():
resp = client.put(
"/quotes-2/1/", headers=input_headers, data=json.dumps(input_data)
def test_missing_base_arg_in_method():
_, base_args = ErroneousRouteBaseArgsView.get_route_base()
# Base arg is recognized
assert base_args == {"arg_1"}
# Rule is correctly generated
assert (
ErroneousRouteBaseArgsView.build_rule("/", ErroneousRouteBaseArgsView.get)
== ErroneousRouteBaseArgsView.route_base + "/<arg_2>"
)
assert input_data["text"] == resp.data.decode("ascii")
assert UglyNameView.base_args.count(UglyNameView.route_base) == 0


# see: https://github.com/pallets-eco/flask-classful/pull/56#issuecomment-328985183
def test_unique_elements():
client.put("/quotes-2/1/", headers=input_headers, data=json.dumps(input_data))
assert UglyNameView.base_args.count(UglyNameView.route_base) == 0
client = app.test_client()
# But calling the method fails because ErroneousRouteBaseArgsView.get is
# supplied with an unexpected "arg_1" argument
with raises(TypeError):
client.get("/route/foo/error/baz/")