mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
200 lines
7.3 KiB
Python
200 lines
7.3 KiB
Python
from typing import List
|
||
|
||
import boto3
|
||
import json
|
||
import math
|
||
import asyncio
|
||
from contextlib import contextmanager
|
||
from sqlmodel import Session
|
||
|
||
from datetime import datetime
|
||
|
||
from fastapi import APIRouter, Depends
|
||
from backend.app.dependencies import validate_token
|
||
from backend.app.plan.schemas import PlanTriggerRequest
|
||
from backend.app.config import get_settings
|
||
from sqlalchemy.orm import sessionmaker
|
||
from backend.categorisation.categorisation_trigger_request import (
|
||
CategorisationTriggerRequest,
|
||
)
|
||
from utils.logger import setup_logger
|
||
from backend.app.db.connection import db_engine
|
||
|
||
from backend.app.db.functions.recommendations_functions import (
|
||
create_scenario,
|
||
get_property_ids,
|
||
get_scenarios_count_by_portfolio_id,
|
||
)
|
||
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
|
||
|
||
# Module-level logger, obtained from the project's logging helper.
logger = setup_logger()

# Router for every endpoint under "/plan". `validate_token` is attached as a
# router-wide dependency, so it applies to each route registered below.
router = APIRouter(
    tags=["plan"],
    prefix="/plan",
    responses={404: {"description": "Not found"}},
    dependencies=[Depends(validate_token)],
)

# One SQS client, created at import time and shared by the handlers.
sqs_client = boto3.client("sqs")
|
||
|
||
|
||
@router.post("/categorisation", status_code=202)
async def trigger_categorisation(
    body: CategorisationTriggerRequest,
) -> dict[str, int]:
    """
    Split a portfolio's properties into per-bucket categorisation requests.

    The portfolio's property ids are sorted and cut into contiguous chunks.
    Each chunk becomes one ``CategorisationTriggerRequest`` described by its
    ``min_property_id`` / ``max_property_id``. Chunks are sized so that
    ``properties_per_bucket * num_scenarios`` stays around 1000.

    Args:
        body: The incoming trigger request; its ``portfolio_id``,
            ``scenarios_to_consider`` and ``scenario_priority_order`` are
            copied onto every bucket request.

    Returns:
        ``{"num_buckets": <count>}`` - the number of bucket requests built.
        This is 0 when the portfolio has no properties.
    """
    payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
        body
    )

    logger.info("API triggered with body: %s", payload)

    # sorted() instead of .sort() so the list returned by the helper is not
    # mutated in place.
    property_ids: List[int] = sorted(get_property_ids(payload.portfolio_id))

    num_scenarios: int = get_scenarios_count_by_portfolio_id(payload.portfolio_id)
    # Clamp to 1 so a portfolio with no scenarios does not divide by zero.
    batch_size: int = math.ceil(1000 / max(1, num_scenarios))

    bucket_requests: List[CategorisationTriggerRequest] = []

    # Walk the sorted ids in contiguous slices. Because each request is
    # expressed as an id *range*, the slices must be contiguous for the ranges
    # to be disjoint; an empty id list simply yields no buckets.
    for start in range(0, len(property_ids), batch_size):
        bucket_property_ids: List[int] = property_ids[start : start + batch_size]
        bucket_requests.append(
            CategorisationTriggerRequest(
                portfolio_id=payload.portfolio_id,
                scenarios_to_consider=payload.scenarios_to_consider,
                scenario_priority_order=payload.scenario_priority_order,
                # The slice is sorted and non-empty, so its ends are min/max.
                min_property_id=bucket_property_ids[0],
                max_property_id=bucket_property_ids[-1],
            )
        )

    # TODO: Dispatch requests to lambdas (nothing is dispatched here yet).

    return {"num_buckets": len(bucket_requests)}
