"""Plan API routes: fan categorisation and plan-engine work out to SQS queues."""

from typing import List
from uuid import UUID
import boto3
import json
import math
import asyncio
from datetime import datetime
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from backend.app.db.connection import db_session
from backend.app.db.models.tasks import SourceEnum
from backend.app.dependencies import validate_token
from backend.app.plan.schemas import PlanTriggerRequest
from backend.app.config import get_settings
from backend.categorisation.categorisation_trigger_request import (
    CategorisationTriggerRequest,
)
from utils.logger import setup_logger
from backend.app.db.functions.recommendations_functions import (
    create_scenario,
    get_property_ids,
    get_scenarios_count_by_portfolio_id,
)
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface

logger = setup_logger()

router = APIRouter(
    prefix="/plan",
    tags=["plan"],
    dependencies=[Depends(validate_token)],
    responses={404: {"description": "Not found"}},
)

settings = get_settings()
sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION)

# Upper bound on plan rows (properties x scenarios) one categorisation batch may touch.
_MAX_WRITES_PER_BATCH = 1000
# Number of asset-list rows handed to each plan-engine SQS message.
_ASSET_LIST_ROWS_PER_CHUNK = 30
# Pause between consecutive SQS sends to avoid throttling.
_SQS_SEND_DELAY_SECONDS = 0.05

# NOTE(review): boto3 `send_message` and the DB helpers used below are
# synchronous, so each call blocks the event loop while it runs inside these
# `async def` handlers. Consider `asyncio.to_thread(...)` or plain `def`
# handlers if request latency under load becomes a problem — confirm.


