diff --git a/.devcontainer/backend/devcontainer.json b/.devcontainer/backend/devcontainer.json index 3727d8a3..ac654ac1 100644 --- a/.devcontainer/backend/devcontainer.json +++ b/.devcontainer/backend/devcontainer.json @@ -6,7 +6,8 @@ "workspaceFolder": "/workspaces/model", "postStartCommand": "bash .devcontainer/backend/post-install.sh", "mounts": [ - "source=${localEnv:HOME},target=/home/vscode,type=bind" + // "source=${localEnv:HOME},target=/home/vscode,type=bind", + "source=${localEnv:HOME}/.aws,target=/home/vscode/.aws,type=bind,consistency=cached" ], "customizations": { "vscode": { diff --git a/.devcontainer/backend/requirements.txt b/.devcontainer/backend/requirements.txt index 9814c8d4..c84332dd 100644 --- a/.devcontainer/backend/requirements.txt +++ b/.devcontainer/backend/requirements.txt @@ -19,4 +19,5 @@ pytest==9.0.2 pytest-cov==7.0.0 ipykernel>=6.25,<7 # Formatting -black==26.1.0 \ No newline at end of file +black==26.1.0 +boto3-stubs \ No newline at end of file diff --git a/backend/.env.test b/backend/.env.test index 1679f10f..34a1803d 100644 --- a/backend/.env.test +++ b/backend/.env.test @@ -19,4 +19,5 @@ PLAN_TRIGGER_BUCKET=test DATA_BUCKET=test EPC_AUTH_TOKEN=test ENGINE_SQS_URL=test +CATEGORISATION_SQS_URL=test ENERGY_ASSESSMENTS_BUCKET=test diff --git a/backend/README.md b/backend/README.md index 005d6fc4..b8e859c2 100644 --- a/backend/README.md +++ b/backend/README.md @@ -45,12 +45,14 @@ cp .env.example .env ## Running the Application -from within the application you can run with the following command: +from `model/backend/` you can run with the following command: ```commandline uvicorn app.main:app --reload ``` +Or run `sh run_local.sh`, which runs that same uvicorn command. + You application will be available at the designated url ## API Documentation @@ -172,7 +174,7 @@ For instance, if your server is running locally on port 8000, you can use curl to get a dummy token: ```commandline -curl http://localhost:8000/dummy-token +curl http://localhost:8000/local/dummy-token ``` You will receive a response containing the dummy JWT diff --git a/backend/app/config.py b/backend/app/config.py index feb312b4..26fb6b8b 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,17 +1,29 @@ import os from functools import lru_cache +from pathlib import Path from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Optional +from utils.logger import setup_logger + +logger = setup_logger() + def resolve_env_file() -> Optional[str]: env = os.getenv("ENVIRONMENT", "local") + backend_dir = Path(__file__).resolve().parents[1] + if env == "local": - return "backend/.env" + env_file = backend_dir / ".env" + print("USING ENV FILE:", env_file) + logger.debug("USING ENV FILE:", env_file) + return str(env_file) if env == "test": - return "backend/.env.test" + env_file = backend_dir / ".env.test" + logger.debug("USING ENV FILE:", env_file) + return str(env_file) # prod = no env file return None @@ -25,6 +37,7 @@ class Settings(BaseSettings): DATA_BUCKET: str = "changeme" PLAN_TRIGGER_BUCKET: str ENGINE_SQS_URL: str = "changeme" + CATEGORISATION_SQS_URL: str = "changeme" # Third parties EPC_AUTH_TOKEN: str = "changeme" diff --git a/backend/app/db/functions/recommendations_functions.py b/backend/app/db/functions/recommendations_functions.py index 09d6da83..ed3fb435 100644 --- a/backend/app/db/functions/recommendations_functions.py +++ b/backend/app/db/functions/recommendations_functions.py @@ -1,5 +1,14 @@ -from typing import Any, Dict, List, Tuple -from sqlalchemy import inspect, text, insert, delete, select +from typing import Any, Dict, List, Optional +from sqlalchemy import ( + ColumnElement, + and_, + func, + inspect, + text, + insert, + delete, + select, +) from sqlalchemy.orm import Session, Mapper from sqlalchemy.exc import SQLAlchemyError from sqlmodel import Session @@ -625,11 +634,22 @@ def get_plans_by_scenario_ids(ids: List[int]) -> List[PlanModel]: return session_any.exec(stmt).scalars().all() -def get_most_recent_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]: +def get_most_recent_plans_by_portfolio_id( + portfolio_id: int, + min_property_id: Optional[int] = None, + max_property_id: Optional[int] = None, +) -> List[PlanModel]: + filters = [PlanModel.portfolio_id == portfolio_id] + + if min_property_id is not None: + filters.append(PlanModel.property_id >= min_property_id) + if max_property_id is not None: + filters.append(PlanModel.property_id <= max_property_id) + # NOTE: This statement works for Postgres only, because of the Distinct stmt = ( select(PlanModel) - .where(PlanModel.portfolio_id == portfolio_id) + .where(and_(*filters)) .distinct( PlanModel.property_id, PlanModel.scenario_id ) # one plan per property per scenario @@ -645,11 +665,27 @@ def get_most_recent_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]: return session_any.exec(stmt).scalars().all() -def get_most_recent_plans_by_scenario_ids(scenario_ids: List[int]) -> List[PlanModel]: +def get_most_recent_plans_by_scenario_ids( + scenario_ids: List[int], + min_property_id: Optional[int] = None, + max_property_id: Optional[int] = None, +) -> List[PlanModel]: + if not scenario_ids: + return [] + + # Base filter: scenario_id in provided list + filters: List[ColumnElement[bool]] = [PlanModel.scenario_id.in_(scenario_ids)] + + # Add optional property ID range filters + if min_property_id is not None: + filters.append(PlanModel.property_id >= min_property_id) + if max_property_id is not None: + filters.append(PlanModel.property_id <= max_property_id) + # NOTE: This statement works for Postgres only, because of the Distinct stmt = ( select(PlanModel) - .where(PlanModel.scenario_id.in_(scenario_ids)) + .where(and_(*filters)) .distinct( PlanModel.property_id, PlanModel.scenario_id ) # one plan per property per scenario @@ -673,16 +709,37 @@ def get_scenarios_by_portfolio_id(portfolio_id: int) -> List[ScenarioModel]: return session_any.exec(stmt).scalars().all() +def get_scenarios_count_by_portfolio_id(portfolio_id: int) -> int: + stmt = ( + select(func.count()) + .select_from(ScenarioModel) + .where(ScenarioModel.portfolio_id == portfolio_id) + ) + with db_read_session() as session: + session_any: Any = session # Typehint as Any to satisfy Pylance... + return session_any.exec(stmt).scalar_one() + + def get_default_plans( portfolio_id: int, + min_property_id: Optional[int] = None, + max_property_id: Optional[int] = None, ) -> List[PlanModel]: - plan_stmt = select(PlanModel).where( - (PlanModel.portfolio_id == portfolio_id) & (PlanModel.is_default == True) - ) + filters: List[ColumnElement[bool]] = [ + PlanModel.portfolio_id == portfolio_id, + PlanModel.is_default.is_(True), + ] + + if min_property_id is not None: + filters.append(PlanModel.property_id >= min_property_id) + if max_property_id is not None: + filters.append(PlanModel.property_id <= max_property_id) + + stmt = select(PlanModel).where(and_(*filters)) with db_read_session() as session: session_any: Any = session # Typehint as Any to satisfy Pylance... - plans: List[PlanModel] = session_any.exec(plan_stmt).scalars().all() + plans: List[PlanModel] = session_any.exec(stmt).scalars().all() return plans diff --git a/backend/app/db/functions/tasks/Tasks.py b/backend/app/db/functions/tasks/Tasks.py index d1ab9536..0f987f3b 100644 --- a/backend/app/db/functions/tasks/Tasks.py +++ b/backend/app/db/functions/tasks/Tasks.py @@ -25,7 +25,12 @@ class SubTaskInterface: # -------------------------------------------------------- # CREATE SUBTASK # -------------------------------------------------------- - def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None, status=None): + def create_subtask( + self, + task_id: UUID, + inputs: Optional[Dict[str, Any]] = None, + status: Optional[str] = None, + ): now = datetime.now(timezone.utc) with get_db_session() as session: @@ -56,8 +61,12 @@ class SubTaskInterface: # UPDATE STATUS (in progress, complete, failed) # -------------------------------------------------------- def update_subtask_status( - self, subtask_id: UUID, status: str, outputs=None, cloud_logs_url=None - ): + self, + subtask_id: UUID, + status: str, + outputs: Optional[Dict[str, str]] = None, + cloud_logs_url: Optional[str] = None, + ) -> SubTask: """ Update the status of a subtask, and recalculate the parent task progress. :param subtask_id: UUID of the subtask to update @@ -177,9 +186,7 @@ class SubTaskInterface: if not task: return - subtasks = session.exec( - select(SubTask).where(SubTask.task_id == task_id) - ).all() + subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all() statuses = [s.status.lower() for s in subtasks] now = datetime.now(timezone.utc) @@ -211,7 +218,7 @@ class SubTaskInterface: subtask_id: UUID, status: str, outputs: Optional[Dict[str, Any]], - cloud_logs_url: Optional[str] + cloud_logs_url: Optional[str], ): now = datetime.now(timezone.utc) diff --git a/backend/app/plan/router.py b/backend/app/plan/router.py index ea41162f..2b534679 100644 --- a/backend/app/plan/router.py +++ b/backend/app/plan/router.py @@ -1,9 +1,10 @@ +from typing import List +from uuid import UUID + import boto3 import json import math import asyncio -from contextlib import contextmanager -from sqlmodel import Session from datetime import datetime @@ -11,11 +12,16 @@ from fastapi import APIRouter, Depends from backend.app.dependencies import validate_token from backend.app.plan.schemas import PlanTriggerRequest from backend.app.config import get_settings -from sqlalchemy.orm import sessionmaker +from backend.categorisation.categorisation_trigger_request import ( + CategorisationTriggerRequest, +) from utils.logger import setup_logger -from backend.app.db.connection import db_engine -from backend.app.db.functions.recommendations_functions import create_scenario +from backend.app.db.functions.recommendations_functions import ( + create_scenario, + get_property_ids, + get_scenarios_count_by_portfolio_id, +) from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface logger = setup_logger() @@ -24,23 +30,84 @@ router = APIRouter( prefix="/plan", tags=["plan"], dependencies=[Depends(validate_token)], - responses={404: {"description": "Not found"}} + responses={404: {"description": "Not found"}}, ) -sqs_client = boto3.client("sqs") +settings = get_settings() +sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION) -@contextmanager -def db_session(): - session = Session(db_engine) - try: - yield session - session.commit() - except Exception: - session.rollback() - raise - finally: - session.close() +@router.post("/categorisation", status_code=202) +async def trigger_categorisation( + body: CategorisationTriggerRequest, +) -> dict[str, str]: + payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate( + body + ) + + logger.info("API triggered with body: %s", payload) + + property_ids: list[int] = get_property_ids(payload.portfolio_id) + property_ids.sort() + + num_scenarios: int = get_scenarios_count_by_portfolio_id(payload.portfolio_id) + total_plans_to_update: int = len(property_ids) * num_scenarios + + max_writes_per_batch: int = 1000 + properties_per_batch: int = max(1, max_writes_per_batch // num_scenarios) + + num_property_batches: int = math.ceil(len(property_ids) / properties_per_batch) + + logger.info("total_plans_to_update", total_plans_to_update) + logger.info("properties_per_batch", properties_per_batch) + logger.info("num_property_buckets", num_property_batches) + + # Create task + task_id, _ = TasksInterface.create_task( + task_source="backend/plan/router.py:trigger_categorisation", + service="plan_categorisation", + inputs=payload.model_dump(), + task_only=True, + ) + + # Dispatch requests to lambdas + subtask_interface = SubTaskInterface() + + for batch_index in range(num_property_batches): + + start: int = batch_index * properties_per_batch + end: int = start + properties_per_batch + + batch_property_ids: List[int] = property_ids[start:end] + + if not batch_property_ids: + continue + + batch_request: CategorisationTriggerRequest = CategorisationTriggerRequest( + portfolio_id=payload.portfolio_id, + scenarios_to_consider=payload.scenarios_to_consider, + scenario_priority_order=payload.scenario_priority_order, + min_property_id=min(batch_property_ids), + max_property_id=max(batch_property_ids), + ) + # Create sub-task for each + subtask_id: UUID = subtask_interface.create_subtask( + task_id=task_id, inputs=batch_request.model_dump() + ) + batch_request.subtask_id = str(subtask_id) + + response = sqs_client.send_message( + QueueUrl=settings.CATEGORISATION_SQS_URL, + MessageBody=batch_request.model_dump_json(), + ) + + logger.info( + f"Chunk {batch_index} sent to SQS. {len(batch_property_ids)} Property IDs in batch (total {len(property_ids)}). Property IDs {min(batch_property_ids)}–{max(batch_property_ids)}. Message ID: {response.get('MessageId')}" + ) + + await asyncio.sleep(0.05) # Small delay to avoid SQS throttling + + return {"message": "Categorisation jobs distributed"} @router.post("/trigger", status_code=202) @@ -50,8 +117,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): """ logger.info("API triggered with body: %s", body) - settings = get_settings() - try: data = body.model_dump() except Exception as e: @@ -59,7 +124,10 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): return {"message": "Invalid request"}, 400 # If file_format is domna_asset_list and type is xlsx, read and chunk it - if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx": + if ( + data.get("file_format") == "domna_asset_list" + and data.get("file_type") == "xlsx" + ): try: total_rows = data.get("sheet_count", 0) @@ -88,8 +156,8 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): "patches_file_path": body.patches_file_path, "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path, "exclusions": body.exclusions, - "multi_plan": body.multi_plan - } + "multi_plan": body.multi_plan, + }, ) # Insert the scenario ID into the data payload data["scenario_id"] = scenario_id @@ -99,7 +167,7 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): task_source="backend/plan/router.py:trigger_plan_entrypoint", service="plan_engine", inputs=data, - task_only=True + task_only=True, ) subtask_interface = SubTaskInterface() @@ -109,13 +177,14 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): index_end = min((i + 1) * chunk_size, total_rows) message_payload = { - **data, "index_start": index_start, "index_end": index_end, + **data, + "index_start": index_start, + "index_end": index_end, } # Create a subtask for this chunk subtask_id = subtask_interface.create_subtask( - task_id=task_id, - inputs=message_payload + task_id=task_id, inputs=message_payload ) # Add task and subtask to message @@ -125,8 +194,7 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): message_body = json.dumps(message_payload) response = sqs_client.send_message( - QueueUrl=settings.ENGINE_SQS_URL, - MessageBody=message_body + QueueUrl=settings.ENGINE_SQS_URL, MessageBody=message_body ) logger.info( f"Chunk {i} sent to SQS. Rows {index_start}–{index_end}. Message ID: {response.get('MessageId')}" @@ -153,8 +221,7 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): data["subtask_id"] = str(subtask_id) message_body = json.dumps(data) response = sqs_client.send_message( - QueueUrl=settings.ENGINE_SQS_URL, - MessageBody=message_body + QueueUrl=settings.ENGINE_SQS_URL, MessageBody=message_body ) logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}") except Exception as e: diff --git a/backend/app/plan/utils.py b/backend/app/plan/utils.py index 10d7fb06..7dfe5538 100644 --- a/backend/app/plan/utils.py +++ b/backend/app/plan/utils.py @@ -1,5 +1,6 @@ import ast import os +from typing import Optional import msgpack from uuid import UUID from utils.s3 import read_from_s3 @@ -24,7 +25,7 @@ def get_cleaned(): cleaned = read_from_s3( s3_file_name="cleaned_epc_data/cleaned.bson", - bucket_name=get_settings().DATA_BUCKET + bucket_name=get_settings().DATA_BUCKET, ) cleaned = msgpack.unpackb(cleaned, raw=False) @@ -56,32 +57,45 @@ def extract_property_request_data( ): patch_has_uprn = "uprn" in patches[0] if patches else True if patch_has_uprn: - patch = next(( - x for x in patches if str(x["uprn"]) == str(address.uprn) - ), {}) + patch = next((x for x in patches if str(x["uprn"]) == str(address.uprn)), {}) else: - patch = next(( - x for x in patches if (x["address"] == address.address) and (x["postcode"] == address.postcode) - ), {}) + patch = next( + ( + x + for x in patches + if (x["address"] == address.address) + and (x["postcode"] == address.postcode) + ), + {}, + ) # Because we have some non-invasive recommendations that match on address and postcode, but not UPRN # we need to check existence of uprn - has_uprn = "uprn" in non_invasive_recommendations[0] if non_invasive_recommendations else False + has_uprn = ( + "uprn" in non_invasive_recommendations[0] + if non_invasive_recommendations + else False + ) if has_uprn: has_uprn = non_invasive_recommendations[0]["uprn"] not in ["", None] if has_uprn: - property_non_invasive_recommendations = next(( - x for x in non_invasive_recommendations if - (str(x["uprn"]) == str(uprn)) - ), {}) + property_non_invasive_recommendations = next( + (x for x in non_invasive_recommendations if (str(x["uprn"]) == str(uprn))), + {}, + ) # We patch the non-invasive recs that are ['cavity_extract_and_refill'] else: - property_non_invasive_recommendations = next(( - x for x in non_invasive_recommendations if - (x["address"] == address.address) and (x["postcode"] == address.postcode) - ), {}) + property_non_invasive_recommendations = next( + ( + x + for x in non_invasive_recommendations + if (x["address"] == address.address) + and (x["postcode"] == address.postcode) + ), + {}, + ) if isinstance(property_non_invasive_recommendations.get("recommendations"), str): property_non_invasive_recommendations["recommendations"] = ast.literal_eval( @@ -90,7 +104,11 @@ def extract_property_request_data( transformed = [] for rec in property_non_invasive_recommendations["recommendations"]: if isinstance(rec, str): - transformed.append({"type": rec, }) + transformed.append( + { + "type": rec, + } + ) else: transformed.append(rec) @@ -102,26 +120,36 @@ def extract_property_request_data( valuation_has_uprn = valuation_data[0]["uprn"] not in ["", None] if valuation_has_uprn: - property_valuation = next(( - float(x["valuation"]) for x in valuation_data if - (str(x["uprn"]) == str(uprn)) - ), None) + property_valuation = next( + ( + float(x["valuation"]) + for x in valuation_data + if (str(x["uprn"]) == str(uprn)) + ), + None, + ) else: - property_valuation = next(( - float(x["valuation"]) for x in valuation_data if - (x["address"] == address.address) and (x["postcode"] == address.postcode) - ), None) + property_valuation = next( + ( + float(x["valuation"]) + for x in valuation_data + if (x["address"] == address.address) + and (x["postcode"] == address.postcode) + ), + None, + ) # Return data class to give a structured format return PropertyRequestData( patch=patch, non_invasive_recommendations=property_non_invasive_recommendations, - valuation=property_valuation + valuation=property_valuation, ) -def parse_eco_packages(addr: Address, prepared_epc) -> tuple[list[str], int, str, list[str]] | tuple[ - None, None, None, list]: +def parse_eco_packages( + addr: Address, prepared_epc +) -> tuple[list[str], int, str, list[str]] | tuple[None, None, None, list]: solar_identification = addr.solar_reason cavity_identification = addr.cavity_reason if not solar_identification and not cavity_identification: @@ -140,47 +168,51 @@ def parse_eco_packages(addr: Address, prepared_epc) -> tuple[list[str], int, str "Solar Eligible": { "measures": ["solar_pv", "loft_insulation", "mechanical_ventilation"], "target_sap": 86, # High B - "plan_type": "solar_eco4" + "plan_type": "solar_eco4", }, "Solar Eligible, Solid Wall Uninsulated, EPC E or Below": { "measures": ["solar_pv", "loft_insulation", "mechanical_ventilation"], "target_sap": 86, # High B - "plan_type": "solar_eco4" + "plan_type": "solar_eco4", }, "Solar Eligible, Needs Heating Upgrade": { - "measures": ["solar_pv", "loft_insulation", "high_heat_retention_storage_heaters", - "mechanical_ventilation"], + "measures": [ + "solar_pv", + "loft_insulation", + "high_heat_retention_storage_heaters", + "mechanical_ventilation", + ], "target_sap": 86, # High B - "plan_type": "solar_hhrsh_eco4" + "plan_type": "solar_hhrsh_eco4", }, "Non-Intrusive Data Shows Empty Cavity": { "measures": ["cavity_wall_insulation", "mechanical_ventilation"], "target_sap": 69, # Low C - "plan_type": "empty_cavity_eco" + "plan_type": "empty_cavity_eco", }, - 'Non-Intrusive Data Shows Empty Cavity, built after 2002': { + "Non-Intrusive Data Shows Empty Cavity, built after 2002": { "measures": ["cavity_wall_insulation", "mechanical_ventilation"], "target_sap": 69, # Low C - "plan_type": "empty_cavity_eco" + "plan_type": "empty_cavity_eco", }, "EPC Shows Empty Cavity, inspections show retro drilled": { # EPC Indicates it's empty, so we simulate a fill "measures": ["cavity_wall_insulation", "mechanical_ventilation"], "target_sap": 69, # Low C - "plan_type": "extraction_eco" + "plan_type": "extraction_eco", }, "EPC Shows Empty Cavity, inspections show filled at build": { # EPC Indicates it's empty, so we simulate a fill "measures": ["cavity_wall_insulation", "mechanical_ventilation"], "target_sap": 69, # Low C - "plan_type": "extraction_eco" + "plan_type": "extraction_eco", }, "EPC Shows Empty Cavity": { # EPC Indicates it's empty, so we simulate a fill "measures": ["cavity_wall_insulation", "mechanical_ventilation"], "target_sap": 69, # Low C - "plan_type": "empty_cavity_eco" - } + "plan_type": "empty_cavity_eco", + }, } # Always prioritise solar @@ -214,9 +246,13 @@ def build_cloudwatch_log_url(start_ms: int) -> str: Build a CloudWatch Logs URL for the current Lambda invocation, including timestamp window from start_ms to end_ms (epoch ms). """ + logger.info("Building cloudwatch logs URL") region = os.environ["AWS_REGION"] + logger.info("Building cloudwatch logs URL: Got AWS region") log_group = os.environ["AWS_LAMBDA_LOG_GROUP_NAME"] + logger.info("Building cloudwatch logs URL: Got lambda log group name") log_stream = os.environ["AWS_LAMBDA_LOG_STREAM_NAME"] + logger.info("Building cloudwatch logs URL: Got lambda log stream name") # CloudWatch console requires / encoded as $252F encoded_group = log_group.replace("/", "$252F") @@ -232,15 +268,21 @@ def build_cloudwatch_log_url(start_ms: int) -> str: ) -def handle_error(msg, e, subtask_id, status=500, start_ms=None): +def handle_error( + msg: str, + exception: Exception, + subtask_id: str, + status_code: int = 500, + start_ms: Optional[int] = None, +): # When the pipeline fails, handles error process cloud_logs_url = build_cloudwatch_log_url(start_ms) SubTaskInterface().update_subtask_status( subtask_id=UUID(subtask_id), status="failed", - outputs=str(e), - cloud_logs_url=cloud_logs_url + outputs=str(exception), + cloud_logs_url=cloud_logs_url, ) logger.error(msg, exc_info=True) - return Response(status_code=status, content=msg) + return Response(status_code=status_code, content=msg) diff --git a/backend/app/tasks/router.py b/backend/app/tasks/router.py index 90b62dd1..1c266f2c 100644 --- a/backend/app/tasks/router.py +++ b/backend/app/tasks/router.py @@ -9,7 +9,7 @@ from backend.app.tasks.schema import ( CreateSubTaskRequest, UpdateSubTaskStatusRequest, FinalizeSubTaskRequest, - TaskSqsTriggerRequest + TaskSqsTriggerRequest, ) # Correct location of interfaces @@ -51,18 +51,18 @@ async def get_task(task_id: UUID): if not task: raise HTTPException(status_code=404, detail="Task not found") - subtasks = session.exec( - select(SubTask).where(SubTask.taskId == task_id) - ).all() + subtasks = session.exec(select(SubTask).where(SubTask.taskId == task_id)).all() formatted = [] for st in subtasks: - formatted.append({ - **st.dict(), - "inputs": json.loads(st.inputs) if st.inputs else None, - "outputs": json.loads(st.outputs) if st.outputs else None, - "cloud_logs_url": st.cloudLogsURL, - }) + formatted.append( + { + **st.dict(), + "inputs": json.loads(st.inputs) if st.inputs else None, + "outputs": json.loads(st.outputs) if st.outputs else None, + "cloud_logs_url": st.cloudLogsURL, + } + ) return { "task": task, @@ -111,7 +111,10 @@ async def update_subtask_status(subtask_id: UUID, req: UpdateSubTaskStatusReques # === # Sub task is complete -@router.post("/subtask/{subtask_id}/finalize", summary="Finalize a subtask with status, outputs, logs") +@router.post( + "/subtask/{subtask_id}/finalize", + summary="Finalize a subtask with status, outputs, logs", +) async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest): subtasks = SubTaskInterface() @@ -120,7 +123,7 @@ async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest): subtask_id=subtask_id, status=req.status, outputs=req.outputs, - cloud_logs_url=req.cloud_logs_url + cloud_logs_url=req.cloud_logs_url, ) return { @@ -142,9 +145,10 @@ from backend.app.tasks.schema import TaskSqsTriggerRequest from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface from backend.app.config import get_settings -sqs = boto3.client("sqs") -@router.post("/trigger", summary="Create task + subtask and publish to SQS", status_code=202) +@router.post( + "/trigger", summary="Create task + subtask and publish to SQS", status_code=202 +) async def trigger_task(req: TaskSqsTriggerRequest): """ Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it. @@ -152,11 +156,12 @@ async def trigger_task(req: TaskSqsTriggerRequest): """ settings = get_settings() + sqs = boto3.client("sqs", settings.AWS_DEFAULT_REGION) tasks = TasksInterface() # ---- Normalize empty inputs ---- - inputs = req.inputs or {} # ensures {} even if null + inputs = req.inputs or {} # ensures {} even if null # ---- 1. Create Task + SubTask ---- task_id, subtask_id = tasks.create_task( @@ -174,8 +179,8 @@ async def trigger_task(req: TaskSqsTriggerRequest): try: response = sqs.send_message( QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/" - f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue", - MessageBody=json.dumps(sqs_payload) + f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue", + MessageBody=json.dumps(sqs_payload), ) except Exception as e: raise HTTPException(status_code=500, detail=f"SQS error: {e}") @@ -186,4 +191,4 @@ async def trigger_task(req: TaskSqsTriggerRequest): "subtask_id": subtask_id, "sqs_message_id": response.get("MessageId"), "inputs_sent": inputs, - } \ No newline at end of file + } diff --git a/backend/categorisation/categorisation_trigger_request.py b/backend/categorisation/categorisation_trigger_request.py index 44ac0ff1..62879b5d 100644 --- a/backend/categorisation/categorisation_trigger_request.py +++ b/backend/categorisation/categorisation_trigger_request.py @@ -8,5 +8,10 @@ class CategorisationTriggerRequest(BaseModel): scenarios_to_consider: Optional[List[int]] = None scenario_priority_order: Optional[List[int]] = None + min_property_id: Optional[int] = None + max_property_id: Optional[int] = None + + subtask_id: Optional[str] = None + # {"portfolio_id": 556, "scenarios_to_consider": [1039,1041], "scenario_priority_order": [1041,1039]} diff --git a/backend/categorisation/handler/Dockerfile b/backend/categorisation/handler/Dockerfile index 7811ee4a..0a92eaba 100644 --- a/backend/categorisation/handler/Dockerfile +++ b/backend/categorisation/handler/Dockerfile @@ -29,20 +29,10 @@ RUN pip install --no-cache-dir -r requirements.txt # Copy application code # ----------------------------- COPY utils/ utils/ -COPY backend/categorisation/ backend/categorisation/ -COPY backend/app/db/ backend/app/db/ -COPY backend/app/domain/ backend/app/domain/ -COPY backend/addresses/ backend/addresses/ +# NOTE: if build is ever slow we can be more specific with which files are copied +COPY backend/ backend/ COPY datatypes/ datatypes/ -COPY backend/app/db/connection.py backend/app/db/connection.py - -COPY backend/app/config.py backend/app/config.py -COPY backend/app/utils.py backend/app/utils.py - -COPY backend/__init__.py backend/__init__.py -COPY backend/app/__init__.py backend/app/__init__.py - # ----------------------------- # Lambda handler diff --git a/backend/categorisation/handler/handler.py b/backend/categorisation/handler/handler.py index 9fb235d5..a1f69ea6 100644 --- a/backend/categorisation/handler/handler.py +++ b/backend/categorisation/handler/handler.py @@ -1,6 +1,9 @@ import json +import time from typing import Any, Mapping +from backend.app.db.functions.tasks.Tasks import SubTaskInterface +from backend.app.plan.utils import build_cloudwatch_log_url from backend.categorisation.categorisation_trigger_request import ( CategorisationTriggerRequest, ) @@ -25,12 +28,7 @@ def handler(event: Mapping[str, Any], context: Any) -> None: logger.debug("Successfully validated request body") - process_portfolio( - payload.portfolio_id, - payload.scenarios_to_consider, - payload.scenario_priority_order, - ) - + process_portfolio(payload) except Exception as e: logger.info("Handler exception") logger.error(f"Failed to process record: {e}") diff --git a/backend/categorisation/handler/requirements.txt b/backend/categorisation/handler/requirements.txt index e277b094..6e737772 100644 --- a/backend/categorisation/handler/requirements.txt +++ b/backend/categorisation/handler/requirements.txt @@ -1,6 +1,10 @@ sqlmodel pydantic-settings psycopg2-binary==2.9.10 +starlette # Not used but needed to satisfy imports -pytz==2024.2 \ No newline at end of file +pytz==2024.2 +msgpack==1.1.0 +numpy<2 +pandas==2.2.3 \ No newline at end of file diff --git a/backend/categorisation/local_handler/invoke_local_lambda.py b/backend/categorisation/local_handler/invoke_local_lambda.py index 5ed23c2d..5aa82846 100644 --- a/backend/categorisation/local_handler/invoke_local_lambda.py +++ b/backend/categorisation/local_handler/invoke_local_lambda.py @@ -2,16 +2,22 @@ import json import requests -LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations" +HOST = "localhost" +PORT = "9000" + +LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations" payload = { "Records": [ { "body": json.dumps( { - "portfolio_id": 556, + "portfolio_id": 569, "scenarios_to_consider": [], "scenario_priority_order": [], + "min_property_id": 660418, + "max_property_id": 660917, + "subtask_id": "6a0bcbac-ddab-435f-8708-8acd4662b067", } ) } diff --git a/backend/categorisation/local_runner.py b/backend/categorisation/local_runner.py index 7de55bc0..384ce5ef 100644 --- a/backend/categorisation/local_runner.py +++ b/backend/categorisation/local_runner.py @@ -1,5 +1,8 @@ from typing import List +from backend.categorisation.categorisation_trigger_request import ( + CategorisationTriggerRequest, +) from backend.categorisation.processor import process_portfolio @@ -9,9 +12,11 @@ def main() -> None: scenario_priority_order: List[int] = [] process_portfolio( - portfolio_id=portfolio_id, - scenarios_to_consider=scenarios_to_consider, - scenario_priority_order=scenario_priority_order, + CategorisationTriggerRequest( + portfolio_id=portfolio_id, + scenarios_to_consider=scenarios_to_consider, + scenario_priority_order=scenario_priority_order, + ) ) diff --git a/backend/categorisation/processor.py b/backend/categorisation/processor.py index 09db2983..88bc121e 100644 --- a/backend/categorisation/processor.py +++ b/backend/categorisation/processor.py @@ -1,5 +1,8 @@ +import time from collections import defaultdict from typing import Dict, List, Optional +from uuid import UUID +from starlette.responses import Response from backend.app.db.functions.recommendations_functions import ( bulk_update_plans, @@ -8,74 +11,124 @@ from backend.app.db.functions.recommendations_functions import ( get_most_recent_plans_by_scenario_ids, get_scenarios_by_portfolio_id, ) +from backend.app.db.functions.tasks.Tasks import SubTaskInterface from backend.app.db.models.recommendations import PlanModel, ScenarioModel from backend.app.domain.classes.plan import Plan from backend.app.domain.classes.scenario import Scenario +from backend.app.plan.utils import build_cloudwatch_log_url, handle_error +from backend.categorisation.categorisation_trigger_request import ( + CategorisationTriggerRequest, +) from utils.logger import setup_logger logger = setup_logger() def process_portfolio( - portfolio_id: int, - scenarios_to_consider: Optional[List[int]] = None, - scenario_priority_order: Optional[List[int]] = None, -) -> None: # TODO: make this a class + body: CategorisationTriggerRequest, +) -> Response: # TODO: make this a class + portfolio_id: int = body.portfolio_id + scenarios_to_consider: Optional[List[int]] = body.scenarios_to_consider + scenario_priority_order: Optional[List[int]] = body.scenario_priority_order + min_property_id: Optional[int] = body.min_property_id + max_property_id: Optional[int] = body.max_property_id + subtask_id: Optional[str] = body.subtask_id + logger.info(f"Processing portfolio {portfolio_id}") + start_ms = int(time.time() * 1000) + cloud_logs_url = build_cloudwatch_log_url(start_ms) - all_scenarios: List[Scenario] = _load_scenarios_for_portfolio(portfolio_id) - plans_by_id: Dict[int, Plan] = {} # TODO: make this an in-memory repository class + if body.subtask_id: + SubTaskInterface().update_subtask_status( + subtask_id=UUID(subtask_id), + status="in progress", + cloud_logs_url=cloud_logs_url, + ) - if scenarios_to_consider: - if len(scenarios_to_consider) < 2: - raise ValueError( - "Cannot run auto categorisation for fewer than 2 scenarios" - ) + try: - # first get all plans that we're interested in - plans_for_consideration: List[Plan] = _load_plans_for_portfolio( - portfolio_id, all_scenarios, scenarios_to_consider - ) - for plan in plans_for_consideration: - if plan.id is not None: # just in case - plans_by_id[plan.id] = plan + all_scenarios: List[Scenario] = _load_scenarios_for_portfolio(portfolio_id) + plans_by_id: Dict[int, Plan] = ( + {} + ) # TODO: make this an in-memory repository class - # then unset existing defaults on domain objects regardless of whether they're under consideration or not - default_plans: List[Plan] = _get_default_plans(portfolio_id, all_scenarios) - for plan in default_plans: - plan.set_default(False) - if plan.id is not None: # just in case - plans_by_id[plan.id] = plan + if scenarios_to_consider: + if len(scenarios_to_consider) < 2: + raise ValueError( + "Cannot run auto categorisation for fewer than 2 scenarios" + ) - logger.info(f"Successfully unset {len(default_plans)} default plan(s)") - - # then set new defaults on domain objects under consideration - plans_for_consideration_by_property: Dict[int, List[Plan]] = ( - _group_plans_by_property(plans_for_consideration) - ) - - for property_id, property_plans in plans_for_consideration_by_property.items(): - if not property_plans: - raise ValueError(f"No plans for property {property_id}") - - try: - cheapest_plan = choose_cheapest_relevant_plan( - property_plans, scenario_priority_order - ) - except Exception: - logger.error(f"Failed to find cheapest plan for property {property_id}") - raise - - property_plans = _update_plan_objects(property_plans, cheapest_plan) - for plan in property_plans: + # first get all plans that we're interested in + plans_for_consideration: List[Plan] = _load_plans_for_portfolio( + portfolio_id, + all_scenarios, + scenarios_to_consider, + min_property_id, + max_property_id, + ) + for plan in plans_for_consideration: if plan.id is not None: # just in case plans_by_id[plan.id] = plan - logger.info("Successfully set defaults on Plan objects in memory") + # then unset existing defaults on domain objects regardless of whether they're under consideration or not + default_plans: List[Plan] = _get_default_plans( + portfolio_id, all_scenarios, min_property_id, max_property_id + ) + for plan in default_plans: + plan.set_default(False) + if plan.id is not None: # just in case + plans_by_id[plan.id] = plan - # then pass all domain objects to database to update (regardless of whether they've changed) - _update_plans_in_db(list(plans_by_id.values())) - logger.info(f"Successfully updated {len(plans_by_id)} Plans in database") + logger.info(f"Successfully unset {len(default_plans)} default plan(s)") + + # then set new defaults on domain objects under consideration + plans_for_consideration_by_property: Dict[int, List[Plan]] = ( + _group_plans_by_property(plans_for_consideration) + ) + + for property_id, property_plans in plans_for_consideration_by_property.items(): + if not property_plans: + raise ValueError(f"No plans for property {property_id}") + + try: + cheapest_plan = choose_cheapest_relevant_plan( + property_plans, scenario_priority_order + ) + except Exception: + logger.error(f"Failed to find cheapest plan for property {property_id}") + raise + + property_plans = _update_plan_objects(property_plans, cheapest_plan) + for plan in property_plans: + if plan.id is not None: # just in case + plans_by_id[plan.id] = plan + + logger.info("Successfully set defaults on Plan objects in memory") + + # then pass all domain objects to database to update (regardless of whether they've changed) + _update_plans_in_db(list(plans_by_id.values())) + + # Mark the subtask as successful + logger.info(f"Successfully updated {len(plans_by_id)} Plans in database") + if body.subtask_id: + SubTaskInterface().update_subtask_status( + subtask_id=UUID(subtask_id), + status="complete", + cloud_logs_url=cloud_logs_url, + ) + + return Response(status_code=200) + except Exception as e: + if subtask_id: + return handle_error( + "Exception during Categorisation processing.", + e, + subtask_id, + 500, + start_ms, + ) + + raise def choose_cheapest_relevant_plan( @@ -109,8 +162,15 @@ def choose_cheapest_relevant_plan( return cheapest_plans[0] -def _get_default_plans(portfolio_id: int, scenarios: List[Scenario]) -> List[Plan]: - default_plan_models = get_default_plans(portfolio_id) +def _get_default_plans( + portfolio_id: int, + scenarios: List[Scenario], + min_property_id: Optional[int] = None, + max_property_id: Optional[int] = None, +) -> List[Plan]: + default_plan_models = get_default_plans( + portfolio_id, min_property_id, max_property_id + ) scenario_map = {s.id: s for s in scenarios} @@ -131,12 +191,14 @@ def _load_plans_for_portfolio( portfolio_id: int, all_scenarios: List[Scenario], scenarios_to_consider: Optional[List[int]] = None, + min_property_id: Optional[int] = None, + max_property_id: Optional[int] = None, ) -> List[Plan]: if scenarios_to_consider: logger.info(f"Getting plans for {len(scenarios_to_consider)} scenarios") plan_models: List[PlanModel] = get_most_recent_plans_by_scenario_ids( - scenarios_to_consider + scenarios_to_consider, min_property_id, max_property_id ) logger.info(f"Got {len(plan_models)} plan models from database") else: @@ -144,7 +206,7 @@ def _load_plans_for_portfolio( f"No list of Plans to consider provided. Getting all Plans for portfolio {portfolio_id}" ) plan_models: List[PlanModel] = get_most_recent_plans_by_portfolio_id( - portfolio_id + portfolio_id, min_property_id, max_property_id ) plans: List[Plan] = [] diff --git a/infrastructure/terraform/lambda/categorisation/variables.tf b/infrastructure/terraform/lambda/categorisation/variables.tf index e4bab243..347964de 100644 --- a/infrastructure/terraform/lambda/categorisation/variables.tf +++ b/infrastructure/terraform/lambda/categorisation/variables.tf @@ -17,6 +17,11 @@ variable "image_digest" { description = "Image digest (sha256:...)" } +variable "maximum_concurrency" { + type = number + default = 10 # null if you don't want to set it for this handler + description = "Maximum number of concurrent Lambda invocations from SQS (2-1000). null = no limit." +} locals { image_uri = "${var.ecr_repo_url}@${var.image_digest}"