Model/backend/app/plan/router.py
2026-02-26 16:38:44 +00:00

236 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import math
from contextlib import contextmanager
from datetime import datetime
from typing import List
from uuid import UUID

import boto3
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session

from backend.app.dependencies import validate_token
from backend.app.plan.schemas import PlanTriggerRequest
from backend.app.config import get_settings
from backend.categorisation.categorisation_trigger_request import (
    CategorisationTriggerRequest,
)
from utils.logger import setup_logger
from backend.app.db.connection import db_engine
from backend.app.db.functions.recommendations_functions import (
    create_scenario,
    get_property_ids,
    get_scenarios_count_by_portfolio_id,
)
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
# Module-level logger shared by every endpoint defined in this router.
logger = setup_logger()

# Every route registered here is mounted under /plan and runs the
# validate_token dependency before the handler is invoked.
router = APIRouter(
    prefix="/plan",
    tags=["plan"],
    dependencies=[Depends(validate_token)],
    responses={404: {"description": "Not found"}},
)

settings = get_settings()

# The SQS client is built once at import time and reused by all handlers.
print("CONNECTION TO SQS IN REGION", settings.AWS_DEFAULT_REGION)
sqs_client = boto3.client("sqs", region_name=settings.AWS_DEFAULT_REGION)
@router.post("/categorisation", status_code=202)
async def trigger_categorisation(
    body: CategorisationTriggerRequest,
) -> dict[str, str]:
    """Split a portfolio's properties into batches and queue one job per batch.

    The portfolio's property IDs are sorted and chunked so that each batch
    covers at most ``max_writes_per_batch`` (1000) property/scenario
    combinations. A parent task is created, then for every batch a sub-task
    is recorded and a message carrying the batch's minimum and maximum
    property ID is sent to SQS.

    Args:
        body: Categorisation request. Its ``portfolio_id``,
            ``scenarios_to_consider`` and ``scenario_priority_order`` are
            copied into every batch message.

    Returns:
        A confirmation message; the route responds with HTTP 202.
    """
    payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
        body
    )
    logger.info("API triggered with body: %s", payload)

    property_ids: list[int] = get_property_ids(payload.portfolio_id)
    property_ids.sort()
    num_scenarios: int = get_scenarios_count_by_portfolio_id(payload.portfolio_id)
    total_plans_to_update: int = len(property_ids) * num_scenarios

    max_writes_per_batch: int = 1000
    # Clamp the divisor: a scenario count of 0 would otherwise raise
    # ZeroDivisionError before any task is created.
    properties_per_batch: int = max(
        1, max_writes_per_batch // max(1, num_scenarios)
    )
    num_property_batches: int = math.ceil(len(property_ids) / properties_per_batch)
    logger.info(
        "total_plans_to_update=%s properties_per_batch=%s num_property_batches=%s",
        total_plans_to_update,
        properties_per_batch,
        num_property_batches,
    )

    # Create the parent task that all batch sub-tasks hang off.
    task_id, _ = TasksInterface.create_task(
        task_source="backend/plan/router.py:trigger_categorisation",
        service="plan_engine",
        inputs=payload.model_dump(),
        task_only=True,
    )

    # Dispatch one SQS message per batch of property IDs.
    subtask_interface = SubTaskInterface()
    for batch_index in range(num_property_batches):
        start: int = batch_index * properties_per_batch
        batch_property_ids: list[int] = property_ids[
            start : start + properties_per_batch
        ]
        if not batch_property_ids:
            continue

        # The message only carries the ID bounds of the batch, not the full
        # list of IDs.
        lowest_id: int = min(batch_property_ids)
        highest_id: int = max(batch_property_ids)
        batch_request: CategorisationTriggerRequest = CategorisationTriggerRequest(
            portfolio_id=payload.portfolio_id,
            scenarios_to_consider=payload.scenarios_to_consider,
            scenario_priority_order=payload.scenario_priority_order,
            min_property_id=lowest_id,
            max_property_id=highest_id,
        )

        # Record the sub-task first so its ID can travel with the message.
        subtask_id: UUID = subtask_interface.create_subtask(
            task_id=task_id, inputs=batch_request.model_dump()
        )
        batch_request.subtask_id = str(subtask_id)

        response = sqs_client.send_message(
            # NOTE(review): this is a hard-coded, dev-looking queue name rather
            # than a configured queue URL - confirm and move to settings.
            QueueUrl="categorisation-queue-dev",
            MessageBody=batch_request.model_dump_json(),
        )
        logger.info(
            "Chunk %s sent to SQS. %s Property IDs in batch (total %s). "
            "Property IDs %s-%s. Message ID: %s",
            batch_index,
            len(batch_property_ids),
            len(property_ids),
            lowest_id,
            highest_id,
            response.get("MessageId"),
        )
        await asyncio.sleep(0.05)  # Small delay to avoid SQS throttling

    return {"message": "Categorisation jobs distributed"}
