diff --git a/backend/app/plan/router.py b/backend/app/plan/router.py index cd73cce3..a0eca27a 100644 --- a/backend/app/plan/router.py +++ b/backend/app/plan/router.py @@ -2,19 +2,66 @@ import boto3 import json import math from datetime import datetime +from openpyxl import load_workbook +from io import BytesIO -import pandas as pd from fastapi import APIRouter, Depends from backend.app.dependencies import validate_token from backend.app.plan.schemas import PlanTriggerRequest from backend.app.config import get_settings from sqlalchemy.orm import sessionmaker from utils.logger import setup_logger -from utils.s3 import read_excel_from_s3 from backend.app.db.connection import db_engine from backend.app.db.functions.recommendations_functions import create_scenario + +def read_excel_from_s3(bucket_name, file_key, header_row=0, drop_all_na=True, sheet_name=None): + """ + Reads an Excel file from S3 and returns it as a list of dictionaries. + + :param bucket_name: Name of the S3 bucket. + :param file_key: S3 key/path to the file. + :param header_row: Row number (0-indexed) to use as header. + :param drop_all_na: If True, drop columns where all values are None. + :param sheet_name: Name of the worksheet to read. Defaults to the first. + :return: List of dicts, one per row. + """ + s3 = boto3.client("s3") + response = s3.get_object(Bucket=bucket_name, Key=file_key) + excel_buffer = BytesIO(response["Body"].read()) + + wb = load_workbook(filename=excel_buffer, data_only=True) + ws = wb[sheet_name] if sheet_name else wb.active + + rows = list(ws.iter_rows(values_only=True)) + if len(rows) <= header_row: + raise ValueError("Header row index is out of range.") + + headers = [str(h).strip() if h is not None else f"__col_{i}" for i, h in enumerate(rows[header_row])] + data_rows = rows[header_row + 1:] + + # Drop columns where all values are None if required + if drop_all_na: + # Transpose rows to get columns + col_data = list(zip(*data_rows)) + keep_indices = [i for i, col in enumerate(col_data) if not all(v is None for v in col)] + headers = [h for i, h in enumerate(headers) if i in keep_indices] + data_rows = [ + [row[i] for i in keep_indices] + for row in data_rows + ] + + # Create list of dicts + result = [ + {headers[i]: cell for i, cell in enumerate(row)} + for row in data_rows + if any(cell is not None for cell in row) # skip fully empty rows + ] + + return result + + logger = setup_logger() router = APIRouter( @@ -45,11 +92,12 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): # If file_format is domna_asset_list and type is xlsx, read and chunk it if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx": try: - input_data: pd.DataFrame = read_excel_from_s3( + + input_data = read_excel_from_s3( bucket_name=settings.PLAN_TRIGGER_BUCKET, file_key=data.get("trigger_file_path"), sheet_name=data.get("sheet_name"), - header_row=0 + header_row=0, ) total_rows = len(input_data) diff --git a/backend/app/plan/schemas.py b/backend/app/plan/schemas.py index a6d21ae7..d5b92256 100644 --- a/backend/app/plan/schemas.py +++ b/backend/app/plan/schemas.py @@ -108,6 +108,7 @@ class PlanTriggerRequest(BaseModel): file_type: Optional[Literal["csv", "xlsx"]] = None file_format: Optional[Literal["domna_asset_list"]] = None sheet_name: Optional[str] = None + sheet_count: Optional[int] = None # If one of index_start or index_end is set, the other must be set too index_start: Optional[int] = None index_end: Optional[int] = None diff --git a/backend/app/requirements/requirements.txt b/backend/app/requirements/requirements.txt index ca9d0f32..14ec525f 100644 --- a/backend/app/requirements/requirements.txt +++ b/backend/app/requirements/requirements.txt @@ -8,4 +8,5 @@ cryptography==43.0.3 mangum==0.19.0 # AWS boto3==1.35.44 - +# Data +openpyxl==3.1.2 diff --git a/backend/engine/handler.py b/backend/engine/handler.py index fdf48db3..8fce7f16 100644 --- a/backend/engine/handler.py +++ b/backend/engine/handler.py @@ -11,6 +11,7 @@ def handler(event, context): """ Lambda handler that triggers the model engine for each SQS message. """ + logger.info("Received event: %s", json.dumps(event, indent=2)) for record in event.get("Records", []): try: body_dict = json.loads(record["body"])