mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
implement trigger_categorisation API
This commit is contained in:
parent
5646376d1e
commit
76b648c861
4 changed files with 67 additions and 34 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -8,12 +9,16 @@ def resolve_env_file() -> Optional[str]:
|
|||
env = os.getenv("ENVIRONMENT", "local")
|
||||
|
||||
if env == "local":
|
||||
return "backend/.env"
|
||||
env_file: Path = Path("backend/.env").resolve() # resolve to full path
|
||||
print("USING ENV FILE:", env_file)
|
||||
return str(env_file)
|
||||
|
||||
if env == "test":
|
||||
return "backend/.env.test"
|
||||
env_file: Path = Path("backend/.env.test").resolve()
|
||||
print("USING ENV FILE:", env_file)
|
||||
return str(env_file)
|
||||
|
||||
# prod = no env file
|
||||
print("NO ENV FILE")
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,12 @@ class SubTaskInterface:
|
|||
# --------------------------------------------------------
|
||||
# CREATE SUBTASK
|
||||
# --------------------------------------------------------
|
||||
def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None, status=None):
|
||||
def create_subtask(
|
||||
self,
|
||||
task_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
status: Optional[str] = None,
|
||||
):
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
with get_db_session() as session:
|
||||
|
|
@ -177,9 +182,7 @@ class SubTaskInterface:
|
|||
if not task:
|
||||
return
|
||||
|
||||
subtasks = session.exec(
|
||||
select(SubTask).where(SubTask.task_id == task_id)
|
||||
).all()
|
||||
subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all()
|
||||
|
||||
statuses = [s.status.lower() for s in subtasks]
|
||||
now = datetime.now(timezone.utc)
|
||||
|
|
@ -211,7 +214,7 @@ class SubTaskInterface:
|
|||
subtask_id: UUID,
|
||||
status: str,
|
||||
outputs: Optional[Dict[str, Any]],
|
||||
cloud_logs_url: Optional[str]
|
||||
cloud_logs_url: Optional[str],
|
||||
):
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
import boto3
|
||||
import json
|
||||
|
|
@ -36,13 +37,14 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
sqs_client = boto3.client("sqs")
|
||||
settings = get_settings()
|
||||
sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
|
||||
|
||||
|
||||
@router.post("/categorisation", status_code=202)
|
||||
async def trigger_categorisation(
|
||||
body: CategorisationTriggerRequest,
|
||||
) -> dict[str, int]:
|
||||
) -> dict[str, str]:
|
||||
payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
|
||||
body
|
||||
)
|
||||
|
|
@ -56,7 +58,16 @@ async def trigger_categorisation(
|
|||
batch_size: int = math.ceil(1000 / num_scenarios)
|
||||
num_property_buckets: int = max(1, math.ceil(len(property_ids) / batch_size))
|
||||
|
||||
bucket_requests: List[CategorisationTriggerRequest] = []
|
||||
# Create task
|
||||
task_id, _ = TasksInterface.create_task(
|
||||
task_source="backend/plan/router.py:trigger_categorisation",
|
||||
service="plan_engine",
|
||||
inputs=payload.model_dump(),
|
||||
task_only=True,
|
||||
)
|
||||
|
||||
# Dispatch requests to lambdas
|
||||
subtask_interface = SubTaskInterface()
|
||||
|
||||
for bucket_index in range(num_property_buckets):
|
||||
bucket_property_ids: List[int] = [
|
||||
|
|
@ -69,12 +80,23 @@ async def trigger_categorisation(
|
|||
min_property_id=min(bucket_property_ids),
|
||||
max_property_id=max(bucket_property_ids),
|
||||
)
|
||||
# Create sub-task for each
|
||||
subtask_id: UUID = subtask_interface.create_subtask(
|
||||
task_id=task_id, inputs=bucket_request.model_dump()
|
||||
)
|
||||
|
||||
bucket_requests.append(bucket_request)
|
||||
response = sqs_client.send_message(
|
||||
QueueUrl="categorisation-queue-dev",
|
||||
MessageBody=bucket_request.model_dump_json(),
|
||||
)
|
||||
|
||||
# Dispatch requests to lambdas
|
||||
logger.info(
|
||||
f"Chunk {bucket_index} sent to SQS. Property IDs {min(bucket_property_ids)}–{max(bucket_property_ids)}. Message ID: {response.get('MessageId')}"
|
||||
)
|
||||
|
||||
return {"num_buckets": len(bucket_requests)}
|
||||
await asyncio.sleep(0.05) # Small delay to avoid SQS throttling
|
||||
|
||||
return {"message": "Categorisation jobs distributed"}
|
||||
|
||||
|
||||
@router.post("/trigger", status_code=202)
|
||||
|
|
@ -84,8 +106,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
|
|||
"""
|
||||
logger.info("API triggered with body: %s", body)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
try:
|
||||
data = body.model_dump()
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from backend.app.tasks.schema import (
|
|||
CreateSubTaskRequest,
|
||||
UpdateSubTaskStatusRequest,
|
||||
FinalizeSubTaskRequest,
|
||||
TaskSqsTriggerRequest
|
||||
TaskSqsTriggerRequest,
|
||||
)
|
||||
|
||||
# Correct location of interfaces
|
||||
|
|
@ -51,18 +51,18 @@ async def get_task(task_id: UUID):
|
|||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
subtasks = session.exec(
|
||||
select(SubTask).where(SubTask.taskId == task_id)
|
||||
).all()
|
||||
subtasks = session.exec(select(SubTask).where(SubTask.taskId == task_id)).all()
|
||||
|
||||
formatted = []
|
||||
for st in subtasks:
|
||||
formatted.append({
|
||||
**st.dict(),
|
||||
"inputs": json.loads(st.inputs) if st.inputs else None,
|
||||
"outputs": json.loads(st.outputs) if st.outputs else None,
|
||||
"cloud_logs_url": st.cloudLogsURL,
|
||||
})
|
||||
formatted.append(
|
||||
{
|
||||
**st.dict(),
|
||||
"inputs": json.loads(st.inputs) if st.inputs else None,
|
||||
"outputs": json.loads(st.outputs) if st.outputs else None,
|
||||
"cloud_logs_url": st.cloudLogsURL,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"task": task,
|
||||
|
|
@ -111,7 +111,10 @@ async def update_subtask_status(subtask_id: UUID, req: UpdateSubTaskStatusReques
|
|||
|
||||
# ===
|
||||
# Sub task is complete
|
||||
@router.post("/subtask/{subtask_id}/finalize", summary="Finalize a subtask with status, outputs, logs")
|
||||
@router.post(
|
||||
"/subtask/{subtask_id}/finalize",
|
||||
summary="Finalize a subtask with status, outputs, logs",
|
||||
)
|
||||
async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
|
||||
subtasks = SubTaskInterface()
|
||||
|
||||
|
|
@ -120,7 +123,7 @@ async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
|
|||
subtask_id=subtask_id,
|
||||
status=req.status,
|
||||
outputs=req.outputs,
|
||||
cloud_logs_url=req.cloud_logs_url
|
||||
cloud_logs_url=req.cloud_logs_url,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -142,9 +145,10 @@ from backend.app.tasks.schema import TaskSqsTriggerRequest
|
|||
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
|
||||
from backend.app.config import get_settings
|
||||
|
||||
sqs = boto3.client("sqs")
|
||||
|
||||
@router.post("/trigger", summary="Create task + subtask and publish to SQS", status_code=202)
|
||||
@router.post(
|
||||
"/trigger", summary="Create task + subtask and publish to SQS", status_code=202
|
||||
)
|
||||
async def trigger_task(req: TaskSqsTriggerRequest):
|
||||
"""
|
||||
Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it.
|
||||
|
|
@ -152,11 +156,12 @@ async def trigger_task(req: TaskSqsTriggerRequest):
|
|||
"""
|
||||
|
||||
settings = get_settings()
|
||||
sqs = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
|
||||
|
||||
tasks = TasksInterface()
|
||||
|
||||
# ---- Normalize empty inputs ----
|
||||
inputs = req.inputs or {} # ensures {} even if null
|
||||
inputs = req.inputs or {} # ensures {} even if null
|
||||
|
||||
# ---- 1. Create Task + SubTask ----
|
||||
task_id, subtask_id = tasks.create_task(
|
||||
|
|
@ -174,8 +179,8 @@ async def trigger_task(req: TaskSqsTriggerRequest):
|
|||
try:
|
||||
response = sqs.send_message(
|
||||
QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/"
|
||||
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
|
||||
MessageBody=json.dumps(sqs_payload)
|
||||
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
|
||||
MessageBody=json.dumps(sqs_payload),
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"SQS error: {e}")
|
||||
|
|
@ -186,4 +191,4 @@ async def trigger_task(req: TaskSqsTriggerRequest):
|
|||
"subtask_id": subtask_id,
|
||||
"sqs_message_id": response.get("MessageId"),
|
||||
"inputs_sent": inputs,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue