import base64
import secrets
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import (
    datetime,
    timedelta,
)
from types import SimpleNamespace
from typing import Optional
from unittest.mock import (
    MagicMock,
    patch,
)

import jwt
import pytest

# Tools from hazmat should only be used for testing!
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import (
    RSAPrivateKey,
    RSAPublicKey,
)
from jwt import (
    InvalidAudienceError,
    InvalidIssuerError,
    InvalidSignatureError,
)
from social_core.backends.open_id_connect import OpenIdConnectAuth
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from galaxy import model
from galaxy.authnz.managers import AuthnzManager
from galaxy.authnz.oidc_utils import decode_access_token as decode_access_token_oidc
from galaxy.authnz.psa_authnz import (
    AUTH_PIPELINE,
    decode_access_token,
    PSAAuthnz,
    sync_user_profile,
)

# Dotted path of the notification hook the sync_user_profile tests patch out.
_SEND_NOTIFICATION_TARGET = (
    "galaxy.webapps.galaxy.services.notifications.NotificationService.send_notification_internal"
)


@pytest.fixture(scope="module")
def engine():
    """In-memory SQLite engine shared by every test in this module."""
    sqlite_engine = create_engine("sqlite:///:memory:")
    yield sqlite_engine
    # Release pooled connections once the module's tests have finished.
    sqlite_engine.dispose()


@pytest.fixture(scope="module")
def db_session(engine):
    """Session bound to ``engine`` after creating all Galaxy model tables."""
    model.mapper_registry.metadata.create_all(engine)
    session = Session(bind=engine)
    yield session
    session.rollback()
    session.close()


@pytest.fixture
def mock_oidc_backend_config_file(tmp_path):
    """Write a minimal OIDC backends config to ``tmp_path`` and return its path."""
    # NOTE(review): in this copy of the file the literal below holds only element
    # values (no XML markup) although it is written to an .xml file; confirm against
    # version control that the tags were not lost.
    config = """ login.example.com gxyclient dummyclientsecret $galaxy_url/authnz/$provider_name/callback true gxyclient """
    filename = tmp_path / "oidc_backends_config.xml"
    filename.write_text(config)
    return filename


@pytest.fixture
def mock_oidc_config_file(tmp_path):
    """Write a minimal OIDC config to ``tmp_path`` and return its path."""
    # NOTE(review): as above, the XML markup appears to be missing here; confirm.
    config = """ """
    filename = tmp_path / "oidc_config.xml"
    filename.write_text(config)
    return filename


@dataclass
class AuthTokenData:
    """
    Stores all the information needed to generate an access token and test
    that it can be decoded.
    """

    private_key: RSAPrivateKey  # key the token was signed with
    public_key: RSAPublicKey  # matching public key for signature verification
    access_token_str: str  # the encoded (signed) JWT
    access_token_data: dict  # the payload that was encoded into the JWT
    key_id: str  # value placed in the JWT "kid" header


def _epoch_seconds(moment: datetime) -> int:
    """Return ``moment`` as whole seconds since the Unix epoch."""
    # datetime.timestamp() is portable; strftime("%s") is a glibc-only extension
    # that is not a documented Python directive and fails on e.g. Windows.
    return int(moment.timestamp())


def create_access_token(
    email: str = "user@example.com",
    roles: Optional[list[str]] = None,
    iss: str = "https://issuer.example.com",
    sub: Optional[str] = None,
    iat: Optional[int] = None,
    exp: Optional[int] = None,
    aud: str = "https://audience.example.com",
    scope: Optional[list[str]] = None,
    azp: Optional[str] = None,
    permissions: Optional[list[str]] = None,
    algorithm: str = "RS256",
    public_key_id: str = "example-key",
) -> AuthTokenData:
    """
    Create an OIDC access token along with a dummy private and public key for signing it.

    Each field of the payload can be set, but otherwise will get a sensible
    default (e.g. expiry time one hour in the future). ``aud`` is wrapped in a
    single-element list in the payload.
    """
    if roles is None:
        roles = []
    # Generate a random alphanumeric ID
    if sub is None:
        sub = uuid.uuid4().hex
    now = datetime.now()
    if iat is None:
        iat = _epoch_seconds(now)
    if exp is None:
        exp = _epoch_seconds(now + timedelta(hours=1))
    if azp is None:
        azp = uuid.uuid4().hex
    if permissions is None:
        permissions = []
    payload = {
        "email": email,
        "biocommons.org.au/roles": roles,
        "iss": iss,
        "sub": sub,
        "aud": [aud],
        "iat": iat,
        "exp": exp,
        "scope": scope,
        "azp": azp,
        "permissions": permissions,
    }
    public_key, private_key = generate_public_private_key_pair()
    access_token_encoded = jwt.encode(
        payload,
        key=private_key,
        algorithm=algorithm,
        headers={"kid": public_key_id},
    )
    return AuthTokenData(
        private_key=private_key,
        public_key=public_key,
        access_token_str=access_token_encoded,
        access_token_data=payload,
        key_id=public_key_id,
    )


