diff --git a/src/flask_classful/__init__.py b/src/flask_classful/__init__.py index 9deb320..823b548 100644 --- a/src/flask_classful/__init__.py +++ b/src/flask_classful/__init__.py @@ -56,7 +56,6 @@ class FlaskView: route_base = None route_prefix = None trailing_slash = True - base_args = [] excluded_methods = ( [] ) # specify the class methods to be explicitly excluded from routing creation @@ -382,13 +381,13 @@ def build_rule(cls, rule, method=None): if cls.route_prefix: rule_parts.append(cls.route_prefix) - route_base = cls.get_route_base() + route_base, base_args = cls.get_route_base() if route_base: rule_parts.append(route_base) if len(rule) > 0: # the case of rule='' empty string rule_parts.append(rule) - ignored_rule_args = ["self"] + cls.base_args + ignored_rule_args = {"self"} | base_args if method and getattr(cls, "inspect_args", True): argspec = get_true_argspec(method) @@ -413,19 +412,17 @@ def build_rule(cls, rule, method=None): def get_route_base(cls): """Returns the route base to use for the current class.""" - if cls.route_base is not None: - route_base = cls.route_base - if not route_base.startswith("/"): - route_base = "/" + route_base - base_rule = Rule(route_base) - # Add rule to a dummy map and bind that map so that - # the Rule's arguments field is populated - Map(rules=[base_rule]).bind("") - cls.base_args.extend(base_rule.arguments) - else: - route_base = cls.default_route_base() - - return route_base.strip("/") + if cls.route_base is None: + return cls.default_route_base(), set() + + route_base = cls.route_base + if not route_base.startswith("/"): + route_base = "/" + route_base + base_rule = Rule(route_base) + # Add rule to a dummy map and bind that map so that + # the Rule's arguments field is populated + Map(rules=[base_rule]).bind("") + return route_base.strip("/"), base_rule.arguments @classmethod def default_route_base(cls): diff --git a/tests/test_base_args.py b/tests/test_base_args.py index 2dca314..04d37b9 100644 --- a/tests/test_base_args.py +++ b/tests/test_base_args.py @@ -1,216 +1,117 @@ -import json - -import marshmallow as ma from flask import Flask -from marshmallow import Schema -from webargs import fields -from webargs.flaskparser import use_args +from flask import jsonify +from pytest import raises from flask_classful import FlaskView -from flask_classful import route - -# we'll make a list to hold some quotes for our app -quotes = [ - "A noble spirit embiggens the smallest man! ~ Jebediah Springfield", - "If there is a way to do it better... find it. ~ Thomas Edison", - "No one knows what he can do till he tries. ~ Publilius Syrus", -] app = Flask(__name__) app.config["DEBUG"] = True -put_args = {"text": fields.Str(required=True)} - - -class UserSchema(Schema): - email = ma.fields.Str() - - class Meta: - strict = True - - -def make_user_schema(request): - # Filter based on 'fields' query parameter - only = request.args.get("fields", None) - # Respect partial updates for PATCH requests - partial = request.method == "PATCH" - # Add current request to the schema's context - return UserSchema(only=only, partial=partial, context={"request": request}) - - -class UsersView(FlaskView): - base_args = ["args"] - - @use_args(make_user_schema) - def post(self, args): - return args["email"] - - @use_args(make_user_schema) - def put(self, args, id): - return args["email"] - - @use_args(make_user_schema) - def patch(self, args, id): - return args["email"] - - -class QuoteSchema(ma.Schema): - id = ma.fields.Int() - text = ma.fields.Str() - - class Meta: - strict = True - - -def make_quote_schema(request): - # Filter based on 'fields' query parameter - only = request.args.get("fields", None) - # Respect partial updates for PATCH requests - partial = request.method == "PATCH" - # Add current request to the schema's context - return QuoteSchema(only=only, partial=partial, context={"request": request}) - - -class QuotesView(FlaskView): - base_args = ["args"] - - def index(self): - return "
".join(quotes) - - def get(self, id): - quote_id = int(id) - if quote_id < len(quotes) - 1: - return quotes[quote_id] - else: - return "Not Found", 404 - - @use_args(put_args) - def put(self, args, id): - quote_id = int(id) - if quote_id >= len(quotes) - 1: - return "Not Found", 404 - quotes[quote_id] = args["text"] - return quotes[quote_id] - - @route("/", methods=["PATCH"]) - @use_args(make_quote_schema) - def factory(self, args, id): - quote_id = int(id) - if quote_id >= len(quotes) - 1: - return "Not Found", 404 - quotes[quote_id] = args["text"] - return quotes[quote_id] - - -class UglyNameView(FlaskView): - base_args = ["args"] - route_base = "quotes-2" - def index(self): - return "
".join(quotes) - - def get(self, id): - quote_id = int(id) - if quote_id < len(quotes) - 1: - return quotes[quote_id] - else: - return "Not Found", 404 - - @use_args(put_args) - def put(self, args, id): - quote_id = int(id) - if quote_id >= len(quotes) - 1: - return "Not Found", 404 - quotes[quote_id] = args["text"] - return quotes[quote_id] - - -QuotesView.register(app) -UglyNameView.register(app) -UsersView.register(app) - -client = app.test_client() - -input_headers = [("Content-Type", "application/json")] -input_data = {"text": "My quote"} - - -def test_users_post(): - resp = client.post( - "users/", headers=input_headers, data=json.dumps({"email": "test@example.com"}) - ) +class NoRouteBaseArgsView(FlaskView): + route_base = "/route/without/args" + + def get(self, arg_1): + return ( + jsonify( + { + "arg_1": arg_1, + } + ), + 200, + ) + + +class MultiRouteBaseArgsView(FlaskView): + route_base = "/route//with//some_args" + + def get(self, arg_1, arg_2, arg_3): + return ( + jsonify( + { + "arg_1": arg_1, + "arg_2": arg_2, + "arg_3": arg_3, + } + ), + 200, + ) + + +class OtherRouteBaseArgsView(FlaskView): + route_base = "/route//other" + + def get(self, arg_1, arg_2): + return ( + jsonify( + { + "arg_1": arg_1, + "arg_2": arg_2, + } + ), + 200, + ) + + +class ErroneousRouteBaseArgsView(FlaskView): + route_base = "/route//error" + + def get(self, arg_2): + return ( + jsonify( + { + "arg_2": arg_2, + } + ), + 200, + ) + + +NoRouteBaseArgsView.register(app) +MultiRouteBaseArgsView.register(app) +OtherRouteBaseArgsView.register(app) +ErroneousRouteBaseArgsView.register(app) + + +def test_no_route_args(): + _, base_args = NoRouteBaseArgsView.get_route_base() + # No route base with args == no base args + assert base_args == set() + client = app.test_client() + resp = client.get("/route/without/args/foo/") assert resp.status_code == 200 - assert "test@example.com" == resp.data.decode("ascii") + assert resp.json == {"arg_1": "foo"} -def test_users_put(): - resp = client.put( - "users/1/", - headers=input_headers, - data=json.dumps({"email": "test@example.com"}), - ) - assert resp.status_code == 200 - assert "test@example.com" == resp.data.decode("ascii") +def test_route_args_are_detected(): + _, base_args = MultiRouteBaseArgsView.get_route_base() + assert base_args == {"arg_1", "arg_2"} -def test_users_patch(): - resp = client.patch( - "users/1/", - headers=input_headers, - data=json.dumps({"email": "test@example.com"}), - ) +def test_multi_route_args_values(): + client = app.test_client() + resp = client.get("/route/foo/with/bar/some_args/baz/") assert resp.status_code == 200 - assert "test@example.com" == resp.data.decode("ascii") - + assert resp.json == {"arg_1": "foo", "arg_2": "bar", "arg_3": "baz"} -def test_quotes_index(): - resp = client.get("/quotes/") - num = len(str(resp.data).split("
")) - assert 3 == num - resp = client.get("/quotes") - assert resp.status_code == 308 +def test_route_args_are_independent_across_views(): + _, base_args = OtherRouteBaseArgsView.get_route_base() + # arg_2 does not leak from evaluating the previous view + assert base_args == {"arg_1"} -def test_quotes_get(): - resp = client.get("/quotes/0/") - assert quotes[0] == resp.data.decode("ascii") - -def test_quotes_put(): - resp = client.put("/quotes/1/", headers=input_headers, data=json.dumps(input_data)) - assert input_data["text"] == resp.data.decode("ascii") - - -def test_quotes_factory(): - resp = client.patch( - "/quotes/1/", headers=input_headers, data=json.dumps(input_data) - ) - assert input_data["text"] == resp.data.decode("ascii") - - -def test_quotes2_index(): - resp = client.get("/quotes-2/") - num = len(str(resp.data).split("
")) - assert 3 == num - resp = client.get("/quotes-2") - assert resp.status_code == 308 - - -def test_quotes2_get(): - resp = client.get("/quotes-2/0/") - assert quotes[0] == resp.data.decode("ascii") - assert UglyNameView.base_args.count(UglyNameView.route_base) == 0 - - -def test_quotes2_put(): - resp = client.put( - "/quotes-2/1/", headers=input_headers, data=json.dumps(input_data) +def test_missing_base_arg_in_method(): + _, base_args = ErroneousRouteBaseArgsView.get_route_base() + # Base arg is recognized + assert base_args == {"arg_1"} + # Rule is correctly generated + assert ( + ErroneousRouteBaseArgsView.build_rule("/", ErroneousRouteBaseArgsView.get) + == ErroneousRouteBaseArgsView.route_base + "/" ) - assert input_data["text"] == resp.data.decode("ascii") - assert UglyNameView.base_args.count(UglyNameView.route_base) == 0 - - -# see: https://github.com/pallets-eco/flask-classful/pull/56#issuecomment-328985183 -def test_unique_elements(): - client.put("/quotes-2/1/", headers=input_headers, data=json.dumps(input_data)) - assert UglyNameView.base_args.count(UglyNameView.route_base) == 0 + client = app.test_client() + # But calling the method fails because ErroneousRouteBaseArgsView.get is + # supplied with an unexpected "arg_1" argument + with raises(TypeError): + client.get("/route/foo/error/baz/")