mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
from fastapi import Depends, HTTPException, status, Request
|
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
|
from jose import JWTError, jwe, jwt
|
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.backends import default_backend
|
|
from typing import Any
|
|
import json
|
|
from backend.app.config import get_settings
|
|
from backend.app.utils import setup_logger
|
|
|
|
# Module-wide logger shared by every validator in this file.
logger = setup_logger()
|
|
|
|
# Security scheme that reads the API key from the header named by the
# API_KEY_NAME setting. auto_error=False means a missing header is handed to
# the dependant (as None, per FastAPI's documented behaviour) instead of being
# rejected by the scheme itself, so the dependant must do the rejecting.
api_key_header = APIKeyHeader(name=get_settings().API_KEY_NAME, auto_error=False)
|
|
# Security scheme that supplies the bearer token to the token validators
# below; "token" is the tokenUrl advertised in the generated OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
|
|
async def validate_api_key(api_key_header: str = Depends(api_key_header)):
    """Check the API key sent with the request against the configured key.

    Args:
        api_key_header: Value produced by the module-level ``api_key_header``
            security scheme. Because that scheme is created with
            ``auto_error=False`` this may be ``None`` when the header is
            absent. (Inside this function the parameter shadows the
            module-level scheme of the same name.)

    Returns:
        The supplied API key, unchanged, when it matches
        ``get_settings().API_KEY``.

    Raises:
        HTTPException: 403 when the header is missing or empty, when no
            string key is configured on the server, or when the two differ.
    """
    # Function-scope import so the block does not depend on a file-level edit.
    import secrets

    expected_key = get_settings().API_KEY

    # Fail closed: a plain ``!=`` would treat "no header" and "no configured
    # key" as equal (None == None) and let the request through. Requiring two
    # non-empty strings also keeps wrapper types from being compared by their
    # string representation.
    keys_match = (
        isinstance(api_key_header, str)
        and isinstance(expected_key, str)
        and bool(api_key_header)
        and bool(expected_key)
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII ``str`` arguments.
        and secrets.compare_digest(
            api_key_header.encode("utf-8"), expected_key.encode("utf-8")
        )
    )
    if not keys_match:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials"
        )
    return api_key_header
|
|
|
|
|
|
def get_user(user_id: str) -> dict[str, Any]:
    """Return the user record for ``user_id`` (placeholder implementation).

    No database is consulted yet; every branch fabricates a record.

    Args:
        user_id: Identifier of the user to look up.

    Returns:
        A dict with ``"id"`` and ``"name"`` keys. In the ``"local"``
        environment, and for any id other than ``"known_id"``, the name is
        ``"Dummy User"``; for ``"known_id"`` outside local it is
        ``"Known User"``.
    """
    # TODO: Update this function to fetch a user from your actual database
    # using the user_id.
    if get_settings().ENVIRONMENT == "local":
        return {"id": user_id, "name": "Dummy User"}

    if user_id == "known_id":
        return {"id": user_id, "name": "Known User"}

    # Report the missing implementation through the module logger (as the rest
    # of this file does) instead of printing to stdout.
    logger.warning("IMPLEMENT ME! - fetch user from database")
    return {"id": user_id, "name": "Dummy User"}
|
|
|
|
|
|
def get_derived_encryption_key(secret: str) -> Any:
    """Derive the 32-byte symmetric key used to decrypt the session token.

    The key is produced with HKDF-SHA256 using an empty salt and the fixed
    info string ``"NextAuth.js Generated Encryption Key"``.

    Args:
        secret: Shared secret; it is encoded with ``str.encode()`` before
            being fed to the KDF.

    Returns:
        The result of ``HKDF.derive`` for the encoded secret.
    """
    info = "NextAuth.js Generated Encryption Key".encode()
    deriver = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=b"",
        info=info,
        backend=default_backend(),
    )
    return deriver.derive(secret.encode())
|
|
|
|
|
|
def get_token_payload(token: str, secret: str) -> dict[str, Any]:
    """Decrypt a JWE token and return its JSON payload.

    Args:
        token: The encrypted (JWE) token string.
        secret: Secret from which the decryption key is derived via
            ``get_derived_encryption_key``.

    Returns:
        The decrypted payload parsed with ``json.loads``.
    """
    # Examples of how next-auth encrypts the token, and how to decrypt it and
    # extract the payload, are in this repo:
    # https://github.com/jackrdye/Decrypt-NextAuth-JWE-getToken/tree/main
    decrypted = jwe.decrypt(token, get_derived_encryption_key(secret))
    claims: dict[str, Any] = json.loads(decrypted.decode())
    return claims
|
|
|
|
|
|
def validate_jwt_token(token: str = Depends(oauth2_scheme)):
    """Decrypt the bearer token and resolve the user it identifies.

    Args:
        token: Encrypted token supplied by the ``oauth2_scheme`` dependency.

    Returns:
        The user record returned by ``get_user`` for the token's ``dbId``
        claim.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) when
            the token cannot be decrypted or parsed, its payload is not a
            JSON object, it has expired, it carries no ``dbId``, or no user
            is found for that id.
    """
    # Function-scope import so the block does not depend on a file-level edit.
    import time

    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Allowance for clock drift between token issuer and this service.
    clock_skew_seconds = 15

    # The SECRET_KEY should match the NEXTAUTH_SECRET in the front end
    try:
        payload = get_token_payload(token, get_settings().SECRET_KEY)
    except Exception as e:
        # Decryption, decoding and JSON parsing failures all mean the same
        # thing to the caller: the credentials are not usable.
        logger.error(f"An error occurred while validating the token: {e}")
        raise credentials_exception from e

    # json.loads may legally return a list/str/number; without this guard the
    # ``.get`` calls below would raise and surface as a 500 instead of a 401.
    if not isinstance(payload, dict):
        logger.error("An error occurred while decoding the JWT token.")
        raise credentials_exception

    # Decrypting the token does not validate its claims, so expiry has to be
    # enforced here. Tokens without a numeric ``exp`` are accepted as before.
    expires_at = payload.get("exp")
    if (
        isinstance(expires_at, (int, float))
        and not isinstance(expires_at, bool)
        and time.time() > expires_at + clock_skew_seconds
    ):
        logger.error("JWT token has expired.")
        raise credentials_exception

    user_id: str = payload.get("dbId")
    if user_id is None:
        logger.error("No user ID found in the JWT token.")
        raise credentials_exception

    user = get_user(user_id=user_id)
    if user is None:
        logger.error(f"No user found for user ID: {user_id}.")
        raise credentials_exception

    return user
|
|
|
|
|
|
async def validate_token(token: str = Depends(oauth2_scheme), request: Request = None):
    """Dependency that accepts a request only if its bearer token is valid.

    Args:
        token: Bearer token supplied by the ``oauth2_scheme`` dependency.
        request: Not used by this function. The bare ``Request`` annotation
            is kept exactly as written — NOTE(review): FastAPI appears to
            rely on it to inject the request; confirm before changing it.

    Returns:
        The original token string once ``validate_jwt_token`` has accepted it.

    Raises:
        HTTPException: whatever ``validate_jwt_token`` raises, or 403 when it
            returns a falsy value.
    """
    user = validate_jwt_token(token)
    if user:
        return token
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )
|