refactoring tasks to get working with sqlalchemy 2

This commit is contained in:
Khalim Conn-Kowlessar 2025-11-26 16:40:16 +00:00
parent caeab2bf82
commit 5c8c9251c4
7 changed files with 310 additions and 1423 deletions

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,3 @@
from __future__ import annotations
# ---- Standard Library ----
from typing import Optional, Dict, Any
from datetime import datetime, timezone
@ -28,7 +26,6 @@ class SubTaskInterface:
# CREATE SUBTASK
# --------------------------------------------------------
def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None):
now = datetime.now(timezone.utc)
with get_db_session() as session:
task = session.get(Task, task_id)
@ -36,11 +33,11 @@ class SubTaskInterface:
raise ValueError(f"Task {task_id} not found")
subtask = SubTask(
taskId=task_id,
task_id=task_id,
inputs=json.dumps(inputs) if inputs else None,
status="waiting",
jobStarted=None,
jobCompleted=None,
job_started=None,
job_completed=None,
)
session.add(subtask)
@ -49,7 +46,7 @@ class SubTaskInterface:
# Recalculate parent task progress
self._update_task_progress(session, task_id)
return subtask
return subtask.id
# --------------------------------------------------------
# UPDATE STATUS (in progress, complete, failed)
@ -65,21 +62,21 @@ class SubTaskInterface:
normalized = status.lower()
# When job really starts
if normalized == "in progress" and subtask.jobStarted is None:
subtask.jobStarted = now
if normalized == "in progress" and subtask.job_started is None:
subtask.job_started = now
# Completed or failed
if normalized in ("complete", "failed"):
subtask.jobCompleted = now
subtask.job_completed = now
subtask.status = normalized
subtask.updatedAt = now
subtask.updated_at = now
session.add(subtask)
session.commit()
# Recalculate task status
self._update_task_progress(session, subtask.taskId)
self._update_task_progress(session, subtask.task_id)
session.refresh(subtask)
return subtask
@ -87,7 +84,8 @@ class SubTaskInterface:
# --------------------------------------------------------
# UPDATE OUTPUTS
# --------------------------------------------------------
def update_subtask_output(self, subtask_id: UUID, outputs: Dict[str, Any]):
@staticmethod
def update_subtask_output(subtask_id: UUID, outputs: Dict[str, Any]):
now = datetime.now(timezone.utc)
with get_db_session() as session:
@ -96,7 +94,7 @@ class SubTaskInterface:
raise ValueError(f"SubTask {subtask_id} not found")
subtask.outputs = json.dumps(outputs)
subtask.updatedAt = now
subtask.updated_at = now
session.add(subtask)
session.commit()
@ -106,7 +104,8 @@ class SubTaskInterface:
# --------------------------------------------------------
# UPDATE CLOUD LOGS URL
# --------------------------------------------------------
def update_subtask_logs(self, subtask_id: UUID, cloud_logs_url: str):
@staticmethod
def update_subtask_logs(subtask_id: UUID, cloud_logs_url: str):
now = datetime.now(timezone.utc)
with get_db_session() as session:
@ -114,8 +113,8 @@ class SubTaskInterface:
if not subtask:
raise ValueError(f"SubTask {subtask_id} not found")
subtask.cloudLogsURL = cloud_logs_url
subtask.updatedAt = now
subtask.cloud_logs_url = cloud_logs_url
subtask.updated_at = now
session.add(subtask)
session.commit()
@ -125,8 +124,8 @@ class SubTaskInterface:
# --------------------------------------------------------
# SET BOTH OUTPUT + LOGS
# --------------------------------------------------------
@staticmethod
def set_subtask_result(
self,
subtask_id: UUID,
outputs: Optional[Dict[str, Any]] = None,
cloud_logs_url: Optional[str] = None,
@ -142,9 +141,9 @@ class SubTaskInterface:
subtask.outputs = json.dumps(outputs)
if cloud_logs_url is not None:
subtask.cloudLogsURL = cloud_logs_url
subtask.cloud_logs_url = cloud_logs_url
subtask.updatedAt = now
subtask.updated_at = now
session.add(subtask)
session.commit()
session.refresh(subtask)
@ -153,13 +152,14 @@ class SubTaskInterface:
# --------------------------------------------------------
# TASK PROGRESS CALCULATION
# --------------------------------------------------------
def _update_task_progress(self, session: Session, task_id: UUID):
@staticmethod
def _update_task_progress(session: Session, task_id: UUID):
task = session.get(Task, task_id)
if not task:
return
subtasks = session.exec(
select(SubTask).where(SubTask.taskId == task_id)
select(SubTask).where(SubTask.task_id == task_id)
).all()
statuses = [s.status.lower() for s in subtasks]
@ -167,24 +167,24 @@ class SubTaskInterface:
if "failed" in statuses:
task.status = "failed"
task.jobCompleted = now
task.job_completed = now
elif all(s == "complete" for s in statuses):
task.status = "complete"
task.jobCompleted = now
task.job_completed = now
elif "in progress" in statuses:
task.status = "in progress"
if task.jobStarted is None:
task.jobStarted = now
if task.job_started is None:
task.job_started = now
else:
# All waiting
task.status = "waiting"
task.jobStarted = None
task.jobCompleted = None
task.job_started = None
task.job_completed = None
task.updatedAt = now
task.updated_at = now
session.add(task)
session.commit()
@ -212,18 +212,18 @@ class SubTaskInterface:
# Set logs
if cloud_logs_url is not None:
subtask.cloudLogsURL = cloud_logs_url
subtask.cloud_logs_url = cloud_logs_url
# Status + timestamps
subtask.status = normalized
subtask.jobCompleted = now
subtask.updatedAt = now
subtask.job_completed = now
subtask.updated_at = now
session.add(subtask)
session.commit()
# Update parent task (complete/failed)
self._update_task_progress(session, subtask.taskId)
self._update_task_progress(session, subtask.task_id)
session.refresh(subtask)
return subtask
@ -237,38 +237,49 @@ class TasksInterface:
High-level operations for Task records.
"""
@staticmethod
def create_task(
self,
*,
task_source: str,
service: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
task_only: bool = False,
):
now = datetime.now(timezone.utc)
"""
Create a new Task record, and an initial SubTask in waiting state. Can also be used to create just
a task, without a subtask
:param task_source: Text indicating source of task creation (e.g. file path + function name)
:param service: Optional service name
:param inputs: Inputs of the job being run
:param task_only: If True, only create the Task record, without a SubTask
:return:
"""
with get_db_session() as session:
task = Task(
taskSource=task_source,
task_source=task_source,
service=service,
status="waiting",
jobStarted=None,
jobCompleted=None,
job_started=None,
job_completed=None,
)
session.add(task)
session.commit()
session.refresh(task)
if task_only:
return task.id, None
# Create first subtask in waiting state
subtask_interface = SubTaskInterface()
subtask = subtask_interface.create_subtask(
subtask_id = subtask_interface.create_subtask(
task_id=task.id,
inputs=inputs,
)
return task.id, subtask.id
return task.id, subtask_id
def update_task_status(self, task_id: UUID, status: str):
@staticmethod
def update_task_status(task_id: UUID, status: str):
now = datetime.now(timezone.utc)
with get_db_session() as session:
@ -278,14 +289,14 @@ class TasksInterface:
normalized = status.lower()
if normalized == "in progress" and task.jobStarted is None:
task.jobStarted = now
if normalized == "in progress" and task.job_started is None:
task.job_started = now
if normalized == "complete":
task.jobCompleted = now
task.job_completed = now
task.status = normalized
task.updatedAt = now
task.updated_at = now
session.add(task)
session.commit()

View file

@ -1,6 +1,4 @@
from __future__ import annotations
from typing import Optional, List
from typing import Optional
from datetime import datetime
from uuid import UUID, uuid4
@ -10,64 +8,29 @@ from sqlmodel import SQLModel, Field, Relationship
class Task(SQLModel, table=True):
__tablename__ = "tasks"
id: UUID = Field(
default_factory=uuid4,
primary_key=True,
index=True,
)
taskSource: str = Field(alias="task_source")
jobStarted: Optional[datetime] = Field(
default=None, alias="job_started"
)
jobCompleted: Optional[datetime] = Field(
default=None, alias="job_completed"
)
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True, )
task_source: str
job_started: Optional[datetime] = None
job_completed: Optional[datetime] = None
status: str = Field(default="In Progress")
service: Optional[str] = None
updated_at: datetime = Field(default_factory=datetime.utcnow)
updatedAt: datetime = Field(
default_factory=datetime.utcnow,
alias="updated_at",
)
# Relationship
subTasks: List["SubTask"] = Relationship(back_populates="task")
sub_tasks: list["SubTask"] = Relationship(back_populates="task")
class SubTask(SQLModel, table=True):
__tablename__ = "sub_task"
id: UUID = Field(
default_factory=uuid4,
primary_key=True,
index=True,
)
taskId: UUID = Field(
foreign_key="tasks.id",
alias="task_id",
)
jobStarted: Optional[datetime] = Field(
default=None, alias="job_started"
)
jobCompleted: Optional[datetime] = Field(
default=None, alias="job_completed"
)
id: UUID = Field(default_factory=uuid4, primary_key=True, index=True, )
task_id: UUID = Field(foreign_key="tasks.id")
job_started: Optional[datetime] = None
job_completed: Optional[datetime] = None
status: str = Field(default="In Progress")
inputs: Optional[str] = None
outputs: Optional[str] = None
cloudLogsURL: Optional[str] = Field(alias="cloud_logs_url")
cloud_logs_url: Optional[str] = None
updated_at: datetime = Field(default_factory=datetime.utcnow)
updatedAt: datetime = Field(
default_factory=datetime.utcnow,
alias="updated_at",
)
# Relationship
task: Optional[Task] = Relationship(back_populates="subTasks")
task: Optional["Task"] = Relationship(back_populates="sub_tasks")

View file

@ -81,14 +81,38 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
# Insert the scenario ID into the data payload
data["scenario_id"] = scenario_id
# Create a task, and associated sub-tasks
from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface
# Create a main task
task_id = TasksInterface.create_task(
task_source="backend/plan/router.py:trigger_plan_entrypoint",
service="plan_engine",
inputs=data,
task_only=True
)
subtask_interface = SubTaskInterface()
for i in range(total_chunks):
# Create an entry in the request logs table
index_start = i * chunk_size
index_end = min((i + 1) * chunk_size, total_rows)
message_payload = {**data, "index_start": index_start, "index_end": index_end}
message_payload = {
**data, "index_start": index_start, "index_end": index_end,
}
message_body = json.dumps(message_payload)
# Create a subtask for this chunk
subtask_id = subtask_interface.create_subtask(
task_id=task_id,
inputs=message_payload
)
# Add task and subtask to message
message_payload["task_id"] = str(task_id)
message_payload["subtask_id"] = str(subtask_id)
response = sqs_client.send_message(
QueueUrl=settings.ENGINE_SQS_URL,
MessageBody=message_body

View file

@ -129,6 +129,10 @@ class PlanTriggerRequest(BaseModel):
index_start: Optional[int] = None
index_end: Optional[int] = None
# Task and subtask IDs
task_id: Optional[str] = None
subtask_id: Optional[str] = None
@model_validator(mode="after")
def check_indexes(self):
if (self.index_start is None) != (self.index_end is None):

View file

@ -10,6 +10,7 @@ import json
import time
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
from asset_list.utils import get_data_for_property
@ -52,8 +53,6 @@ n_postcodes = property_list["Post Code"].nunique()
postcode_summary = property_list.groupby("Post Code")["UPRN"].count().reset_index()
postcode_summary["UPRN"].mean()
test_match = property_list.merge(sustainability_data, left_on="UPRN", right_on="Org Ref")
def classify_floor_area(x):
if x <= 72:
@ -70,20 +69,187 @@ sustainability_data["Floor Area Band"] = sustainability_data["Total Floor Area (
lambda x: classify_floor_area(x)
)
archetypes = sustainability_data[
["Type", "Attachment", "Construction Years", "Wall Construction", "Wall Insulation",
"Roof Construction", "Roof Insulation", "Floor Construction", "Floor Insulation",
"Glazing", "Heating", "Boiler Efficiency", "Main Fuel", "Controls Adequacy",
"Floor Area Band"]
].drop_duplicates()
# Archetype reductions
# Potential reductions:
# Roof insulation category
# 1) Split roof insulation into > 100mm loft and <= 100mm loft
sustainability_data["Roof Insulation Category"] = sustainability_data["Roof Insulation"].copy()
sustainability_data["Roof Insulation Category"] = np.where(
sustainability_data["Roof Insulation Category"].isin(
['mm200', 'mm300', 'mm250', 'mm150', 'mm270', 'mm400', 'mm350'],
),
"LI > 100mm",
sustainability_data["Roof Insulation Category"],
)
sustainability_data["Roof Insulation Category"] = np.where(
sustainability_data["Roof Insulation Category"].isin(
['mm100', 'mm50', 'mm75', 'mm25'],
),
"LI <= 100mm",
sustainability_data["Roof Insulation Category"],
)
# 2) Group all of the glazed together (e.g. double glazed, secondary glazed, triple glazed)
# 3) Group up boiler efficiency A-C, D - F, G? or someting like this
sustainability_data["Glazing Type"] = sustainability_data["Glazing"].copy()
sustainability_data["Glazing Type"] = np.where(
sustainability_data["Glazing Type"].isin(
['Double 2002 or later', 'Double before 2002', 'Double but age unknown', 'DoubleKnownData']
),
"Double Glazed",
sustainability_data["Glazing Type"],
)
sustainability_data["Glazing Type"] = np.where(
sustainability_data["Glazing Type"].isin(['Triple', 'TripleKnownData']),
"Triple Glazed",
sustainability_data["Glazing Type"],
)
# 3) Group up boiler efficiency A, B-D, E - G? or someting like this
sustainability_data["Boiler Efficiency Group"] = sustainability_data["Boiler Efficiency"].copy()
sustainability_data["Boiler Efficiency Group"] = np.where(
sustainability_data["Boiler Efficiency Group"].isin(['B', 'C', 'D']),
"B-D",
sustainability_data["Boiler Efficiency Group"],
)
sustainability_data["Boiler Efficiency Group"] = np.where(
sustainability_data["Boiler Efficiency Group"].isin(['E', 'F', 'G']),
"E-G",
sustainability_data["Boiler Efficiency Group"],
)
# 4) Group up main fuel into gas, electric, oil, other?
sustainability_data["Main Fuel Group"] = sustainability_data["Main Fuel"].copy()
sustainability_data["Main Fuel Group"] = np.where(
sustainability_data["Main Fuel Group"].isin(
["SmokelessCoal", "BiomassCommunity", "B30DCommunity"]
),
"Other Fuel",
sustainability_data["Main Fuel Group"],
)
# 5) Wall Construction - group up Sandstone and Granite into one category
sustainability_data["Wall Construction"] = np.where(
sustainability_data["Wall Construction"].isin(["Sandstone", "Granite"]),
"Sandstone/Granite",
sustainability_data["Wall Construction"]
)
sustainability_data["Wall Construction"] = np.where(
sustainability_data["Wall Construction"].isin(["Timber Frame", "System", "Solid Brick"]),
"Solid",
sustainability_data["Wall Construction"]
)
# 6) Reduce or remove floor construction
sustainability_data["Floor Construction"] = np.where(
sustainability_data["Floor Construction"].isin(["SuspendedTimber", "SuspendedNotTimber"]),
"Suspended Floor",
sustainability_data["Floor Construction"]
)
# 7) Reduce wall insulation
sustainability_data["Wall Insulation"] = np.where(
sustainability_data["Wall Insulation"].isin(
["FilledCavityPlusInternal", "FilledCavityPlusExternal", "FilledCavity", "External", "Internal"]
),
"Insulated",
sustainability_data["Wall Insulation"]
)
# 8) Fill floor insulation
sustainability_data["Floor Insulation"] = sustainability_data["Floor Insulation"].fillna("Unknown")
# 9) Reduce Age bands
sustainability_data["Construction Years"] = np.where(
sustainability_data["Construction Years"].isin(["2003-2006", "2007-2011", "2012 onwards"]),
"2003 onwards",
sustainability_data["Construction Years"],
)
sustainability_data["Construction Years"] = np.where(
sustainability_data["Construction Years"].isin(["Before 1900", "1900-1929"]),
"Before 1929",
sustainability_data["Construction Years"],
)
sustainability_data["Construction Years"] = np.where(
sustainability_data["Construction Years"].isin(["1983-1990", "1991-1995"]),
"1983-1995",
sustainability_data["Construction Years"],
)
sustainability_data["Construction Years"] = np.where(
sustainability_data["Construction Years"].isin(["1950-1966", "1967-1975", "1976-1982"]),
"1950-1982",
sustainability_data["Construction Years"],
)
# Roof
sustainability_data["Roof Construction"] = np.where(
sustainability_data["Roof Construction"].isin(
["PitchedNormalLoftAccess", "PitchedThatched", "PitchedNormalNoLoftAccess", "PitchedWithSlopingCeiling"]
),
"Pitched Roof",
sustainability_data["Roof Construction"]
)
archetype_variables = [
"Type", "Attachment", "Construction Years", "Wall Construction", "Wall Insulation",
"Roof Construction", "Roof Insulation Category", "Floor Construction", "Floor Insulation",
"Glazing Type", "Heating", "Boiler Efficiency Group", "Main Fuel Group", "Controls Adequacy",
"Floor Area Band"
]
archetypes = sustainability_data[archetype_variables + ["UPRN"]].dropna().groupby(archetype_variables)[
"UPRN"].nunique().reset_index().rename(columns={"UPRN": "Count"}).sort_values(by="Count",
ascending=False).reset_index(
drop=True)
# We take a sample that represents 95% of the properties
archetypes["Cumulative Count"] = archetypes["Count"].cumsum()
archetypes["Cumulative Proportion"] = archetypes["Cumulative Count"] / archetypes["Count"].sum()
archetypes_85 = archetypes[archetypes["Cumulative Proportion"] <= 0.80]
archetypes_85["Archetypes_85_reference"] = archetypes_85.index + 1
archetypes_85["Archetypes_85_reference"] = "Archetype_Sample_" + archetypes_85["Archetypes_85_reference"].astype(str)
# We now take a sample of the properties that represent 85% of the total properties
sustainability_data = sustainability_data.merge(
archetypes_85,
on=archetype_variables,
how="inner"
)
# We take 1 random property, by archetype 85 reference
modelling_sample = sustainability_data.groupby("Archetypes_85_reference").apply(
lambda x: x.sample(1, random_state=42)
).reset_index(drop=True)
# Checking distributions
def compare_distributions(full_df, sample_df, column):
full_dist = full_df[column].value_counts(normalize=True)
sample_dist = sample_df[column].value_counts(normalize=True)
comparison = pd.concat([full_dist, sample_dist], axis=1, keys=['Full', 'Sample']).fillna(0)
return comparison
for col in archetype_variables:
print(f"--- {col} ---")
print(compare_distributions(sustainability_data, modelling_sample, col))
# Save this CSV as input
modelling_sample.to_excel(
"/Users/khalimconn-kowlessar/Documents/hestia/Customers/Peabody/Nov 2025 Consulting Project/modelling_sample.xlsx",
)
# Save the archetype definitions
archetypes_85.to_excel(
"/Users/khalimconn-kowlessar/Documents/hestia/Customers/Peabody/Nov 2025 Consulting Project/archetypes_85.xlsx",
)
# Save the full archetypes
archetypes.to_excel(
"/Users/khalimconn-kowlessar/Documents/hestia/Customers/Peabody/Nov 2025 Consulting Project/full_archetypes.xlsx",
)
# Maps the property types to the format recognised by the EPC api
property_type_map = {}

View file

@ -21,14 +21,16 @@ class RetrieveFindMyEpc:
'Chrome/111.0.0.0 Safari/537.36'
}
def __init__(self, address: str, postcode: str):
def __init__(self, address: str, postcode: str, rrn: str = None):
"""
This class is tasked with retrieving the latest EPC data from the find my epc website
:param address: The address of the property
:param postcode: The postcode of the property
:param rrn: The RRN of the EPC (if known)
"""
self.address = address
self.postcode = postcode
self.rrn = rrn
self.address_cleaned = self.address.replace(",", "").replace(" ", "").lower()
self.walls = []
@ -286,54 +288,12 @@ class RetrieveFindMyEpc:
:return:
"""
postcode_input = self.postcode.replace(" ", "+")
postcode_search = self.SEARCH_POSTCODE_URL.format(postcode_input=postcode_input)
postcode_response = requests.get(postcode_search, headers=self.HEADERS)
postcode_res = BeautifulSoup(postcode_response.text, features="html.parser")
rows = postcode_res.find_all('tr', class_='govuk-table__row')
extracted_table = []
for row in rows:
# Extract the address and URL
address_tag = row.find('a', class_='govuk-link')
if address_tag is None:
continue
extracted_address = None
extracted_address_url = None
if address_tag:
extracted_address = address_tag.text.strip()
extracted_address_url = address_tag['href']
extracted_address_cleaned = extracted_address.replace(",", "").replace(" ", "").lower()
if not extracted_address_cleaned.startswith(self.address_cleaned):
continue
# If the address is a match, we can extract the data
# Extract the expiry date
expiry_date_tag = row.find('td', class_='govuk-table__cell date')
expiry_date = None
if expiry_date_tag is not None:
expiry_date = expiry_date_tag.parent.find('span').text.strip()
extracted_table.append(
{
"extracted_address": extracted_address,
"extracted_address_url": extracted_address_url,
"expiry_date": datetime.strptime(expiry_date, '%d %B %Y'),
}
)
if not extracted_table:
raise ValueError("No EPC found")
if len(extracted_table) > 1:
# We take the one with the most recent expiry date
extracted_table = sorted(extracted_table, key=lambda x: x['expiry_date'], reverse=True)
chosen_epc = self.BASE_ENERGY_URL + extracted_table[0]['extracted_address_url']
epc_certificate = chosen_epc.split('/')[-1]
if self.rrn:
# We build the URL directly
epc_certificate = self.rrn
chosen_epc = f"{self.BASE_ENERGY_URL}/energy-certificate/{epc_certificate}"
else:
chosen_epc, epc_certificate = self._find_epc_page()
address_response = requests.get(chosen_epc, headers=self.HEADERS)
address_res = BeautifulSoup(address_response.text, features="html.parser")
@ -438,11 +398,17 @@ class RetrieveFindMyEpc:
For a post code and address, we pull out all the required data from the find my epc website
"""
if epc_page_source is None:
if epc_page_source is None and rrn is None:
chosen_epc, rrn = self._find_epc_page()
address_response = requests.get(chosen_epc, headers=self.HEADERS)
epc_page_source = address_response.text
address_res = BeautifulSoup(address_response.text, features="html.parser")
elif self.rrn:
epc_certificate = self.rrn
chosen_epc = f"{self.BASE_ENERGY_URL}/energy-certificate/{epc_certificate}"
address_response = requests.get(chosen_epc, headers=self.HEADERS)
epc_page_source = address_response.text
address_res = BeautifulSoup(address_response.text, features="html.parser")
else:
if rrn is None:
raise ValueError("rrn must be provided if epc_page_source is provided")
@ -581,6 +547,19 @@ class RetrieveFindMyEpc:
# 5) Pull out the EPC data
epc_data = self.extract_epc_data(address_res)
# Pull out the address information which can be found in the box with the class "epc-address"
# We split it up on break tags
addr = address_res.find("p", class_="epc-address").get_text(separator="\n").strip()
lines = addr.split("\n")
if len(lines) > 2:
address1 = lines[0]
address2 = lines[1]
postcode = lines[-1]
else:
address1 = lines[0]
address2 = ""
postcode = lines[-1]
resulting_data = {
'epc_certificate': rrn,
'current_epc_rating': current_rating.split(' ')[-6],
@ -594,6 +573,10 @@ class RetrieveFindMyEpc:
**assessment_data,
**low_carbon_energy_sources,
"page_source": epc_page_source,
# Add in address a postcode from the page - covers use cases where we are given RRN
"address1": address1,
"address2": address2,
"postcode": postcode,
}
if return_page: