implement trigger_categorisation API

This commit is contained in:
Daniel Roth 2026-02-24 14:33:29 +00:00
parent 5646376d1e
commit 76b648c861
4 changed files with 67 additions and 34 deletions

View file

@ -1,5 +1,6 @@
import os import os
from functools import lru_cache from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Optional from typing import Optional
@ -8,12 +9,16 @@ def resolve_env_file() -> Optional[str]:
env = os.getenv("ENVIRONMENT", "local") env = os.getenv("ENVIRONMENT", "local")
if env == "local": if env == "local":
return "backend/.env" env_file: Path = Path("backend/.env").resolve() # resolve to full path
print("USING ENV FILE:", env_file)
return str(env_file)
if env == "test": if env == "test":
return "backend/.env.test" env_file: Path = Path("backend/.env.test").resolve()
print("USING ENV FILE:", env_file)
return str(env_file)
# prod = no env file print("NO ENV FILE")
return None return None

View file

@ -25,7 +25,12 @@ class SubTaskInterface:
# -------------------------------------------------------- # --------------------------------------------------------
# CREATE SUBTASK # CREATE SUBTASK
# -------------------------------------------------------- # --------------------------------------------------------
def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None, status=None): def create_subtask(
self,
task_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
status: Optional[str] = None,
):
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
with get_db_session() as session: with get_db_session() as session:
@ -177,9 +182,7 @@ class SubTaskInterface:
if not task: if not task:
return return
subtasks = session.exec( subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all()
select(SubTask).where(SubTask.task_id == task_id)
).all()
statuses = [s.status.lower() for s in subtasks] statuses = [s.status.lower() for s in subtasks]
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
@ -211,7 +214,7 @@ class SubTaskInterface:
subtask_id: UUID, subtask_id: UUID,
status: str, status: str,
outputs: Optional[Dict[str, Any]], outputs: Optional[Dict[str, Any]],
cloud_logs_url: Optional[str] cloud_logs_url: Optional[str],
): ):
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)

View file

