from fastapi import Depends, HTTPException, status, Request
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwe, jwt
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from typing import Any, Optional
import json
import secrets
import time
from sqlalchemy.orm import sessionmaker

from backend.app.config import get_settings
from backend.app.utils import setup_logger
from backend.app.db.connection import db_engine
from backend.app.db.models.users import UserModel
from backend.app import exceptions

logger = setup_logger()

api_key_header = APIKeyHeader(name=get_settings().API_KEY_NAME, auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Built once at import time: a sessionmaker is just a configured factory, so
# re-creating it on every lookup is wasted work.
_SessionFactory = sessionmaker(bind=db_engine)

# HKDF "info" string used by next-auth when deriving its JWE encryption key.
_NEXTAUTH_KEY_INFO = b"NextAuth.js Generated Encryption Key"


def _keys_match(provided: Any, expected: Any) -> bool:
    """Return True when the supplied API key equals the configured one.

    Two ``str`` values are compared in constant time with
    ``secrets.compare_digest`` so response timing does not leak how much of
    the key matched. Any other combination (e.g. a missing header, ``None``)
    falls back to plain ``==``, which preserves the previous semantics.
    """
    if isinstance(provided, str) and isinstance(expected, str):
        return secrets.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))
    return provided == expected


async def validate_api_key(request: Request, api_key_header: str = Depends(api_key_header)):
    """FastAPI dependency that checks the API key header.

    Requests to ``/health`` bypass the check and yield ``None``.

    Raises:
        HTTPException: 403 when the header does not match ``settings.API_KEY``.

    Returns:
        The validated API key header value.
    """
    if request.url.path == "/health":
        return None
    if not _keys_match(api_key_header, get_settings().API_KEY):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials"
        )
    return api_key_header


def validate_user(user_id) -> tuple[bool, Optional[UserModel]]:
    """Look up a user row by primary key.

    Returns:
        ``(True, user)`` when a ``UserModel`` with ``id == user_id`` exists,
        otherwise ``(False, None)``.
    """
    with _SessionFactory() as session:
        # NOTE: temporary approach — simply takes the first matching row.
        user = session.query(UserModel).filter_by(id=user_id).first()
    # NOTE(review): the session is closed here, so the returned instance is
    # detached; attributes not already loaded may be unavailable — confirm
    # callers only read eagerly loaded columns.
    if user is not None:
        return True, user
    return False, None


def get_user(user_id: str):
    """Fetch the user identified by ``user_id``.

    In the ``local`` environment a dummy user dict is returned without touching
    the database. Elsewhere the user is loaded via :func:`validate_user`, so
    the return type differs between environments (dict vs ``UserModel``).

    TODO: Replace the local placeholder with a real database lookup.

    Returns:
        The user, or ``None`` when no user exists (after calling
        ``exceptions.manage_exception(status_code=401)``, which presumably
        raises — confirm against ``backend.app.exceptions``).
    """
    if get_settings().ENVIRONMENT == "local":
        return {"id": user_id, "name": "Dummy User"}

    is_valid, user = validate_user(user_id)
    if not is_valid or user is None:
        exceptions.manage_exception(status_code=401)
        return None
    return user


def get_derived_encryption_key(secret: str) -> Any:
    """Derive the 32-byte JWE key next-auth derives from ``secret``.

    Uses HKDF-SHA256 with an empty salt and the next-auth info string.
    """
    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=b"",
        info=_NEXTAUTH_KEY_INFO,
        backend=default_backend(),
    )
    return hkdf.derive(secret.encode())


def get_token_payload(token: str, secret: str) -> dict[str, Any]:
    """Decrypt a next-auth JWE ``token`` and return its JSON payload.

    See https://github.com/jackrdye/Decrypt-NextAuth-JWE-getToken/tree/main for
    examples of decrypting next-auth tokens this way.

    Raises:
        ValueError: when the decrypted JSON is not an object.
        Any error raised by ``jwe.decrypt`` or ``json.loads`` propagates.
    """
    encryption_key = get_derived_encryption_key(secret)
    payload_str = jwe.decrypt(token, encryption_key).decode()
    payload = json.loads(payload_str)
    if not isinstance(payload, dict):
        raise ValueError("Decrypted token payload is not a JSON object.")
    return payload


def _is_expired(payload: dict[str, Any], now: Optional[float] = None) -> bool:
    """Return True when the payload's ``exp`` claim is not in the future.

    ``exp`` is a JWT NumericDate (seconds since the Unix epoch). A non-numeric
    ``exp`` counts as expired so a malformed claim cannot extend a session.
    NOTE(review): a payload with no ``exp`` at all is accepted — next-auth is
    expected to always set it; tighten this if that assumption fails.
    """
    exp = payload.get("exp")
    if exp is None:
        return False
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        return True
    current = time.time() if now is None else now
    return current >= exp


def validate_jwt_token(token: str = Depends(oauth2_scheme)):
    """Resolve the user for a next-auth bearer token.

    Decrypts the token, rejects expired payloads, reads the ``dbId`` claim,
    and loads the matching user via :func:`get_user`.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            cannot be decrypted/parsed, has expired, lacks ``dbId``, or maps to
            no user.

    Returns:
        The user returned by :func:`get_user`.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # The SECRET_KEY should match the NEXTAUTH_SECRET in the front end.
    try:
        payload = get_token_payload(token, get_settings().SECRET_KEY)
    except Exception as e:
        logger.error(f"An error occurred while validating the token: {e}")
        raise credentials_exception from e

    # jwe.decrypt only decrypts; it never inspects claims, so expiry must be
    # enforced here or expired sessions would be accepted indefinitely.
    if _is_expired(payload):
        logger.error("JWT token has expired.")
        raise credentials_exception

    user_id: str = payload.get("dbId")
    if user_id is None:
        logger.error("No user ID found in the JWT token.")
        raise credentials_exception

    user = get_user(user_id=user_id)
    if user is None:
        logger.error(f"No user found for user ID: {user_id}.")
        raise credentials_exception
    return user


async def validate_token(token: str = Depends(oauth2_scheme), request: Request = None):
    """FastAPI dependency that validates a bearer token and returns it unchanged.

    Raises:
        HTTPException: 401 from :func:`validate_jwt_token`, or 403 when it
            yields a falsy user.
    """
    # NOTE(review): validate_jwt_token is synchronous and may query the
    # database (via get_user); calling it here blocks the event loop during
    # that I/O. Consider run_in_threadpool if this becomes a bottleneck.
    token_data = validate_jwt_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials"
        )
    return token