mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
added validation of indexes to PlanTriggerRequest
This commit is contained in:
parent
ac139174b9
commit
5e717b73b2
3 changed files with 101 additions and 16 deletions
|
|
@ -1,9 +1,19 @@
|
|||
import boto3
|
||||
import json
|
||||
import math
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter, Depends
|
||||
from backend.app.dependencies import validate_token
|
||||
from backend.app.plan.schemas import PlanTriggerRequest
|
||||
from backend.app.config import get_settings
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from utils.logger import setup_logger
|
||||
from utils.s3 import read_excel_from_s3
|
||||
from backend.app.db.connection import db_engine
|
||||
|
||||
from backend.app.db.functions.recommendations_functions import create_scenario
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
|
@ -26,21 +36,86 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
|
|||
|
||||
settings = get_settings()
|
||||
|
||||
# Serialize the PlanTriggerRequest into JSON
|
||||
try:
|
||||
message_body = body.model_dump_json()
|
||||
data = body.model_dump()
|
||||
except Exception as e:
|
||||
logger.error("Failed to serialize request body: %s", e)
|
||||
logger.error("Failed to parse request body: %s", e)
|
||||
return {"message": "Invalid request"}, 400
|
||||
|
||||
try:
|
||||
response = sqs_client.send_message(
|
||||
QueueUrl=settings.ENGINE_SQS_URL,
|
||||
MessageBody=message_body
|
||||
)
|
||||
logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}")
|
||||
except Exception as e:
|
||||
logger.error("Failed to send SQS message: %s", e)
|
||||
return {"message": "Failed to trigger engine"}, 500
|
||||
# If file_format is domna_asset_list and type is xlsx, read and chunk it
|
||||
if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx":
|
||||
try:
|
||||
input_data: pd.DataFrame = read_excel_from_s3(
|
||||
bucket_name=settings.PLAN_TRIGGER_BUCKET,
|
||||
file_key=data.get("trigger_file_path"),
|
||||
sheet_name=data.get("sheet_name"),
|
||||
header_row=0
|
||||
)
|
||||
|
||||
total_rows = len(input_data)
|
||||
chunk_size = 30
|
||||
total_chunks = math.ceil(total_rows / chunk_size)
|
||||
|
||||
# We also need to create a new scenario and pass it to the SQS messages, if one doesn't
|
||||
# exist
|
||||
scenario_id = data.get("scenario_id")
|
||||
if not scenario_id:
|
||||
created_at = datetime.now().isoformat()
|
||||
session = sessionmaker(bind=db_engine)()
|
||||
|
||||
# Create a new scenario
|
||||
new_scenario = create_scenario(
|
||||
session=session,
|
||||
scenario={
|
||||
"name": body.scenario_name,
|
||||
"created_at": created_at,
|
||||
"budget": body.budget,
|
||||
"portfolio_id": body.portfolio_id,
|
||||
"housing_type": body.housing_type,
|
||||
"goal": body.goal,
|
||||
"goal_value": body.goal_value,
|
||||
"trigger_file_path": body.trigger_file_path,
|
||||
"already_installed_file_path": body.already_installed_file_path,
|
||||
"patches_file_path": body.patches_file_path,
|
||||
"non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
|
||||
"exclusions": body.exclusions,
|
||||
"multi_plan": body.multi_plan
|
||||
}
|
||||
)
|
||||
scenario_id = new_scenario.id
|
||||
# Insert the scenario ID into the data payload
|
||||
data["scenario_id"] = scenario_id
|
||||
|
||||
for i in range(total_chunks):
|
||||
index_start = i * chunk_size
|
||||
index_end = min((i + 1) * chunk_size, total_rows)
|
||||
|
||||
message_payload = {**data, "index_start": index_start, "index_end": index_end}
|
||||
|
||||
message_body = json.dumps(message_payload)
|
||||
|
||||
response = sqs_client.send_message(
|
||||
QueueUrl=settings.ENGINE_SQS_URL,
|
||||
MessageBody=message_body
|
||||
)
|
||||
logger.info(
|
||||
f"Chunk {i} sent to SQS. Rows {index_start}–{index_end}. Message ID: {response.get('MessageId')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during Excel file handling: %s", e)
|
||||
return {"message": "Failed to process asset list"}, 500
|
||||
|
||||
else:
|
||||
# Fallback: Just send a single message
|
||||
try:
|
||||
message_body = json.dumps(data)
|
||||
response = sqs_client.send_message(
|
||||
QueueUrl=settings.ENGINE_SQS_URL,
|
||||
MessageBody=message_body
|
||||
)
|
||||
logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}")
|
||||
except Exception as e:
|
||||
logger.error("Failed to send SQS message: %s", e)
|
||||
return {"message": "Failed to trigger engine"}, 500
|
||||
|
||||
return {"message": "Plan job accepted"}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel, Field, BeforeValidator
|
||||
from pydantic import BaseModel, Field, BeforeValidator, model_validator
|
||||
from typing import Annotated, List, Optional, Literal
|
||||
|
||||
# Example constants for validation
|
||||
|
|
@ -104,7 +104,16 @@ class PlanTriggerRequest(BaseModel):
|
|||
simulate_sap_10: Optional[bool] = False

