Model/backend/app/dependencies.py
2026-05-10 21:07:16 +00:00

120 lines
4.2 KiB
Python

import asyncio
import json
import secrets
import time
from typing import Any

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwe, jwt
from sqlalchemy.orm import sessionmaker

from backend.app.config import get_settings
from backend.app.utils import setup_logger
from backend.app.db.connection import db_engine
from backend.app.db.models.users import UserModel
from backend.app import exceptions
# Module-wide logger obtained from the project's logging helper.
logger = setup_logger()
# Security scheme that reads the API key from the header named by settings.API_KEY_NAME.
# auto_error=False: a missing header is passed to the dependency as None instead of
# FastAPI rejecting the request itself, so validate_api_key decides (e.g. lets /health through).
api_key_header = APIKeyHeader(name=get_settings().API_KEY_NAME, auto_error=False)
# Security scheme that extracts a Bearer token from the Authorization header.
# NOTE(review): tokenUrl="token" is what gets advertised in the OpenAPI docs; no "token"
# route is defined in this module — confirm one exists (tokens here come from NextAuth).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def validate_api_key(request: Request, api_key_header: str = Depends(api_key_header)):
    """FastAPI dependency enforcing the shared API key on every route except ``/health``.

    Args:
        request: The incoming request; only its URL path is inspected.
        api_key_header: Value of the header named by ``settings.API_KEY_NAME``. It is
            ``None`` when the header is absent, because the scheme uses ``auto_error=False``.

    Returns:
        ``None`` for ``/health``; otherwise the validated API key.

    Raises:
        HTTPException: 403 when the key is missing or wrong, or when no usable
            ``API_KEY`` is configured.
    """
    if request.url.path == "/health":
        return None

    expected_key = get_settings().API_KEY
    # Fail closed. Previously an unset API_KEY (None) plus a missing header (None)
    # passed, because ``None != None`` is False. Both sides must now be non-empty
    # strings, and compare_digest avoids leaking the match position through timing.
    key_is_valid = (
        isinstance(expected_key, str)
        and bool(expected_key)
        and isinstance(api_key_header, str)
        and bool(api_key_header)
        and secrets.compare_digest(api_key_header.encode(), expected_key.encode())
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials"
        )
    return api_key_header
def validate_user(user_id: Any) -> "tuple[bool, UserModel | None]":
    """Look up a ``UserModel`` row whose ``id`` column equals ``user_id``.

    A new session is opened on ``db_engine`` for the lookup and closed before returning.

    Args:
        user_id: Identifier to match against ``UserModel.id``. Callers here pass the
            ``dbId`` claim from the session token.

    Returns:
        ``(True, user)`` when a matching row exists, otherwise ``(False, None)``.
    """
    session_factory = sessionmaker(bind=db_engine)
    with session_factory() as session:
        # .first(): if several rows share the id, only the first one is used.
        user = session.query(UserModel).filter_by(id=user_id).first()
    # NOTE(review): the session is closed at this point, so ``user`` is detached.
    # Lazily-loaded relationships on it will not be available to callers.
    if user is None:
        return False, None
    return True, user
def get_user(user_id: str):
    """Resolve ``user_id`` to the user making the request.

    When ``settings.ENVIRONMENT`` is ``"local"``, no lookup is performed and a stub dict
    ``{"id": user_id, "name": "Dummy User"}`` is returned. Otherwise the user is loaded
    through ``validate_user``.

    If no user is found, ``exceptions.manage_exception(status_code=401)`` is called.
    ``None`` is then returned, unless that helper raises; its behaviour is not visible here.

    NOTE(review): the local branch bypasses the database completely. Make sure
    ENVIRONMENT can never be "local" in a deployed instance.
    """
    if get_settings().ENVIRONMENT == "local":
        return {"id": user_id, "name": "Dummy User"}

    found, user = validate_user(user_id)
    if found and user is not None:
        return user

    exceptions.manage_exception(status_code=401)
    return None