|
||
|
||
|
||
@router.post("/trigger", status_code=202)
async def trigger_plan_entrypoint(body: PlanTriggerRequest):
    """
    Entry point for triggering the plan engine via SQS.

    Two paths, chosen from the request body:

    * ``file_format == "domna_asset_list"`` and ``file_type == "xlsx"``:
      the ``sheet_count`` rows are split into chunks of 30. A scenario is
      created when the request carries no ``scenario_id``, one main task is
      created, and for each chunk a subtask is created and one SQS message
      (the request data plus ``index_start`` / ``index_end`` / ``task_id`` /
      ``subtask_id``) is sent.
    * anything else: one task + subtask is created and a single SQS message
      is sent.

    Args:
        body: The plan trigger request.

    Returns:
        ``{"message": "Plan job accepted"}`` (with the route's 202 status) on
        success. On failure a ``JSONResponse`` carrying ``{"message": ...}``
        with status 400 (body could not be dumped) or 500 (task creation or
        SQS send failed).
    """
    # Imported locally so this handler is self-contained. Returning
    # ``(dict, status)`` does not set the HTTP status in FastAPI - the tuple is
    # serialised as a JSON array under the decorator's 202 - so error paths
    # return an explicit JSONResponse instead.
    from fastapi.responses import JSONResponse

    logger.info("API triggered with body: %s", body)

    settings = get_settings()

    try:
        data = body.model_dump()
    except Exception as e:
        logger.error("Failed to parse request body: %s", e)
        return JSONResponse(status_code=400, content={"message": "Invalid request"})

    # If file_format is domna_asset_list and type is xlsx, chunk it by row index
    is_chunked_asset_list = (
        data.get("file_format") == "domna_asset_list"
        and data.get("file_type") == "xlsx"
    )

    if is_chunked_asset_list:
        try:
            # `or 0` also covers a present-but-None sheet_count, which would
            # otherwise break math.ceil below.
            total_rows = data.get("sheet_count") or 0
            chunk_size = 30
            total_chunks = math.ceil(total_rows / chunk_size)

            # We also need to create a new scenario and pass it to the SQS
            # messages, if one doesn't exist
            scenario_id = data.get("scenario_id")
            if not scenario_id:
                created_at = datetime.now().isoformat()
                # NOTE(review): `db_session` is not defined or imported in the
                # visible part of this module - confirm it is in scope.
                with db_session() as session:
                    # Create a new scenario
                    scenario_id = create_scenario(
                        session=session,
                        scenario={
                            "name": body.scenario_name,
                            "created_at": created_at,
                            "budget": body.budget,
                            "portfolio_id": body.portfolio_id,
                            "housing_type": body.housing_type,
                            "goal": body.goal,
                            "goal_value": body.goal_value,
                            "trigger_file_path": body.trigger_file_path,
                            "already_installed_file_path": body.already_installed_file_path,
                            "patches_file_path": body.patches_file_path,
                            "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
                            "exclusions": body.exclusions,
                            "multi_plan": body.multi_plan,
                        },
                    )
                # Insert the scenario ID into the data payload
                data["scenario_id"] = scenario_id

            # Create a main task
            task_id, _ = TasksInterface.create_task(
                task_source="backend/plan/router.py:trigger_plan_entrypoint",
                service="plan_engine",
                inputs=data,
                task_only=True,
            )

            subtask_interface = SubTaskInterface()
            for i in range(total_chunks):
                # Half-open row window [index_start, index_end) for this
                # chunk; the last chunk is clipped to total_rows.
                index_start = i * chunk_size
                index_end = min((i + 1) * chunk_size, total_rows)

                message_payload = {
                    **data,
                    "index_start": index_start,
                    "index_end": index_end,
                }

                # Create a subtask for this chunk
                subtask_id = subtask_interface.create_subtask(
                    task_id=task_id, inputs=message_payload
                )

                # Add task and subtask to message
                message_payload["task_id"] = str(task_id)
                message_payload["subtask_id"] = str(subtask_id)

                message_body = json.dumps(message_payload)

                # NOTE(review): this boto3 call is not awaited, so it runs
                # synchronously inside the async handler.
                response = sqs_client.send_message(
                    QueueUrl=settings.ENGINE_SQS_URL, MessageBody=message_body
                )
                logger.info(
                    f"Chunk {i} sent to SQS. Rows {index_start}–{index_end}. Message ID: {response.get('MessageId')}"
                )

                await asyncio.sleep(0.05)  # Small delay to avoid SQS throttling

        except Exception as e:
            logger.error("Error during Excel file handling: %s", e)
            return JSONResponse(
                status_code=500, content={"message": "Failed to process asset list"}
            )

    else:
        # Fallback: Just send a single message
        try:
            task_id, subtask_id = TasksInterface.create_task(
                task_source="backend/plan/router.py:trigger_plan_entrypoint",
                service="plan_engine",
                inputs=data,
                task_only=False,
            )
            data["task_id"] = str(task_id)
            data["subtask_id"] = str(subtask_id)
            message_body = json.dumps(data)
            response = sqs_client.send_message(
                QueueUrl=settings.ENGINE_SQS_URL, MessageBody=message_body
            )
            logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}")
        except Exception as e:
            logger.error("Failed to send SQS message: %s", e)
            return JSONResponse(
                status_code=500, content={"message": "Failed to trigger engine"}
            )

    return {"message": "Plan job accepted"}
|