# Optional fields which describe the format of the asset list being used.
# These defaults must be plain ``None``: a trailing comma after the value
# (``= None,``) makes the default the one-element tuple ``(None,)``, which
# is stored as-is when the field is omitted and is then serialised as
# ``[null]`` in the outgoing payload.
file_type: Optional[Literal["csv", "xlsx"]] = None
file_format: Optional[Literal["domna_asset_list"]] = None
sheet_name: Optional[str] = None

# Optional row window into the asset list, filled in per chunk by the
# trigger endpoint. If one of index_start or index_end is set, the other
# must be set too.
# NOTE(review): index_end looks like an exclusive bound (the endpoint
# computes it as min((i + 1) * chunk_size, total_rows)) - confirm against
# the consumer of the message.
index_start: Optional[int] = None
index_end: Optional[int] = None
|
||||
|
||||
@model_validator(mode="after")
def check_indexes(self):
    """Ensure index_start and index_end are supplied together.

    Runs after field validation. The pair is valid when both values are
    set or both are None; supplying exactly one of them is rejected.

    :return: the validated model instance, unchanged.
    :raises ValueError: if exactly one of index_start / index_end is None.
    """
    start_missing = self.index_start is None
    end_missing = self.index_end is None

    # XOR is true only when exactly one of the two bounds is missing.
    if start_missing ^ end_missing:
        raise ValueError("Both index_start and index_end must be set or both must be None")

    return self
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ def read_pickle_from_s3(bucket_name, s3_file_name):
|
|||
return data
|
||||
|
||||
|
||||
def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True):
|
||||
def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True, sheet_name=None):
|
||||
"""
|
||||
Read an Excel file from an S3 bucket and return it as a pandas DataFrame.
|
||||
|
||||
|
|
@ -206,6 +206,7 @@ def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True):
|
|||
:param file_key: Key of the file (including directory path within the bucket).
|
||||
:param header_row: The row number to use as the header (0-indexed).
|
||||
:param drop_all_na: Whether to drop columns where all values are NaN.
|
||||
:param sheet_name: The name of the sheet to read from the Excel file. If None, reads the first sheet.
|
||||
:return: A pandas DataFrame containing the data from the Excel file.
|
||||
"""
|
||||
|
||||
|
|
@ -217,7 +218,7 @@ def read_excel_from_s3(bucket_name, file_key, header_row, drop_all_na=True):
|
|||
excel_buffer = read_io_from_s3(bucket_name, file_key)
|
||||
|
||||
# Read the Excel file into a pandas DataFrame
|
||||
df = pd.read_excel(excel_buffer, header=header_row)
|
||||
df = pd.read_excel(excel_buffer, header=header_row, sheet_name=sheet_name)
|
||||
|
||||
# Drop columns where all values are NaN
|
||||
if drop_all_na:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue