"""Authentication dependencies for the FastAPI application.

Provides two mechanisms:

* a static API key read from a request header (``validate_api_key``), and
* an encrypted JWT (JWE) bearer token whose encryption key is derived from
  ``SECRET_KEY`` (``validate_jwt_token`` / ``validate_token``).
"""

import json
import secrets
import time
from typing import Any

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwe, jwt

from app.config import get_settings
from app.utils import logger

# auto_error=False: a missing header reaches validate_api_key as None instead
# of being rejected by the security scheme itself.
api_key_header = APIKeyHeader(name=get_settings().API_KEY_NAME, auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


async def validate_api_key(api_key_header: str = Depends(api_key_header)) -> str:
    """Check the API key header against the configured ``API_KEY``.

    Args:
        api_key_header: Header value extracted by the ``APIKeyHeader`` scheme;
            ``None`` when the request did not carry the header.

    Returns:
        The supplied API key when it matches the configured one.

    Raises:
        HTTPException: 403 when the header is missing or empty, when no API
            key is configured, or when the two values differ.
    """
    expected_key = get_settings().API_KEY
    # Fail closed: a plain `!=` would accept a request without the header
    # whenever API_KEY is unset (None != None is False).
    # The comparison runs on encoded bytes because compare_digest rejects
    # non-ASCII str, and it is constant-time so response timing does not
    # reveal how much of the key matched.
    if (
        not api_key_header
        or not expected_key
        or not secrets.compare_digest(api_key_header.encode(), expected_key.encode())
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return api_key_header


def get_user(user_id: str) -> dict[str, Any]:
    """Return the user record for ``user_id`` (placeholder implementation).

    TODO: Update this function to fetch a user from your actual database.

    Args:
        user_id: Identifier of the user to look up.

    Returns:
        A dict with ``id`` and ``name`` keys. The placeholder never returns
        ``None``, so every ``user_id`` currently resolves to a user.
    """
    if get_settings().ENVIRONMENT == "local":
        return {"id": user_id, "name": "Dummy User"}

    if user_id == "known_id":
        return {"id": user_id, "name": "Known User"}

    # NOTE(review): outside "local" an unknown id still yields a dummy user;
    # replace with a real lookup before relying on this for authorization.
    logger.warning("IMPLEMENT ME! - fetch user from database")
    return {"id": user_id, "name": "Dummy User"}


def get_derived_encryption_key(secret: str) -> bytes:
    """Derive the 32-byte token encryption key from ``secret`` using HKDF-SHA256.

    The salt is empty and the ``info`` string is
    ``"NextAuth.js Generated Encryption Key"``; presumably this mirrors the
    derivation done by the NextAuth.js front end (see ``get_token_payload``).

    Args:
        secret: Shared secret to derive the key from.

    Returns:
        The derived key bytes.
    """
    context = str.encode("NextAuth.js Generated Encryption Key")
    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=b"",
        info=context,
        backend=default_backend(),
    )
    return hkdf.derive(secret.encode())


def get_token_payload(token: str, secret: str) -> dict[str, Any]:
    """Decrypt a JWE token and return its JSON payload.

    This repo: https://github.com/jackrdye/Decrypt-NextAuth-JWE-getToken/tree/main
    contains examples of how to decrypt the JWE token and extract the payload
    as has been implemented by next-auth.

    Only decryption and JSON parsing happen here; claims such as ``exp`` are
    not checked by this function.

    Args:
        token: The compact-serialized JWE token.
        secret: Secret from which the decryption key is derived.

    Returns:
        The parsed JSON payload.
    """
    encryption_key = get_derived_encryption_key(secret)
    payload_str = jwe.decrypt(token, encryption_key).decode()
    payload: dict[str, Any] = json.loads(payload_str)
    return payload


def _is_token_expired(payload: dict[str, Any]) -> bool:
    """Return True when the payload's ``exp`` claim is in the past or malformed.

    A payload without ``exp`` is treated as non-expiring so tokens issued
    without the claim keep working.
    """
    exp = payload.get("exp")
    if exp is None:
        return False
    # bool is a subclass of int; reject it along with any non-numeric value.
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        return True
    # Assumes `exp` is seconds since the epoch, as in RFC 7519 NumericDate
    # - TODO confirm against the token issuer.
    return time.time() >= exp


def validate_jwt_token(token: str = Depends(oauth2_scheme)) -> dict[str, Any]:
    """Decrypt the bearer token and resolve the user it refers to.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The user record returned by ``get_user`` for the token's ``dbId``.

    Raises:
        HTTPException: 401 when the token cannot be decrypted or parsed, its
            payload is not a JSON object, it has expired, it carries no
            ``dbId``, or no user exists for that id.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # The SECRET_KEY should match the NEXTAUTH_SECRET in the front end
        try:
            payload = get_token_payload(token, get_settings().SECRET_KEY)
        except jwt.ExpiredSignatureError:
            logger.error("JWT token has expired.")
            raise credentials_exception
        except Exception as e:
            logger.error(f"An error occurred while validating the token: {e}")
            raise credentials_exception from e

        # json.loads may yield a list/str/number; only an object has claims.
        if not isinstance(payload, dict):
            logger.error("JWT token payload is not a JSON object.")
            raise credentials_exception

        # Decryption alone does not enforce `exp`, so check it explicitly;
        # otherwise an expired token would be accepted indefinitely.
        if _is_token_expired(payload):
            logger.error("JWT token has expired.")
            raise credentials_exception

        user_id: str = payload.get("dbId")
        if user_id is None:
            logger.error("No user ID found in the JWT token.")
            raise credentials_exception

        user = get_user(user_id=user_id)
        if user is None:
            logger.error(f"No user found for user ID: {user_id}.")
            raise credentials_exception
        return user
    except JWTError:
        logger.error("An error occurred while decoding the JWT token.")
        raise credentials_exception


async def validate_token(token: str = Depends(oauth2_scheme), request: Request = None) -> str:
    """FastAPI dependency that validates the bearer token and returns it.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.
        request: Unused; kept so the dependency signature stays unchanged.

    Returns:
        The validated token string (not the user record).

    Raises:
        HTTPException: 401 from ``validate_jwt_token`` for an invalid token,
            or 403 when validation yields an empty user record.
    """
    # Never log the token or SECRET_KEY: both are credentials and would end
    # up readable in log storage.
    logger.info("Validating token")
    token_data = validate_jwt_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return token