Model/backend/app/plan/router.py

152 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import math
import random
from datetime import datetime

import boto3
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from sqlalchemy.orm import sessionmaker

from backend.app.config import get_settings
from backend.app.db.connection import db_engine
from backend.app.db.functions.recommendations_functions import create_scenario
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
from backend.app.dependencies import validate_token
from backend.app.plan.schemas import PlanTriggerRequest
from utils.logger import setup_logger
logger = setup_logger()

# Every route registered on this router runs the validate_token dependency
# before its handler, and advertises a generic 404 response in the schema.
router = APIRouter(
    prefix="/plan",
    tags=["plan"],
    responses={404: {"description": "Not found"}},
    dependencies=[Depends(validate_token)],
)

# Created once at import time and shared by the handlers in this module.
sqs_client = boto3.client("sqs")
# Number of asset-list rows described by each chunk message sent to the plan engine.
_CHUNK_SIZE = 30
# Pause between chunk sends, to spread bursts against the SQS queue.
_CHUNK_SEND_DELAY_SECONDS = 0.1
# Recorded as the originating location of tasks created by this endpoint.
_TASK_SOURCE = "backend/plan/router.py:trigger_plan_entrypoint"


def _create_scenario_for_request(body: PlanTriggerRequest):
    """Persist a new scenario built from the request body and return its id.

    The session is always closed, even when ``create_scenario`` raises, so
    its connection resources are released instead of leaking per request.
    """
    created_at = datetime.now().isoformat()
    session = sessionmaker(bind=db_engine)()
    try:
        new_scenario = create_scenario(
            session=session,
            scenario={
                "name": body.scenario_name,
                "created_at": created_at,
                "budget": body.budget,
                "portfolio_id": body.portfolio_id,
                "housing_type": body.housing_type,
                "goal": body.goal,
                "goal_value": body.goal_value,
                "trigger_file_path": body.trigger_file_path,
                "already_installed_file_path": body.already_installed_file_path,
                "patches_file_path": body.patches_file_path,
                "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
                "exclusions": body.exclusions,
                "multi_plan": body.multi_plan,
            },
        )
        # Read the id before closing: if create_scenario commits, the instance
        # may be expired and need the open session to reload its attributes.
        return new_scenario.id
    finally:
        session.close()


def _send_engine_message(queue_url, payload):
    """JSON-encode ``payload``, send it to ``queue_url`` and return the SQS MessageId."""
    response = sqs_client.send_message(
        QueueUrl=queue_url,
        MessageBody=json.dumps(payload),
    )
    return response.get("MessageId")


@router.post("/trigger", status_code=202)
async def trigger_plan_entrypoint(body: PlanTriggerRequest):
    """
    Entry point for triggering the plan engine via SQS.

    Two dispatch modes:

    * ``file_format == "domna_asset_list"`` and ``file_type == "xlsx"``:
      ensures a scenario exists (creating one from the request body when no
      ``scenario_id`` is given), creates one parent task, then for every
      block of ``_CHUNK_SIZE`` rows (row count read from ``sheet_count``)
      creates a subtask and sends one SQS message carrying
      ``index_start``/``index_end`` plus the task and subtask ids.
    * anything else: creates a task with a single subtask and sends the whole
      payload as one SQS message.

    Returns ``{"message": "Plan job accepted"}`` with status 202 on success.
    Failures return a JSON body ``{"message": ...}`` with status 400 (body
    could not be dumped) or 500 (task creation / dispatch failed).

    NOTE(review): the boto3, database and TasksInterface calls below are not
    awaited, so they run on the event loop thread and block it while they
    execute; consider offloading them to a thread pool.
    """
    logger.info("API triggered with body: %s", body)
    settings = get_settings()
    try:
        data = body.model_dump()
    except Exception as e:
        logger.exception("Failed to parse request body: %s", e)
        # FastAPI does not interpret Flask-style (body, status) tuples; an
        # explicit response is needed for the status code to reach the client.
        return JSONResponse(status_code=400, content={"message": "Invalid request"})

    is_asset_list = (
        data.get("file_format") == "domna_asset_list"
        and data.get("file_type") == "xlsx"
    )
    if is_asset_list:
        try:
            total_rows = data.get("sheet_count", 0)
            total_chunks = math.ceil(total_rows / _CHUNK_SIZE)
            # Every chunk message must reference the same scenario, so create
            # one up front when the caller did not supply it.
            if not data.get("scenario_id"):
                data["scenario_id"] = _create_scenario_for_request(body)
            task_id, _ = TasksInterface.create_task(
                task_source=_TASK_SOURCE,
                service="plan_engine",
                inputs=data,
                task_only=True,
            )
            if total_chunks == 0:
                logger.warning(
                    "sheet_count is %s; no chunk messages will be sent for task %s",
                    total_rows,
                    task_id,
                )
            subtask_interface = SubTaskInterface()
            for i in range(total_chunks):
                index_start = i * _CHUNK_SIZE
                index_end = min(index_start + _CHUNK_SIZE, total_rows)
                message_payload = {
                    **data,
                    "index_start": index_start,
                    "index_end": index_end,
                }
                subtask_id = subtask_interface.create_subtask(
                    task_id=task_id,
                    inputs=message_payload,
                )
                # Tracking ids go on the message after the subtask exists.
                message_payload["task_id"] = str(task_id)
                message_payload["subtask_id"] = str(subtask_id)
                message_id = _send_engine_message(settings.ENGINE_SQS_URL, message_payload)
                logger.info(
                    "Chunk %s sent to SQS. Rows %s-%s. Message ID: %s",
                    i,
                    index_start,
                    index_end,
                    message_id,
                )
                await asyncio.sleep(_CHUNK_SEND_DELAY_SECONDS)
        except Exception as e:
            # NOTE(review): chunks sent before the failure remain queued; the
            # 500 does not mean that nothing was dispatched.
            logger.exception("Error during Excel file handling: %s", e)
            return JSONResponse(
                status_code=500,
                content={"message": "Failed to process asset list"},
            )
    else:
        # Fallback: a single task/subtask pair and a single message.
        try:
            task_id, subtask_id = TasksInterface.create_task(
                task_source=_TASK_SOURCE,
                service="plan_engine",
                inputs=data,
                task_only=False,
            )
            data["task_id"] = str(task_id)
            data["subtask_id"] = str(subtask_id)
            message_id = _send_engine_message(settings.ENGINE_SQS_URL, data)
            logger.info("SQS message sent. Message ID: %s", message_id)
        except Exception as e:
            logger.exception("Failed to send SQS message: %s", e)
            return JSONResponse(
                status_code=500,
                content={"message": "Failed to trigger engine"},
            )
    return {"message": "Plan job accepted"}