added validation of indexes to PlanTriggerRequest

This commit is contained in:
Khalim Conn-Kowlessar 2025-07-22 09:36:20 +01:00
parent ac139174b9
commit 5e717b73b2
3 changed files with 101 additions and 16 deletions

View file

@ -1,9 +1,19 @@
import boto3
import json
import math
from datetime import datetime
import pandas as pd
from fastapi import APIRouter, Depends
from backend.app.dependencies import validate_token
from backend.app.plan.schemas import PlanTriggerRequest
from backend.app.config import get_settings
from sqlalchemy.orm import sessionmaker
from utils.logger import setup_logger
from utils.s3 import read_excel_from_s3
from backend.app.db.connection import db_engine
from backend.app.db.functions.recommendations_functions import create_scenario
logger = setup_logger()
@ -26,21 +36,86 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
settings = get_settings()
# Serialize the PlanTriggerRequest into JSON
try:
message_body = body.model_dump_json()
data = body.model_dump()
except Exception as e:
logger.error("Failed to serialize request body: %s", e)
logger.error("Failed to parse request body: %s", e)
return {"message": "Invalid request"}, 400
try:
response = sqs_client.send_message(
QueueUrl=settings.ENGINE_SQS_URL,
MessageBody=message_body
)
logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}")
except Exception as e:
logger.error("Failed to send SQS message: %s", e)
return {"message": "Failed to trigger engine"}, 500
# If file_format is domna_asset_list and type is xlsx, read and chunk it
if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx":
try:
input_data: pd.DataFrame = read_excel_from_s3(
bucket_name=settings.PLAN_TRIGGER_BUCKET,
file_key=data.get("trigger_file_path"),
sheet_name=data.get("sheet_name"),
header_row=0
)
total_rows = len(input_data)
chunk_size = 30
total_chunks = math.ceil(total_rows / chunk_size)
# We also need to create a new scenario and pass it to the SQS messages, if one doesn't
# exist
scenario_id = data.get("scenario_id")
if not scenario_id:
created_at = datetime.now().isoformat()
session = sessionmaker(bind=db_engine)()
# Create a new scenario
new_scenario = create_scenario(
session=session,
scenario={
"name": body.scenario_name,
"created_at": created_at,
"budget": body.budget,
"portfolio_id": body.portfolio_id,
"housing_type": body.housing_type,
"goal": body.goal,
"goal_value": body.goal_value,
"trigger_file_path": body.trigger_file_path,
"already_installed_file_path": body.already_installed_file_path,
"patches_file_path": body.patches_file_path,
"non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
"exclusions": body.exclusions,
"multi_plan": body.multi_plan
}
)
scenario_id = new_scenario.id
# Insert the scenario ID into the data payload
data["scenario_id"] = scenario_id
for i in range(total_chunks):
index_start = i * chunk_size
index_end = min((i + 1) * chunk_size, total_rows)
message_payload = {**data, "index_start": index_start, "index_end": index_end}
message_body = json.dumps(message_payload)
response = sqs_client.send_message(
QueueUrl=settings.ENGINE_SQS_URL,
MessageBody=message_body
)
logger.info(
f"Chunk {i} sent to SQS. Rows {index_start}{index_end}. Message ID: {response.get('MessageId')}")
except Exception as e:
logger.error("Error during Excel file handling: %s", e)
return {"message": "Failed to process asset list"}, 500
else:
# Fallback: Just send a single message
try:
message_body = json.dumps(data)
response = sqs_client.send_message(
QueueUrl=settings.ENGINE_SQS_URL,
MessageBody=message_body
)
logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}")
except Exception as e:
logger.error("Failed to send SQS message: %s", e)
return {"message": "Failed to trigger engine"}, 500
return {"message": "Plan job accepted"}

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, BeforeValidator
from pydantic import BaseModel, Field, BeforeValidator, model_validator
from typing import Annotated, List, Optional, Literal
# Example constants for validation
@ -104,7 +104,16 @@ class PlanTriggerRequest(BaseModel):
simulate_sap_10: Optional[bool] = False
# Add in optional fields which describe the format of the asset list being used
file_type: Optional[Literal["csv", "xlsx"]] = None,
file_format: Optional[Literal["domna_asset_list"]] = None,
sheet_name: Optional[str] = None
# If one of index_start or index_end is set, the other must be set too
index_start: Optional[int] = None
index_end: Optional[int] = None
@model_validator(mode="after")
def check_indexes(self):
if (self.index_start is None) != (self.index_end is None):
raise ValueError("Both index_start and index_end must be set or both must be None")
return self

View file

@ -198,7 +198,7 @@ def read_pickle_from_s3(bucket_name, s3_file_name):
return data
def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True):
def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True, sheet_name=None):
"""
Read an Excel file from an S3 bucket and return it as a pandas DataFrame.
@ -206,6 +206,7 @@ def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True):
:param file_key: Key of the file (including directory path within the bucket).
:param header_row: The row number to use as the header (0-indexed).
:param drop_all_na: Whether to drop columns where all values are NaN.
:param sheet_name: The name of the sheet to read from the Excel file. If None, reads the first sheet.
:return: A pandas DataFrame containing the data from the Excel file.
"""
@ -217,7 +218,7 @@ def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True):
excel_buffer = read_io_from_s3(bucket_name, file_key)
# Read the Excel file into a pandas DataFrame
df = pd.read_excel(excel_buffer, header=header_row)
df = pd.read_excel(excel_buffer, header=header_row, sheet_name=sheet_name)
# Drop columns where all values are NaN
if drop_all_na: