From 5e717b73b2061cb8a3cc14d620fc39e8f3d0ecfb Mon Sep 17 00:00:00 2001 From: Khalim Conn-Kowlessar Date: Tue, 22 Jul 2025 09:36:20 +0100 Subject: [PATCH] added validation of indexes to PlanTriggerRequest --- backend/app/plan/router.py | 99 ++++++++++++++++++++++++++++++++----- backend/app/plan/schemas.py | 13 ++++- utils/s3.py | 5 +- 3 files changed, 101 insertions(+), 16 deletions(-) diff --git a/backend/app/plan/router.py b/backend/app/plan/router.py index a9979e31..cd73cce3 100644 --- a/backend/app/plan/router.py +++ b/backend/app/plan/router.py @@ -1,9 +1,19 @@ import boto3 +import json +import math +from datetime import datetime + +import pandas as pd from fastapi import APIRouter, Depends from backend.app.dependencies import validate_token from backend.app.plan.schemas import PlanTriggerRequest from backend.app.config import get_settings +from sqlalchemy.orm import sessionmaker from utils.logger import setup_logger +from utils.s3 import read_excel_from_s3 +from backend.app.db.connection import db_engine + +from backend.app.db.functions.recommendations_functions import create_scenario logger = setup_logger() @@ -26,21 +36,86 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): settings = get_settings() - # Serialize the PlanTriggerRequest into JSON try: - message_body = body.model_dump_json() + data = body.model_dump() except Exception as e: - logger.error("Failed to serialize request body: %s", e) + logger.error("Failed to parse request body: %s", e) return {"message": "Invalid request"}, 400 - try: - response = sqs_client.send_message( - QueueUrl=settings.ENGINE_SQS_URL, - MessageBody=message_body - ) - logger.info(f"SQS message sent. 
Message ID: {response.get('MessageId')}") - except Exception as e: - logger.error("Failed to send SQS message: %s", e) - return {"message": "Failed to trigger engine"}, 500 + # If file_format is domna_asset_list and type is xlsx, read and chunk it + if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx": + try: + input_data: pd.DataFrame = read_excel_from_s3( + bucket_name=settings.PLAN_TRIGGER_BUCKET, + file_key=data.get("trigger_file_path"), + sheet_name=data.get("sheet_name"), + header_row=0 + ) + + total_rows = len(input_data) + chunk_size = 30 + total_chunks = math.ceil(total_rows / chunk_size) + + # We also need to create a new scenario and pass it to the SQS messages, if one doesn't + # exist + scenario_id = data.get("scenario_id") + if not scenario_id: + created_at = datetime.now().isoformat() + session = sessionmaker(bind=db_engine)() + + # Create a new scenario + new_scenario = create_scenario( + session=session, + scenario={ + "name": body.scenario_name, + "created_at": created_at, + "budget": body.budget, + "portfolio_id": body.portfolio_id, + "housing_type": body.housing_type, + "goal": body.goal, + "goal_value": body.goal_value, + "trigger_file_path": body.trigger_file_path, + "already_installed_file_path": body.already_installed_file_path, + "patches_file_path": body.patches_file_path, + "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path, + "exclusions": body.exclusions, + "multi_plan": body.multi_plan + } + ) + scenario_id = new_scenario.id + # Insert the scenario ID into the data payload + data["scenario_id"] = scenario_id + + for i in range(total_chunks): + index_start = i * chunk_size + index_end = min((i + 1) * chunk_size, total_rows) + + message_payload = {**data, "index_start": index_start, "index_end": index_end} + + message_body = json.dumps(message_payload) + + response = sqs_client.send_message( + QueueUrl=settings.ENGINE_SQS_URL, + MessageBody=message_body + ) + logger.info( 
+ f"Chunk {i} sent to SQS. Rows {index_start}–{index_end}. Message ID: {response.get('MessageId')}") + + except Exception as e: + logger.error("Error during Excel file handling: %s", e) + return {"message": "Failed to process asset list"}, 500 + + else: + # Fallback: Just send a single message + try: + message_body = json.dumps(data) + response = sqs_client.send_message( + QueueUrl=settings.ENGINE_SQS_URL, + MessageBody=message_body + ) + logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}") + except Exception as e: + logger.error("Failed to send SQS message: %s", e) + return {"message": "Failed to trigger engine"}, 500 return {"message": "Plan job accepted"} diff --git a/backend/app/plan/schemas.py b/backend/app/plan/schemas.py index 2a388b2f..85a48a6f 100644 --- a/backend/app/plan/schemas.py +++ b/backend/app/plan/schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, BeforeValidator +from pydantic import BaseModel, Field, BeforeValidator, model_validator from typing import Annotated, List, Optional, Literal # Example constants for validation @@ -104,7 +104,16 @@ class PlanTriggerRequest(BaseModel): simulate_sap_10: Optional[bool] = False # Add in optional fields which describe the format of the asset list being used - + file_type: Optional[Literal["csv", "xlsx"]] = None, file_format: Optional[Literal["domna_asset_list"]] = None, sheet_name: Optional[str] = None + # If one of index_start or index_end is set, the other must be set too + index_start: Optional[int] = None + index_end: Optional[int] = None + + @model_validator(mode="after") + def check_indexes(self): + if (self.index_start is None) != (self.index_end is None): + raise ValueError("Both index_start and index_end must be set or both must be None") + return self diff --git a/utils/s3.py b/utils/s3.py index 1a686b55..e70669d0 100644 --- a/utils/s3.py +++ b/utils/s3.py @@ -198,7 +198,7 @@ def read_pickle_from_s3(bucket_name, s3_file_name): return data -def 
read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True): +def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True, sheet_name=None): """ Read an Excel file from an S3 bucket and return it as a pandas DataFrame. @@ -206,6 +206,7 @@ def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True): :param file_key: Key of the file (including directory path within the bucket). :param header_row: The row number to use as the header (0-indexed). :param drop_all_na: Whether to drop columns where all values are NaN. + :param sheet_name: The name of the sheet to read from the Excel file. NOTE: the value is passed straight to pandas, where None means "read every sheet" and yields a dict of DataFrames rather than a single DataFrame; pass 0 (or a sheet name) to read one sheet. :return: A pandas DataFrame containing the data from the Excel file. """ @@ -217,7 +218,7 @@ def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True): excel_buffer = read_io_from_s3(bucket_name, file_key) # Read the Excel file into a pandas DataFrame - df = pd.read_excel(excel_buffer, header=header_row) + df = pd.read_excel(excel_buffer, header=header_row, sheet_name=sheet_name) # Drop columns where all values are NaN if drop_all_na: