added actual jwe decrpytion to fastapi

This commit is contained in:
Khalim Conn-Kowlessar 2023-07-17 17:01:07 +01:00
parent 748e87e74c
commit dbe0a27869
6 changed files with 60 additions and 16 deletions

View file

@ -6,7 +6,6 @@ class Settings(BaseSettings):
API_KEY: str
API_KEY_NAME: str = "X-API-KEY"
SECRET_KEY: str
ALGORITHM: str
ENVIRONMENT: str
PLAN_TRIGGER_BUCKET: str

View file

@ -1,6 +1,10 @@
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import jwt, JWTError
from jose import JWTError, jwe
from Crypto.Protocol.KDF import HKDF
from Crypto.Hash import SHA256
from typing import Any
import json
from app.config import get_settings
@ -29,6 +33,28 @@ def get_user(user_id: str):
return user
def get_derived_encryption_key(secret: str) -> Any:
context = str.encode("NextAuth.js Generated Encryption Key")
return HKDF(
master=secret.encode(),
key_len=32,
salt="".encode(),
hashmod=SHA256,
num_keys=1,
context=context,
)
def get_token_payload(token: str, secret: str) -> dict[str, Any]:
# This repo: https://github.com/jackrdye/Decrypt-NextAuth-JWE-getToken/tree/main
# Contains examples of how to decrypt the JWE token and extract the payload as has been implemented by
# next-auth
encryption_key = get_derived_encryption_key(secret)
payload_str = jwe.decrypt(token, encryption_key).decode()
payload: dict[str, Any] = json.loads(payload_str)
return payload
def validate_jwt_token(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -37,8 +63,12 @@ def validate_jwt_token(token: str = Depends(oauth2_scheme)):
)
try:
# The SECRET_KEY should match the NEXTAUTH_SECRET in the front end
payload = jwt.decode(token, get_settings().SECRET_KEY, algorithms=[get_settings().ALGORITHM])
user_id: str = payload.get("sub")
try:
payload = get_token_payload(token, get_settings().SECRET_KEY)
except Exception as e:
print(e)
raise credentials_exception
user_id: str = payload.get("dbId")
if user_id is None:
raise credentials_exception
user = get_user(user_id=user_id)

View file

@ -1,7 +1,9 @@
from fastapi import APIRouter, HTTPException, status
from jose import jwt
from jose import jwt, jwe
import json
import datetime
from app.config import get_settings
from app.dependencies import get_derived_encryption_key
router = APIRouter(
prefix="/local",
@ -9,14 +11,24 @@ router = APIRouter(
)
def create_dummy_token(secret: str, algorithm: str):
data = {
"sub": "known_id",
"name": "Test User",
"iat": datetime.datetime.utcnow(),
"exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=30)
def create_dummy_token(secret: str) -> str:
"""
Create a JWE token using NextAuth.js encryption method
Arguments:
sub -- The subject or identifier for who the token is for (usually a user id)
secret -- The secret key to encrypt the token. Should be the same as the key used in NextAuth.js
exp -- Optional expiry time for the token. If not provided, token does not expire
Returns:
A string containing the JWE token
"""
claims = {
"dbId": "known_id",
}
return jwt.encode(data, secret, algorithm=algorithm)
token = jwe.encrypt(json.dumps(claims), get_derived_encryption_key(secret), algorithm="dir", encryption="A256GCM")
return token
@router.get("/dummy-token")
@ -25,4 +37,4 @@ async def dummy_token():
if settings.ENVIRONMENT != "local":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Dummy token can only be generated in local environment")
return {"dummy_token": create_dummy_token(settings.SECRET_KEY, settings.ALGORITHM)}
return {"dummy_token": create_dummy_token(settings.SECRET_KEY)}

View file

@ -1,6 +1,7 @@
from fastapi import FastAPI, Depends
from mangum import Mangum
from app.portfolio import router as portfolio_router
from app.plan import router as plan_router
from app.dependencies import validate_api_key
from app.config import get_settings
@ -9,6 +10,7 @@ app = FastAPI(dependencies=[Depends(validate_api_key)])
app.include_router(portfolio_router.router, prefix="/v1")
app.include_router(plan_router.router, prefix="/v1")
if get_settings().ENVIRONMENT == "local":
from app.local import router as local_router

View file

@ -4,7 +4,7 @@ from pydantic import BaseModel
class PlanTriggerRequest(BaseModel):
budget: float | None = None
goal: str
housting_type: str
goal_value: float
housing_type: str
goal_value: str
portfolio_id: int
trigger_file_path: str

View file

@ -25,4 +25,5 @@ uvicorn==0.22.0
uvloop==0.17.0
watchfiles==0.19.0
websockets==11.0.3
boto3
boto3
pycryptodome