mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
152 lines
5.6 KiB
Python
152 lines
5.6 KiB
Python
import boto3
|
||
import json
|
||
import math
|
||
import asyncio
|
||
import random
|
||
|
||
from datetime import datetime
|
||
|
||
from fastapi import APIRouter, Depends
|
||
from backend.app.dependencies import validate_token
|
||
from backend.app.plan.schemas import PlanTriggerRequest
|
||
from backend.app.config import get_settings
|
||
from sqlalchemy.orm import sessionmaker
|
||
from utils.logger import setup_logger
|
||
from backend.app.db.connection import db_engine
|
||
|
||
from backend.app.db.functions.recommendations_functions import create_scenario
|
||
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
|
||
|
||
# Module-wide logger built by the project's ``setup_logger`` helper.
logger = setup_logger()

# Router for the plan endpoints: mounted under ``/plan``, grouped under the
# "plan" tag in the OpenAPI schema, and every route runs ``validate_token``
# as a dependency before the handler body executes.
router = APIRouter(
    prefix="/plan",
    tags=["plan"],
    dependencies=[Depends(validate_token)],
    responses={404: dict(description="Not found")},
)

# A single SQS client created at import time and shared by every request.
# No region/credentials are passed explicitly, so boto3 resolves them from
# its default configuration chain.
sqs_client = boto3.client("sqs")
|
||
|
||
|
||
# Identifiers recorded on every task created by this endpoint.
_TASK_SOURCE = "backend/plan/router.py:trigger_plan_entrypoint"
_ENGINE_SERVICE = "plan_engine"

# Number of asset-list rows handled by one engine message / subtask.
_ASSET_LIST_CHUNK_SIZE = 30

# Pause between consecutive chunk messages, to avoid SQS throttling.
_SQS_SEND_DELAY_SECONDS = 0.1


def _error_response(message, status_code):
    """Return a JSON ``{"message": message}`` response with *status_code*.

    FastAPI does not interpret a Flask-style ``(body, status)`` tuple: it
    would serialise the tuple as a JSON list and still reply with the
    route's default 202. The status therefore has to be set on an explicit
    response object.
    """
    from fastapi.responses import JSONResponse

    return JSONResponse(status_code=status_code, content={"message": message})


def _create_scenario_for_request(body):
    """Persist a new scenario built from the request *body* and return its id.

    The session is always closed, even when ``create_scenario`` raises, so a
    failed request does not leak a pooled DB connection.
    NOTE(review): assumes ``create_scenario`` commits its own transaction;
    closing an uncommitted session would roll the insert back.
    """
    session = sessionmaker(bind=db_engine)()
    try:
        new_scenario = create_scenario(
            session=session,
            scenario={
                "name": body.scenario_name,
                "created_at": datetime.now().isoformat(),
                "budget": body.budget,
                "portfolio_id": body.portfolio_id,
                "housing_type": body.housing_type,
                "goal": body.goal,
                "goal_value": body.goal_value,
                "trigger_file_path": body.trigger_file_path,
                "already_installed_file_path": body.already_installed_file_path,
                "patches_file_path": body.patches_file_path,
                "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
                "exclusions": body.exclusions,
                "multi_plan": body.multi_plan,
            },
        )
        # Read the id while the session is still open (it may need a refresh).
        return new_scenario.id
    finally:
        session.close()


def _send_engine_message(queue_url, payload):
    """JSON-encode *payload*, send it to *queue_url*, return the SQS response."""
    # NOTE(review): this is a synchronous boto3 call made from an async
    # handler, so it blocks the event loop while the request is in flight.
    return sqs_client.send_message(QueueUrl=queue_url, MessageBody=json.dumps(payload))


