-
-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Basic, token and login authentication
- Loading branch information
1 parent
f6876c0
commit 2703d6e
Showing
4 changed files
with
343 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from microdot import abort, redirect | ||
from microdot.microdot import urlencode, invoke_handler | ||
|
||
|
||
class BaseAuth: | ||
def __init__(self): | ||
self.auth_callback = None | ||
self.error_callback = lambda request: abort(401) | ||
|
||
def __call__(self, f): | ||
"""Decorator to protect a route with authentication. | ||
Microdot will only call the route if the authentication callback | ||
returns a valid user object, otherwise it will call the error | ||
callback.""" | ||
async def wrapper(request, *args, **kwargs): | ||
auth = self._get_auth(request) | ||
if not auth: | ||
return await invoke_handler(self.error_callback, request) | ||
request.g.current_user = await invoke_handler( | ||
self.auth_callback, request, *auth) | ||
if not request.g.current_user: | ||
return await invoke_handler(self.error_callback, request) | ||
return await invoke_handler(f, request, *args, **kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
class HTTPAuth(BaseAuth): | ||
def authenticate(self, f): | ||
"""Decorator to configure the authentication callback. | ||
Microdot calls the authentication callback to allow the application to | ||
check user credentials. | ||
""" | ||
self.auth_callback = f | ||
|
||
|
||
class BasicAuth(HTTPAuth): | ||
def __init__(self, realm='Please login', charset='UTF-8', scheme='Basic', | ||
error_status=401): | ||
super().__init__() | ||
self.realm = realm | ||
self.charset = charset | ||
self.scheme = scheme | ||
self.error_status = error_status | ||
self.error_callback = self.authentication_error | ||
|
||
def _get_auth(self, request): | ||
auth = request.headers.get('Authorization') | ||
if auth and auth.startswith('Basic '): | ||
import binascii | ||
try: | ||
username, password = binascii.a2b_base64( | ||
auth[6:]).decode().split(':', 1) | ||
except Exception: # pragma: no cover | ||
return None | ||
return username, password | ||
|
||
def authentication_error(self, request): | ||
return '', self.error_status, { | ||
'WWW-Authenticate': '{} realm="{}", charset="{}"'.format( | ||
self.scheme, self.realm, self.charset)} | ||
|
||
|
||
class TokenAuth(HTTPAuth): | ||
def __init__(self, header='Authorization', scheme='Bearer'): | ||
super().__init__() | ||
self.header = header | ||
self.scheme = scheme.lower() | ||
|
||
def _get_auth(self, request): | ||
auth = request.headers.get(self.header) | ||
if auth: | ||
if self.header == 'Authorization': | ||
try: | ||
scheme, token = auth.split(' ', 1) | ||
except Exception: | ||
return None | ||
if scheme.lower() == self.scheme: | ||
return (token.strip(),) | ||
else: | ||
return (auth,) | ||
|
||
def errorhandler(self, f): | ||
"""Decorator to configure the error callback. | ||
Microdot calls the error callback to allow the application to generate | ||
a custom error response. The default error response is to call | ||
``abort(401)``. | ||
""" | ||
self.error_callback = f | ||
|
||
|
||
class LoginAuth(BaseAuth): | ||
def __init__(self, login_url='/login'): | ||
super().__init__() | ||
self.login_url = login_url | ||
self.user_callback = None | ||
self.user_id_callback = None | ||
self.auth_callback = self._authenticate | ||
self.error_callback = self._redirect_to_login | ||
|
||
def id_to_user(self, f): | ||
"""Decorator to configure the user callback. | ||
Microdot calls the user callback to load the user object from the | ||
user ID stored in the user session. | ||
""" | ||
self.user_callback = f | ||
|
||
def user_to_id(self, f): | ||
"""Decorator to configure the user ID callback. | ||
Microdot calls the user ID callback to load the user ID from the | ||
user session. | ||
""" | ||
self.user_id_callback = f | ||
|
||
def _get_session(self, request): | ||
return request.app._session.get(request) | ||
|
||
def _get_auth(self, request): | ||
session = self._get_session(request) | ||
if session and 'user_id' in session: | ||
return (session['user_id'],) | ||
|
||
async def _authenticate(self, request, user_id): | ||
return await invoke_handler(self.user_callback, user_id) | ||
|
||
async def _redirect_to_login(self, request): | ||
return '', 302, {'Location': self.login_url + '?next=' + urlencode( | ||
request.url)} | ||
|
||
async def login_user(self, request, user, redirect_url='/'): | ||
session = self._get_session(request) | ||
session['user_id'] = await invoke_handler(self.user_id_callback, user) | ||
session.save() | ||
next_url = request.args.get('next', redirect_url) | ||
if not next_url.startswith('/'): | ||
next_url = redirect_url | ||
return redirect(next_url) | ||
|
||
async def logout_user(self, request): | ||
session = self._get_session(request) | ||
session.pop('user_id', None) | ||
session.save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import asyncio | ||
import binascii | ||
import unittest | ||
from microdot import Microdot | ||
from microdot.auth import BasicAuth, TokenAuth, LoginAuth | ||
from microdot.session import Session, with_session | ||
from microdot.test_client import TestClient | ||
|
||
|
||
class TestAuth(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
if hasattr(asyncio, 'set_event_loop'): | ||
asyncio.set_event_loop(asyncio.new_event_loop()) | ||
cls.loop = asyncio.get_event_loop() | ||
|
||
def _run(self, coro): | ||
return self.loop.run_until_complete(coro) | ||
|
||
def test_basic_auth(self): | ||
app = Microdot() | ||
basic_auth = BasicAuth() | ||
|
||
@basic_auth.authenticate | ||
def authenticate(request, username, password): | ||
if username == 'foo' and password == 'bar': | ||
return {'username': username} | ||
|
||
@app.route('/') | ||
@basic_auth | ||
def index(request): | ||
return request.g.current_user['username'] | ||
|
||
client = TestClient(app) | ||
res = self._run(client.get('/')) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={ | ||
'Authorization': 'Basic ' + binascii.b2a_base64( | ||
b'foo:bar').decode()})) | ||
self.assertEqual(res.status_code, 200) | ||
self.assertEqual(res.text, 'foo') | ||
|
||
res = self._run(client.get('/', headers={ | ||
'Authorization': 'Basic ' + binascii.b2a_base64( | ||
b'foo:baz').decode()})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
def test_token_auth(self): | ||
app = Microdot() | ||
token_auth = TokenAuth() | ||
|
||
@token_auth.authenticate | ||
def authenticate(request, token): | ||
if token == 'foo': | ||
return 'user' | ||
|
||
@app.route('/') | ||
@token_auth | ||
def index(request): | ||
return request.g.current_user | ||
|
||
client = TestClient(app) | ||
res = self._run(client.get('/')) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'Authorization': 'Basic foo'})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'Authorization': 'foo'})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'Authorization': 'Bearer foo'})) | ||
self.assertEqual(res.status_code, 200) | ||
self.assertEqual(res.text, 'user') | ||
|
||
def test_token_auth_custom_header(self): | ||
app = Microdot() | ||
token_auth = TokenAuth(header='X-Auth-Token') | ||
|
||
@token_auth.authenticate | ||
def authenticate(request, token): | ||
if token == 'foo': | ||
return 'user' | ||
|
||
@app.route('/') | ||
@token_auth | ||
def index(request): | ||
return request.g.current_user | ||
|
||
client = TestClient(app) | ||
res = self._run(client.get('/')) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'Authorization': 'Basic foo'})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'Authorization': 'foo'})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'Authorization': 'Bearer foo'})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'X-Token-Auth': 'Bearer foo'})) | ||
self.assertEqual(res.status_code, 401) | ||
|
||
res = self._run(client.get('/', headers={'X-Auth-Token': 'foo'})) | ||
self.assertEqual(res.status_code, 200) | ||
self.assertEqual(res.text, 'user') | ||
|
||
res = self._run(client.get('/', headers={'x-auth-token': 'foo'})) | ||
self.assertEqual(res.status_code, 200) | ||
self.assertEqual(res.text, 'user') | ||
|
||
@token_auth.errorhandler | ||
def error_handler(request): | ||
return {'status_code': 403}, 403 | ||
|
||
res = self._run(client.get('/')) | ||
self.assertEqual(res.status_code, 403) | ||
self.assertEqual(res.json, {'status_code': 403}) | ||
|
||
def test_login_auth(self): | ||
app = Microdot() | ||
Session(app, secret_key='secret') | ||
login_auth = LoginAuth() | ||
|
||
@login_auth.id_to_user | ||
def id_to_user(user_id): | ||
return user_id | ||
|
||
@login_auth.user_to_id | ||
def user_to_id(user): | ||
return user | ||
|
||
@app.get('/') | ||
@login_auth | ||
def index(request): | ||
return request.g.current_user | ||
|
||
@app.post('/login') | ||
async def login(request): | ||
return await login_auth.login_user(request, 'user') | ||
|
||
@app.post('/logout') | ||
async def logout(request): | ||
await login_auth.logout_user(request) | ||
return 'ok' | ||
|
||
client = TestClient(app) | ||
res = self._run(client.get('/?foo=bar')) | ||
self.assertEqual(res.status_code, 302) | ||
self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar') | ||
|
||
res = self._run(client.post('/login?next=/%3Ffoo=bar')) | ||
self.assertEqual(res.status_code, 302) | ||
self.assertEqual(res.headers['Location'], '/?foo=bar') | ||
|
||
res = self._run(client.get('/')) | ||
self.assertEqual(res.status_code, 200) | ||
self.assertEqual(res.text, 'user') | ||
|
||
res = self._run(client.post('/logout')) | ||
self.assertEqual(res.status_code, 200) | ||
|
||
res = self._run(client.get('/')) | ||
self.assertEqual(res.status_code, 302) | ||
|
||
def test_login_auth_bad_redirect(self): | ||
app = Microdot() | ||
Session(app, secret_key='secret') | ||
login_auth = LoginAuth() | ||
|
||
@login_auth.id_to_user | ||
def id_to_user(user_id): | ||
return user_id | ||
|
||
@login_auth.user_to_id | ||
def user_to_id(user): | ||
return user | ||
|
||
@app.get('/') | ||
@login_auth | ||
async def index(request): | ||
return 'ok' | ||
|
||
@app.post('/login') | ||
async def login(request): | ||
return await login_auth.login_user(request, 'user') | ||
|
||
client = TestClient(app) | ||
res = self._run(client.post('/login?next=http://example.com')) | ||
self.assertEqual(res.status_code, 302) | ||
self.assertEqual(res.headers['Location'], '/') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters