perf(modelling_e2e): batch SubTask bookkeeping to stop per-property writes

Even after batching the data writes, the handler still wrote to the DB per
property through the orchestrator's SubTask bookkeeping: create + start +
complete each self-committed, and _cascade re-listed every sibling and re-saved
the parent on every transition — ~5 writes per property plus an O(N^2) cascade.

- TaskOrchestrator.run_subtasks: create all children in one INSERT, run each
  (failures isolated per child), then persist all terminal states in one bulk
  save and cascade the parent once. Children go WAITING -> terminal; the
  transient IN_PROGRESS row is never written.
- SubTaskRepository.create_many / save_many (bulk INSERT / bulk fetch + update).
- _cascade short-circuits when the Task is already FAILED (terminal) — skips the
  sibling roll-up entirely.
- modelling_e2e handler fans out via run_subtasks instead of per-property
  create_child_subtask + run_subtask.

Per N-property batch the SubTask bookkeeping drops from ~5N writes + an O(N^2)
cascade to ~2 writes + 1 cascade.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Jun-te Kim 2026-06-24 19:26:42 +00:00
parent de5e9a2362
commit b1ff711260
7 changed files with 412 additions and 142 deletions

View file

@ -65,6 +65,7 @@ from domain.modelling.plan import Plan
from domain.property.property import Property, PropertyIdentity
from domain.property_baseline.calculator_rebaseliner import CalculatorRebaseliner
from domain.sap10_calculator.calculator import Sap10Calculator
from domain.tasks.subtasks import SubTask
from domain.tasks.tasks import Source
from harness.console import run_modelling
from orchestration.task_orchestrator import TaskOrchestrator
@ -486,142 +487,145 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator,
# whole batch is flushed in one transaction after the loop.
accumulated: list[_PropertyWrite] = []
for property_id in property_ids:
child = orchestrator.create_child_subtask(
task_id, inputs={"property_id": property_id}
def _work(subtask: SubTask) -> None:
inputs = subtask.inputs or {}
pid = int(inputs["property_id"])
uprn = uprns[pid]
postcode = postcodes.get(pid, "")
logger.info(f"property={pid} uprn={uprn} postcode={postcode!r}")
spatial = _spatial_for(geospatial, uprn)
restrictions = (
spatial.restrictions
if spatial is not None
else PlanningRestrictions()
)
coordinates: Optional[Coordinates] = (
spatial.coordinates if spatial is not None else None
)
def _work(pid: int = property_id) -> None:
uprn = uprns[pid]
postcode = postcodes.get(pid, "")
logger.info(f"property={pid} uprn={uprn} postcode={postcode!r}")
epc: Optional[EpcPropertyData] = epc_client.get_by_uprn(uprn)
overrides = overlays_from(overrides_reader.overrides_for(pid))
predicted_epc: Optional[EpcPropertyData] = None
spatial = _spatial_for(geospatial, uprn)
restrictions = (
spatial.restrictions
if spatial is not None
else PlanningRestrictions()
)
coordinates: Optional[Coordinates] = (
spatial.coordinates if spatial is not None else None
)
epc: Optional[EpcPropertyData] = epc_client.get_by_uprn(uprn)
overrides = overlays_from(overrides_reader.overrides_for(pid))
predicted_epc: Optional[EpcPropertyData] = None
if epc is not None:
logger.info(f"property={pid} lodged EPC found")
effective_epc = Property(
identity=PropertyIdentity(
portfolio_id=portfolio_id,
postcode=postcode,
address="",
uprn=uprn,
),
epc=epc,
landlord_overrides=overrides,
).effective_epc
else:
logger.info(
f"property={pid} no lodged EPC — attempting prediction"
)
predicted_epc = _predict_epc(
property_id=pid,
uprn=uprn,
postcode=postcode,
if epc is not None:
logger.info(f"property={pid} lodged EPC found")
effective_epc = Property(
identity=PropertyIdentity(
portfolio_id=portfolio_id,
attributes_reader=prediction_attrs_reader,
coordinates=coordinates,
cohort_for=_get_cohort,
broaden=_broaden,
predictor=predictor,
)
effective_epc = Property(
identity=PropertyIdentity(
portfolio_id=portfolio_id,
postcode=postcode,
address="",
uprn=uprn,
),
epc=None,
predicted_epc=predicted_epc,
landlord_overrides=overrides,
).effective_epc
postcode=postcode,
address="",
uprn=uprn,
),
epc=epc,
landlord_overrides=overrides,
).effective_epc
else:
logger.info(
f"property={pid} no lodged EPC — attempting prediction"
)
predicted_epc = _predict_epc(
property_id=pid,
uprn=uprn,
postcode=postcode,
portfolio_id=portfolio_id,
attributes_reader=prediction_attrs_reader,
coordinates=coordinates,
cohort_for=_get_cohort,
broaden=_broaden,
predictor=predictor,
)
effective_epc = Property(
identity=PropertyIdentity(
portfolio_id=portfolio_id,
postcode=postcode,
address="",
uprn=uprn,
),
epc=None,
predicted_epc=predicted_epc,
landlord_overrides=overrides,
).effective_epc
solar_insights: Optional[dict[str, Any]]
solar_was_fetched = False
if no_solar:
solar_insights = None
else:
solar_insights = stored_solar.get(uprn)
if solar_insights is None:
solar_insights = _solar_insights_for(solar_client, spatial)
solar_was_fetched = solar_insights is not None
solar_insights: Optional[dict[str, Any]]
solar_was_fetched = False
if no_solar:
solar_insights = None
else:
solar_insights = stored_solar.get(uprn)
if solar_insights is None:
solar_insights = _solar_insights_for(solar_client, spatial)
solar_was_fetched = solar_insights is not None
plan = run_modelling(
effective_epc,
planning_restrictions=restrictions,
solar_insights=solar_insights,
considered_measures=None,
products=products,
scenario=scenario,
print_table=False,
plan = run_modelling(
effective_epc,
planning_restrictions=restrictions,
solar_insights=solar_insights,
considered_measures=None,
products=products,
scenario=scenario,
print_table=False,
)
logger.info(
f"property={pid} modelling complete "
f"measures={len(plan.measures)}"
)
if dry_run:
measure_types = (
", ".join(m.measure_type for m in plan.measures) or "none"
)
logger.info(
f"property={pid} modelling complete "
f"measures={len(plan.measures)}"
f"[dry_run] property={pid} "
f"measures=[{measure_types}] — skipping DB write"
)
return
solar_write: Optional[_SolarWrite] = None
if (
solar_was_fetched
and solar_insights is not None
and spatial is not None
and spatial.coordinates is not None
):
solar_write = _SolarWrite(
uprn=uprn,
longitude=spatial.coordinates.longitude,
latitude=spatial.coordinates.latitude,
insights=solar_insights,
)
if dry_run:
measure_types = (
", ".join(m.measure_type for m in plan.measures) or "none"
)
logger.info(
f"[dry_run] property={pid} "
f"measures=[{measure_types}] — skipping DB write"
)
return
solar_write: Optional[_SolarWrite] = None
if (
solar_was_fetched
and solar_insights is not None
and spatial is not None
and spatial.coordinates is not None
):
solar_write = _SolarWrite(
uprn=uprn,
longitude=spatial.coordinates.longitude,
latitude=spatial.coordinates.latitude,
insights=solar_insights,
)
# Queue this Property's writes rather than committing now — the
# whole batch is persisted in one Unit of Work after the loop
# (see _flush_writes). The EPC is saved in its lodged or predicted
# slot (ADR-0031) at flush time depending on which is set here.
accumulated.append(
_PropertyWrite(
property_id=pid,
uprn=uprn,
portfolio_id=portfolio_id,
scenario_id=scenario_id,
is_default=scenario.is_default,
lodged_epc=epc,
predicted_epc=predicted_epc,
spatial=spatial,
solar=solar_write,
plan=plan,
has_recommendations=bool(plan.measures),
)
# Queue this Property's writes rather than committing now — the
# whole batch is persisted in one Unit of Work after the loop
# (see _flush_writes). The EPC is saved in its lodged or predicted
# slot (ADR-0031) at flush time depending on which is set here.
accumulated.append(
_PropertyWrite(
property_id=pid,
uprn=uprn,
portfolio_id=portfolio_id,
scenario_id=scenario_id,
is_default=scenario.is_default,
lodged_epc=epc,
predicted_epc=predicted_epc,
spatial=spatial,
solar=solar_write,
plan=plan,
has_recommendations=bool(plan.measures),
)
logger.info(f"property={pid} queued for write")
)
logger.info(f"property={pid} queued for write")
try:
orchestrator.run_subtask(child.id, work=_work)
except Exception: # noqa: BLE001
pass
# Fan the batch out into one child SubTask per property and run them in
# a single batched pass: create all children, model each (failures
# isolated per child), then persist all their statuses in two writes +
# one cascade — not ~5 writes and a full parent re-roll-up per property
# (see TaskOrchestrator.run_subtasks).
orchestrator.run_subtasks(
task_id,
[{"property_id": pid} for pid in property_ids],
work=_work,
)
# Persist the whole batch in one transaction, then re-establish every
# written Property's Baseline (the orchestrator batches its own UoW). The

View file

@ -2,7 +2,7 @@ from typing import Any, Callable, Optional
from uuid import UUID
from domain.tasks.subtasks import SubTask
from domain.tasks.tasks import Source, Task
from domain.tasks.tasks import Source, Task, TaskStatus
from repositories.tasks.subtask_repository import SubTaskRepository
from repositories.tasks.task_repository import TaskRepository
from utilities.private import private
@ -98,9 +98,59 @@ class TaskOrchestrator:
self.complete_subtask(subtask_id, result)
return result
def run_subtasks(
    self,
    parent_task_id: UUID,
    inputs_per_subtask: list[dict[str, Any]],
    work: Callable[[SubTask], Any],
    cloud_logs_url: Optional[str] = None,
) -> list[Any]:
    """Run ``work`` once per item as child SubTasks of one parent, in a batch.

    Batched counterpart of ``run_subtask``. One child SubTask is built per
    entry of ``inputs_per_subtask`` and all of them are inserted with a
    single ``create_many``. Each child is then started, handed to ``work``
    and completed or failed purely in memory; the final states are written
    with a single ``save_many`` and the parent is rolled up with a single
    ``_cascade``. Nothing is saved between ``start`` and the terminal
    transition, so the intermediate started state never reaches the
    repository.

    Failures are isolated per child: an ``Exception`` raised by ``work``
    marks only that child as failed, and the remaining children still run.

    Args:
        parent_task_id: Id of the Task the children belong to.
        inputs_per_subtask: One inputs dict per child, in run order.
        work: Callable invoked with each child SubTask.
        cloud_logs_url: Optional logs URL passed to every child's ``start``.

    Returns:
        One entry per item, in input order: the value returned by ``work``,
        or ``None`` for an item whose ``work`` raised.
    """
    children = [
        SubTask.create(task_id=parent_task_id, inputs=item)
        for item in inputs_per_subtask
    ]
    self._subtasks.create_many(children)

    def _run_one(child: SubTask) -> Any:
        # Start, run and settle a single child without touching the repository.
        child.start(cloud_logs_url)
        try:
            outcome = work(child)
        except Exception as exc:  # noqa: BLE001 — isolate per child; siblings continue
            child.fail(exc)
            return None
        child.complete(outcome)
        return outcome

    outcomes = [_run_one(child) for child in children]

    # One bulk write for every terminal state, then a single parent roll-up.
    self._subtasks.save_many(children)
    self._cascade(parent_task_id)
    return outcomes
@private
def _cascade(self, task_id: UUID) -> None:
    """Roll the SubTask statuses of ``task_id`` up into its parent Task.

    The Task is fetched first. If it is already FAILED the method returns
    immediately, before any SubTask is listed, so the short-circuit costs a
    single Task fetch and performs no sibling query and no Task save.
    Otherwise the statuses of all SubTasks of the Task are collected once,
    handed to ``Task.recalculate_from_subtasks`` and the Task is saved.

    Args:
        task_id: Id of the Task whose status should be recalculated.
    """
    task = self._tasks.get(task_id)
    # FAILED is terminal: once any SubTask has failed the Task is failed and
    # stays failed, so skip the (potentially large) sibling roll-up entirely —
    # no need to list and re-check the SubTasks.
    if task.status is TaskStatus.FAILED:
        return
    # Listing happens only past the guard, and only once per cascade.
    statuses = [s.status for s in self._subtasks.list_by_task(task_id)]
    task.recalculate_from_subtasks(statuses)
    self._tasks.save(task)

View file

@ -3,7 +3,7 @@ from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID
from sqlmodel import Session, select
from sqlmodel import Session, col, select
from domain.tasks.subtasks import SubTask, SubTaskStatus
from infrastructure.postgres.subtask_table import SubTaskRow
@ -22,6 +22,12 @@ class SubTaskPostgresRepository(SubTaskRepository):
self._session.refresh(row)
return self._to_domain(row)
def create_many(self, subtasks: list[SubTask]) -> None:
    """Add a row for every given SubTask to the session and commit once.

    An empty list returns without touching the session.
    """
    if subtasks:
        rows = [self._to_row(subtask) for subtask in subtasks]
        self._session.add_all(rows)
        self._session.commit()
def get(self, subtask_id: UUID) -> SubTask:
row = self._session.get(SubTaskRow, subtask_id)
if row is None:
@ -46,6 +52,36 @@ class SubTaskPostgresRepository(SubTaskRepository):
self._session.add(row)
self._session.commit()
def save_many(self, subtasks: list[SubTask]) -> None:
    """Copy the state of many SubTasks onto their existing rows and commit once.

    All rows are loaded with one ``IN`` query keyed on the SubTask ids. If
    any id has no row, nothing is modified and ``ValueError`` is raised
    listing the missing ids. Every updated row receives the same
    ``updated_at`` timestamp. An empty list returns without touching the
    session.

    Raises:
        ValueError: If one or more of the given SubTasks has no stored row.
    """
    if not subtasks:
        return

    pending = {subtask.id: subtask for subtask in subtasks}
    id_filter = col(SubTaskRow.id).in_(list(pending))
    existing = self._session.exec(select(SubTaskRow).where(id_filter)).all()

    missing = set(pending).difference(row.id for row in existing)
    if missing:
        raise ValueError(f"SubTask(s) not found: {sorted(str(m) for m in missing)}")

    stamp = datetime.now(timezone.utc)
    for row in existing:
        source = pending[row.id]
        row.status = source.status.value
        row.job_started = source.job_started
        row.job_completed = source.job_completed
        # inputs / outputs are stored as JSON text; None stays NULL.
        row.inputs = None if source.inputs is None else json.dumps(source.inputs)
        row.outputs = None if source.outputs is None else json.dumps(source.outputs)
        row.cloud_logs_url = source.cloud_logs_url
        row.updated_at = stamp
        self._session.add(row)
    self._session.commit()
def list_by_task(self, task_id: UUID) -> list[SubTask]:
rows = self._session.exec(
select(SubTaskRow).where(SubTaskRow.task_id == task_id)

View file

@ -8,11 +8,24 @@ class SubTaskRepository(ABC):
@abstractmethod
def create(self, subtask: SubTask) -> SubTask: ...
@abstractmethod
def create_many(self, subtasks: list[SubTask]) -> None:
    """Batch form of ``create``: persist many SubTasks in one round-trip.

    Meant for callers that fan a Task out into many children at once;
    implementations are expected to use one INSERT and one commit.
    """
    ...
@abstractmethod
def get(self, subtask_id: UUID) -> SubTask: ...
@abstractmethod
def save(self, subtask: SubTask) -> None: ...
@abstractmethod
def save_many(self, subtasks: list[SubTask]) -> None:
    """Batch form of ``save``: persist updates to many SubTasks in one round-trip.

    Implementations are expected to use one bulk fetch and one commit.
    """
    ...
@abstractmethod
def list_by_task(self, task_id: UUID) -> list[SubTask]: ...

View file

@ -18,6 +18,7 @@ from pydantic import ValidationError
from applications.modelling_e2e.modelling_e2e_trigger_body import (
ModellingE2ETriggerBody,
)
from domain.tasks.subtasks import SubTask
PROPERTY_ID = 12345
UPRN = 987654321
@ -41,11 +42,26 @@ _BODY = {
def _mock_orchestrator() -> MagicMock:
"""Mock TaskOrchestrator whose run_subtasks creates one SubTask per input and
runs work for each, isolating per-item failures (mirroring the real method).

The stub only builds the SubTasks and collects results: it never starts,
completes, fails or persists them, so SubTask status is not exercised here."""
mock = MagicMock()
# Single-subtask path: run_subtask simply invokes the work callable.
mock.run_subtask.side_effect = lambda subtask_id, work, **kwargs: work()
# create_child_subtask returns a stand-in child that only carries a fresh id.
child = MagicMock()
child.id = uuid4()
mock.create_child_subtask.return_value = child
# Batched path: one real SubTask per inputs dict, each handed to work in order.
def _run_subtasks(
parent_task_id: UUID,
inputs_per_subtask: list[dict[str, Any]],
work: Any,
**kwargs: Any,
) -> list[Any]:
results: list[Any] = []
for inputs in inputs_per_subtask:
subtask = SubTask.create(task_id=parent_task_id, inputs=inputs)
try:
results.append(work(subtask))
except Exception: # noqa: BLE001 — siblings continue, as in real impl
# A raising item contributes None, keeping results aligned with inputs.
results.append(None)
return results
mock.run_subtasks.side_effect = _run_subtasks
return mock
@ -194,13 +210,13 @@ def test_handler_creates_one_child_subtask_per_property_id() -> None:
None, mock_orch, task_id,
)
# Assert — one child SubTask per property, inputs record the property_id
assert mock_orch.create_child_subtask.call_count == 3
calls = mock_orch.create_child_subtask.call_args_list
recorded_ids = [c.kwargs["inputs"]["property_id"] for c in calls]
assert recorded_ids == [pid1, pid2, pid3]
# All three calls used the same task_id
assert all(c.args[0] == task_id for c in calls)
# Assert — the batch is fanned out in ONE run_subtasks call, with one input
# per property (each recording its property_id) under the same parent task.
mock_orch.run_subtasks.assert_called_once()
args = mock_orch.run_subtasks.call_args.args
assert args[0] == task_id
inputs_per_subtask = args[1]
assert [i["property_id"] for i in inputs_per_subtask] == [pid1, pid2, pid3]
# ---------------------------------------------------------------------------
@ -786,8 +802,8 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None:
def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> None:
"""Two properties: property 1 succeeds, property 2 fails during modelling.
The handler does not raise; property 1's UoW was committed; run_subtask was
called for both properties."""
The handler does not raise; both properties are run in the one batched
run_subtasks pass, and property 1's write is still committed."""
# Arrange
pid1, pid2 = 111, 222
mock_engine = _engine_mock([pid1, pid2], [1001, 1002], [POSTCODE, POSTCODE])
@ -847,8 +863,10 @@ def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> Non
orchestrator=mock_orch,
)
# run_subtask called for both properties; pid1 committed
assert mock_orch.run_subtask.call_count == 2
# Both properties run in one batched run_subtasks pass; pid1 still committed
# (the failed pid2 is isolated and simply never queued for write).
mock_orch.run_subtasks.assert_called_once()
assert len(mock_orch.run_subtasks.call_args.args[1]) == 2
mock_uow.commit.assert_called_once()

View file

@ -195,3 +195,87 @@ def test_run_subtask_failing_work_marks_failed_and_reraises(
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
def test_run_subtasks_creates_runs_and_completes_a_whole_batch(
    harness: Harness,
) -> None:
    """A three-item batch produces three COMPLETE children, ordered results
    and a parent Task cascaded to COMPLETE."""
    # arrange — parent task whose coordinator subtask is already complete
    task, coordinator = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    harness.orchestrator.complete_subtask(coordinator.id)
    batch_inputs = [{"property_id": pid} for pid in (1, 2, 3)]

    def _times_ten(st: SubTask) -> int:
        return (st.inputs or {})["property_id"] * 10

    # act
    results = harness.orchestrator.run_subtasks(
        task.id, batch_inputs, work=_times_ten
    )

    # assert — results keep input order
    assert results == [10, 20, 30]

    # assert — every fanned-out child (coordinator excluded) is COMPLETE
    children = []
    for stored in harness.subtasks.list_by_task(task.id):
        if stored.id != coordinator.id:
            children.append(stored)
    assert len(children) == 3
    for child in children:
        assert child.status is SubTaskStatus.COMPLETE
    assert {(c.inputs or {})["property_id"] for c in children} == {1, 2, 3}

    # assert — parent rolled up to COMPLETE
    assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
def test_run_subtasks_isolates_a_failing_item_and_continues(
    harness: Harness,
) -> None:
    """One raising item ends FAILED with its error recorded and a ``None``
    result; its siblings still complete and the parent cascades to FAILED."""
    # arrange
    task, coordinator = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    harness.orchestrator.complete_subtask(coordinator.id)

    def _explode_on_two(st: SubTask) -> int:
        property_id = (st.inputs or {})["property_id"]
        if property_id != 2:
            return property_id
        raise RuntimeError("property 2 exploded")

    # act — must not raise; the failure stays inside its own child
    results = harness.orchestrator.run_subtasks(
        task.id,
        [{"property_id": pid} for pid in (1, 2, 3)],
        work=_explode_on_two,
    )

    # assert
    assert results == [1, None, 3]
    by_property = {}
    for stored in harness.subtasks.list_by_task(task.id):
        if stored.id == coordinator.id:
            continue
        by_property[(stored.inputs or {})["property_id"]] = stored
    assert by_property[1].status is SubTaskStatus.COMPLETE
    assert by_property[3].status is SubTaskStatus.COMPLETE
    failed = by_property[2]
    assert failed.status is SubTaskStatus.FAILED
    assert failed.outputs == {"error": "property 2 exploded"}
    # any FAILED child → parent FAILED
    assert harness.tasks.get(task.id).status is TaskStatus.FAILED
def test_cascade_short_circuits_once_task_already_failed(harness: Harness) -> None:
    """A Task that has reached FAILED stays FAILED when another of its
    SubTasks completes afterwards."""
    # arrange — a task with two subtasks; failing the first fails the task
    task, coordinator = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    other_child = harness.orchestrator.create_child_subtask(task.id)
    harness.orchestrator.fail_subtask(coordinator.id, RuntimeError("boom"))
    status_after_failure = harness.tasks.get(task.id).status
    assert status_after_failure is TaskStatus.FAILED

    # act — the remaining child completes
    harness.orchestrator.complete_subtask(other_child.id)

    # assert — the terminal state was not recomputed away
    status_after_completion = harness.tasks.get(task.id).status
    assert status_after_completion is TaskStatus.FAILED

View file

@ -98,3 +98,68 @@ def test_get_missing_raises(session: Session) -> None:
# act / assert
with pytest.raises(ValueError, match="not found"):
repo.get(uuid4())
def test_create_many_persists_every_subtask(session: Session) -> None:
    # arrange
    repo = SubTaskPostgresRepository(session=session)
    task_id = _persisted_task_id(session)
    batch = []
    for pid in (1, 2, 3):
        batch.append(SubTask.create(task_id=task_id, inputs={"property_id": pid}))

    # act
    repo.create_many(batch)

    # assert — each subtask is stored, still WAITING, with its inputs intact
    stored = repo.list_by_task(task_id)
    by_property = {(s.inputs or {})["property_id"]: s for s in stored}
    assert set(by_property) == {1, 2, 3}
    for subtask in by_property.values():
        assert subtask.status is SubTaskStatus.WAITING
def test_create_many_empty_is_a_noop(session: Session) -> None:
    # arrange
    repo = SubTaskPostgresRepository(session=session)

    # act / assert — an empty batch simply returns without raising
    repo.create_many([])
def test_save_many_persists_each_subtasks_terminal_state(session: Session) -> None:
    # arrange — three subtasks already inserted
    repo = SubTaskPostgresRepository(session=session)
    task_id = _persisted_task_id(session)
    first, second, third = (
        SubTask.create(task_id=task_id, inputs={"property_id": pid})
        for pid in (1, 2, 3)
    )
    batch = [first, second, third]
    repo.create_many(batch)

    # act — settle each one in memory, then write them all with one bulk save
    first.complete({"ok": True})
    second.fail(RuntimeError("nope"))
    third.complete()
    repo.save_many(batch)

    # assert
    by_property = {}
    for stored in repo.list_by_task(task_id):
        by_property[(stored.inputs or {})["property_id"]] = stored
    assert by_property[1].status is SubTaskStatus.COMPLETE
    assert by_property[1].outputs == {"result": {"ok": True}}
    assert by_property[2].status is SubTaskStatus.FAILED
    assert by_property[2].outputs == {"error": "nope"}
    assert by_property[3].status is SubTaskStatus.COMPLETE
def test_save_many_raises_when_a_subtask_does_not_exist(session: Session) -> None:
    # arrange — one subtask is inserted, the other exists only in memory
    repo = SubTaskPostgresRepository(session=session)
    task_id = _persisted_task_id(session)
    stored = SubTask.create(task_id=task_id)
    repo.create_many([stored])
    never_created = SubTask.create(task_id=task_id)
    batch = [stored, never_created]

    # act / assert
    with pytest.raises(ValueError, match="not found"):
        repo.save_many(batch)