Skip to content

Commit

Permalink
Basic, token and login authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Mar 15, 2024
1 parent f6876c0 commit 2703d6e
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 1 deletion.
147 changes: 147 additions & 0 deletions src/microdot/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from microdot import abort, redirect
from microdot.microdot import urlencode, invoke_handler


class BaseAuth:
def __init__(self):
self.auth_callback = None
self.error_callback = lambda request: abort(401)

def __call__(self, f):
"""Decorator to protect a route with authentication.
Microdot will only call the route if the authentication callback
returns a valid user object, otherwise it will call the error
callback."""
async def wrapper(request, *args, **kwargs):
auth = self._get_auth(request)
if not auth:
return await invoke_handler(self.error_callback, request)
request.g.current_user = await invoke_handler(
self.auth_callback, request, *auth)
if not request.g.current_user:
return await invoke_handler(self.error_callback, request)
return await invoke_handler(f, request, *args, **kwargs)

return wrapper


class HTTPAuth(BaseAuth):
def authenticate(self, f):
"""Decorator to configure the authentication callback.
Microdot calls the authentication callback to allow the application to
check user credentials.
"""
self.auth_callback = f


class BasicAuth(HTTPAuth):
def __init__(self, realm='Please login', charset='UTF-8', scheme='Basic',
error_status=401):
super().__init__()
self.realm = realm
self.charset = charset
self.scheme = scheme
self.error_status = error_status
self.error_callback = self.authentication_error

def _get_auth(self, request):
auth = request.headers.get('Authorization')
if auth and auth.startswith('Basic '):
import binascii
try:
username, password = binascii.a2b_base64(
auth[6:]).decode().split(':', 1)
except Exception: # pragma: no cover
return None
return username, password

def authentication_error(self, request):
return '', self.error_status, {
'WWW-Authenticate': '{} realm="{}", charset="{}"'.format(
self.scheme, self.realm, self.charset)}


class TokenAuth(HTTPAuth):
def __init__(self, header='Authorization', scheme='Bearer'):
super().__init__()
self.header = header
self.scheme = scheme.lower()

def _get_auth(self, request):
auth = request.headers.get(self.header)
if auth:
if self.header == 'Authorization':
try:
scheme, token = auth.split(' ', 1)
except Exception:
return None
if scheme.lower() == self.scheme:
return (token.strip(),)
else:
return (auth,)

def errorhandler(self, f):
"""Decorator to configure the error callback.
Microdot calls the error callback to allow the application to generate
a custom error response. The default error response is to call
``abort(401)``.
"""
self.error_callback = f


class LoginAuth(BaseAuth):
def __init__(self, login_url='/login'):
super().__init__()
self.login_url = login_url
self.user_callback = None
self.user_id_callback = None
self.auth_callback = self._authenticate
self.error_callback = self._redirect_to_login

def id_to_user(self, f):
"""Decorator to configure the user callback.
Microdot calls the user callback to load the user object from the
user ID stored in the user session.
"""
self.user_callback = f

def user_to_id(self, f):
"""Decorator to configure the user ID callback.
Microdot calls the user ID callback to load the user ID from the
user session.
"""
self.user_id_callback = f

def _get_session(self, request):
return request.app._session.get(request)

def _get_auth(self, request):
session = self._get_session(request)
if session and 'user_id' in session:
return (session['user_id'],)

async def _authenticate(self, request, user_id):
return await invoke_handler(self.user_callback, user_id)

async def _redirect_to_login(self, request):
return '', 302, {'Location': self.login_url + '?next=' + urlencode(
request.url)}

async def login_user(self, request, user, redirect_url='/'):
session = self._get_session(request)
session['user_id'] = await invoke_handler(self.user_id_callback, user)
session.save()
next_url = request.args.get('next', redirect_url)
if not next_url.startswith('/'):
next_url = redirect_url
return redirect(next_url)

async def logout_user(self, request):
session = self._get_session(request)
session.pop('user_id', None)
session.save()
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from tests.test_cors import * # noqa: F401, F403
from tests.test_utemplate import * # noqa: F401, F403
from tests.test_session import * # noqa: F401, F403
from tests.test_auth import * # noqa: F401, F403
194 changes: 194 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio
import binascii
import unittest
from microdot import Microdot
from microdot.auth import BasicAuth, TokenAuth, LoginAuth
from microdot.session import Session, with_session
from microdot.test_client import TestClient