@ -1,4 +1,5 @@
from typing import List from typing import List
from uuid import UUID
import boto3 import boto3
import json import json
@ -36,13 +37,14 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
sqs_client = boto3.client("sqs") settings = get_settings()
sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
@router.post("/categorisation", status_code=202) @router.post("/categorisation", status_code=202)
async def trigger_categorisation( async def trigger_categorisation(
body: CategorisationTriggerRequest, body: CategorisationTriggerRequest,
) -> dict[str, int]: ) -> dict[str, str]:
payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate( payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
body body
) )
@ -56,7 +58,16 @@ async def trigger_categorisation(
batch_size: int = math.ceil(1000 / num_scenarios) batch_size: int = math.ceil(1000 / num_scenarios)
num_property_buckets: int = max(1, math.ceil(len(property_ids) / batch_size)) num_property_buckets: int = max(1, math.ceil(len(property_ids) / batch_size))
bucket_requests: List[CategorisationTriggerRequest] = [] # Create task
task_id, _ = TasksInterface.create_task(
task_source="backend/plan/router.py:trigger_categorisation",
service="plan_engine",
inputs=payload.model_dump(),
task_only=True,
)
# Dispatch requests to lambdas
subtask_interface = SubTaskInterface()
for bucket_index in range(num_property_buckets): for bucket_index in range(num_property_buckets):
bucket_property_ids: List[int] = [ bucket_property_ids: List[int] = [
@ -69,12 +80,23 @@ async def trigger_categorisation(
min_property_id=min(bucket_property_ids), min_property_id=min(bucket_property_ids),
max_property_id=max(bucket_property_ids), max_property_id=max(bucket_property_ids),
) )
# Create sub-task for each
subtask_id: UUID = subtask_interface.create_subtask(
task_id=task_id, inputs=bucket_request.model_dump()
)
bucket_requests.append(bucket_request) response = sqs_client.send_message(
QueueUrl="categorisation-queue-dev",
MessageBody=bucket_request.model_dump_json(),
)
# Dispatch requests to lambdas logger.info(
f"Chunk {bucket_index} sent to SQS. Property IDs {min(bucket_property_ids)}{max(bucket_property_ids)}. Message ID: {response.get('MessageId')}"
)
return {"num_buckets": len(bucket_requests)} await asyncio.sleep(0.05) # Small delay to avoid SQS throttling
return {"message": "Categorisation jobs distributed"}
@router.post("/trigger", status_code=202) @router.post("/trigger", status_code=202)
@ -84,8 +106,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
""" """
logger.info("API triggered with body: %s", body) logger.info("API triggered with body: %s", body)
settings = get_settings()
try: try:
data = body.model_dump() data = body.model_dump()
except Exception as e: except Exception as e:

View file

@ -9,7 +9,7 @@ from backend.app.tasks.schema import (
CreateSubTaskRequest, CreateSubTaskRequest,
UpdateSubTaskStatusRequest, UpdateSubTaskStatusRequest,
FinalizeSubTaskRequest, FinalizeSubTaskRequest,
TaskSqsTriggerRequest TaskSqsTriggerRequest,
) )
# Correct location of interfaces # Correct location of interfaces
@ -51,18 +51,18 @@ async def get_task(task_id: UUID):
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
subtasks = session.exec( subtasks = session.exec(select(SubTask).where(SubTask.taskId == task_id)).all()
select(SubTask).where(SubTask.taskId == task_id)
).all()
formatted = [] formatted = []
for st in subtasks: for st in subtasks:
formatted.append({ formatted.append(
**st.dict(), {
"inputs": json.loads(st.inputs) if st.inputs else None, **st.dict(),
"outputs": json.loads(st.outputs) if st.outputs else None, "inputs": json.loads(st.inputs) if st.inputs else None,
"cloud_logs_url": st.cloudLogsURL, "outputs": json.loads(st.outputs) if st.outputs else None,
}) "cloud_logs_url": st.cloudLogsURL,
}
)
return { return {
"task": task, "task": task,
@ -111,7 +111,10 @@ async def update_subtask_status(subtask_id: UUID, req: UpdateSubTaskStatusReques
# === # ===
# Sub task is complete # Sub task is complete
@router.post("/subtask/{subtask_id}/finalize", summary="Finalize a subtask with status, outputs, logs") @router.post(
"/subtask/{subtask_id}/finalize",
summary="Finalize a subtask with status, outputs, logs",
)
async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest): async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
subtasks = SubTaskInterface() subtasks = SubTaskInterface()
@ -120,7 +123,7 @@ async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
subtask_id=subtask_id, subtask_id=subtask_id,
status=req.status, status=req.status,
outputs=req.outputs, outputs=req.outputs,
cloud_logs_url=req.cloud_logs_url cloud_logs_url=req.cloud_logs_url,
) )
return { return {
@ -142,9 +145,10 @@ from backend.app.tasks.schema import TaskSqsTriggerRequest
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
from backend.app.config import get_settings from backend.app.config import get_settings
sqs = boto3.client("sqs")
@router.post("/trigger", summary="Create task + subtask and publish to SQS", status_code=202) @router.post(
"/trigger", summary="Create task + subtask and publish to SQS", status_code=202
)
async def trigger_task(req: TaskSqsTriggerRequest): async def trigger_task(req: TaskSqsTriggerRequest):
""" """
Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it. Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it.
@ -152,11 +156,12 @@ async def trigger_task(req: TaskSqsTriggerRequest):
""" """
settings = get_settings() settings = get_settings()
sqs = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
tasks = TasksInterface() tasks = TasksInterface()
# ---- Normalize empty inputs ---- # ---- Normalize empty inputs ----
inputs = req.inputs or {} # ensures {} even if null inputs = req.inputs or {} # ensures {} even if null
# ---- 1. Create Task + SubTask ---- # ---- 1. Create Task + SubTask ----
task_id, subtask_id = tasks.create_task( task_id, subtask_id = tasks.create_task(
@ -174,8 +179,8 @@ async def trigger_task(req: TaskSqsTriggerRequest):
try: try:
response = sqs.send_message( response = sqs.send_message(
QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/" QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/"
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue", f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
MessageBody=json.dumps(sqs_payload) MessageBody=json.dumps(sqs_payload),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"SQS error: {e}") raise HTTPException(status_code=500, detail=f"SQS error: {e}")
@ -186,4 +191,4 @@ async def trigger_task(req: TaskSqsTriggerRequest):
"subtask_id": subtask_id, "subtask_id": subtask_id,
"sqs_message_id": response.get("MessageId"), "sqs_message_id": response.get("MessageId"),
"inputs_sent": inputs, "inputs_sent": inputs,
} }