From 76b648c861b7f26ab8af8c7d411015cdc09934b3 Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Tue, 24 Feb 2026 14:33:29 +0000 Subject: [PATCH] implement trigger_categorisation API --- backend/app/config.py | 11 +++++-- backend/app/db/functions/tasks/Tasks.py | 13 +++++--- backend/app/plan/router.py | 36 +++++++++++++++++----- backend/app/tasks/router.py | 41 ++++++++++++++----------- 4 files changed, 67 insertions(+), 34 deletions(-) diff --git a/backend/app/config.py b/backend/app/config.py index feb312b4..22e4e302 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,5 +1,6 @@ import os from functools import lru_cache +from pathlib import Path from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Optional @@ -8,12 +9,16 @@ def resolve_env_file() -> Optional[str]: env = os.getenv("ENVIRONMENT", "local") if env == "local": - return "backend/.env" + env_file: Path = Path("backend/.env").resolve() # resolve to full path + print("USING ENV FILE:", env_file) + return str(env_file) if env == "test": - return "backend/.env.test" + env_file: Path = Path("backend/.env.test").resolve() + print("USING ENV FILE:", env_file) + return str(env_file) - # prod = no env file + print("NO ENV FILE") return None diff --git a/backend/app/db/functions/tasks/Tasks.py b/backend/app/db/functions/tasks/Tasks.py index d1ab9536..13229447 100644 --- a/backend/app/db/functions/tasks/Tasks.py +++ b/backend/app/db/functions/tasks/Tasks.py @@ -25,7 +25,12 @@ class SubTaskInterface: # -------------------------------------------------------- # CREATE SUBTASK # -------------------------------------------------------- - def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None, status=None): + def create_subtask( + self, + task_id: UUID, + inputs: Optional[Dict[str, Any]] = None, + status: Optional[str] = None, + ): now = datetime.now(timezone.utc) with get_db_session() as session: @@ -177,9 +182,7 @@ class SubTaskInterface: if not task: return - subtasks = session.exec( - select(SubTask).where(SubTask.task_id == task_id) - ).all() + subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all() statuses = [s.status.lower() for s in subtasks] now = datetime.now(timezone.utc) @@ -211,7 +214,7 @@ class SubTaskInterface: subtask_id: UUID, status: str, outputs: Optional[Dict[str, Any]], - cloud_logs_url: Optional[str] + cloud_logs_url: Optional[str], ): now = datetime.now(timezone.utc) diff --git a/backend/app/plan/router.py b/backend/app/plan/router.py index e9c06e40..cdf2873d 100644 --- a/backend/app/plan/router.py +++ b/backend/app/plan/router.py @@ -1,4 +1,5 @@ from typing import List +from uuid import UUID import boto3 import json @@ -36,13 +37,14 @@ router = APIRouter( responses={404: {"description": "Not found"}}, ) -sqs_client = boto3.client("sqs") +settings = get_settings() +sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION) @router.post("/categorisation", status_code=202) async def trigger_categorisation( body: CategorisationTriggerRequest, -) -> dict[str, int]: +) -> dict[str, str]: payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate( body ) @@ -56,7 +58,16 @@ async def trigger_categorisation( batch_size: int = math.ceil(1000 / num_scenarios) num_property_buckets: int = max(1, math.ceil(len(property_ids) / batch_size)) - bucket_requests: List[CategorisationTriggerRequest] = [] + # Create task + task_id, _ = TasksInterface.create_task( + task_source="backend/plan/router.py:trigger_categorisation", + service="plan_engine", + inputs=payload.model_dump(), + task_only=True, + ) + + # Dispatch requests to lambdas + subtask_interface = SubTaskInterface() for bucket_index in range(num_property_buckets): bucket_property_ids: List[int] = [ @@ -69,12 +80,23 @@ async def trigger_categorisation( min_property_id=min(bucket_property_ids), max_property_id=max(bucket_property_ids), ) + # Create sub-task for each + subtask_id: UUID = subtask_interface.create_subtask( + task_id=task_id, inputs=bucket_request.model_dump() + ) - bucket_requests.append(bucket_request) + response = sqs_client.send_message( + QueueUrl="categorisation-queue-dev", + MessageBody=bucket_request.model_dump_json(), + ) - # Dispatch requests to lambdas + logger.info( + f"Chunk {bucket_index} sent to SQS. Property IDs {min(bucket_property_ids)}–{max(bucket_property_ids)}. Message ID: {response.get('MessageId')}" + ) - return {"num_buckets": len(bucket_requests)} + await asyncio.sleep(0.05) # Small delay to avoid SQS throttling + + return {"message": "Categorisation jobs distributed"} @router.post("/trigger", status_code=202) @@ -84,8 +106,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): """ logger.info("API triggered with body: %s", body) - settings = get_settings() - try: data = body.model_dump() except Exception as e: diff --git a/backend/app/tasks/router.py b/backend/app/tasks/router.py index 90b62dd1..1c266f2c 100644 --- a/backend/app/tasks/router.py +++ b/backend/app/tasks/router.py @@ -9,7 +9,7 @@ from backend.app.tasks.schema import ( CreateSubTaskRequest, UpdateSubTaskStatusRequest, FinalizeSubTaskRequest, - TaskSqsTriggerRequest + TaskSqsTriggerRequest, ) # Correct location of interfaces @@ -51,18 +51,18 @@ async def get_task(task_id: UUID): if not task: raise HTTPException(status_code=404, detail="Task not found") - subtasks = session.exec( - select(SubTask).where(SubTask.taskId == task_id) - ).all() + subtasks = session.exec(select(SubTask).where(SubTask.taskId == task_id)).all() formatted = [] for st in subtasks: - formatted.append({ - **st.dict(), - "inputs": json.loads(st.inputs) if st.inputs else None, - "outputs": json.loads(st.outputs) if st.outputs else None, - "cloud_logs_url": st.cloudLogsURL, - }) + formatted.append( + { + **st.dict(), + "inputs": json.loads(st.inputs) if st.inputs else None, + "outputs": json.loads(st.outputs) if st.outputs else None, + "cloud_logs_url": st.cloudLogsURL, + } + ) return { "task": task, @@ -111,7 +111,10 @@ async def update_subtask_status(subtask_id: UUID, req: UpdateSubTaskStatusReques # === # Sub task is complete -@router.post("/subtask/{subtask_id}/finalize", summary="Finalize a subtask with status, outputs, logs") +@router.post( + "/subtask/{subtask_id}/finalize", + summary="Finalize a subtask with status, outputs, logs", +) async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest): subtasks = SubTaskInterface() @@ -120,7 +123,7 @@ async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest): subtask_id=subtask_id, status=req.status, outputs=req.outputs, - cloud_logs_url=req.cloud_logs_url + cloud_logs_url=req.cloud_logs_url, ) return { @@ -142,9 +145,10 @@ from backend.app.tasks.schema import TaskSqsTriggerRequest from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface from backend.app.config import get_settings -sqs = boto3.client("sqs") -@router.post("/trigger", summary="Create task + subtask and publish to SQS", status_code=202) +@router.post( + "/trigger", summary="Create task + subtask and publish to SQS", status_code=202 +) async def trigger_task(req: TaskSqsTriggerRequest): """ Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it. @@ -152,11 +156,12 @@ async def trigger_task(req: TaskSqsTriggerRequest): """ settings = get_settings() + sqs = boto3.client("sqs", settings.AWS_DEFAULT_REGION) tasks = TasksInterface() # ---- Normalize empty inputs ---- - inputs = req.inputs or {} # ensures {} even if null + inputs = req.inputs or {} # ensures {} even if null # ---- 1. Create Task + SubTask ---- task_id, subtask_id = tasks.create_task( @@ -174,8 +179,8 @@ async def trigger_task(req: TaskSqsTriggerRequest): try: response = sqs.send_message( QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/" - f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue", - MessageBody=json.dumps(sqs_payload) + f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue", + MessageBody=json.dumps(sqs_payload), ) except Exception as e: raise HTTPException(status_code=500, detail=f"SQS error: {e}") @@ -186,4 +191,4 @@ async def trigger_task(req: TaskSqsTriggerRequest): "subtask_id": subtask_id, "sqs_message_id": response.get("MessageId"), "inputs_sent": inputs, - } \ No newline at end of file + }