Merge pull request #750 from Hestia-Homes/feature/categorisation-work-distributor

Categorisation work distributor
This commit is contained in:
Daniel Roth 2026-03-02 10:20:52 +00:00 committed by GitHub
commit 46da1a71fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 464 additions and 193 deletions

View file

@ -6,7 +6,8 @@
"workspaceFolder": "/workspaces/model",
"postStartCommand": "bash .devcontainer/backend/post-install.sh",
"mounts": [
"source=${localEnv:HOME},target=/home/vscode,type=bind"
// "source=${localEnv:HOME},target=/home/vscode,type=bind",
"source=${localEnv:HOME}/.aws,target=/home/vscode/.aws,type=bind,consistency=cached"
],
"customizations": {
"vscode": {

View file

@ -19,4 +19,5 @@ pytest==9.0.2
pytest-cov==7.0.0
ipykernel>=6.25,<7
# Formatting
black==26.1.0
black==26.1.0
boto3-stubs

View file

@ -19,4 +19,5 @@ PLAN_TRIGGER_BUCKET=test
DATA_BUCKET=test
EPC_AUTH_TOKEN=test
ENGINE_SQS_URL=test
CATEGORISATION_SQS_URL=test
ENERGY_ASSESSMENTS_BUCKET=test

View file

@ -45,12 +45,14 @@ cp .env.example .env
## Running the Application
from within the application you can run with the following command:
from `model/backend/` you can run with the following command:
```commandline
uvicorn app.main:app --reload
```
Or run `sh run_local.sh`, which runs that same uvicorn command.
Your application will be available at the designated URL.
## API Documentation
@ -172,7 +174,7 @@ For instance, if your server is running locally on port 8000, you can use curl
to get a dummy token:
```commandline
curl http://localhost:8000/dummy-token
curl http://localhost:8000/local/dummy-token
```
You will receive a response containing the dummy JWT

View file

@ -1,17 +1,29 @@
import os
from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Optional
from utils.logger import setup_logger
logger = setup_logger()
def resolve_env_file() -> Optional[str]:
    """Return the dotenv file path for the current ``ENVIRONMENT``, if any.

    ``ENVIRONMENT`` defaults to ``"local"`` when unset. Paths are resolved
    relative to this module's location (the directory one level above the
    directory containing this file) rather than the current working
    directory.

    :return: Path to ``.env`` for "local", ``.env.test`` for "test", and
        ``None`` for any other environment (no env file is loaded).
    """
    env = os.getenv("ENVIRONMENT", "local")
    backend_dir = Path(__file__).resolve().parents[1]

    if env == "local":
        env_file = backend_dir / ".env"
        print("USING ENV FILE:", env_file)
        # logging needs a %s placeholder; passing a bare extra argument makes
        # the handler fail with "not all arguments converted".
        logger.debug("USING ENV FILE: %s", env_file)
        return str(env_file)

    if env == "test":
        env_file = backend_dir / ".env.test"
        logger.debug("USING ENV FILE: %s", env_file)
        return str(env_file)

    # prod = no env file
    return None
@ -25,6 +37,7 @@ class Settings(BaseSettings):
DATA_BUCKET: str = "changeme"
PLAN_TRIGGER_BUCKET: str
ENGINE_SQS_URL: str = "changeme"
CATEGORISATION_SQS_URL: str = "changeme"
# Third parties
EPC_AUTH_TOKEN: str = "changeme"

View file

@ -1,5 +1,14 @@
from typing import Any, Dict, List, Tuple
from sqlalchemy import inspect, text, insert, delete, select
from typing import Any, Dict, List, Optional
from sqlalchemy import (
ColumnElement,
and_,
func,
inspect,
text,
insert,
delete,
select,
)
from sqlalchemy.orm import Session, Mapper
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import Session
@ -625,11 +634,22 @@ def get_plans_by_scenario_ids(ids: List[int]) -> List[PlanModel]:
return session_any.exec(stmt).scalars().all()
def get_most_recent_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]:
def get_most_recent_plans_by_portfolio_id(
portfolio_id: int,
min_property_id: Optional[int] = None,
max_property_id: Optional[int] = None,
) -> List[PlanModel]:
filters = [PlanModel.portfolio_id == portfolio_id]
if min_property_id is not None:
filters.append(PlanModel.property_id >= min_property_id)
if max_property_id is not None:
filters.append(PlanModel.property_id <= max_property_id)
# NOTE: This statement works for Postgres only, because of the Distinct
stmt = (
select(PlanModel)
.where(PlanModel.portfolio_id == portfolio_id)
.where(and_(*filters))
.distinct(
PlanModel.property_id, PlanModel.scenario_id
) # one plan per property per scenario
@ -645,11 +665,27 @@ def get_most_recent_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]:
return session_any.exec(stmt).scalars().all()
def get_most_recent_plans_by_scenario_ids(scenario_ids: List[int]) -> List[PlanModel]:
def get_most_recent_plans_by_scenario_ids(
scenario_ids: List[int],
min_property_id: Optional[int] = None,
max_property_id: Optional[int] = None,
) -> List[PlanModel]:
if not scenario_ids:
return []
# Base filter: scenario_id in provided list
filters: List[ColumnElement[bool]] = [PlanModel.scenario_id.in_(scenario_ids)]
# Add optional property ID range filters
if min_property_id is not None:
filters.append(PlanModel.property_id >= min_property_id)
if max_property_id is not None:
filters.append(PlanModel.property_id <= max_property_id)
# NOTE: This statement works for Postgres only, because of the Distinct
stmt = (
select(PlanModel)
.where(PlanModel.scenario_id.in_(scenario_ids))
.where(and_(*filters))
.distinct(
PlanModel.property_id, PlanModel.scenario_id
) # one plan per property per scenario
@ -673,16 +709,37 @@ def get_scenarios_by_portfolio_id(portfolio_id: int) -> List[ScenarioModel]:
return session_any.exec(stmt).scalars().all()
def get_scenarios_count_by_portfolio_id(portfolio_id: int) -> int:
    """Return how many ScenarioModel rows have the given ``portfolio_id``."""
    belongs_to_portfolio = ScenarioModel.portfolio_id == portfolio_id
    count_query = (
        select(func.count()).select_from(ScenarioModel).where(belongs_to_portfolio)
    )

    with db_read_session() as read_session:
        # Typed as Any so Pylance accepts the .exec(...) call below.
        untyped_session: Any = read_session
        result = untyped_session.exec(count_query)
        return result.scalar_one()
def get_default_plans(
portfolio_id: int,
min_property_id: Optional[int] = None,
max_property_id: Optional[int] = None,
) -> List[PlanModel]:
plan_stmt = select(PlanModel).where(
(PlanModel.portfolio_id == portfolio_id) & (PlanModel.is_default == True)
)
filters: List[ColumnElement[bool]] = [
PlanModel.portfolio_id == portfolio_id,
PlanModel.is_default.is_(True),
]
if min_property_id is not None:
filters.append(PlanModel.property_id >= min_property_id)
if max_property_id is not None:
filters.append(PlanModel.property_id <= max_property_id)
stmt = select(PlanModel).where(and_(*filters))
with db_read_session() as session:
session_any: Any = session # Typehint as Any to satisfy Pylance...
plans: List[PlanModel] = session_any.exec(plan_stmt).scalars().all()
plans: List[PlanModel] = session_any.exec(stmt).scalars().all()
return plans

View file

@ -25,7 +25,12 @@ class SubTaskInterface:
# --------------------------------------------------------
# CREATE SUBTASK
# --------------------------------------------------------
def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None, status=None):
def create_subtask(
self,
task_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
status: Optional[str] = None,
):
now = datetime.now(timezone.utc)
with get_db_session() as session:
@ -56,8 +61,12 @@ class SubTaskInterface:
# UPDATE STATUS (in progress, complete, failed)
# --------------------------------------------------------
def update_subtask_status(
self, subtask_id: UUID, status: str, outputs=None, cloud_logs_url=None
):
self,
subtask_id: UUID,
status: str,
outputs: Optional[Dict[str, str]] = None,
cloud_logs_url: Optional[str] = None,
) -> SubTask:
"""
Update the status of a subtask, and recalculate the parent task progress.
:param subtask_id: UUID of the subtask to update
@ -177,9 +186,7 @@ class SubTaskInterface:
if not task:
return
subtasks = session.exec(
select(SubTask).where(SubTask.task_id == task_id)
).all()
subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all()
statuses = [s.status.lower() for s in subtasks]
now = datetime.now(timezone.utc)
@ -211,7 +218,7 @@ class SubTaskInterface:
subtask_id: UUID,
status: str,
outputs: Optional[Dict[str, Any]],
cloud_logs_url: Optional[str]
cloud_logs_url: Optional[str],
):
now = datetime.now(timezone.utc)

View file

@ -1,9 +1,10 @@
from typing import List
from uuid import UUID
import boto3
import json
import math
import asyncio
from contextlib import contextmanager
from sqlmodel import Session
from datetime import datetime
@ -11,11 +12,16 @@ from fastapi import APIRouter, Depends
from backend.app.dependencies import validate_token
from backend.app.plan.schemas import PlanTriggerRequest
from backend.app.config import get_settings
from sqlalchemy.orm import sessionmaker
from backend.categorisation.categorisation_trigger_request import (
CategorisationTriggerRequest,
)
from utils.logger import setup_logger
from backend.app.db.connection import db_engine
from backend.app.db.functions.recommendations_functions import create_scenario
from backend.app.db.functions.recommendations_functions import (
create_scenario,
get_property_ids,
get_scenarios_count_by_portfolio_id,
)
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
logger = setup_logger()
@ -24,23 +30,84 @@ router = APIRouter(
prefix="/plan",
tags=["plan"],
dependencies=[Depends(validate_token)],
responses={404: {"description": "Not found"}}
responses={404: {"description": "Not found"}},
)
sqs_client = boto3.client("sqs")
settings = get_settings()
sqs_client = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
@contextmanager
def db_session():
    """Yield a Session bound to ``db_engine`` inside a commit-or-rollback scope.

    The session is committed when the caller's block finishes normally,
    rolled back (and the exception re-raised) when it raises, and closed in
    every case.
    """
    # Session's own context manager closes the session on exit.
    with Session(db_engine) as active_session:
        try:
            yield active_session
            active_session.commit()
        except Exception:
            active_session.rollback()
            raise
@router.post("/categorisation", status_code=202)
async def trigger_categorisation(
    body: CategorisationTriggerRequest,
) -> dict[str, str]:
    """Split a portfolio's categorisation work into batches and queue them on SQS.

    The portfolio's property IDs are sorted and cut into consecutive batches
    sized so that ``properties * scenarios`` stays at or below
    ``max_writes_per_batch`` (at least one property per batch). One task is
    created for the request; each batch gets its own subtask and one SQS
    message carrying the batch's inclusive ``min_property_id`` /
    ``max_property_id`` range and the subtask ID.

    :param body: Categorisation request for a portfolio.
    :return: A confirmation message once every batch has been sent.
    """
    payload: CategorisationTriggerRequest = CategorisationTriggerRequest.model_validate(
        body
    )
    logger.info("API triggered with body: %s", payload)

    property_ids: list[int] = sorted(get_property_ids(payload.portfolio_id))
    num_scenarios: int = get_scenarios_count_by_portfolio_id(payload.portfolio_id)
    total_plans_to_update: int = len(property_ids) * num_scenarios

    max_writes_per_batch: int = 1000
    # Guard the divisor: a portfolio with no scenarios would otherwise raise
    # ZeroDivisionError here.
    properties_per_batch: int = max(1, max_writes_per_batch // max(1, num_scenarios))
    num_property_batches: int = math.ceil(len(property_ids) / properties_per_batch)

    # Logging calls need %s placeholders for their arguments.
    logger.info("total_plans_to_update: %s", total_plans_to_update)
    logger.info("properties_per_batch: %s", properties_per_batch)
    logger.info("num_property_batches: %s", num_property_batches)

    # Create task
    task_id, _ = TasksInterface.create_task(
        task_source="backend/plan/router.py:trigger_categorisation",
        service="plan_categorisation",
        inputs=payload.model_dump(),
        task_only=True,
    )

    # Dispatch requests to lambdas
    subtask_interface = SubTaskInterface()

    for batch_index in range(num_property_batches):
        start: int = batch_index * properties_per_batch
        end: int = start + properties_per_batch
        batch_property_ids: List[int] = property_ids[start:end]

        if not batch_property_ids:
            continue

        # property_ids is sorted, so the batch bounds are its first/last items.
        batch_min_id: int = batch_property_ids[0]
        batch_max_id: int = batch_property_ids[-1]

        batch_request: CategorisationTriggerRequest = CategorisationTriggerRequest(
            portfolio_id=payload.portfolio_id,
            scenarios_to_consider=payload.scenarios_to_consider,
            scenario_priority_order=payload.scenario_priority_order,
            min_property_id=batch_min_id,
            max_property_id=batch_max_id,
        )

        # Create sub-task for each
        subtask_id: UUID = subtask_interface.create_subtask(
            task_id=task_id, inputs=batch_request.model_dump()
        )
        batch_request.subtask_id = str(subtask_id)

        response = sqs_client.send_message(
            QueueUrl=settings.CATEGORISATION_SQS_URL,
            MessageBody=batch_request.model_dump_json(),
        )
        logger.info(
            "Chunk %s sent to SQS. %s Property IDs in batch (total %s). "
            "Property IDs %s-%s. Message ID: %s",
            batch_index,
            len(batch_property_ids),
            len(property_ids),
            batch_min_id,
            batch_max_id,
            response.get("MessageId"),
        )

        await asyncio.sleep(0.05)  # Small delay to avoid SQS throttling

    return {"message": "Categorisation jobs distributed"}
@router.post("/trigger", status_code=202)
@ -50,8 +117,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
"""
logger.info("API triggered with body: %s", body)
settings = get_settings()
try:
data = body.model_dump()
except Exception as e:
@ -59,7 +124,10 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
return {"message": "Invalid request"}, 400
# If file_format is domna_asset_list and type is xlsx, read and chunk it
if data.get("file_format") == "domna_asset_list" and data.get("file_type") == "xlsx":
if (
data.get("file_format") == "domna_asset_list"
and data.get("file_type") == "xlsx"
):
try:
total_rows = data.get("sheet_count", 0)
@ -88,8 +156,8 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
"patches_file_path": body.patches_file_path,
"non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path,
"exclusions": body.exclusions,
"multi_plan": body.multi_plan
}
"multi_plan": body.multi_plan,
},
)
# Insert the scenario ID into the data payload
data["scenario_id"] = scenario_id
@ -99,7 +167,7 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
task_source="backend/plan/router.py:trigger_plan_entrypoint",
service="plan_engine",
inputs=data,
task_only=True
task_only=True,
)
subtask_interface = SubTaskInterface()
@ -109,13 +177,14 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
index_end = min((i + 1) * chunk_size, total_rows)
message_payload = {
**data, "index_start": index_start, "index_end": index_end,
**data,
"index_start": index_start,
"index_end": index_end,
}
# Create a subtask for this chunk
subtask_id = subtask_interface.create_subtask(
task_id=task_id,
inputs=message_payload
task_id=task_id, inputs=message_payload
)
# Add task and subtask to message
@ -125,8 +194,7 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
message_body = json.dumps(message_payload)
response = sqs_client.send_message(
QueueUrl=settings.ENGINE_SQS_URL,
MessageBody=message_body
QueueUrl=settings.ENGINE_SQS_URL, MessageBody=message_body
)
logger.info(
f"Chunk {i} sent to SQS. Rows {index_start}{index_end}. Message ID: {response.get('MessageId')}"
@ -153,8 +221,7 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
data["subtask_id"] = str(subtask_id)
message_body = json.dumps(data)
response = sqs_client.send_message(
QueueUrl=settings.ENGINE_SQS_URL,
MessageBody=message_body
QueueUrl=settings.ENGINE_SQS_URL, MessageBody=message_body
)
logger.info(f"SQS message sent. Message ID: {response.get('MessageId')}")
except Exception as e:

View file

@ -1,5 +1,6 @@
import ast
import os
from typing import Optional
import msgpack
from uuid import UUID
from utils.s3 import read_from_s3
@ -24,7 +25,7 @@ def get_cleaned():
cleaned = read_from_s3(
s3_file_name="cleaned_epc_data/cleaned.bson",
bucket_name=get_settings().DATA_BUCKET
bucket_name=get_settings().DATA_BUCKET,
)
cleaned = msgpack.unpackb(cleaned, raw=False)
@ -56,32 +57,45 @@ def extract_property_request_data(
):
patch_has_uprn = "uprn" in patches[0] if patches else True
if patch_has_uprn:
patch = next((
x for x in patches if str(x["uprn"]) == str(address.uprn)
), {})
patch = next((x for x in patches if str(x["uprn"]) == str(address.uprn)), {})
else:
patch = next((
x for x in patches if (x["address"] == address.address) and (x["postcode"] == address.postcode)
), {})
patch = next(
(
x
for x in patches
if (x["address"] == address.address)
and (x["postcode"] == address.postcode)
),
{},
)
# Because we have some non-invasive recommendations that match on address and postcode, but not UPRN
# we need to check existence of uprn
has_uprn = "uprn" in non_invasive_recommendations[0] if non_invasive_recommendations else False
has_uprn = (
"uprn" in non_invasive_recommendations[0]
if non_invasive_recommendations
else False
)
if has_uprn:
has_uprn = non_invasive_recommendations[0]["uprn"] not in ["", None]
if has_uprn:
property_non_invasive_recommendations = next((
x for x in non_invasive_recommendations if
(str(x["uprn"]) == str(uprn))
), {})
property_non_invasive_recommendations = next(
(x for x in non_invasive_recommendations if (str(x["uprn"]) == str(uprn))),
{},
)
# We patch the non-invasive recs that are ['cavity_extract_and_refill']
else:
property_non_invasive_recommendations = next((
x for x in non_invasive_recommendations if
(x["address"] == address.address) and (x["postcode"] == address.postcode)
), {})
property_non_invasive_recommendations = next(
(
x
for x in non_invasive_recommendations
if (x["address"] == address.address)
and (x["postcode"] == address.postcode)
),
{},
)
if isinstance(property_non_invasive_recommendations.get("recommendations"), str):
property_non_invasive_recommendations["recommendations"] = ast.literal_eval(
@ -90,7 +104,11 @@ def extract_property_request_data(
transformed = []
for rec in property_non_invasive_recommendations["recommendations"]:
if isinstance(rec, str):
transformed.append({"type": rec, })
transformed.append(
{
"type": rec,
}
)
else:
transformed.append(rec)
@ -102,26 +120,36 @@ def extract_property_request_data(
valuation_has_uprn = valuation_data[0]["uprn"] not in ["", None]
if valuation_has_uprn:
property_valuation = next((
float(x["valuation"]) for x in valuation_data if
(str(x["uprn"]) == str(uprn))
), None)
property_valuation = next(
(
float(x["valuation"])
for x in valuation_data
if (str(x["uprn"]) == str(uprn))
),
None,
)
else:
property_valuation = next((
float(x["valuation"]) for x in valuation_data if
(x["address"] == address.address) and (x["postcode"] == address.postcode)
), None)
property_valuation = next(
(
float(x["valuation"])
for x in valuation_data
if (x["address"] == address.address)
and (x["postcode"] == address.postcode)
),
None,
)
# Return data class to give a structured format
return PropertyRequestData(
patch=patch,
non_invasive_recommendations=property_non_invasive_recommendations,
valuation=property_valuation
valuation=property_valuation,
)
def parse_eco_packages(addr: Address, prepared_epc) -> tuple[list[str], int, str, list[str]] | tuple[
None, None, None, list]:
def parse_eco_packages(
addr: Address, prepared_epc
) -> tuple[list[str], int, str, list[str]] | tuple[None, None, None, list]:
solar_identification = addr.solar_reason
cavity_identification = addr.cavity_reason
if not solar_identification and not cavity_identification:
@ -140,47 +168,51 @@ def parse_eco_packages(addr: Address, prepared_epc) -> tuple[list[str], int, str
"Solar Eligible": {
"measures": ["solar_pv", "loft_insulation", "mechanical_ventilation"],
"target_sap": 86, # High B
"plan_type": "solar_eco4"
"plan_type": "solar_eco4",
},
"Solar Eligible, Solid Wall Uninsulated, EPC E or Below": {
"measures": ["solar_pv", "loft_insulation", "mechanical_ventilation"],
"target_sap": 86, # High B
"plan_type": "solar_eco4"
"plan_type": "solar_eco4",
},
"Solar Eligible, Needs Heating Upgrade": {
"measures": ["solar_pv", "loft_insulation", "high_heat_retention_storage_heaters",
"mechanical_ventilation"],
"measures": [
"solar_pv",
"loft_insulation",
"high_heat_retention_storage_heaters",
"mechanical_ventilation",
],
"target_sap": 86, # High B
"plan_type": "solar_hhrsh_eco4"
"plan_type": "solar_hhrsh_eco4",
},
"Non-Intrusive Data Shows Empty Cavity": {
"measures": ["cavity_wall_insulation", "mechanical_ventilation"],
"target_sap": 69, # Low C
"plan_type": "empty_cavity_eco"
"plan_type": "empty_cavity_eco",
},
'Non-Intrusive Data Shows Empty Cavity, built after 2002': {
"Non-Intrusive Data Shows Empty Cavity, built after 2002": {
"measures": ["cavity_wall_insulation", "mechanical_ventilation"],
"target_sap": 69, # Low C
"plan_type": "empty_cavity_eco"
"plan_type": "empty_cavity_eco",
},
"EPC Shows Empty Cavity, inspections show retro drilled": {
# EPC Indicates it's empty, so we simulate a fill
"measures": ["cavity_wall_insulation", "mechanical_ventilation"],
"target_sap": 69, # Low C
"plan_type": "extraction_eco"
"plan_type": "extraction_eco",
},
"EPC Shows Empty Cavity, inspections show filled at build": {
# EPC Indicates it's empty, so we simulate a fill
"measures": ["cavity_wall_insulation", "mechanical_ventilation"],
"target_sap": 69, # Low C
"plan_type": "extraction_eco"
"plan_type": "extraction_eco",
},
"EPC Shows Empty Cavity": {
# EPC Indicates it's empty, so we simulate a fill
"measures": ["cavity_wall_insulation", "mechanical_ventilation"],
"target_sap": 69, # Low C
"plan_type": "empty_cavity_eco"
}
"plan_type": "empty_cavity_eco",
},
}
# Always prioritise solar
@ -214,9 +246,13 @@ def build_cloudwatch_log_url(start_ms: int) -> str:
Build a CloudWatch Logs URL for the current Lambda invocation,
including timestamp window from start_ms to end_ms (epoch ms).
"""
logger.info("Building cloudwatch logs URL")
region = os.environ["AWS_REGION"]
logger.info("Building cloudwatch logs URL: Got AWS region")
log_group = os.environ["AWS_LAMBDA_LOG_GROUP_NAME"]
logger.info("Building cloudwatch logs URL: Got lambda log group name")
log_stream = os.environ["AWS_LAMBDA_LOG_STREAM_NAME"]
logger.info("Building cloudwatch logs URL: Got lambda log stream name")
# CloudWatch console requires / encoded as $252F
encoded_group = log_group.replace("/", "$252F")
@ -232,15 +268,21 @@ def build_cloudwatch_log_url(start_ms: int) -> str:
)
def handle_error(
    msg: str,
    exception: Exception,
    subtask_id: str,
    status_code: int = 500,
    start_ms: Optional[int] = None,
):
    """Mark a subtask as failed, log the error and build an error Response.

    :param msg: Message logged and returned as the response content.
    :param exception: The exception that caused the failure; its string form
        is stored as the subtask outputs.
    :param subtask_id: String UUID of the subtask to mark as "failed".
    :param status_code: HTTP status code of the returned Response.
    :param start_ms: Start of the invocation in epoch milliseconds, passed to
        ``build_cloudwatch_log_url``.
    :return: A Response carrying ``status_code`` and ``msg``.
    """
    # When the pipeline fails, handles error process
    cloud_logs_url: Optional[str]
    try:
        cloud_logs_url = build_cloudwatch_log_url(start_ms)
    except KeyError:
        # The URL builder reads Lambda-provided environment variables. If they
        # are absent, still record the failure instead of letting this
        # KeyError replace the original error.
        logger.warning("CloudWatch environment variables missing; no logs URL")
        cloud_logs_url = None

    SubTaskInterface().update_subtask_status(
        subtask_id=UUID(subtask_id),
        status="failed",
        outputs=str(exception),
        cloud_logs_url=cloud_logs_url,
    )
    logger.error(msg, exc_info=True)
    return Response(status_code=status_code, content=msg)

View file

@ -9,7 +9,7 @@ from backend.app.tasks.schema import (
CreateSubTaskRequest,
UpdateSubTaskStatusRequest,
FinalizeSubTaskRequest,
TaskSqsTriggerRequest
TaskSqsTriggerRequest,
)
# Correct location of interfaces
@ -51,18 +51,18 @@ async def get_task(task_id: UUID):
if not task:
raise HTTPException(status_code=404, detail="Task not found")
subtasks = session.exec(
select(SubTask).where(SubTask.taskId == task_id)
).all()
subtasks = session.exec(select(SubTask).where(SubTask.taskId == task_id)).all()
formatted = []
for st in subtasks:
formatted.append({
**st.dict(),
"inputs": json.loads(st.inputs) if st.inputs else None,
"outputs": json.loads(st.outputs) if st.outputs else None,
"cloud_logs_url": st.cloudLogsURL,
})
formatted.append(
{
**st.dict(),
"inputs": json.loads(st.inputs) if st.inputs else None,
"outputs": json.loads(st.outputs) if st.outputs else None,
"cloud_logs_url": st.cloudLogsURL,
}
)
return {
"task": task,
@ -111,7 +111,10 @@ async def update_subtask_status(subtask_id: UUID, req: UpdateSubTaskStatusReques
# ===
# Sub task is complete
@router.post("/subtask/{subtask_id}/finalize", summary="Finalize a subtask with status, outputs, logs")
@router.post(
"/subtask/{subtask_id}/finalize",
summary="Finalize a subtask with status, outputs, logs",
)
async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
subtasks = SubTaskInterface()
@ -120,7 +123,7 @@ async def finalize_subtask(subtask_id: UUID, req: FinalizeSubTaskRequest):
subtask_id=subtask_id,
status=req.status,
outputs=req.outputs,
cloud_logs_url=req.cloud_logs_url
cloud_logs_url=req.cloud_logs_url,
)
return {
@ -142,9 +145,10 @@ from backend.app.tasks.schema import TaskSqsTriggerRequest
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
from backend.app.config import get_settings
sqs = boto3.client("sqs")
@router.post("/trigger", summary="Create task + subtask and publish to SQS", status_code=202)
@router.post(
"/trigger", summary="Create task + subtask and publish to SQS", status_code=202
)
async def trigger_task(req: TaskSqsTriggerRequest):
"""
Creates a Task + SubTask, then pushes the SubTask into SQS so a Lambda can process it.
@ -152,11 +156,12 @@ async def trigger_task(req: TaskSqsTriggerRequest):
"""
settings = get_settings()
sqs = boto3.client("sqs", settings.AWS_DEFAULT_REGION)
tasks = TasksInterface()
# ---- Normalize empty inputs ----
inputs = req.inputs or {} # ensures {} even if null
inputs = req.inputs or {} # ensures {} even if null
# ---- 1. Create Task + SubTask ----
task_id, subtask_id = tasks.create_task(
@ -174,8 +179,8 @@ async def trigger_task(req: TaskSqsTriggerRequest):
try:
response = sqs.send_message(
QueueUrl=f"https://sqs.{settings.AWS_REGION}.amazonaws.com/"
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
MessageBody=json.dumps(sqs_payload)
f"{settings.AWS_ACCOUNT_ID}/lambda-example-queue",
MessageBody=json.dumps(sqs_payload),
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"SQS error: {e}")
@ -186,4 +191,4 @@ async def trigger_task(req: TaskSqsTriggerRequest):
"subtask_id": subtask_id,
"sqs_message_id": response.get("MessageId"),
"inputs_sent": inputs,
}
}

View file

@ -8,5 +8,10 @@ class CategorisationTriggerRequest(BaseModel):
scenarios_to_consider: Optional[List[int]] = None
scenario_priority_order: Optional[List[int]] = None
min_property_id: Optional[int] = None
max_property_id: Optional[int] = None
subtask_id: Optional[str] = None
# {"portfolio_id": 556, "scenarios_to_consider": [1039,1041], "scenario_priority_order": [1041,1039]}

View file

@ -29,20 +29,10 @@ RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
# -----------------------------
COPY utils/ utils/
COPY backend/categorisation/ backend/categorisation/
COPY backend/app/db/ backend/app/db/
COPY backend/app/domain/ backend/app/domain/
COPY backend/addresses/ backend/addresses/
# NOTE: if build is ever slow we can be more specific with which files are copied
COPY backend/ backend/
COPY datatypes/ datatypes/
COPY backend/app/db/connection.py backend/app/db/connection.py
COPY backend/app/config.py backend/app/config.py
COPY backend/app/utils.py backend/app/utils.py
COPY backend/__init__.py backend/__init__.py
COPY backend/app/__init__.py backend/app/__init__.py
# -----------------------------
# Lambda handler

View file

@ -1,6 +1,9 @@
import json
import time
from typing import Any, Mapping
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.app.plan.utils import build_cloudwatch_log_url
from backend.categorisation.categorisation_trigger_request import (
CategorisationTriggerRequest,
)
@ -25,12 +28,7 @@ def handler(event: Mapping[str, Any], context: Any) -> None:
logger.debug("Successfully validated request body")
process_portfolio(
payload.portfolio_id,
payload.scenarios_to_consider,
payload.scenario_priority_order,
)
process_portfolio(payload)
except Exception as e:
logger.info("Handler exception")
logger.error(f"Failed to process record: {e}")

View file

@ -1,6 +1,10 @@
sqlmodel
pydantic-settings
psycopg2-binary==2.9.10
starlette
# Not used but needed to satisfy imports
pytz==2024.2
pytz==2024.2
msgpack==1.1.0
numpy<2
pandas==2.2.3

View file

@ -2,16 +2,22 @@
import json
import requests
LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations"
HOST = "localhost"
PORT = "9000"
LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"
payload = {
"Records": [
{
"body": json.dumps(
{
"portfolio_id": 556,
"portfolio_id": 569,
"scenarios_to_consider": [],
"scenario_priority_order": [],
"min_property_id": 660418,
"max_property_id": 660917,
"subtask_id": "6a0bcbac-ddab-435f-8708-8acd4662b067",
}
)
}

View file

@ -1,5 +1,8 @@
from typing import List
from backend.categorisation.categorisation_trigger_request import (
CategorisationTriggerRequest,
)
from backend.categorisation.processor import process_portfolio
@ -9,9 +12,11 @@ def main() -> None:
scenario_priority_order: List[int] = []
process_portfolio(
portfolio_id=portfolio_id,
scenarios_to_consider=scenarios_to_consider,
scenario_priority_order=scenario_priority_order,
CategorisationTriggerRequest(
portfolio_id=portfolio_id,
scenarios_to_consider=scenarios_to_consider,
scenario_priority_order=scenario_priority_order,
)
)

View file

@ -1,5 +1,8 @@
import time
from collections import defaultdict
from typing import Dict, List, Optional
from uuid import UUID
from starlette.responses import Response
from backend.app.db.functions.recommendations_functions import (
bulk_update_plans,
@ -8,74 +11,124 @@ from backend.app.db.functions.recommendations_functions import (
get_most_recent_plans_by_scenario_ids,
get_scenarios_by_portfolio_id,
)
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
from backend.app.domain.classes.plan import Plan
from backend.app.domain.classes.scenario import Scenario
from backend.app.plan.utils import build_cloudwatch_log_url, handle_error
from backend.categorisation.categorisation_trigger_request import (
CategorisationTriggerRequest,
)
from utils.logger import setup_logger
logger = setup_logger()
def process_portfolio(
    body: CategorisationTriggerRequest,
) -> Response:  # TODO: make this a class
    """Recompute which plan is the default for each property covered by ``body``.

    Loads the most recent plans for the portfolio (restricted to
    ``body.scenarios_to_consider`` when given, and to the property-ID bounds
    ``body.min_property_id`` / ``body.max_property_id`` when given), unsets
    the existing default plans in the same range, chooses the cheapest
    relevant plan per property, and writes every touched plan back to the
    database.

    When ``body.subtask_id`` is set, the subtask is marked "in progress"
    before processing and "complete" afterwards; a failure is passed to
    ``handle_error`` and its Response is returned. Without a subtask ID,
    exceptions propagate to the caller.

    :param body: Categorisation request describing the portfolio and range.
    :return: A 200 Response on success, or the Response from ``handle_error``.
    :raises ValueError: (no subtask ID only) if ``scenarios_to_consider`` has
        exactly one entry, or a property ends up with no plans.
    """
    portfolio_id: int = body.portfolio_id
    scenarios_to_consider: Optional[List[int]] = body.scenarios_to_consider
    scenario_priority_order: Optional[List[int]] = body.scenario_priority_order
    min_property_id: Optional[int] = body.min_property_id
    max_property_id: Optional[int] = body.max_property_id
    subtask_id: Optional[str] = body.subtask_id

    logger.info(f"Processing portfolio {portfolio_id}")
    start_ms = int(time.time() * 1000)

    # The logs URL is only needed for subtask bookkeeping, and the builder
    # reads Lambda-provided environment variables. Building it unconditionally
    # made runs without those variables fail before any work was done.
    cloud_logs_url: Optional[str] = None
    if subtask_id:
        try:
            cloud_logs_url = build_cloudwatch_log_url(start_ms)
        except KeyError:
            logger.warning("CloudWatch environment variables missing; no logs URL")
        SubTaskInterface().update_subtask_status(
            subtask_id=UUID(subtask_id),
            status="in progress",
            cloud_logs_url=cloud_logs_url,
        )

    try:
        all_scenarios: List[Scenario] = _load_scenarios_for_portfolio(portfolio_id)
        # TODO: make this an in-memory repository class
        plans_by_id: Dict[int, Plan] = {}

        if scenarios_to_consider and len(scenarios_to_consider) < 2:
            raise ValueError(
                "Cannot run auto categorisation for fewer than 2 scenarios"
            )

        # first get all plans that we're interested in
        plans_for_consideration: List[Plan] = _load_plans_for_portfolio(
            portfolio_id,
            all_scenarios,
            scenarios_to_consider,
            min_property_id,
            max_property_id,
        )
        for plan in plans_for_consideration:
            if plan.id is not None:  # just in case
                plans_by_id[plan.id] = plan

        # then unset existing defaults on domain objects regardless of whether
        # they're under consideration or not
        default_plans: List[Plan] = _get_default_plans(
            portfolio_id, all_scenarios, min_property_id, max_property_id
        )
        for plan in default_plans:
            plan.set_default(False)
            if plan.id is not None:  # just in case
                plans_by_id[plan.id] = plan

        logger.info(f"Successfully unset {len(default_plans)} default plan(s)")

        # then set new defaults on domain objects under consideration
        plans_for_consideration_by_property: Dict[int, List[Plan]] = (
            _group_plans_by_property(plans_for_consideration)
        )
        for property_id, property_plans in plans_for_consideration_by_property.items():
            if not property_plans:
                raise ValueError(f"No plans for property {property_id}")

            try:
                cheapest_plan = choose_cheapest_relevant_plan(
                    property_plans, scenario_priority_order
                )
            except Exception:
                logger.error(f"Failed to find cheapest plan for property {property_id}")
                raise

            property_plans = _update_plan_objects(property_plans, cheapest_plan)
            for plan in property_plans:
                if plan.id is not None:  # just in case
                    plans_by_id[plan.id] = plan

        logger.info("Successfully set defaults on Plan objects in memory")

        # then pass all domain objects to database to update (regardless of
        # whether they've changed)
        _update_plans_in_db(list(plans_by_id.values()))
        logger.info(f"Successfully updated {len(plans_by_id)} Plans in database")

        # Mark the subtask as successful
        if subtask_id:
            SubTaskInterface().update_subtask_status(
                subtask_id=UUID(subtask_id),
                status="complete",
                cloud_logs_url=cloud_logs_url,
            )

        return Response(status_code=200)

    except Exception as e:
        if subtask_id:
            return handle_error(
                "Exception during Categorisation processing.",
                e,
                subtask_id,
                500,
                start_ms,
            )
        raise
def choose_cheapest_relevant_plan(
@ -109,8 +162,15 @@ def choose_cheapest_relevant_plan(
return cheapest_plans[0]
def _get_default_plans(portfolio_id: int, scenarios: List[Scenario]) -> List[Plan]:
default_plan_models = get_default_plans(portfolio_id)
def _get_default_plans(
portfolio_id: int,
scenarios: List[Scenario],
min_property_id: Optional[int] = None,
max_property_id: Optional[int] = None,
) -> List[Plan]:
default_plan_models = get_default_plans(
portfolio_id, min_property_id, max_property_id
)
scenario_map = {s.id: s for s in scenarios}
@ -131,12 +191,14 @@ def _load_plans_for_portfolio(
portfolio_id: int,
all_scenarios: List[Scenario],
scenarios_to_consider: Optional[List[int]] = None,
min_property_id: Optional[int] = None,
max_property_id: Optional[int] = None,
) -> List[Plan]:
if scenarios_to_consider:
logger.info(f"Getting plans for {len(scenarios_to_consider)} scenarios")
plan_models: List[PlanModel] = get_most_recent_plans_by_scenario_ids(
scenarios_to_consider
scenarios_to_consider, min_property_id, max_property_id
)
logger.info(f"Got {len(plan_models)} plan models from database")
else:
@ -144,7 +206,7 @@ def _load_plans_for_portfolio(
f"No list of Plans to consider provided. Getting all Plans for portfolio {portfolio_id}"
)
plan_models: List[PlanModel] = get_most_recent_plans_by_portfolio_id(
portfolio_id
portfolio_id, min_property_id, max_property_id
)
plans: List[Plan] = []

View file

@ -17,6 +17,11 @@ variable "image_digest" {
description = "Image digest (sha256:...)"
}
variable "maximum_concurrency" {
  type        = number
  default     = 10 # null if you don't want to set it for this handler
  description = "Maximum number of concurrent Lambda invocations from SQS (2-1000). null = no limit."

  # Enforce the range stated in the description; null stays allowed.
  # The conditional form is used so the comparisons are not evaluated on null.
  validation {
    condition     = var.maximum_concurrency == null ? true : (var.maximum_concurrency >= 2 && var.maximum_concurrency <= 1000)
    error_message = "The maximum_concurrency value must be null or a number between 2 and 1000."
  }
}
locals {
image_uri = "${var.ecr_repo_url}@${var.image_digest}"