"""Authentication dependencies for FastAPI routes.

Provides API-key validation, validation of encrypted (JWE) session tokens whose
key is derived the way NextAuth.js derives it, and a placeholder user lookup.
"""

import json
import logging
import secrets
import time
from typing import Any

from Crypto.Hash import SHA256
from Crypto.Protocol.KDF import HKDF
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwe, jwt

from app.config import get_settings

logger = logging.getLogger(__name__)

# Header name is read from settings once, at import time.
api_key_header = APIKeyHeader(name=get_settings().API_KEY_NAME, auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


async def validate_api_key(api_key_header: str = Depends(api_key_header)):
    """Check the request's API key header against the configured ``API_KEY``.

    Args:
        api_key_header: Value of the API key header, or ``None`` when the
            header is absent (the scheme is built with ``auto_error=False``).

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 when the header is missing, no key is configured,
            or the key does not match.
    """
    expected = get_settings().API_KEY  # presumably a plain str — TODO confirm
    # Fail closed: a missing header or an unset key must never match.
    # (Previously `None != None` let requests through when API_KEY was unset.)
    # compare_digest avoids leaking the key through comparison timing.
    if (
        not api_key_header
        or not expected
        or not secrets.compare_digest(
            api_key_header.encode("utf-8"), expected.encode("utf-8")
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return api_key_header


def get_user(user_id: str):
    """Return a user record for ``user_id``.

    Placeholder implementation: in the ``local`` environment, and for any id
    other than ``"known_id"`` elsewhere, a dummy user is returned.

    NOTE(review): outside ``local`` this currently accepts ANY user id. Replace
    it with a real database lookup that returns ``None`` for unknown users.
    """
    # TODO: Update this function to fetch a user from your actual database
    if get_settings().ENVIRONMENT == "local":
        return {"id": user_id, "name": "Dummy User"}
    if user_id == "known_id":
        return {"id": user_id, "name": "Known User"}
    logger.warning("IMPLEMENT ME! - fetch user from database")
    return {"id": user_id, "name": "Dummy User"}


def get_derived_encryption_key(secret: str) -> bytes:
    """Derive the 32-byte JWE decryption key from ``secret``.

    Uses HKDF-SHA256 with an empty salt and the info/context string
    ``"NextAuth.js Generated Encryption Key"``. This is intended to match how
    NextAuth.js derives its encryption key from ``NEXTAUTH_SECRET``.
    """
    return HKDF(
        master=secret.encode(),
        key_len=32,
        salt=b"",
        hashmod=SHA256,
        num_keys=1,
        context=b"NextAuth.js Generated Encryption Key",
    )


def get_token_payload(token: str, secret: str) -> dict[str, Any]:
    """Decrypt a JWE ``token`` with a key derived from ``secret``.

    This repo:
    https://github.com/jackrdye/Decrypt-NextAuth-JWE-getToken/tree/main
    contains examples of how to decrypt the JWE token and extract the payload
    as has been implemented by next-auth.

    Returns:
        The decrypted JSON payload.

    Raises:
        ValueError: if the decrypted payload is not a JSON object.
        Any error from ``jwe.decrypt``, decoding, or ``json.loads`` for
        malformed or tampered tokens.
    """
    encryption_key = get_derived_encryption_key(secret)
    payload = json.loads(jwe.decrypt(token, encryption_key).decode())
    if not isinstance(payload, dict):
        raise ValueError("token payload is not a JSON object")
    return payload


def _token_expired(payload: dict[str, Any], now: float) -> bool:
    """Return True when the payload's ``exp`` claim is at or before ``now``.

    ``exp`` is seconds since the epoch (RFC 7519 §4.1.4). A payload without
    ``exp`` is treated as not expired, which preserves existing behaviour.

    Raises:
        ValueError: if ``exp`` is present but is not a number.
    """
    exp = payload.get("exp")
    if exp is None:
        return False
    # bool is an int subclass; reject it explicitly as a malformed claim.
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        raise ValueError(f"invalid 'exp' claim: {exp!r}")
    return exp <= now


def validate_jwt_token(token: str = Depends(oauth2_scheme)):
    """Decrypt and validate a bearer session token, returning its user.

    The ``SECRET_KEY`` setting should match ``NEXTAUTH_SECRET`` in the front end.

    Returns:
        The user record from ``get_user`` for the token's ``dbId`` claim.

    Raises:
        HTTPException: 401 if the token cannot be decrypted or parsed, is
            expired, lacks a ``dbId`` claim, or maps to no user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = get_token_payload(token, get_settings().SECRET_KEY)
    except Exception as exc:
        # Untrusted input boundary: JWE, base64, UTF-8 and JSON failures all
        # surface as different exception types, and all mean "invalid token".
        logger.error("An error occurred while validating the token: %s", exc)
        raise credentials_exception from exc

    # jwe.decrypt does not enforce expiry, so check the `exp` claim ourselves.
    try:
        expired = _token_expired(payload, time.time())
    except ValueError as exc:
        logger.error("JWT token has an invalid expiry claim: %s", exc)
        raise credentials_exception from exc
    if expired:
        logger.error("JWT token has expired.")
        raise credentials_exception

    user_id = payload.get("dbId")
    if user_id is None:
        logger.error("No user ID found in the JWT token.")
        raise credentials_exception

    user = get_user(user_id=user_id)
    if user is None:
        logger.error("No user found for user ID: %s.", user_id)
        raise credentials_exception
    return user


async def validate_token(token: str = Depends(oauth2_scheme), request: Request = None):
    """Validate a bearer token and return the raw token string on success.

    Raises:
        HTTPException: 401 from ``validate_jwt_token`` for invalid tokens, or
            403 if validation yields an empty user record.
    """
    # `request` is accepted for interface compatibility but not used here.
    token_data = validate_jwt_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return token