@router.post("/trigger", status_code=202)
async def trigger_plan_entrypoint(body: PlanTriggerRequest):
    """
    Entry point for triggering the plan engine via SQS.

    For a ``domna_asset_list`` request of type ``xlsx`` the work is split
    into chunks of 30 rows: a scenario is created when the request carries
    no ``scenario_id``, a parent task is recorded, and each chunk gets its
    own sub-task and SQS message. Any other request is forwarded as a single
    message.

    Args:
        body: Plan trigger request; its dumped fields form the SQS payload.

    Returns:
        ``{"message": "Plan job accepted"}`` on success (HTTP 202 from the
        route decorator). On failure, a ``JSONResponse`` with status 400 or
        500 whose body has a ``message`` key describing the failure.
    """
    # No return annotation on purpose: FastAPI would use it as the response
    # model, and this handler returns either a dict or a JSONResponse.
    logger.info("API triggered with body: %s", body)
    try:
        data = body.model_dump()
    except Exception as e:
        logger.error("Failed to parse request body: %s", e)
        # A (dict, status) tuple is not interpreted as a status code by
        # FastAPI; an explicit response is needed to signal the error.
        return JSONResponse(status_code=400, content={"message": "Invalid request"})

    # If file_format is domna_asset_list and type is xlsx, chunk the request.
    is_chunked_asset_list = (
        data.get("file_format") == "domna_asset_list"
        and data.get("file_type") == "xlsx"
    )
    if is_chunked_asset_list:
        try:
            total_rows = data.get("sheet_count", 0)
            chunk_size = 30
            total_chunks = math.ceil(total_rows / chunk_size)

            # A scenario must accompany every SQS message; create one if the
            # request did not supply it.
            scenario_id = data.get("scenario_id")
            if not scenario_id:
                created_at = datetime.now().isoformat()
                # NOTE(review): db_session is not defined or imported in the
                # visible part of this module - confirm it exists.
                with db_session() as session:
                    scenario_id = create_scenario(
                        session=session,
                        scenario={
                            "name": body.scenario_name,
                            "created_at": created_at,
                            "budget": body.budget,
                            "portfolio_id": body.portfolio_id,
                            "housing_type": body.housing_type,
                            "goal": body.goal,
                            "goal_value": body.goal_value,
                            "trigger_file_path": body.trigger_file_path,
                            "already_installed_file_path": body.already_installed_file_path,
                            "patches_file_path": body.patches_file_path,
                            "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
                            "exclusions": body.exclusions,
                            "multi_plan": body.multi_plan,
                        },
                    )
                # Insert the scenario ID into the data payload
                data["scenario_id"] = scenario_id

            # Create a main task
            task_id, _ = TasksInterface.create_task(
                task_source="backend/plan/router.py:trigger_plan_entrypoint",
                service="plan_engine",
                inputs=data,
                task_only=True,
            )

            subtask_interface = SubTaskInterface()
            for i in range(total_chunks):
                # Row window handled by this chunk; the last chunk is clipped
                # to the total row count.
                index_start = i * chunk_size
                index_end = min((i + 1) * chunk_size, total_rows)
                message_payload = {
                    **data,
                    "index_start": index_start,
                    "index_end": index_end,
                }

                # Create a subtask for this chunk
                subtask_id = subtask_interface.create_subtask(
                    task_id=task_id, inputs=message_payload
                )

                # Add task and subtask to message
                message_payload["task_id"] = str(task_id)
                message_payload["subtask_id"] = str(subtask_id)

                response = sqs_client.send_message(
                    QueueUrl=settings.ENGINE_SQS_URL,
                    MessageBody=json.dumps(message_payload),
                )
                logger.info(
                    "Chunk %s sent to SQS. Rows %s-%s. Message ID: %s",
                    i,
                    index_start,
                    index_end,
                    response.get("MessageId"),
                )
                await asyncio.sleep(0.05)  # Small delay to avoid SQS throttling
        except Exception as e:
            logger.error("Error during Excel file handling: %s", e)
            return JSONResponse(
                status_code=500, content={"message": "Failed to process asset list"}
            )
    else:
        # Fallback: Just send a single message
        try:
            task_id, subtask_id = TasksInterface.create_task(
                task_source="backend/plan/router.py:trigger_plan_entrypoint",
                service="plan_engine",
                inputs=data,
                task_only=False,
            )
            data["task_id"] = str(task_id)
            data["subtask_id"] = str(subtask_id)

            response = sqs_client.send_message(
                QueueUrl=settings.ENGINE_SQS_URL, MessageBody=json.dumps(data)
            )
            logger.info("SQS message sent. Message ID: %s", response.get("MessageId"))
        except Exception as e:
            logger.error("Failed to send SQS message: %s", e)
            return JSONResponse(
                status_code=500, content={"message": "Failed to trigger engine"}
            )

    return {"message": "Plan job accepted"}