Skip to content

Commit

Permalink
Configurable session cookie options (Fixes #242)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jun 17, 2024
1 parent 4204db6 commit 0151611
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 13 deletions.
17 changes: 12 additions & 5 deletions src/microdot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@ class Session:
"""
secret_key = None

def __init__(self, app=None, secret_key=None):
def __init__(self, app=None, secret_key=None, cookie_options=None):
self.secret_key = secret_key
self.cookie_options = cookie_options or {}
if app is not None:
self.initialize(app)

def initialize(self, app, secret_key=None):
def initialize(self, app, secret_key=None, cookie_options=None):
if secret_key is not None:
self.secret_key = secret_key
if cookie_options is not None:
self.cookie_options = cookie_options
if 'path' not in self.cookie_options:
self.cookie_options['path'] = '/'
if 'http_only' not in self.cookie_options:
self.cookie_options['http_only'] = True
app._session = self

def get(self, request):
Expand Down Expand Up @@ -86,7 +93,8 @@ def index(request, session):

@request.after_request
def _update_session(request, response):
response.set_cookie('session', encoded_session, http_only=True)
response.set_cookie('session', encoded_session,
**self.cookie_options)
return response

def delete(self, request):
Expand All @@ -109,8 +117,7 @@ def index(request, session):
"""
@request.after_request
def _delete_session(request, response):
response.set_cookie('session', '', http_only=True,
expires='Thu, 01 Jan 1970 00:00:01 GMT')
response.delete_cookie('session')
return response

def encode(self, payload, secret_key=None):
Expand Down
27 changes: 19 additions & 8 deletions src/microdot/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ def _process_body(self, body, headers):
headers['Host'] = 'example.com:1234'
return body, headers

def _process_cookies(self, headers):
def _process_cookies(self, path, headers):
cookies = ''
for name, value in self.cookies.items():
if isinstance(value, tuple):
value, cookie_path = value
if not path.startswith(cookie_path):
continue
if cookies:
cookies += '; '
cookies += name + '=' + value
Expand All @@ -123,7 +127,7 @@ def _process_cookies(self, headers):
headers['Cookie'] += '; ' + cookies
else:
headers['Cookie'] = cookies
return cookies, headers
return headers

def _render_request(self, method, path, headers, body):
request_bytes = '{method} {path} HTTP/1.0\n'.format(
Expand All @@ -139,36 +143,43 @@ def _update_cookies(self, res):
for cookie in cookies:
cookie_name, cookie_value = cookie.split('=', 1)
cookie_options = cookie_value.split(';')
path = '/'
delete = False
for option in cookie_options[1:]:
if option.strip().lower().startswith(
option = option.strip().lower()
if option.startswith(
'max-age='): # pragma: no cover
_, age = option.strip().split('=', 1)
_, age = option.split('=', 1)
try:
age = int(age)
except ValueError: # pragma: no cover
age = 0
if age <= 0:
delete = True
break
elif option.strip().lower().startswith('expires='):
_, e = option.strip().split('=', 1)
elif option.startswith('expires='):
_, e = option.split('=', 1)
# this is a very limited parser for cookie expiry
# that only detects a cookie deletion request when
# the date is 1/1/1970
if '1 jan 1970' in e.lower(): # pragma: no branch
delete = True
break
elif option.startswith('path='):
_, path = option.split('=', 1)
if delete:
if cookie_name in self.cookies: # pragma: no branch
del self.cookies[cookie_name]
else:
self.cookies[cookie_name] = cookie_options[0]
if path == '/':
self.cookies[cookie_name] = cookie_options[0]
else:
self.cookies[cookie_name] = (cookie_options[0], path)

async def request(self, method, path, headers=None, body=None, sock=None):
headers = headers or {}
body, headers = self._process_body(body, headers)
cookies, headers = self._process_cookies(headers)
headers = self._process_cookies(path, headers)
request_bytes = self._render_request(method, path, headers, body)
if sock:
reader = sock[0]
Expand Down
52 changes: 52 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,55 @@ def index(req):

res = self._run(client.get('/'))
self.assertEqual(res.status_code, 200)

def test_session_default_path(self):
app = Microdot()
session_ext.initialize(app, secret_key='some-other-secret')
client = TestClient(app)

@app.get('/')
@with_session
def index(req, session):
session['foo'] = 'bar'
session.save()
return ''

@app.get('/child')
@with_session
def child(req, session):
return str(session.get('foo'))

res = self._run(client.get('/'))
self.assertEqual(res.status_code, 200)
res = self._run(client.get('/child'))
self.assertEqual(res.text, 'bar')

def test_session_custom_path(self):
app = Microdot()
session_ext.initialize(app, secret_key='some-other-secret',
cookie_options={'path': '/child'})
client = TestClient(app)

@app.get('/')
@with_session
def index(req, session):
return str(session.get('foo'))

@app.get('/child')
@with_session
def child(req, session):
session['foo'] = 'bar'
session.save()
return ''

@app.get('/child/foo')
@with_session
def foo(req, session):
return str(session.get('foo'))

res = self._run(client.get('/child'))
self.assertEqual(res.status_code, 200)
res = self._run(client.get('/'))
self.assertEqual(res.text, 'None')
res = self._run(client.get('/child/foo'))
self.assertEqual(res.text, 'bar')

0 comments on commit 0151611

Please sign in to comment.