Skip to content

Commit

Permalink
update is_valid_jwt function
Browse files Browse the repository at this point in the history
  • Loading branch information
silentworks committed Nov 22, 2024
1 parent 3e0934f commit 1314595
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
8 changes: 5 additions & 3 deletions supabase/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from storage3.constants import DEFAULT_TIMEOUT as DEFAULT_STORAGE_CLIENT_TIMEOUT
from supafunc import AsyncFunctionsClient

from supabase.lib.helpers import is_jwt
from supabase.lib.helpers import is_valid_jwt

from ..lib.client_options import AsyncClientOptions as ClientOptions
from .auth_client import AsyncSupabaseAuthClient
Expand Down Expand Up @@ -280,7 +280,7 @@ def _create_auth_header(self, token: str):

def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]:
if authorization is None:
if is_jwt(self.supabase_key):
if is_valid_jwt(self.supabase_key):
authorization = self.options.headers.get(
"Authorization", self._create_auth_header(self.supabase_key)
)
Expand All @@ -294,7 +294,9 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st
def _listen_to_auth_events(
self, event: AuthChangeEvent, session: Optional[Session]
):
default_access_token = self.supabase_key if is_jwt(self.supabase_key) else None
default_access_token = (
self.supabase_key if is_valid_jwt(self.supabase_key) else None
)
access_token = default_access_token
if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]:
# reset postgrest and storage instance on event change
Expand Down
8 changes: 5 additions & 3 deletions supabase/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from storage3.constants import DEFAULT_TIMEOUT as DEFAULT_STORAGE_CLIENT_TIMEOUT
from supafunc import SyncFunctionsClient

from supabase.lib.helpers import is_jwt
from supabase.lib.helpers import is_valid_jwt

from ..lib.client_options import SyncClientOptions as ClientOptions
from .auth_client import SyncSupabaseAuthClient
Expand Down Expand Up @@ -279,7 +279,7 @@ def _create_auth_header(self, token: str):

def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, str]:
if authorization is None:
if is_jwt(self.supabase_key):
if is_valid_jwt(self.supabase_key):
authorization = self.options.headers.get(
"Authorization", self._create_auth_header(self.supabase_key)
)
Expand All @@ -293,7 +293,9 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st
def _listen_to_auth_events(
self, event: AuthChangeEvent, session: Optional[Session]
):
default_access_token = self.supabase_key if is_jwt(self.supabase_key) else None
default_access_token = (
self.supabase_key if is_valid_jwt(self.supabase_key) else None
)
access_token = default_access_token
if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]:
# reset postgrest and storage instance on event change
Expand Down
36 changes: 25 additions & 11 deletions supabase/lib/helpers.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import re
from typing import Dict

BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"


def is_jwt(value: str) -> bool:
if value.startswith("Bearer "):
value = value.replace("Bearer ", "")
def is_valid_jwt(value: str) -> bool:
"""Checks if value looks like a JWT, does not do any extra parsing."""
if not isinstance(value, str):
return False

# Remove trailing whitespaces if any.
value = value.strip()
if not value:
return False

parts = value.split(".")
if len(parts) != 3:
# Remove "Bearer " prefix if any.
if value.startswith("Bearer "):
value = value[7:]

# Valid JWT must have 2 dots (Header.Paylod.Signature)
if value.count(".") != 2:
return False

# loop through the parts and test against regex
for part in parts:
if len(part) < 4 or not re.search(BASE64URL_REGEX, part, re.IGNORECASE):
for part in value.split("."):
if not re.search(BASE64URL_REGEX, part, re.IGNORECASE):
return False

return True


def check_authorization_header(headers):
def check_authorization_header(headers: Dict[str, str]):
authorization = headers.get("Authorization")
if not authorization:
return

if authorization.startswith("Bearer "):
if not is_valid_jwt(authorization):
raise ValueError(
"create_client called with global Authorization header that does not contain a JWT"
)

return True

0 comments on commit 1314595

Please sign in to comment.