def get_derived_encryption_key(secret: str) -> Any:
    """Derive the symmetric key used to decrypt NextAuth.js JWE session tokens.

    The key is derived with HKDF-SHA256 using these parameters:
    - an empty salt;
    - the info label ``"NextAuth.js Generated Encryption Key"``;
    - a 32-byte output length;
    - input keying material equal to the UTF-8 encoding of ``secret``.

    Args:
        secret: The shared secret, i.e. the backend counterpart of NEXTAUTH_SECRET.

    Returns:
        The 32 derived key bytes produced by ``HKDF.derive``.
    """
    info_label = b"NextAuth.js Generated Encryption Key"
    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=b"",
        info=info_label,
        backend=default_backend(),
    )
    secret_bytes = secret.encode()
    return kdf.derive(secret_bytes)
def get_token_payload(token: str, secret: str) -> dict[str, Any]:
    """Decrypt a NextAuth.js JWE session token and return its JSON claims.

    This repo: https://github.com/jackrdye/Decrypt-NextAuth-JWE-getToken/tree/main
    contains examples of how next-auth tokens are decrypted; this function follows
    that approach. It only decrypts and parses. No claims (e.g. ``exp``) are checked here.

    Args:
        token: The compact-serialized JWE string sent by the frontend.
        secret: Shared secret; expected to equal the frontend's NEXTAUTH_SECRET.

    Returns:
        The decoded claims as a dict.

    Raises:
        JWTError: If the decrypted plaintext is valid JSON but not a JSON object.
        Errors from ``jwe.decrypt``, UTF-8 decoding and ``json.loads`` propagate unchanged.
    """
    encryption_key = get_derived_encryption_key(secret)
    plaintext = jwe.decrypt(token, encryption_key)
    payload = json.loads(plaintext.decode("utf-8"))
    # json.loads may return a list/str/number. Enforce the dict contract here so
    # callers calling payload.get(...) fail with a token error, not an AttributeError.
    if not isinstance(payload, dict):
        raise JWTError("Decrypted token payload is not a JSON object")
    return payload
def validate_jwt_token(token: str = Depends(oauth2_scheme)):
    """Authenticate a NextAuth bearer token and return the matching user.

    Steps:
    1. Decrypt the token with ``settings.SECRET_KEY``.
    2. Reject it if its ``exp`` claim is in the past.
    3. Read the ``dbId`` claim and resolve it through ``get_user``.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        Whatever ``get_user`` returns for the token's ``dbId``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when any of these occurs:
            - decryption or parsing fails;
            - the payload is not an object;
            - the token is expired or its ``exp`` is malformed;
            - ``dbId`` is missing;
            - no user is found.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # The SECRET_KEY should match the NEXTAUTH_SECRET in the front end
    try:
        payload = get_token_payload(token, get_settings().SECRET_KEY)
    except Exception as e:
        logger.error(f"An error occurred while validating the token: {e}")
        raise credentials_exception from e

    if not isinstance(payload, dict):
        logger.error("An error occurred while decoding the JWT token.")
        raise credentials_exception

    # get_token_payload only decrypts and parses; nothing on that path raises
    # ExpiredSignatureError. Expiry must therefore be enforced explicitly here.
    # ``exp`` is a JWT NumericDate (seconds since the Unix epoch).
    # NOTE(review): tokens without ``exp`` are still accepted; consider requiring it.
    exp = payload.get("exp")
    if exp is not None:
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            logger.error("An error occurred while decoding the JWT token.")
            raise credentials_exception
        if exp <= time.time():
            logger.error("JWT token has expired.")
            raise credentials_exception

    user_id: str = payload.get("dbId")
    if user_id is None:
        logger.error("No user ID found in the JWT token.")
        raise credentials_exception

    user = get_user(user_id=user_id)
    if user is None:
        logger.error(f"No user found for user ID: {user_id}.")
        raise credentials_exception
    return user
async def validate_token(token: str = Depends(oauth2_scheme), request: Request = None):
    """Async FastAPI dependency that authenticates the bearer token and returns it unchanged.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.
        request: Accepted for interface compatibility; not used.

    Returns:
        The original ``token`` string. The resolved user is not returned.

    Raises:
        HTTPException: 401 propagated from ``validate_jwt_token``, or 403 if that
            function returns a falsy value.
    """
    # validate_jwt_token is synchronous. Outside ENVIRONMENT == "local" it queries the
    # database (get_user -> validate_user). Calling it directly inside this coroutine
    # would block the event loop, so it runs in a worker thread instead.
    token_data = await asyncio.to_thread(validate_jwt_token, token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials"
        )
    return token