class TestAuth(unittest.TestCase):
@classmethod
def setUpClass(cls):
if hasattr(asyncio, 'set_event_loop'):
asyncio.set_event_loop(asyncio.new_event_loop())
cls.loop = asyncio.get_event_loop()

def _run(self, coro):
return self.loop.run_until_complete(coro)

def test_basic_auth(self):
app = Microdot()
basic_auth = BasicAuth()

@basic_auth.authenticate
def authenticate(request, username, password):
if username == 'foo' and password == 'bar':
return {'username': username}

@app.route('/')
@basic_auth
def index(request):
return request.g.current_user['username']

client = TestClient(app)
res = self._run(client.get('/'))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={
'Authorization': 'Basic ' + binascii.b2a_base64(
b'foo:bar').decode()}))
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'foo')

res = self._run(client.get('/', headers={
'Authorization': 'Basic ' + binascii.b2a_base64(
b'foo:baz').decode()}))
self.assertEqual(res.status_code, 401)

def test_token_auth(self):
app = Microdot()
token_auth = TokenAuth()

@token_auth.authenticate
def authenticate(request, token):
if token == 'foo':
return 'user'

@app.route('/')
@token_auth
def index(request):
return request.g.current_user

client = TestClient(app)
res = self._run(client.get('/'))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'Authorization': 'Basic foo'}))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'Authorization': 'foo'}))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'Authorization': 'Bearer foo'}))
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')

def test_token_auth_custom_header(self):
app = Microdot()
token_auth = TokenAuth(header='X-Auth-Token')

@token_auth.authenticate
def authenticate(request, token):
if token == 'foo':
return 'user'

@app.route('/')
@token_auth
def index(request):
return request.g.current_user

client = TestClient(app)
res = self._run(client.get('/'))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'Authorization': 'Basic foo'}))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'Authorization': 'foo'}))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'Authorization': 'Bearer foo'}))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'X-Token-Auth': 'Bearer foo'}))
self.assertEqual(res.status_code, 401)

res = self._run(client.get('/', headers={'X-Auth-Token': 'foo'}))
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')

res = self._run(client.get('/', headers={'x-auth-token': 'foo'}))
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')

@token_auth.errorhandler
def error_handler(request):
return {'status_code': 403}, 403

res = self._run(client.get('/'))
self.assertEqual(res.status_code, 403)
self.assertEqual(res.json, {'status_code': 403})

def test_login_auth(self):
app = Microdot()
Session(app, secret_key='secret')
login_auth = LoginAuth()

@login_auth.id_to_user
def id_to_user(user_id):
return user_id

@login_auth.user_to_id
def user_to_id(user):
return user

@app.get('/')
@login_auth
def index(request):
return request.g.current_user

@app.post('/login')
async def login(request):
return await login_auth.login_user(request, 'user')

@app.post('/logout')
async def logout(request):
await login_auth.logout_user(request)
return 'ok'

client = TestClient(app)
res = self._run(client.get('/?foo=bar'))
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar')

res = self._run(client.post('/login?next=/%3Ffoo=bar'))
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/?foo=bar')

res = self._run(client.get('/'))
self.assertEqual(res.status_code, 200)
self.assertEqual(res.text, 'user')

res = self._run(client.post('/logout'))
self.assertEqual(res.status_code, 200)

res = self._run(client.get('/'))
self.assertEqual(res.status_code, 302)

def test_login_auth_bad_redirect(self):
app = Microdot()
Session(app, secret_key='secret')
login_auth = LoginAuth()

@login_auth.id_to_user
def id_to_user(user_id):
return user_id

@login_auth.user_to_id
def user_to_id(user):
return user

@app.get('/')
@login_auth
async def index(request):
return 'ok'

@app.post('/login')
async def login(request):
return await login_auth.login_user(request, 'user')

client = TestClient(app)
res = self._run(client.post('/login?next=http://example.com'))
self.assertEqual(res.status_code, 302)
self.assertEqual(res.headers['Location'], '/')
2 changes: 1 addition & 1 deletion tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def session_context_manager(req, session):

@app.post('/set')
@with_session
async def save_session(req, session):
def save_session(req, session):
session['name'] = 'joe'
session.save()
return 'OK'
Expand Down

0 comments on commit 2703d6e

Please sign in to comment.