fix(airflow): fix JWT decode and verify

This commit is contained in:
Masaki Yatsu
2025-09-18 15:01:25 +09:00
parent dc30a37a42
commit 0106e22c84
3 changed files with 131 additions and 64 deletions

View File

@@ -1,8 +1,14 @@
import os
import logging
import json
import base64
import requests
from typing import Dict, Any, Optional
from urllib.parse import urljoin
from flask_appbuilder.security.manager import AUTH_OAUTH
from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride
log = logging.getLogger(__name__)
AUTH_TYPE = AUTH_OAUTH
@@ -56,8 +62,7 @@ class KeycloakSecurityManager(FabAirflowSecurityManagerOverride):
"""Extract user info and roles from Keycloak token"""
if provider == "keycloak":
import jwt
import base64
import json
import requests
# Get access token
token = response.get("access_token")
@@ -66,69 +71,84 @@ class KeycloakSecurityManager(FabAirflowSecurityManagerOverride):
return None
try:
# Decode token without verification for debugging
# In production, you should verify the signature
parts = token.split('.')
if len(parts) == 3:
# Decode payload
payload_b64 = parts[1]
# Add padding if needed
payload_b64 += '=' * (4 - len(payload_b64) % 4)
payload = json.loads(base64.b64decode(payload_b64))
# Get JWKS URL from OpenID configuration
jwks_url = f"{OIDC_ISSUER}/.well-known/openid-configuration"
oidc_config = requests.get(jwks_url).json()
jwks_uri = oidc_config["jwks_uri"]
log.info(f"Decoded token payload: {payload}")
# Use PyJWT to decode and verify the token
from jwt import PyJWKClient
jwks_client = PyJWKClient(jwks_uri)
signing_key = jwks_client.get_signing_key_from_jwt(token)
# Extract user information
userinfo = {
"username": payload.get("preferred_username"),
"email": payload.get("email"),
"first_name": payload.get("given_name"),
"last_name": payload.get("family_name"),
payload = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
audience=["airflow", "account"], # Keycloak uses both
issuer=OIDC_ISSUER,
options={"verify_signature": True, "verify_aud": False} # Relax audience check
)
log.info(f"JWT signature verified successfully")
log.debug(f"Decoded token payload keys: {list(payload.keys())}")
log.debug(f"Token has preferred_username: {bool(payload.get('preferred_username'))}")
log.debug(f"Token has email: {bool(payload.get('email'))}")
# Extract user information
userinfo = {
"username": payload.get("preferred_username"),
"email": payload.get("email"),
"first_name": payload.get("given_name"),
"last_name": payload.get("family_name"),
}
log.debug(f"Extracted userinfo keys: {list(userinfo.keys())}")
# Extract roles from different possible locations
roles = []
# Check realm access roles
realm_access = payload.get("realm_access", {})
realm_roles = realm_access.get("roles", [])
# Check resource access (client roles)
resource_access = payload.get("resource_access", {})
client_access = resource_access.get("airflow", {})
client_roles = client_access.get("roles", [])
# Check airflow_roles claim directly
direct_roles = payload.get("airflow_roles", [])
log.info(f"Realm roles: {realm_roles}")
log.info(f"Client roles: {client_roles}")
log.info(f"Direct airflow roles: {direct_roles}")
# Prefer client roles, then direct roles, then realm roles
if client_roles:
roles = client_roles
log.info(f"Using client roles: {roles}")
elif direct_roles:
roles = direct_roles
log.info(f"Using direct airflow roles: {roles}")
elif realm_roles:
# Map common realm roles to Airflow roles
role_mapping = {
'admin': 'Admin',
'user': 'User',
'viewer': 'Viewer'
}
roles = [role_mapping.get(role.lower(), 'Viewer') for role in realm_roles]
log.info(f"Using mapped realm roles: {roles}")
else:
roles = ['Viewer']
log.info("No roles found, defaulting to Viewer")
# Extract roles from different possible locations
roles = []
userinfo["role_keys"] = roles
log.info(f"User authentication successful for: {userinfo.get('username', 'unknown')}")
log.debug(f"Final userinfo keys: {list(userinfo.keys())}")
# Check realm access roles
realm_access = payload.get("realm_access", {})
realm_roles = realm_access.get("roles", [])
# Check resource access (client roles)
resource_access = payload.get("resource_access", {})
client_access = resource_access.get("airflow", {})
client_roles = client_access.get("roles", [])
# Check airflow_roles claim directly
direct_roles = payload.get("airflow_roles", [])
log.info(f"Realm roles: {realm_roles}")
log.info(f"Client roles: {client_roles}")
log.info(f"Direct airflow roles: {direct_roles}")
# Prefer client roles, then direct roles, then realm roles
if client_roles:
roles = client_roles
log.info(f"Using client roles: {roles}")
elif direct_roles:
roles = direct_roles
log.info(f"Using direct airflow roles: {roles}")
elif realm_roles:
# Map common realm roles to Airflow roles
role_mapping = {
'admin': 'Admin',
'user': 'User',
'viewer': 'Viewer'
}
roles = [role_mapping.get(role.lower(), 'Viewer') for role in realm_roles]
log.info(f"Using mapped realm roles: {roles}")
else:
roles = ['Viewer']
log.info("No roles found, defaulting to Viewer")
userinfo["role_keys"] = roles
log.info(f"Final userinfo: {userinfo}")
return userinfo
return userinfo
except Exception as e:
log.error(f"Error decoding JWT token: {e}")