async def _dispatch_asset_list_chunks(body, data, queue_url):
    """Fan an xlsx asset list out to the engine as one SQS message per chunk.

    Ensures ``data["scenario_id"]`` is set (creating a scenario from *body*
    when it is missing), creates one parent task, then for every
    ``_ASSET_LIST_CHUNK_SIZE``-row slice of ``sheet_count`` creates a subtask
    and sends a message carrying the row range plus task/subtask ids.
    """
    total_rows = data.get("sheet_count", 0)
    total_chunks = math.ceil(total_rows / _ASSET_LIST_CHUNK_SIZE)

    # Every chunk must reference the same scenario, so create it up front
    # when the caller did not supply one.
    if not data.get("scenario_id"):
        data["scenario_id"] = _create_scenario_for_request(body)

    task_id, _ = TasksInterface.create_task(
        task_source=_TASK_SOURCE,
        service=_ENGINE_SERVICE,
        inputs=data,
        task_only=True,
    )

    subtask_interface = SubTaskInterface()
    for chunk_index in range(total_chunks):
        index_start = chunk_index * _ASSET_LIST_CHUNK_SIZE
        index_end = min(index_start + _ASSET_LIST_CHUNK_SIZE, total_rows)

        chunk_inputs = {**data, "index_start": index_start, "index_end": index_end}
        subtask_id = subtask_interface.create_subtask(task_id=task_id, inputs=chunk_inputs)

        # Build a separate dict for the message so the inputs object handed to
        # create_subtask is never mutated afterwards.
        message_payload = {
            **chunk_inputs,
            "task_id": str(task_id),
            "subtask_id": str(subtask_id),
        }
        response = _send_engine_message(queue_url, message_payload)
        logger.info(
            "Chunk %s sent to SQS. Rows %s–%s. Message ID: %s",
            chunk_index,
            index_start,
            index_end,
            response.get("MessageId"),
        )

        await asyncio.sleep(_SQS_SEND_DELAY_SECONDS)  # Small delay to avoid SQS throttling


def _dispatch_single_message(data, queue_url):
    """Create a task + subtask for *data* and send it as one SQS message."""
    task_id, subtask_id = TasksInterface.create_task(
        task_source=_TASK_SOURCE,
        service=_ENGINE_SERVICE,
        inputs=data,
        task_only=False,
    )
    # Copy rather than mutate ``data``, which was just passed as the task inputs.
    message_payload = {**data, "task_id": str(task_id), "subtask_id": str(subtask_id)}
    response = _send_engine_message(queue_url, message_payload)
    logger.info("SQS message sent. Message ID: %s", response.get("MessageId"))


@router.post("/trigger", status_code=202)
async def trigger_plan_entrypoint(body: PlanTriggerRequest):
    """
    Entry point for triggering the plan engine via SQS.

    Dispatch mode depends on the payload:

    * ``file_format == "domna_asset_list"`` and ``file_type == "xlsx"``:
      the ``sheet_count`` rows are split into chunks, and each chunk is sent
      as its own SQS message with its own subtask under one parent task.
      A scenario is created first if ``scenario_id`` is absent.
    * anything else: one task + subtask and a single SQS message carrying
      the whole payload.

    Returns ``{"message": "Plan job accepted"}`` (HTTP 202) on success. On
    failure, returns a JSON ``{"message": ...}`` body with HTTP 400 if the
    body could not be dumped, or HTTP 500 if dispatch failed.

    NOTE(review): a failure part-way through chunking leaves the scenario,
    the task, earlier subtasks and already-sent messages in place. A 500
    does not mean nothing was dispatched.
    """
    logger.info("API triggered with body: %s", body)

    settings = get_settings()

    try:
        data = body.model_dump()
    except Exception as e:
        logger.exception("Failed to parse request body: %s", e)
        return _error_response("Invalid request", 400)

    if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx":
        try:
            await _dispatch_asset_list_chunks(body, data, settings.ENGINE_SQS_URL)
        except Exception as e:
            logger.exception("Error during Excel file handling: %s", e)
            return _error_response("Failed to process asset list", 500)
    else:
        try:
            _dispatch_single_message(data, settings.ENGINE_SQS_URL)
        except Exception as e:
            logger.exception("Failed to send SQS message: %s", e)
            return _error_response("Failed to trigger engine", 500)

    return {"message": "Plan job accepted"}
|