@router.post("/categorisation", status_code=202)
async def trigger_categorisation(
    body: CategorisationTriggerRequest,
) -> dict[str, str]:
    """Split a portfolio's properties into batches and queue one categorisation job per batch.

    Property IDs are sorted and sliced so that each batch touches at most
    ``_MAX_WRITES_PER_BATCH`` plans (properties x scenarios). A parent task is
    recorded first; each batch then gets its own sub-task, whose ID is carried
    in the SQS message as ``subtask_id``. Each message describes its batch by
    the inclusive ``min_property_id``/``max_property_id`` range.

    Args:
        body: The categorisation request (portfolio, scenario selection and priority).

    Returns:
        A confirmation message; the work itself happens asynchronously (HTTP 202).
    """
    payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
        body
    )
    logger.info("API triggered with body: %s", payload)

    property_ids: list[int] = sorted(get_property_ids(payload.portfolio_id))
    num_scenarios: int = get_scenarios_count_by_portfolio_id(payload.portfolio_id)

    total_plans_to_update: int = len(property_ids) * num_scenarios
    # A portfolio with no scenarios previously crashed here with ZeroDivisionError.
    # With zero scenarios each property writes nothing, so any positive batch size works.
    properties_per_batch: int = max(1, _MAX_WRITES_PER_BATCH // max(1, num_scenarios))
    num_property_batches: int = math.ceil(len(property_ids) / properties_per_batch)

    logger.info("total_plans_to_update: %s", total_plans_to_update)
    logger.info("properties_per_batch: %s", properties_per_batch)
    logger.info("num_property_batches: %s", num_property_batches)

    # Parent task that groups every batch's sub-task.
    task_id, _ = TasksInterface.create_task(
        task_source="backend/plan/router.py:trigger_categorisation",
        service="plan_categorisation",
        inputs=payload.model_dump(),
        task_only=True,
        source=SourceEnum.PORTFOLIO,
        source_id=str(payload.portfolio_id),
    )

    subtask_interface = SubTaskInterface()
    for batch_index in range(num_property_batches):
        start: int = batch_index * properties_per_batch
        batch_property_ids: List[int] = property_ids[start : start + properties_per_batch]
        if not batch_property_ids:
            continue

        # property_ids is sorted, so the batch's first/last elements are its min/max.
        min_id, max_id = batch_property_ids[0], batch_property_ids[-1]

        batch_request: CategorisationTriggerRequest = CategorisationTriggerRequest(
            portfolio_id=payload.portfolio_id,
            scenarios_to_consider=payload.scenarios_to_consider,
            scenario_priority_order=payload.scenario_priority_order,
            min_property_id=min_id,
            max_property_id=max_id,
        )

        # Record the sub-task before sending so the consumer can report against it.
        subtask_id: UUID = subtask_interface.create_subtask(
            task_id=task_id, inputs=batch_request.model_dump()
        )
        batch_request.subtask_id = str(subtask_id)

        response = sqs_client.send_message(
            QueueUrl=settings.CATEGORISATION_SQS_URL,
            MessageBody=batch_request.model_dump_json(),
        )
        logger.info(
            "Chunk %s sent to SQS. %s Property IDs in batch (total %s). "
            "Property IDs %s–%s. Message ID: %s",
            batch_index,
            len(batch_property_ids),
            len(property_ids),
            min_id,
            max_id,
            response.get("MessageId"),
        )
        await asyncio.sleep(_SQS_SEND_DELAY_SECONDS)

    return {"message": "Categorisation jobs distributed"}


def _error_response(message: str, status_code: int) -> JSONResponse:
    """Return a ``{"message": ...}`` body with a real HTTP status code.

    FastAPI does not understand Flask-style ``(body, status)`` tuples: it would
    serialise the tuple as a JSON array and keep the route's 202 status.
    """
    return JSONResponse(status_code=status_code, content={"message": message})


def _create_scenario_for_request(body: PlanTriggerRequest):
    """Persist a new scenario built from the request's fields and return its ID."""
    created_at = datetime.now().isoformat()
    with db_session() as session:
        return create_scenario(
            session=session,
            scenario={
                "name": body.scenario_name,
                "created_at": created_at,
                "budget": body.budget,
                "portfolio_id": body.portfolio_id,
                "housing_type": body.housing_type,
                "goal": body.goal,
                "goal_value": body.goal_value,
                "trigger_file_path": body.trigger_file_path,
                "already_installed_file_path": body.already_installed_file_path,
                "patches_file_path": body.patches_file_path,
                "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
                "exclusions": body.exclusions,
                "multi_plan": body.multi_plan,
            },
        )


async def _dispatch_asset_list_chunks(body: PlanTriggerRequest, data: dict) -> None:
    """Queue one plan-engine message per chunk of asset-list rows under a shared parent task.

    The message count comes from ``data["sheet_count"]`` split into chunks of
    ``_ASSET_LIST_ROWS_PER_CHUNK``. Each message carries the half-open row
    window ``[index_start, index_end)`` plus the task and sub-task IDs.
    ``data`` is updated in place with ``scenario_id``.
    """
    total_rows = data.get("sheet_count", 0)
    chunk_size = _ASSET_LIST_ROWS_PER_CHUNK
    total_chunks = math.ceil(total_rows / chunk_size)

    # Every chunk must write to the same scenario, so create one up front
    # when the caller did not supply it.
    if not data.get("scenario_id"):
        data["scenario_id"] = _create_scenario_for_request(body)

    task_id, _ = TasksInterface.create_task(
        task_source="backend/plan/router.py:trigger_plan_entrypoint",
        service="plan_engine",
        inputs=data,
        task_only=True,
        source=SourceEnum.PORTFOLIO,
        source_id=str(body.portfolio_id),
    )

    subtask_interface = SubTaskInterface()
    for i in range(total_chunks):
        index_start = i * chunk_size
        index_end = min(index_start + chunk_size, total_rows)
        message_payload = {
            **data,
            "index_start": index_start,
            "index_end": index_end,
        }

        # One sub-task per chunk; its ID travels with the message.
        subtask_id = subtask_interface.create_subtask(
            task_id=task_id, inputs=message_payload
        )
        message_payload["task_id"] = str(task_id)
        message_payload["subtask_id"] = str(subtask_id)

        # NOTE(review): json.dumps requires every value from model_dump() to be
        # JSON-native; if PlanTriggerRequest holds UUID/datetime/Decimal fields,
        # model_dump(mode="json") would be needed — confirm against the schema.
        response = sqs_client.send_message(
            QueueUrl=settings.ENGINE_SQS_URL, MessageBody=json.dumps(message_payload)
        )
        logger.info(
            "Chunk %s sent to SQS. Rows %s–%s. Message ID: %s",
            i,
            index_start,
            index_end,
            response.get("MessageId"),
        )
        await asyncio.sleep(_SQS_SEND_DELAY_SECONDS)


def _dispatch_single_message(body: PlanTriggerRequest, data: dict) -> None:
    """Create a task with one sub-task and send the whole request as a single engine message.

    ``data`` is updated in place with ``task_id`` and ``subtask_id``.
    """
    task_id, subtask_id = TasksInterface.create_task(
        task_source="backend/plan/router.py:trigger_plan_entrypoint",
        service="plan_engine",
        inputs=data,
        task_only=False,
        source=SourceEnum.PORTFOLIO,
        source_id=str(body.portfolio_id),
    )
    data["task_id"] = str(task_id)
    data["subtask_id"] = str(subtask_id)

    response = sqs_client.send_message(
        QueueUrl=settings.ENGINE_SQS_URL, MessageBody=json.dumps(data)
    )
    logger.info("SQS message sent. Message ID: %s", response.get("MessageId"))


@router.post("/trigger", status_code=202)
async def trigger_plan_entrypoint(body: PlanTriggerRequest):
    """Entry point for triggering the plan engine via SQS.

    A ``domna_asset_list`` request with ``file_type == "xlsx"`` is split into
    row chunks, one SQS message per chunk. A new scenario is created first if
    none was supplied. Any other request is forwarded as a single message.

    Returns:
        ``{"message": "Plan job accepted"}`` with HTTP 202 on success. On failure,
        a ``JSONResponse`` with ``{"message": ...}`` and status 400 (body could
        not be dumped) or 500 (dispatch failed).
    """
    logger.info("API triggered with body: %s", body)
    try:
        data = body.model_dump()
    except Exception as e:
        logger.exception("Failed to parse request body: %s", e)
        return _error_response("Invalid request", 400)

    is_xlsx_asset_list = (
        data.get("file_format") == "domna_asset_list"
        and data.get("file_type") == "xlsx"
    )

    if is_xlsx_asset_list:
        try:
            await _dispatch_asset_list_chunks(body, data)
        except Exception as e:
            logger.exception("Error during Excel file handling: %s", e)
            return _error_response("Failed to process asset list", 500)
    else:
        try:
            _dispatch_single_message(body, data)
        except Exception as e:
            logger.exception("Failed to send SQS message: %s", e)
            return _error_response("Failed to trigger engine", 500)

    return {"message": "Plan job accepted"}