def generate_public_private_key_pair():
    """Generate a fresh 2048-bit RSA key pair, returned as ``(public_key, private_key)``."""
    # Code from https://fmpm.dev/mocking-auth0-tokens
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    public_key = private_key.public_key()
    return public_key, private_key


def get_jwk_data(public_key: RSAPublicKey) -> dict:
    """
    Format an RSAPublicKey into the structure PyJWK expects.
    """

    def base64url_uint(val: int) -> str:
        """Base64url encode a non-negative big integer (RFC 7518 Base64urlUInt)."""
        # Zero must still be encoded as one zero octet ("AA"); bit_length() alone
        # would yield an empty byte string for it.
        byte_length = max(1, (val.bit_length() + 7) // 8)
        raw = val.to_bytes(byte_length, "big")
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")

    numbers = public_key.public_numbers()
    return {"kty": "RSA", "n": base64url_uint(numbers.n), "e": base64url_uint(numbers.e)}


def _make_social(access_token: str) -> MagicMock:
    """Mock social-auth record whose ``extra_data.get`` returns ``access_token``."""
    social = MagicMock()
    social.extra_data.get.return_value = access_token
    return social


def _make_oidc_backend(public_key: RSAPublicKey, accepted_audiences, issuer: str) -> MagicMock:
    """
    Build a mock backend that is recognized as an OIDC backend.

    ``find_valid_key`` returns the JWK for ``public_key``, the strategy config
    exposes ``accepted_audiences`` and ``id_token_issuer()`` returns ``issuer``.
    """
    backend = MagicMock()
    backend.find_valid_key.return_value = get_jwk_data(public_key)
    backend.strategy.config = {"accepted_audiences": accepted_audiences}
    backend.id_token_issuer.return_value = issuer
    # Make isinstance() checks pass by setting __class__ after configuring the mock
    backend.__class__ = OpenIdConnectAuth  # type: ignore[assignment]
    return backend


def test_decode_access_token():
    """
    Test we can decode a valid access token.
    """
    dummy_access_token = create_access_token()
    mock_social = _make_social(dummy_access_token.access_token_str)
    mock_backend = _make_oidc_backend(
        dummy_access_token.public_key,
        accepted_audiences=dummy_access_token.access_token_data["aud"],
        issuer=dummy_access_token.access_token_data["iss"],
    )

    # Check that access token is decoded successfully to return the original data
    data = decode_access_token(social=mock_social, backend=mock_backend)
    assert data["access_token"] == dummy_access_token.access_token_data


def test_decode_access_token_invalid_key():
    """
    Test that decoding fails when an invalid key is provided.
    """
    dummy_access_token = create_access_token()
    # A freshly generated key pair that did not sign the token.
    incorrect_public_key, _ = generate_public_private_key_pair()
    mock_social = _make_social(dummy_access_token.access_token_str)
    mock_backend = _make_oidc_backend(
        incorrect_public_key,
        accepted_audiences=dummy_access_token.access_token_data["aud"],
        issuer=dummy_access_token.access_token_data["iss"],
    )

    # Test that the decode function returns None for the access token
    result = decode_access_token(social=mock_social, backend=mock_backend)
    assert result["access_token"] is None
    # Test the actual decoding raises expected error
    with pytest.raises(InvalidSignatureError):
        decode_access_token_oidc(token_str=dummy_access_token.access_token_str, backend=mock_backend)


def test_decode_access_token_invalid_issuer():
    """
    Test that a token with an invalid issuer (doesn't match what we expect)
    is not decoded/returned.
    """
    dummy_access_token = create_access_token(iss="https://invalid.url")
    mock_social = _make_social(dummy_access_token.access_token_str)
    mock_backend = _make_oidc_backend(
        dummy_access_token.public_key,
        accepted_audiences=dummy_access_token.access_token_data["aud"],
        issuer="https://validissuer.com",
    )

    # Test that the decode function returns None for the access token
    result = decode_access_token(social=mock_social, backend=mock_backend)
    assert result["access_token"] is None
    # Test the actual decoding raises expected error
    with pytest.raises(InvalidIssuerError):
        decode_access_token_oidc(token_str=dummy_access_token.access_token_str, backend=mock_backend)


def test_decode_access_token_invalid_audience():
    """
    Test that a token with an invalid audience (doesn't match what we expect)
    is not decoded/returned.
    """
    dummy_access_token = create_access_token(aud="https://invalidaudience.url")
    mock_social = _make_social(dummy_access_token.access_token_str)
    mock_backend = _make_oidc_backend(
        dummy_access_token.public_key,
        accepted_audiences=["https://validaudience.url"],
        issuer=dummy_access_token.access_token_data["iss"],
    )

    # Test that the decode function returns None for the access token
    result = decode_access_token(social=mock_social, backend=mock_backend)
    assert result["access_token"] is None
    # Test the actual decoding raises expected error
    with pytest.raises(InvalidAudienceError):
        decode_access_token_oidc(token_str=dummy_access_token.access_token_str, backend=mock_backend)


def test_decode_access_token_opaque_token():
    """
    Test that when the access token is opaque (e.g. those returned by Google Auth),
    we don't decode and just return None.
    """

    def generate_google_style_token():
        prefix = "ya29"
        part1 = secrets.token_urlsafe(32)
        part2 = secrets.token_urlsafe(64)
        return f"{prefix}.{part1}{part2}"

    opaque_token = generate_google_style_token()
    mock_social = _make_social(opaque_token)

    result = decode_access_token(social=mock_social, backend=MagicMock())
    assert result["access_token"] is None


def _build_psa_authnz(oidc_config_file, oidc_backends_config_file, oidc_auth_pipeline, oidc_auth_pipeline_extra):
    """
    Construct a PSAAuthnz for the "oidc" provider from the given config files and
    pipeline settings, using a mock app with delegated auth disabled.
    """
    mock_app = MagicMock()
    mock_app.config = SimpleNamespace(
        oidc_auth_pipeline=oidc_auth_pipeline,
        oidc_auth_pipeline_extra=oidc_auth_pipeline_extra,
        oidc=defaultdict(dict),
        fixed_delegated_auth=False,
    )
    manager = AuthnzManager(
        app=mock_app, oidc_config_file=oidc_config_file, oidc_backends_config_file=oidc_backends_config_file
    )
    return PSAAuthnz(
        provider="oidc",
        oidc_config=manager.oidc_config,
        oidc_backend_config=manager.oidc_backends_config,
        app_config=mock_app.config,
    )


def test_oidc_config_custom_auth_pipeline(mock_oidc_config_file, mock_oidc_backend_config_file):
    """
    Test that a custom oidc_auth_pipeline replaces the default auth pipeline.
    """
    custom_auth_pipeline = ("custom", "auth", "steps")
    psa_authnz = _build_psa_authnz(
        mock_oidc_config_file,
        mock_oidc_backend_config_file,
        oidc_auth_pipeline=custom_auth_pipeline,
        oidc_auth_pipeline_extra=None,
    )
    assert psa_authnz.config["SOCIAL_AUTH_PIPELINE"] == custom_auth_pipeline


def test_oidc_config_auth_pipeline_extra(mock_oidc_config_file, mock_oidc_backend_config_file):
    """
    Test that the oidc_auth_pipeline_extra config option is used to extend the auth pipeline.
    """
    custom_auth_pipeline_extra = ["extra", "auth", "steps"]
    psa_authnz = _build_psa_authnz(
        mock_oidc_config_file,
        mock_oidc_backend_config_file,
        oidc_auth_pipeline=None,
        oidc_auth_pipeline_extra=custom_auth_pipeline_extra,
    )
    assert psa_authnz.config["SOCIAL_AUTH_PIPELINE"] == AUTH_PIPELINE + tuple(custom_auth_pipeline_extra)


def test_oidc_config_custom_auth_pipeline_and_extra(mock_oidc_config_file, mock_oidc_backend_config_file):
    """
    Test that the oidc_auth_pipeline_extra config option is used to extend the auth pipeline,
    when a custom auth pipeline is also specified in the config file.
    """
    custom_auth_pipeline = ("custom", "auth", "steps")
    custom_auth_pipeline_extra = ["extra", "auth", "steps"]
    psa_authnz = _build_psa_authnz(
        mock_oidc_config_file,
        mock_oidc_backend_config_file,
        oidc_auth_pipeline=custom_auth_pipeline,
        oidc_auth_pipeline_extra=custom_auth_pipeline_extra,
    )
    assert psa_authnz.config["SOCIAL_AUTH_PIPELINE"] == custom_auth_pipeline + tuple(custom_auth_pipeline_extra)


def _make_sync_strategy(*, enable_account_interface: bool, fixed_delegated_auth: bool, notification_manager):
    """
    Build the strategy scaffolding consumed by sync_user_profile.

    Returns ``(strategy, user_manager, sa_session, trans)`` so tests can assert on
    the mocked user manager and session.
    """
    manager = MagicMock()
    session = MagicMock()
    app_config = SimpleNamespace(enable_account_interface=enable_account_interface, enable_notification_system=True)
    app = SimpleNamespace(config=app_config, user_manager=manager, notification_manager=notification_manager)
    trans = SimpleNamespace(app=app, sa_session=session)
    strategy = SimpleNamespace(config={"GALAXY_TRANS": trans, "FIXED_DELEGATED_AUTH": fixed_delegated_auth})
    return strategy, manager, session, trans


def test_sync_user_profile_skips_when_account_interface_enabled():
    strategy, manager, session, _ = _make_sync_strategy(
        enable_account_interface=True,
        fixed_delegated_auth=True,
        notification_manager=SimpleNamespace(),
    )
    user = SimpleNamespace(id=1, preferences={})
    details = {"email": "new@example.com", "username": "newname"}

    with patch(_SEND_NOTIFICATION_TARGET) as notify:
        sync_user_profile(strategy=strategy, details=details, user=user)

    manager.update_email.assert_not_called()
    manager.update_username.assert_not_called()
    session.commit.assert_not_called()
    notify.assert_not_called()


def test_sync_user_profile_skips_when_fixed_delegated_auth_disabled():
    strategy, manager, session, _ = _make_sync_strategy(
        enable_account_interface=False,
        fixed_delegated_auth=False,
        notification_manager=SimpleNamespace(),
    )
    user = SimpleNamespace(id=2, email="old@example.com", username="oldname", preferences={})
    details = {"email": "new@example.com", "username": "newname"}

    with patch(_SEND_NOTIFICATION_TARGET) as notify:
        sync_user_profile(strategy=strategy, details=details, user=user)

    manager.update_email.assert_not_called()
    manager.update_username.assert_not_called()
    session.commit.assert_not_called()
    notify.assert_not_called()


def test_sync_user_profile_updates_when_account_interface_disabled():
    strategy, manager, session, trans = _make_sync_strategy(
        enable_account_interface=False,
        fixed_delegated_auth=True,
        notification_manager=SimpleNamespace(notifications_enabled=True),
    )
    user = SimpleNamespace(id=2, email="old@example.com", username="oldname", preferences={})
    details = {"email": "new@example.com", "username": "newname"}

    with patch(_SEND_NOTIFICATION_TARGET) as notify:
        sync_user_profile(strategy=strategy, details=details, user=user)

    manager.update_email.assert_called_once_with(
        trans, user, "new@example.com", commit=False, send_activation_email=False
    )
    manager.update_username.assert_called_once_with(trans, user, "newname", commit=False)
    assert session.commit.call_count == 1
    notify.assert_called_once()