mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
implement trigger_categorisation API
This commit is contained in:
parent
5646376d1e
commit
76b648c861
4 changed files with 67 additions and 34 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -8,12 +9,16 @@ def resolve_env_file() -> Optional[str]:
|
||||||
env = os.getenv("ENVIRONMENT", "local")
|
env = os.getenv("ENVIRONMENT", "local")
|
||||||
|
|
||||||
if env == "local":
|
if env == "local":
|
||||||
return "backend/.env"
|
env_file: Path = Path("backend/.env").resolve() # resolve to full path
|
||||||
|
print("USING ENV FILE:", env_file)
|
||||||
|
return str(env_file)
|
||||||
|
|
||||||
if env == "test":
|
if env == "test":
|
||||||
return "backend/.env.test"
|
env_file: Path = Path("backend/.env.test").resolve()
|
||||||
|
print("USING ENV FILE:", env_file)
|
||||||
|
return str(env_file)
|
||||||
|
|
||||||
# prod = no env file
|
print("NO ENV FILE")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,12 @@ class SubTaskInterface:
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
# CREATE SUBTASK
|
# CREATE SUBTASK
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None, status=None):
|
def create_subtask(
|
||||||
|
self,
|
||||||
|
task_id: UUID,
|
||||||
|
inputs: Optional[Dict[str, Any]] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
):
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
|
|
@ -177,9 +182,7 @@ class SubTaskInterface:
|
||||||
if not task:
|
if not task:
|
||||||
return
|
return
|
||||||
|
|
||||||
subtasks = session.exec(
|
subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all()
|
||||||
select(SubTask).where(SubTask.task_id == task_id)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
statuses = [s.status.lower() for s in subtasks]
|
statuses = [s.status.lower() for s in subtasks]
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
@ -211,7 +214,7 @@ class SubTaskInterface:
|
||||||
subtask_id: UUID,
|
subtask_id: UUID,
|
||||||
status: str,
|
status: str,
|
||||||
outputs: Optional[Dict[str, Any]],
|
outputs: Optional[Dict[str, Any]],
|
||||||
cloud_logs_url: Optional[str]
|
cloud_logs_url: Optional[str],
|
||||||
):
|
):
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import json
|
import json
|
||||||
|
|
@ -36,13 +37,14 @@ router = APIRouter(
|
||||||
responses={404: {"description": "Not found"}},
|
responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
sqs_client = boto3.client("sqs")
|
settings = get_settings()
|
||||||
|
sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/categorisation", status_code=202)
|
@router.post("/categorisation", status_code=202)
|
||||||
async def trigger_categorisation(
|
async def trigger_categorisation(
|
||||||
body: CategorisationTriggerRequest,
|
body: CategorisationTriggerRequest,
|
||||||
) -> dict[str, int]:
|
) -> dict[str, str]:
|
||||||
payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
|
payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
|
||||||
body
|
body
|
||||||
)
|
)
|
||||||
|
|
@ -56,7 +58,16 @@ async def trigger_categorisation(
|
||||||
batch_size: int = math.ceil(1000 / num_scenarios)
|
batch_size: int = math.ceil(1000 / num_scenarios)
|
||||||
num_property_buckets: int = max(1, math.ceil(len(property_ids) / batch_size))
|
num_property_buckets: int = max(1, math.ceil(len(property_ids) / batch_size))
|
||||||
|
|
||||||
bucket_requests: List[CategorisationTriggerRequest] = []
|
# Create task
|
||||||
|
task_id, _ = TasksInterface.create_task(
|
||||||
|
task_source="backend/plan/router.py:trigger_categorisation",
|
||||||
|
service="plan_engine",
|
||||||
|
inputs=payload.model_dump(),
|
||||||
|
task_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dispatch requests to lambdas
|
||||||
|
subtask_interface = SubTaskInterface()
|
||||||
|
|
||||||
for bucket_index in range(num_property_buckets):
|
for bucket_index in range(num_property_buckets):
|
||||||
bucket_property_ids: List[int] = [
|
bucket_property_ids: List[int] = [
|
||||||
|
|
@ -69,12 +80,23 @@ async def trigger_categorisation(
|
||||||
min_property_id=min(bucket_property_ids),
|
min_property_id=min(bucket_property_ids),
|
||||||
max_property_id=max(bucket_property_ids),
|
max_property_id=max(bucket_property_ids),
|
||||||
)
|
)
|
||||||
|
# Create sub-task for each
|
||||||
|
subtask_id: UUID = subtask_interface.create_subtask(
|
||||||
|
task_id=task_id, inputs=bucket_request.model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
bucket_requests.append(bucket_request)
|
response = sqs_client.send_message(
|
||||||
|
QueueUrl="categorisation-queue-dev",
|
||||||
|
MessageBody=bucket_request.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
# Dispatch requests to lambdas
|
logger.info(
|
||||||
|
f"Chunk {bucket_index} sent to SQS. Property IDs {min(bucket_property_ids)}–{max(bucket_property_ids)}. Message ID: {response.get('MessageId')}"
|
||||||
|
)
|
||||||
|
|
||||||
return {"num_buckets": len(bucket_requests)}
|
await asyncio.sleep(0.05) # Small delay to avoid SQS throttling
|
||||||
|
|
||||||
|
return {"message": "Categorisation jobs distributed"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/trigger", status_code=202)
|
@router.post("/trigger", status_code=202)
|
||||||
|
|
@ -84,8 +106,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
|
||||||
"""
|
"""
|
||||||
logger.info("API triggered with body: %s", body)
|
logger.info("API triggered with body: %s", body)
|
||||||
|
|
||||||
settings = get_settings()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = body.model_dump()
|
data = body.model_dump()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from backend.app.tasks.schema import (
|
||||||
CreateSubTaskRequest,
|
CreateSubTaskRequest,
|
||||||
UpdateSubTaskStatusRequest,
|
UpdateSubTaskStatusRequest,
|
||||||
FinalizeSubTaskRequest,
|
FinalizeSubTaskRequest,
|
||||||
TaskSqsTriggerRequest
|
TaskSqsTriggerRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Correct location of interfaces
|
# Correct location of interfaces
|
||||||
|
|
@ -51,18 +51,18 @@ async def get_task(task_id: UUID):
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
subtasks = session.exec(
|
subtasks = session.exec(select(SubTask).where(SubTask.taskId == task_id)).all()
|
||||||
select(SubTask).where(SubTask.taskId == task_id)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
formatted = []
|
formatted = []
|
||||||
for st in subtasks:
|
for st in subtasks:
|
||||||
formatted.append({
|
formatted.append(
|
||||||
**st.dict(),
|
{
|
||||||
"inputs": json.loads(st.inputs) if st.inputs else None,
|
**st.dict(),
|
||||||
"outputs": json.loads(st.outputs) if st.outputs else None,
|
"inputs": json.loads(st.inputs) if st.inputs else None,
|
||||||
"cloud_logs_url": st.cloudLogsURL,
|
"outputs": json.loads(st.outputs) if st.outputs else None,
|
||||||
})
|
"cloud_logs_url": st.cloudLogsURL,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"task": task,
|
"task": task,
|
||||||
|
|
@ -111,7 +111,10 @@ async def update_subtask_status(subtask_id: UUID, req: UpdateSubTaskStatusReques
|
||||||
|
|
||||||
# ===
|
# ===
|
||||||
# Sub task is complete
|
# Sub task is complete
|
||||||
@router.post("/subtask/{subtask_id}/finalize", summary="Finalize a subtask with status, outputs, logs")
|
@router.post(
|
||||||
|
"/subtask/{subtask_id}/finalize",
|
||||||
|
summary="Finalize a subtask with status, outputs, logs",
|
||||||
|
)
|
||||||
async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
|
async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
|
||||||
subtasks = SubTaskInterface()
|
subtasks = SubTaskInterface()
|
||||||
|
|
||||||
|
|
@ -120,7 +123,7 @@ async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
|
||||||
subtask_id=subtask_id,
|
subtask_id=subtask_id,
|
||||||
status=req.status,
|
status=req.status,
|
||||||
outputs=req.outputs,
|
outputs=req.outputs,
|
||||||
cloud_logs_url=req.cloud_logs_url
|
cloud_logs_url=req.cloud_logs_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -142,9 +145,10 @@ from backend.app.tasks.schema import TaskSqsTriggerRequest
|
||||||
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
|
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
|
||||||
from backend.app.config import get_settings
|
from backend.app.config import get_settings
|
||||||
|
|
||||||
sqs = boto3.client("sqs")
|
|
||||||
|
|
||||||
@router.post("/trigger", summary="Create task + subtask and publish to SQS", status_code=202)
|
@router.post(
|
||||||
|
"/trigger", summary="Create task + subtask and publish to SQS", status_code=202
|
||||||
|
)
|
||||||
async def trigger_task(req: TaskSqsTriggerRequest):
|
async def trigger_task(req: TaskSqsTriggerRequest):
|
||||||
"""
|
"""
|
||||||
Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it.
|
Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it.
|
||||||
|
|
@ -152,11 +156,12 @@ async def trigger_task(req: TaskSqsTriggerRequest):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
sqs = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
|
||||||
|
|
||||||
tasks = TasksInterface()
|
tasks = TasksInterface()
|
||||||
|
|
||||||
# ---- Normalize empty inputs ----
|
# ---- Normalize empty inputs ----
|
||||||
inputs = req.inputs or {} # ensures {} even if null
|
inputs = req.inputs or {} # ensures {} even if null
|
||||||
|
|
||||||
# ---- 1. Create Task + SubTask ----
|
# ---- 1. Create Task + SubTask ----
|
||||||
task_id, subtask_id = tasks.create_task(
|
task_id, subtask_id = tasks.create_task(
|
||||||
|
|
@ -174,8 +179,8 @@ async def trigger_task(req: TaskSqsTriggerRequest):
|
||||||
try:
|
try:
|
||||||
response = sqs.send_message(
|
response = sqs.send_message(
|
||||||
QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/"
|
QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/"
|
||||||
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
|
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
|
||||||
MessageBody=json.dumps(sqs_payload)
|
MessageBody=json.dumps(sqs_payload),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"SQS error: {e}")
|
raise HTTPException(status_code=500, detail=f"SQS error: {e}")
|
||||||
|
|
@ -186,4 +191,4 @@ async def trigger_task(req: TaskSqsTriggerRequest):
|
||||||
"subtask_id": subtask_id,
|
"subtask_id": subtask_id,
|
||||||
"sqs_message_id": response.get("MessageId"),
|
"sqs_message_id": response.get("MessageId"),
|
||||||
"inputs_sent": inputs,
|
"inputs_sent": inputs,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue