diff --git a/src/flask_classful/__init__.py b/src/flask_classful/__init__.py
index 9deb320..823b548 100644
--- a/src/flask_classful/__init__.py
+++ b/src/flask_classful/__init__.py
@@ -56,7 +56,6 @@ class FlaskView:
route_base = None
route_prefix = None
trailing_slash = True
- base_args = []
excluded_methods = (
[]
) # specify the class methods to be explicitly excluded from routing creation
@@ -382,13 +381,13 @@ def build_rule(cls, rule, method=None):
if cls.route_prefix:
rule_parts.append(cls.route_prefix)
- route_base = cls.get_route_base()
+ route_base, base_args = cls.get_route_base()
if route_base:
rule_parts.append(route_base)
if len(rule) > 0: # the case of rule='' empty string
rule_parts.append(rule)
- ignored_rule_args = ["self"] + cls.base_args
+ ignored_rule_args = {"self"} | base_args
if method and getattr(cls, "inspect_args", True):
argspec = get_true_argspec(method)
@@ -413,19 +412,17 @@ def build_rule(cls, rule, method=None):
def get_route_base(cls):
"""Returns the route base to use for the current class."""
- if cls.route_base is not None:
- route_base = cls.route_base
- if not route_base.startswith("/"):
- route_base = "/" + route_base
- base_rule = Rule(route_base)
- # Add rule to a dummy map and bind that map so that
- # the Rule's arguments field is populated
- Map(rules=[base_rule]).bind("")
- cls.base_args.extend(base_rule.arguments)
- else:
- route_base = cls.default_route_base()
-
- return route_base.strip("/")
+ if cls.route_base is None:
+ return cls.default_route_base(), set()
+
+ route_base = cls.route_base
+ if not route_base.startswith("/"):
+ route_base = "/" + route_base
+ base_rule = Rule(route_base)
+ # Add rule to a dummy map and bind that map so that
+ # the Rule's arguments field is populated
+ Map(rules=[base_rule]).bind("")
+ return route_base.strip("/"), base_rule.arguments
@classmethod
def default_route_base(cls):
diff --git a/tests/test_base_args.py b/tests/test_base_args.py
index 2dca314..04d37b9 100644
--- a/tests/test_base_args.py
+++ b/tests/test_base_args.py
@@ -1,216 +1,117 @@
-import json
-
-import marshmallow as ma
from flask import Flask
-from marshmallow import Schema
-from webargs import fields
-from webargs.flaskparser import use_args
+from flask import jsonify
+from pytest import raises
from flask_classful import FlaskView
-from flask_classful import route
-
-# we'll make a list to hold some quotes for our app
-quotes = [
- "A noble spirit embiggens the smallest man! ~ Jebediah Springfield",
- "If there is a way to do it better... find it. ~ Thomas Edison",
- "No one knows what he can do till he tries. ~ Publilius Syrus",
-]
app = Flask(__name__)
app.config["DEBUG"] = True
-put_args = {"text": fields.Str(required=True)}
-
-
-class UserSchema(Schema):
- email = ma.fields.Str()
-
- class Meta:
- strict = True
-
-
-def make_user_schema(request):
- # Filter based on 'fields' query parameter
- only = request.args.get("fields", None)
- # Respect partial updates for PATCH requests
- partial = request.method == "PATCH"
- # Add current request to the schema's context
- return UserSchema(only=only, partial=partial, context={"request": request})
-
-
-class UsersView(FlaskView):
- base_args = ["args"]
-
- @use_args(make_user_schema)
- def post(self, args):
- return args["email"]
-
- @use_args(make_user_schema)
- def put(self, args, id):
- return args["email"]
-
- @use_args(make_user_schema)
- def patch(self, args, id):
- return args["email"]
-
-
-class QuoteSchema(ma.Schema):
- id = ma.fields.Int()
- text = ma.fields.Str()
-
- class Meta:
- strict = True
-
-
-def make_quote_schema(request):
- # Filter based on 'fields' query parameter
- only = request.args.get("fields", None)
- # Respect partial updates for PATCH requests
- partial = request.method == "PATCH"
- # Add current request to the schema's context
- return QuoteSchema(only=only, partial=partial, context={"request": request})
-
-
-class QuotesView(FlaskView):
- base_args = ["args"]
-
- def index(self):
- return "
".join(quotes)
-
- def get(self, id):
- quote_id = int(id)
- if quote_id < len(quotes) - 1:
- return quotes[quote_id]
- else:
- return "Not Found", 404
-
- @use_args(put_args)
- def put(self, args, id):
- quote_id = int(id)
- if quote_id >= len(quotes) - 1:
- return "Not Found", 404
- quotes[quote_id] = args["text"]
- return quotes[quote_id]
-
- @route("/", methods=["PATCH"])
- @use_args(make_quote_schema)
- def factory(self, args, id):
- quote_id = int(id)
- if quote_id >= len(quotes) - 1:
- return "Not Found", 404
- quotes[quote_id] = args["text"]
- return quotes[quote_id]
-
-
-class UglyNameView(FlaskView):
- base_args = ["args"]
- route_base = "quotes-2"
- def index(self):
- return "
".join(quotes)
-
- def get(self, id):
- quote_id = int(id)
- if quote_id < len(quotes) - 1:
- return quotes[quote_id]
- else:
- return "Not Found", 404
-
- @use_args(put_args)
- def put(self, args, id):
- quote_id = int(id)
- if quote_id >= len(quotes) - 1:
- return "Not Found", 404
- quotes[quote_id] = args["text"]
- return quotes[quote_id]
-
-
-QuotesView.register(app)
-UglyNameView.register(app)
-UsersView.register(app)
-
-client = app.test_client()
-
-input_headers = [("Content-Type", "application/json")]
-input_data = {"text": "My quote"}
-
-
-def test_users_post():
- resp = client.post(
- "users/", headers=input_headers, data=json.dumps({"email": "test@example.com"})
- )
+class NoRouteBaseArgsView(FlaskView):
+ route_base = "/route/without/args"
+
+ def get(self, arg_1):
+ return (
+ jsonify(
+ {
+ "arg_1": arg_1,
+ }
+ ),
+ 200,
+ )
+
+
+class MultiRouteBaseArgsView(FlaskView):
+ route_base = "/route//with//some_args"
+
+ def get(self, arg_1, arg_2, arg_3):
+ return (
+ jsonify(
+ {
+ "arg_1": arg_1,
+ "arg_2": arg_2,
+ "arg_3": arg_3,
+ }
+ ),
+ 200,
+ )
+
+
+class OtherRouteBaseArgsView(FlaskView):
+ route_base = "/route//other"
+
+ def get(self, arg_1, arg_2):
+ return (
+ jsonify(
+ {
+ "arg_1": arg_1,
+ "arg_2": arg_2,
+ }
+ ),
+ 200,
+ )
+
+
+class ErroneousRouteBaseArgsView(FlaskView):
+ route_base = "/route//error"
+
+ def get(self, arg_2):
+ return (
+ jsonify(
+ {
+ "arg_2": arg_2,
+ }
+ ),
+ 200,
+ )
+
+
+NoRouteBaseArgsView.register(app)
+MultiRouteBaseArgsView.register(app)
+OtherRouteBaseArgsView.register(app)
+ErroneousRouteBaseArgsView.register(app)
+
+
+def test_no_route_args():
+ _, base_args = NoRouteBaseArgsView.get_route_base()
+ # No route base with args == no base args
+ assert base_args == set()
+ client = app.test_client()
+ resp = client.get("/route/without/args/foo/")
assert resp.status_code == 200
- assert "test@example.com" == resp.data.decode("ascii")
+ assert resp.json == {"arg_1": "foo"}
-def test_users_put():
- resp = client.put(
- "users/1/",
- headers=input_headers,
- data=json.dumps({"email": "test@example.com"}),
- )
- assert resp.status_code == 200
- assert "test@example.com" == resp.data.decode("ascii")
+def test_route_args_are_detected():
+ _, base_args = MultiRouteBaseArgsView.get_route_base()
+ assert base_args == {"arg_1", "arg_2"}
-def test_users_patch():
- resp = client.patch(
- "users/1/",
- headers=input_headers,
- data=json.dumps({"email": "test@example.com"}),
- )
+def test_multi_route_args_values():
+ client = app.test_client()
+ resp = client.get("/route/foo/with/bar/some_args/baz/")
assert resp.status_code == 200
- assert "test@example.com" == resp.data.decode("ascii")
-
+ assert resp.json == {"arg_1": "foo", "arg_2": "bar", "arg_3": "baz"}
-def test_quotes_index():
- resp = client.get("/quotes/")
- num = len(str(resp.data).split("
"))
- assert 3 == num
- resp = client.get("/quotes")
- assert resp.status_code == 308
+def test_route_args_are_independent_across_views():
+ _, base_args = OtherRouteBaseArgsView.get_route_base()
+ # arg_2 does not leak from evaluating the previous view
+ assert base_args == {"arg_1"}
-def test_quotes_get():
- resp = client.get("/quotes/0/")
- assert quotes[0] == resp.data.decode("ascii")
-
-def test_quotes_put():
- resp = client.put("/quotes/1/", headers=input_headers, data=json.dumps(input_data))
- assert input_data["text"] == resp.data.decode("ascii")
-
-
-def test_quotes_factory():
- resp = client.patch(
- "/quotes/1/", headers=input_headers, data=json.dumps(input_data)
- )
- assert input_data["text"] == resp.data.decode("ascii")
-
-
-def test_quotes2_index():
- resp = client.get("/quotes-2/")
- num = len(str(resp.data).split("
"))
- assert 3 == num
- resp = client.get("/quotes-2")
- assert resp.status_code == 308
-
-
-def test_quotes2_get():
- resp = client.get("/quotes-2/0/")
- assert quotes[0] == resp.data.decode("ascii")
- assert UglyNameView.base_args.count(UglyNameView.route_base) == 0
-
-
-def test_quotes2_put():
- resp = client.put(
- "/quotes-2/1/", headers=input_headers, data=json.dumps(input_data)
+def test_missing_base_arg_in_method():
+ _, base_args = ErroneousRouteBaseArgsView.get_route_base()
+ # Base arg is recognized
+ assert base_args == {"arg_1"}
+ # Rule is correctly generated
+ assert (
+ ErroneousRouteBaseArgsView.build_rule("/", ErroneousRouteBaseArgsView.get)
+ == ErroneousRouteBaseArgsView.route_base + "/"
)
- assert input_data["text"] == resp.data.decode("ascii")
- assert UglyNameView.base_args.count(UglyNameView.route_base) == 0
-
-
-# see: https://github.com/pallets-eco/flask-classful/pull/56#issuecomment-328985183
-def test_unique_elements():
- client.put("/quotes-2/1/", headers=input_headers, data=json.dumps(input_data))
- assert UglyNameView.base_args.count(UglyNameView.route_base) == 0
+ client = app.test_client()
+ # But calling the method fails because ErroneousRouteBaseArgsView.get is
+ # supplied with an unexpected "arg_1" argument
+ with raises(TypeError):
+ client.get("/route/foo/error/baz/")