mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-30 13:10:47 +00:00
perf(modelling_e2e): batch SubTask bookkeeping to stop per-property writes
Even after batching the data writes, the handler still wrote to the DB per property through the orchestrator's SubTask bookkeeping: create + start + complete each self-committed, and _cascade re-listed every sibling and re-saved the parent on every transition — ~5 writes per property plus an O(N^2) cascade. - TaskOrchestrator.run_subtasks: create all children in one INSERT, run each (failures isolated per child), then persist all terminal states in one bulk save and cascade the parent once. Children go WAITING -> terminal; the transient IN_PROGRESS row is never written. - SubTaskRepository.create_many / save_many (bulk INSERT / bulk fetch + update). - _cascade short-circuits when the Task is already FAILED (terminal) — skips the sibling roll-up entirely. - modelling_e2e handler fans out via run_subtasks instead of per-property create_child_subtask + run_subtask. Per N-property batch the SubTask bookkeeping drops from ~5N writes + an O(N^2) cascade to ~2 writes + 1 cascade. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
de5e9a2362
commit
b1ff711260
7 changed files with 412 additions and 142 deletions
|
|
@ -65,6 +65,7 @@ from domain.modelling.plan import Plan
|
|||
from domain.property.property import Property, PropertyIdentity
|
||||
from domain.property_baseline.calculator_rebaseliner import CalculatorRebaseliner
|
||||
from domain.sap10_calculator.calculator import Sap10Calculator
|
||||
from domain.tasks.subtasks import SubTask
|
||||
from domain.tasks.tasks import Source
|
||||
from harness.console import run_modelling
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
|
|
@ -486,142 +487,145 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator,
|
|||
# whole batch is flushed in one transaction after the loop.
|
||||
accumulated: list[_PropertyWrite] = []
|
||||
|
||||
for property_id in property_ids:
|
||||
child = orchestrator.create_child_subtask(
|
||||
task_id, inputs={"property_id": property_id}
|
||||
def _work(subtask: SubTask) -> None:
|
||||
inputs = subtask.inputs or {}
|
||||
pid = int(inputs["property_id"])
|
||||
uprn = uprns[pid]
|
||||
postcode = postcodes.get(pid, "")
|
||||
logger.info(f"property={pid} uprn={uprn} postcode={postcode!r}")
|
||||
|
||||
spatial = _spatial_for(geospatial, uprn)
|
||||
restrictions = (
|
||||
spatial.restrictions
|
||||
if spatial is not None
|
||||
else PlanningRestrictions()
|
||||
)
|
||||
coordinates: Optional[Coordinates] = (
|
||||
spatial.coordinates if spatial is not None else None
|
||||
)
|
||||
|
||||
def _work(pid: int = property_id) -> None:
|
||||
uprn = uprns[pid]
|
||||
postcode = postcodes.get(pid, "")
|
||||
logger.info(f"property={pid} uprn={uprn} postcode={postcode!r}")
|
||||
epc: Optional[EpcPropertyData] = epc_client.get_by_uprn(uprn)
|
||||
overrides = overlays_from(overrides_reader.overrides_for(pid))
|
||||
predicted_epc: Optional[EpcPropertyData] = None
|
||||
|
||||
spatial = _spatial_for(geospatial, uprn)
|
||||
restrictions = (
|
||||
spatial.restrictions
|
||||
if spatial is not None
|
||||
else PlanningRestrictions()
|
||||
)
|
||||
coordinates: Optional[Coordinates] = (
|
||||
spatial.coordinates if spatial is not None else None
|
||||
)
|
||||
|
||||
epc: Optional[EpcPropertyData] = epc_client.get_by_uprn(uprn)
|
||||
overrides = overlays_from(overrides_reader.overrides_for(pid))
|
||||
predicted_epc: Optional[EpcPropertyData] = None
|
||||
|
||||
if epc is not None:
|
||||
logger.info(f"property={pid} lodged EPC found")
|
||||
effective_epc = Property(
|
||||
identity=PropertyIdentity(
|
||||
portfolio_id=portfolio_id,
|
||||
postcode=postcode,
|
||||
address="",
|
||||
uprn=uprn,
|
||||
),
|
||||
epc=epc,
|
||||
landlord_overrides=overrides,
|
||||
).effective_epc
|
||||
else:
|
||||
logger.info(
|
||||
f"property={pid} no lodged EPC — attempting prediction"
|
||||
)
|
||||
predicted_epc = _predict_epc(
|
||||
property_id=pid,
|
||||
uprn=uprn,
|
||||
postcode=postcode,
|
||||
if epc is not None:
|
||||
logger.info(f"property={pid} lodged EPC found")
|
||||
effective_epc = Property(
|
||||
identity=PropertyIdentity(
|
||||
portfolio_id=portfolio_id,
|
||||
attributes_reader=prediction_attrs_reader,
|
||||
coordinates=coordinates,
|
||||
cohort_for=_get_cohort,
|
||||
broaden=_broaden,
|
||||
predictor=predictor,
|
||||
)
|
||||
effective_epc = Property(
|
||||
identity=PropertyIdentity(
|
||||
portfolio_id=portfolio_id,
|
||||
postcode=postcode,
|
||||
address="",
|
||||
uprn=uprn,
|
||||
),
|
||||
epc=None,
|
||||
predicted_epc=predicted_epc,
|
||||
landlord_overrides=overrides,
|
||||
).effective_epc
|
||||
postcode=postcode,
|
||||
address="",
|
||||
uprn=uprn,
|
||||
),
|
||||
epc=epc,
|
||||
landlord_overrides=overrides,
|
||||
).effective_epc
|
||||
else:
|
||||
logger.info(
|
||||
f"property={pid} no lodged EPC — attempting prediction"
|
||||
)
|
||||
predicted_epc = _predict_epc(
|
||||
property_id=pid,
|
||||
uprn=uprn,
|
||||
postcode=postcode,
|
||||
portfolio_id=portfolio_id,
|
||||
attributes_reader=prediction_attrs_reader,
|
||||
coordinates=coordinates,
|
||||
cohort_for=_get_cohort,
|
||||
broaden=_broaden,
|
||||
predictor=predictor,
|
||||
)
|
||||
effective_epc = Property(
|
||||
identity=PropertyIdentity(
|
||||
portfolio_id=portfolio_id,
|
||||
postcode=postcode,
|
||||
address="",
|
||||
uprn=uprn,
|
||||
),
|
||||
epc=None,
|
||||
predicted_epc=predicted_epc,
|
||||
landlord_overrides=overrides,
|
||||
).effective_epc
|
||||
|
||||
solar_insights: Optional[dict[str, Any]]
|
||||
solar_was_fetched = False
|
||||
if no_solar:
|
||||
solar_insights = None
|
||||
else:
|
||||
solar_insights = stored_solar.get(uprn)
|
||||
if solar_insights is None:
|
||||
solar_insights = _solar_insights_for(solar_client, spatial)
|
||||
solar_was_fetched = solar_insights is not None
|
||||
solar_insights: Optional[dict[str, Any]]
|
||||
solar_was_fetched = False
|
||||
if no_solar:
|
||||
solar_insights = None
|
||||
else:
|
||||
solar_insights = stored_solar.get(uprn)
|
||||
if solar_insights is None:
|
||||
solar_insights = _solar_insights_for(solar_client, spatial)
|
||||
solar_was_fetched = solar_insights is not None
|
||||
|
||||
plan = run_modelling(
|
||||
effective_epc,
|
||||
planning_restrictions=restrictions,
|
||||
solar_insights=solar_insights,
|
||||
considered_measures=None,
|
||||
products=products,
|
||||
scenario=scenario,
|
||||
print_table=False,
|
||||
plan = run_modelling(
|
||||
effective_epc,
|
||||
planning_restrictions=restrictions,
|
||||
solar_insights=solar_insights,
|
||||
considered_measures=None,
|
||||
products=products,
|
||||
scenario=scenario,
|
||||
print_table=False,
|
||||
)
|
||||
logger.info(
|
||||
f"property={pid} modelling complete "
|
||||
f"measures={len(plan.measures)}"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
measure_types = (
|
||||
", ".join(m.measure_type for m in plan.measures) or "none"
|
||||
)
|
||||
logger.info(
|
||||
f"property={pid} modelling complete "
|
||||
f"measures={len(plan.measures)}"
|
||||
f"[dry_run] property={pid} "
|
||||
f"measures=[{measure_types}] — skipping DB write"
|
||||
)
|
||||
return
|
||||
|
||||
solar_write: Optional[_SolarWrite] = None
|
||||
if (
|
||||
solar_was_fetched
|
||||
and solar_insights is not None
|
||||
and spatial is not None
|
||||
and spatial.coordinates is not None
|
||||
):
|
||||
solar_write = _SolarWrite(
|
||||
uprn=uprn,
|
||||
longitude=spatial.coordinates.longitude,
|
||||
latitude=spatial.coordinates.latitude,
|
||||
insights=solar_insights,
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
measure_types = (
|
||||
", ".join(m.measure_type for m in plan.measures) or "none"
|
||||
)
|
||||
logger.info(
|
||||
f"[dry_run] property={pid} "
|
||||
f"measures=[{measure_types}] — skipping DB write"
|
||||
)
|
||||
return
|
||||
|
||||
solar_write: Optional[_SolarWrite] = None
|
||||
if (
|
||||
solar_was_fetched
|
||||
and solar_insights is not None
|
||||
and spatial is not None
|
||||
and spatial.coordinates is not None
|
||||
):
|
||||
solar_write = _SolarWrite(
|
||||
uprn=uprn,
|
||||
longitude=spatial.coordinates.longitude,
|
||||
latitude=spatial.coordinates.latitude,
|
||||
insights=solar_insights,
|
||||
)
|
||||
|
||||
# Queue this Property's writes rather than committing now — the
|
||||
# whole batch is persisted in one Unit of Work after the loop
|
||||
# (see _flush_writes). The EPC is saved in its lodged or predicted
|
||||
# slot (ADR-0031) at flush time depending on which is set here.
|
||||
accumulated.append(
|
||||
_PropertyWrite(
|
||||
property_id=pid,
|
||||
uprn=uprn,
|
||||
portfolio_id=portfolio_id,
|
||||
scenario_id=scenario_id,
|
||||
is_default=scenario.is_default,
|
||||
lodged_epc=epc,
|
||||
predicted_epc=predicted_epc,
|
||||
spatial=spatial,
|
||||
solar=solar_write,
|
||||
plan=plan,
|
||||
has_recommendations=bool(plan.measures),
|
||||
)
|
||||
# Queue this Property's writes rather than committing now — the
|
||||
# whole batch is persisted in one Unit of Work after the loop
|
||||
# (see _flush_writes). The EPC is saved in its lodged or predicted
|
||||
# slot (ADR-0031) at flush time depending on which is set here.
|
||||
accumulated.append(
|
||||
_PropertyWrite(
|
||||
property_id=pid,
|
||||
uprn=uprn,
|
||||
portfolio_id=portfolio_id,
|
||||
scenario_id=scenario_id,
|
||||
is_default=scenario.is_default,
|
||||
lodged_epc=epc,
|
||||
predicted_epc=predicted_epc,
|
||||
spatial=spatial,
|
||||
solar=solar_write,
|
||||
plan=plan,
|
||||
has_recommendations=bool(plan.measures),
|
||||
)
|
||||
logger.info(f"property={pid} queued for write")
|
||||
)
|
||||
logger.info(f"property={pid} queued for write")
|
||||
|
||||
try:
|
||||
orchestrator.run_subtask(child.id, work=_work)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
# Fan the batch out into one child SubTask per property and run them in
|
||||
# a single batched pass: create all children, model each (failures
|
||||
# isolated per child), then persist all their statuses in two writes +
|
||||
# one cascade — not ~5 writes and a full parent re-roll-up per property
|
||||
# (see TaskOrchestrator.run_subtasks).
|
||||
orchestrator.run_subtasks(
|
||||
task_id,
|
||||
[{"property_id": pid} for pid in property_ids],
|
||||
work=_work,
|
||||
)
|
||||
|
||||
# Persist the whole batch in one transaction, then re-establish every
|
||||
# written Property's Baseline (the orchestrator batches its own UoW). The
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any, Callable, Optional
|
|||
from uuid import UUID
|
||||
|
||||
from domain.tasks.subtasks import SubTask
|
||||
from domain.tasks.tasks import Source, Task
|
||||
from domain.tasks.tasks import Source, Task, TaskStatus
|
||||
from repositories.tasks.subtask_repository import SubTaskRepository
|
||||
from repositories.tasks.task_repository import TaskRepository
|
||||
from utilities.private import private
|
||||
|
|
@ -98,9 +98,59 @@ class TaskOrchestrator:
|
|||
self.complete_subtask(subtask_id, result)
|
||||
return result
|
||||
|
||||
def run_subtasks(
|
||||
self,
|
||||
parent_task_id: UUID,
|
||||
inputs_per_subtask: list[dict[str, Any]],
|
||||
work: Callable[[SubTask], Any],
|
||||
cloud_logs_url: Optional[str] = None,
|
||||
) -> list[Any]:
|
||||
"""Fan a parent Task out into one child SubTask per item, run ``work`` for
|
||||
each (failures isolated per child — a raising item is marked failed and
|
||||
its siblings still run), and persist the whole batch in **two** writes
|
||||
plus **one** cascade.
|
||||
|
||||
This is the batched form of ``run_subtask``: instead of ~5 writes and a
|
||||
full parent re-roll-up *per child* (``create`` + ``start`` + ``complete``
|
||||
each cascading — an O(N²) cost on the parent's children), it does one bulk
|
||||
``create_many``, runs every item recording its terminal state in memory,
|
||||
then one bulk ``save_many`` and a single ``_cascade``. Children move
|
||||
straight from WAITING to their terminal state — the transient IN_PROGRESS
|
||||
row is never written, since for a fast batch it only adds DB churn.
|
||||
|
||||
Returns one entry per item in order: the work's result, or ``None`` for an
|
||||
item whose work raised.
|
||||
"""
|
||||
subtasks = [
|
||||
SubTask.create(task_id=parent_task_id, inputs=inputs)
|
||||
for inputs in inputs_per_subtask
|
||||
]
|
||||
self._subtasks.create_many(subtasks)
|
||||
|
||||
results: list[Any] = []
|
||||
for subtask in subtasks:
|
||||
subtask.start(cloud_logs_url)
|
||||
try:
|
||||
result = work(subtask)
|
||||
except Exception as e: # noqa: BLE001 — isolate per child; siblings continue
|
||||
subtask.fail(e)
|
||||
results.append(None)
|
||||
else:
|
||||
subtask.complete(result)
|
||||
results.append(result)
|
||||
|
||||
self._subtasks.save_many(subtasks)
|
||||
self._cascade(parent_task_id)
|
||||
return results
|
||||
|
||||
@private
|
||||
def _cascade(self, task_id: UUID) -> None:
|
||||
statuses = [s.status for s in self._subtasks.list_by_task(task_id)]
|
||||
task = self._tasks.get(task_id)
|
||||
# FAILED is terminal: once any SubTask has failed the Task is failed and
|
||||
# stays failed, so skip the (potentially large) sibling roll-up entirely —
|
||||
# no need to list and re-check the SubTasks.
|
||||
if task.status is TaskStatus.FAILED:
|
||||
return
|
||||
statuses = [s.status for s in self._subtasks.list_by_task(task_id)]
|
||||
task.recalculate_from_subtasks(statuses)
|
||||
self._tasks.save(task)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from datetime import datetime, timezone
|
|||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from domain.tasks.subtasks import SubTask, SubTaskStatus
|
||||
from infrastructure.postgres.subtask_table import SubTaskRow
|
||||
|
|
@ -22,6 +22,12 @@ class SubTaskPostgresRepository(SubTaskRepository):
|
|||
self._session.refresh(row)
|
||||
return self._to_domain(row)
|
||||
|
||||
def create_many(self, subtasks: list[SubTask]) -> None:
|
||||
if not subtasks:
|
||||
return
|
||||
self._session.add_all([self._to_row(s) for s in subtasks])
|
||||
self._session.commit()
|
||||
|
||||
def get(self, subtask_id: UUID) -> SubTask:
|
||||
row = self._session.get(SubTaskRow, subtask_id)
|
||||
if row is None:
|
||||
|
|
@ -46,6 +52,36 @@ class SubTaskPostgresRepository(SubTaskRepository):
|
|||
self._session.add(row)
|
||||
self._session.commit()
|
||||
|
||||
def save_many(self, subtasks: list[SubTask]) -> None:
|
||||
if not subtasks:
|
||||
return
|
||||
by_id = {s.id: s for s in subtasks}
|
||||
rows = self._session.exec(
|
||||
select(SubTaskRow).where(col(SubTaskRow.id).in_(list(by_id)))
|
||||
).all()
|
||||
found = {row.id for row in rows}
|
||||
missing = set(by_id) - found
|
||||
if missing:
|
||||
raise ValueError(f"SubTask(s) not found: {sorted(str(m) for m in missing)}")
|
||||
now = datetime.now(timezone.utc)
|
||||
for row in rows:
|
||||
subtask = by_id[row.id]
|
||||
row.status = subtask.status.value
|
||||
row.job_started = subtask.job_started
|
||||
row.job_completed = subtask.job_completed
|
||||
row.inputs = (
|
||||
json.dumps(subtask.inputs) if subtask.inputs is not None else None
|
||||
)
|
||||
row.outputs = (
|
||||
json.dumps(subtask.outputs)
|
||||
if subtask.outputs is not None
|
||||
else None
|
||||
)
|
||||
row.cloud_logs_url = subtask.cloud_logs_url
|
||||
row.updated_at = now
|
||||
self._session.add(row)
|
||||
self._session.commit()
|
||||
|
||||
def list_by_task(self, task_id: UUID) -> list[SubTask]:
|
||||
rows = self._session.exec(
|
||||
select(SubTaskRow).where(SubTaskRow.task_id == task_id)
|
||||
|
|
|
|||
|
|
@ -8,11 +8,24 @@ class SubTaskRepository(ABC):
|
|||
@abstractmethod
|
||||
def create(self, subtask: SubTask) -> SubTask: ...
|
||||
|
||||
@abstractmethod
|
||||
def create_many(self, subtasks: list[SubTask]) -> None:
|
||||
"""Persist many SubTasks in one round-trip (one INSERT + one commit) —
|
||||
the batch form of ``create`` for callers that fan a Task out into many
|
||||
children at once."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, subtask_id: UUID) -> SubTask: ...
|
||||
|
||||
@abstractmethod
|
||||
def save(self, subtask: SubTask) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def save_many(self, subtasks: list[SubTask]) -> None:
|
||||
"""Persist updates to many SubTasks in one round-trip (one bulk fetch +
|
||||
one commit) — the batch form of ``save``."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def list_by_task(self, task_id: UUID) -> list[SubTask]: ...
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from pydantic import ValidationError
|
|||
from applications.modelling_e2e.modelling_e2e_trigger_body import (
|
||||
ModellingE2ETriggerBody,
|
||||
)
|
||||
from domain.tasks.subtasks import SubTask
|
||||
|
||||
PROPERTY_ID = 12345
|
||||
UPRN = 987654321
|
||||
|
|
@ -41,11 +42,26 @@ _BODY = {
|
|||
|
||||
|
||||
def _mock_orchestrator() -> MagicMock:
|
||||
"""Mock TaskOrchestrator whose run_subtasks creates one SubTask per input and
|
||||
runs work for each, isolating per-item failures (mirroring the real method)."""
|
||||
mock = MagicMock()
|
||||
mock.run_subtask.side_effect = lambda subtask_id, work, **kwargs: work()
|
||||
child = MagicMock()
|
||||
child.id = uuid4()
|
||||
mock.create_child_subtask.return_value = child
|
||||
|
||||
def _run_subtasks(
|
||||
parent_task_id: UUID,
|
||||
inputs_per_subtask: list[dict[str, Any]],
|
||||
work: Any,
|
||||
**kwargs: Any,
|
||||
) -> list[Any]:
|
||||
results: list[Any] = []
|
||||
for inputs in inputs_per_subtask:
|
||||
subtask = SubTask.create(task_id=parent_task_id, inputs=inputs)
|
||||
try:
|
||||
results.append(work(subtask))
|
||||
except Exception: # noqa: BLE001 — siblings continue, as in real impl
|
||||
results.append(None)
|
||||
return results
|
||||
|
||||
mock.run_subtasks.side_effect = _run_subtasks
|
||||
return mock
|
||||
|
||||
|
||||
|
|
@ -194,13 +210,13 @@ def test_handler_creates_one_child_subtask_per_property_id() -> None:
|
|||
None, mock_orch, task_id,
|
||||
)
|
||||
|
||||
# Assert — one child SubTask per property, inputs record the property_id
|
||||
assert mock_orch.create_child_subtask.call_count == 3
|
||||
calls = mock_orch.create_child_subtask.call_args_list
|
||||
recorded_ids = [c.kwargs["inputs"]["property_id"] for c in calls]
|
||||
assert recorded_ids == [pid1, pid2, pid3]
|
||||
# All three calls used the same task_id
|
||||
assert all(c.args[0] == task_id for c in calls)
|
||||
# Assert — the batch is fanned out in ONE run_subtasks call, with one input
|
||||
# per property (each recording its property_id) under the same parent task.
|
||||
mock_orch.run_subtasks.assert_called_once()
|
||||
args = mock_orch.run_subtasks.call_args.args
|
||||
assert args[0] == task_id
|
||||
inputs_per_subtask = args[1]
|
||||
assert [i["property_id"] for i in inputs_per_subtask] == [pid1, pid2, pid3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -786,8 +802,8 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None:
|
|||
|
||||
def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> None:
|
||||
"""Two properties: property 1 succeeds, property 2 fails during modelling.
|
||||
The handler does not raise; property 1's UoW was committed; run_subtask was
|
||||
called for both properties."""
|
||||
The handler does not raise; both properties are run in the one batched
|
||||
run_subtasks pass, and property 1's write is still committed."""
|
||||
# Arrange
|
||||
pid1, pid2 = 111, 222
|
||||
mock_engine = _engine_mock([pid1, pid2], [1001, 1002], [POSTCODE, POSTCODE])
|
||||
|
|
@ -847,8 +863,10 @@ def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> Non
|
|||
orchestrator=mock_orch,
|
||||
)
|
||||
|
||||
# run_subtask called for both properties; pid1 committed
|
||||
assert mock_orch.run_subtask.call_count == 2
|
||||
# Both properties run in one batched run_subtasks pass; pid1 still committed
|
||||
# (the failed pid2 is isolated and simply never queued for write).
|
||||
mock_orch.run_subtasks.assert_called_once()
|
||||
assert len(mock_orch.run_subtasks.call_args.args[1]) == 2
|
||||
mock_uow.commit.assert_called_once()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -195,3 +195,87 @@ def test_run_subtask_failing_work_marks_failed_and_reraises(
|
|||
|
||||
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
|
||||
|
||||
|
||||
def test_run_subtasks_creates_runs_and_completes_a_whole_batch(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
"""run_subtasks fans the parent into one child per item, runs each, and
|
||||
leaves every child COMPLETE with the parent cascaded to COMPLETE."""
|
||||
# arrange — a parent task whose coordinator subtask is done
|
||||
task, coordinator = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
harness.orchestrator.complete_subtask(coordinator.id)
|
||||
|
||||
# act
|
||||
results = harness.orchestrator.run_subtasks(
|
||||
task.id,
|
||||
[{"property_id": 1}, {"property_id": 2}, {"property_id": 3}],
|
||||
work=lambda st: (st.inputs or {})["property_id"] * 10,
|
||||
)
|
||||
|
||||
# assert — results in order, all children COMPLETE, parent COMPLETE
|
||||
assert results == [10, 20, 30]
|
||||
children = [s for s in harness.subtasks.list_by_task(task.id) if s.id != coordinator.id]
|
||||
assert len(children) == 3
|
||||
assert all(c.status is SubTaskStatus.COMPLETE for c in children)
|
||||
assert {(c.inputs or {})["property_id"] for c in children} == {1, 2, 3}
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
|
||||
|
||||
|
||||
def test_run_subtasks_isolates_a_failing_item_and_continues(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
"""A raising item is marked FAILED with its error recorded, its siblings still
|
||||
complete (result None for the failure), and the parent cascades to FAILED."""
|
||||
# arrange
|
||||
task, coordinator = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
harness.orchestrator.complete_subtask(coordinator.id)
|
||||
|
||||
def _work(st: SubTask) -> int:
|
||||
pid = (st.inputs or {})["property_id"]
|
||||
if pid == 2:
|
||||
raise RuntimeError("property 2 exploded")
|
||||
return pid
|
||||
|
||||
# act — does NOT raise; failure is isolated
|
||||
results = harness.orchestrator.run_subtasks(
|
||||
task.id,
|
||||
[{"property_id": 1}, {"property_id": 2}, {"property_id": 3}],
|
||||
work=_work,
|
||||
)
|
||||
|
||||
# assert
|
||||
assert results == [1, None, 3]
|
||||
children = {
|
||||
(c.inputs or {})["property_id"]: c
|
||||
for c in harness.subtasks.list_by_task(task.id)
|
||||
if c.id != coordinator.id
|
||||
}
|
||||
assert children[1].status is SubTaskStatus.COMPLETE
|
||||
assert children[3].status is SubTaskStatus.COMPLETE
|
||||
assert children[2].status is SubTaskStatus.FAILED
|
||||
assert children[2].outputs == {"error": "property 2 exploded"}
|
||||
# any FAILED child → parent FAILED
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
|
||||
|
||||
|
||||
def test_cascade_short_circuits_once_task_already_failed(harness: Harness) -> None:
|
||||
"""Once the Task is FAILED, completing another SubTask leaves it FAILED — the
|
||||
terminal state is not recomputed away."""
|
||||
# arrange — two children; fail the first so the task is FAILED
|
||||
task, coordinator = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
second = harness.orchestrator.create_child_subtask(task.id)
|
||||
harness.orchestrator.fail_subtask(coordinator.id, RuntimeError("boom"))
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
|
||||
|
||||
# act — complete the other child
|
||||
harness.orchestrator.complete_subtask(second.id)
|
||||
|
||||
# assert — task stays FAILED (cascade short-circuited on the terminal state)
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
|
||||
|
|
|
|||
|
|
@ -98,3 +98,68 @@ def test_get_missing_raises(session: Session) -> None:
|
|||
# act / assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
repo.get(uuid4())
|
||||
|
||||
|
||||
def test_create_many_persists_every_subtask(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_id = _persisted_task_id(session)
|
||||
subtasks = [
|
||||
SubTask.create(task_id=task_id, inputs={"property_id": pid})
|
||||
for pid in (1, 2, 3)
|
||||
]
|
||||
|
||||
# act
|
||||
repo.create_many(subtasks)
|
||||
|
||||
# assert — all three persisted, in WAITING, with their inputs
|
||||
persisted = {
|
||||
(s.inputs or {})["property_id"]: s for s in repo.list_by_task(task_id)
|
||||
}
|
||||
assert set(persisted) == {1, 2, 3}
|
||||
assert all(s.status is SubTaskStatus.WAITING for s in persisted.values())
|
||||
|
||||
|
||||
def test_create_many_empty_is_a_noop(session: Session) -> None:
|
||||
# arrange / act / assert — does not raise, writes nothing
|
||||
SubTaskPostgresRepository(session=session).create_many([])
|
||||
|
||||
|
||||
def test_save_many_persists_each_subtasks_terminal_state(session: Session) -> None:
|
||||
# arrange — three created subtasks
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_id = _persisted_task_id(session)
|
||||
subtasks = [
|
||||
SubTask.create(task_id=task_id, inputs={"property_id": pid})
|
||||
for pid in (1, 2, 3)
|
||||
]
|
||||
repo.create_many(subtasks)
|
||||
|
||||
# act — move each to a terminal state in memory, then one bulk save
|
||||
subtasks[0].complete({"ok": True})
|
||||
subtasks[1].fail(RuntimeError("nope"))
|
||||
subtasks[2].complete()
|
||||
repo.save_many(subtasks)
|
||||
|
||||
# assert
|
||||
persisted = {
|
||||
(s.inputs or {})["property_id"]: s for s in repo.list_by_task(task_id)
|
||||
}
|
||||
assert persisted[1].status is SubTaskStatus.COMPLETE
|
||||
assert persisted[1].outputs == {"result": {"ok": True}}
|
||||
assert persisted[2].status is SubTaskStatus.FAILED
|
||||
assert persisted[2].outputs == {"error": "nope"}
|
||||
assert persisted[3].status is SubTaskStatus.COMPLETE
|
||||
|
||||
|
||||
def test_save_many_raises_when_a_subtask_does_not_exist(session: Session) -> None:
|
||||
# arrange — one persisted, one never created
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_id = _persisted_task_id(session)
|
||||
persisted = SubTask.create(task_id=task_id)
|
||||
repo.create_many([persisted])
|
||||
ghost = SubTask.create(task_id=task_id)
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
repo.save_many([persisted, ghost])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue