From b1ff71126003cd0792abfa050308ef1c512224c9 Mon Sep 17 00:00:00 2001 From: Jun-te Kim Date: Wed, 24 Jun 2026 19:26:42 +0000 Subject: [PATCH] perf(modelling_e2e): batch SubTask bookkeeping to stop per-property writes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Even after batching the data writes, the handler still wrote to the DB per property through the orchestrator's SubTask bookkeeping: create + start + complete each self-committed, and _cascade re-listed every sibling and re-saved the parent on every transition — ~5 writes per property plus an O(N^2) cascade. - TaskOrchestrator.run_subtasks: create all children in one INSERT, run each (failures isolated per child), then persist all terminal states in one bulk save and cascade the parent once. Children go WAITING -> terminal; the transient IN_PROGRESS row is never written. - SubTaskRepository.create_many / save_many (bulk INSERT / bulk fetch + update). - _cascade short-circuits when the Task is already FAILED (terminal) — skips the sibling roll-up entirely. - modelling_e2e handler fans out via run_subtasks instead of per-property create_child_subtask + run_subtask. Per N-property batch the SubTask bookkeeping drops from ~5N writes + an O(N^2) cascade to ~2 writes + 1 cascade. Co-Authored-By: Claude Opus 4.8 (1M context) --- applications/modelling_e2e/handler.py | 252 +++++++++--------- orchestration/task_orchestrator.py | 54 +++- .../tasks/subtask_postgres_repository.py | 38 ++- repositories/tasks/subtask_repository.py | 13 + .../modelling_e2e/test_handler.py | 48 ++-- tests/orchestration/test_task_orchestrator.py | 84 ++++++ .../test_subtask_postgres_repository.py | 65 +++++ 7 files changed, 412 insertions(+), 142 deletions(-) diff --git a/applications/modelling_e2e/handler.py b/applications/modelling_e2e/handler.py index 9454e6d4..cf1d1564 100644 --- a/applications/modelling_e2e/handler.py +++ b/applications/modelling_e2e/handler.py @@ -65,6 +65,7 @@ from domain.modelling.plan import Plan from domain.property.property import Property, PropertyIdentity from domain.property_baseline.calculator_rebaseliner import CalculatorRebaseliner from domain.sap10_calculator.calculator import Sap10Calculator +from domain.tasks.subtasks import SubTask from domain.tasks.tasks import Source from harness.console import run_modelling from orchestration.task_orchestrator import TaskOrchestrator @@ -486,142 +487,145 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator, # whole batch is flushed in one transaction after the loop. accumulated: list[_PropertyWrite] = [] - for property_id in property_ids: - child = orchestrator.create_child_subtask( - task_id, inputs={"property_id": property_id} + def _work(subtask: SubTask) -> None: + inputs = subtask.inputs or {} + pid = int(inputs["property_id"]) + uprn = uprns[pid] + postcode = postcodes.get(pid, "") + logger.info(f"property={pid} uprn={uprn} postcode={postcode!r}") + + spatial = _spatial_for(geospatial, uprn) + restrictions = ( + spatial.restrictions + if spatial is not None + else PlanningRestrictions() + ) + coordinates: Optional[Coordinates] = ( + spatial.coordinates if spatial is not None else None ) - def _work(pid: int = property_id) -> None: - uprn = uprns[pid] - postcode = postcodes.get(pid, "") - logger.info(f"property={pid} uprn={uprn} postcode={postcode!r}") + epc: Optional[EpcPropertyData] = epc_client.get_by_uprn(uprn) + overrides = overlays_from(overrides_reader.overrides_for(pid)) + predicted_epc: Optional[EpcPropertyData] = None - spatial = _spatial_for(geospatial, uprn) - restrictions = ( - spatial.restrictions - if spatial is not None - else PlanningRestrictions() - ) - coordinates: Optional[Coordinates] = ( - spatial.coordinates if spatial is not None else None - ) - - epc: Optional[EpcPropertyData] = epc_client.get_by_uprn(uprn) - overrides = overlays_from(overrides_reader.overrides_for(pid)) - predicted_epc: Optional[EpcPropertyData] = None - - if epc is not None: - logger.info(f"property={pid} lodged EPC found") - effective_epc = Property( - identity=PropertyIdentity( - portfolio_id=portfolio_id, - postcode=postcode, - address="", - uprn=uprn, - ), - epc=epc, - landlord_overrides=overrides, - ).effective_epc - else: - logger.info( - f"property={pid} no lodged EPC — attempting prediction" - ) - predicted_epc = _predict_epc( - property_id=pid, - uprn=uprn, - postcode=postcode, + if epc is not None: + logger.info(f"property={pid} lodged EPC found") + effective_epc = Property( + identity=PropertyIdentity( portfolio_id=portfolio_id, - attributes_reader=prediction_attrs_reader, - coordinates=coordinates, - cohort_for=_get_cohort, - broaden=_broaden, - predictor=predictor, - ) - effective_epc = Property( - identity=PropertyIdentity( - portfolio_id=portfolio_id, - postcode=postcode, - address="", - uprn=uprn, - ), - epc=None, - predicted_epc=predicted_epc, - landlord_overrides=overrides, - ).effective_epc + postcode=postcode, + address="", + uprn=uprn, + ), + epc=epc, + landlord_overrides=overrides, + ).effective_epc + else: + logger.info( + f"property={pid} no lodged EPC — attempting prediction" + ) + predicted_epc = _predict_epc( + property_id=pid, + uprn=uprn, + postcode=postcode, + portfolio_id=portfolio_id, + attributes_reader=prediction_attrs_reader, + coordinates=coordinates, + cohort_for=_get_cohort, + broaden=_broaden, + predictor=predictor, + ) + effective_epc = Property( + identity=PropertyIdentity( + portfolio_id=portfolio_id, + postcode=postcode, + address="", + uprn=uprn, + ), + epc=None, + predicted_epc=predicted_epc, + landlord_overrides=overrides, + ).effective_epc - solar_insights: Optional[dict[str, Any]] - solar_was_fetched = False - if no_solar: - solar_insights = None - else: - solar_insights = stored_solar.get(uprn) - if solar_insights is None: - solar_insights = _solar_insights_for(solar_client, spatial) - solar_was_fetched = solar_insights is not None + solar_insights: Optional[dict[str, Any]] + solar_was_fetched = False + if no_solar: + solar_insights = None + else: + solar_insights = stored_solar.get(uprn) + if solar_insights is None: + solar_insights = _solar_insights_for(solar_client, spatial) + solar_was_fetched = solar_insights is not None - plan = run_modelling( - effective_epc, - planning_restrictions=restrictions, - solar_insights=solar_insights, - considered_measures=None, - products=products, - scenario=scenario, - print_table=False, + plan = run_modelling( + effective_epc, + planning_restrictions=restrictions, + solar_insights=solar_insights, + considered_measures=None, + products=products, + scenario=scenario, + print_table=False, + ) + logger.info( + f"property={pid} modelling complete " + f"measures={len(plan.measures)}" + ) + + if dry_run: + measure_types = ( + ", ".join(m.measure_type for m in plan.measures) or "none" ) logger.info( - f"property={pid} modelling complete " - f"measures={len(plan.measures)}" + f"[dry_run] property={pid} " + f"measures=[{measure_types}] — skipping DB write" + ) + return + + solar_write: Optional[_SolarWrite] = None + if ( + solar_was_fetched + and solar_insights is not None + and spatial is not None + and spatial.coordinates is not None + ): + solar_write = _SolarWrite( + uprn=uprn, + longitude=spatial.coordinates.longitude, + latitude=spatial.coordinates.latitude, + insights=solar_insights, ) - if dry_run: - measure_types = ( - ", ".join(m.measure_type for m in plan.measures) or "none" - ) - logger.info( - f"[dry_run] property={pid} " - f"measures=[{measure_types}] — skipping DB write" - ) - return - - solar_write: Optional[_SolarWrite] = None - if ( - solar_was_fetched - and solar_insights is not None - and spatial is not None - and spatial.coordinates is not None - ): - solar_write = _SolarWrite( - uprn=uprn, - longitude=spatial.coordinates.longitude, - latitude=spatial.coordinates.latitude, - insights=solar_insights, - ) - - # Queue this Property's writes rather than committing now — the - # whole batch is persisted in one Unit of Work after the loop - # (see _flush_writes). The EPC is saved in its lodged or predicted - # slot (ADR-0031) at flush time depending on which is set here. - accumulated.append( - _PropertyWrite( - property_id=pid, - uprn=uprn, - portfolio_id=portfolio_id, - scenario_id=scenario_id, - is_default=scenario.is_default, - lodged_epc=epc, - predicted_epc=predicted_epc, - spatial=spatial, - solar=solar_write, - plan=plan, - has_recommendations=bool(plan.measures), - ) + # Queue this Property's writes rather than committing now — the + # whole batch is persisted in one Unit of Work after the loop + # (see _flush_writes). The EPC is saved in its lodged or predicted + # slot (ADR-0031) at flush time depending on which is set here. + accumulated.append( + _PropertyWrite( + property_id=pid, + uprn=uprn, + portfolio_id=portfolio_id, + scenario_id=scenario_id, + is_default=scenario.is_default, + lodged_epc=epc, + predicted_epc=predicted_epc, + spatial=spatial, + solar=solar_write, + plan=plan, + has_recommendations=bool(plan.measures), ) - logger.info(f"property={pid} queued for write") + ) + logger.info(f"property={pid} queued for write") - try: - orchestrator.run_subtask(child.id, work=_work) - except Exception: # noqa: BLE001 - pass + # Fan the batch out into one child SubTask per property and run them in + # a single batched pass: create all children, model each (failures + # isolated per child), then persist all their statuses in two writes + + # one cascade — not ~5 writes and a full parent re-roll-up per property + # (see TaskOrchestrator.run_subtasks). + orchestrator.run_subtasks( + task_id, + [{"property_id": pid} for pid in property_ids], + work=_work, + ) # Persist the whole batch in one transaction, then re-establish every # written Property's Baseline (the orchestrator batches its own UoW). The diff --git a/orchestration/task_orchestrator.py b/orchestration/task_orchestrator.py index ebb71a32..15915d9f 100644 --- a/orchestration/task_orchestrator.py +++ b/orchestration/task_orchestrator.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Optional from uuid import UUID from domain.tasks.subtasks import SubTask -from domain.tasks.tasks import Source, Task +from domain.tasks.tasks import Source, Task, TaskStatus from repositories.tasks.subtask_repository import SubTaskRepository from repositories.tasks.task_repository import TaskRepository from utilities.private import private @@ -98,9 +98,59 @@ class TaskOrchestrator: self.complete_subtask(subtask_id, result) return result + def run_subtasks( + self, + parent_task_id: UUID, + inputs_per_subtask: list[dict[str, Any]], + work: Callable[[SubTask], Any], + cloud_logs_url: Optional[str] = None, + ) -> list[Any]: + """Fan a parent Task out into one child SubTask per item, run ``work`` for + each (failures isolated per child — a raising item is marked failed and + its siblings still run), and persist the whole batch in **two** writes + plus **one** cascade. + + This is the batched form of ``run_subtask``: instead of ~5 writes and a + full parent re-roll-up *per child* (``create`` + ``start`` + ``complete`` + each cascading — an O(N²) cost on the parent's children), it does one bulk + ``create_many``, runs every item recording its terminal state in memory, + then one bulk ``save_many`` and a single ``_cascade``. Children move + straight from WAITING to their terminal state — the transient IN_PROGRESS + row is never written, since for a fast batch it only adds DB churn. + + Returns one entry per item in order: the work's result, or ``None`` for an + item whose work raised. + """ + subtasks = [ + SubTask.create(task_id=parent_task_id, inputs=inputs) + for inputs in inputs_per_subtask + ] + self._subtasks.create_many(subtasks) + + results: list[Any] = [] + for subtask in subtasks: + subtask.start(cloud_logs_url) + try: + result = work(subtask) + except Exception as e: # noqa: BLE001 — isolate per child; siblings continue + subtask.fail(e) + results.append(None) + else: + subtask.complete(result) + results.append(result) + + self._subtasks.save_many(subtasks) + self._cascade(parent_task_id) + return results + @private def _cascade(self, task_id: UUID) -> None: - statuses = [s.status for s in self._subtasks.list_by_task(task_id)] task = self._tasks.get(task_id) + # FAILED is terminal: once any SubTask has failed the Task is failed and + # stays failed, so skip the (potentially large) sibling roll-up entirely — + # no need to list and re-check the SubTasks. + if task.status is TaskStatus.FAILED: + return + statuses = [s.status for s in self._subtasks.list_by_task(task_id)] task.recalculate_from_subtasks(statuses) self._tasks.save(task) diff --git a/repositories/tasks/subtask_postgres_repository.py b/repositories/tasks/subtask_postgres_repository.py index affc280e..06929c7d 100644 --- a/repositories/tasks/subtask_postgres_repository.py +++ b/repositories/tasks/subtask_postgres_repository.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from typing import Any, Optional from uuid import UUID -from sqlmodel import Session, select +from sqlmodel import Session, col, select from domain.tasks.subtasks import SubTask, SubTaskStatus from infrastructure.postgres.subtask_table import SubTaskRow @@ -22,6 +22,12 @@ class SubTaskPostgresRepository(SubTaskRepository): self._session.refresh(row) return self._to_domain(row) + def create_many(self, subtasks: list[SubTask]) -> None: + if not subtasks: + return + self._session.add_all([self._to_row(s) for s in subtasks]) + self._session.commit() + def get(self, subtask_id: UUID) -> SubTask: row = self._session.get(SubTaskRow, subtask_id) if row is None: @@ -46,6 +52,36 @@ class SubTaskPostgresRepository(SubTaskRepository): self._session.add(row) self._session.commit() + def save_many(self, subtasks: list[SubTask]) -> None: + if not subtasks: + return + by_id = {s.id: s for s in subtasks} + rows = self._session.exec( + select(SubTaskRow).where(col(SubTaskRow.id).in_(list(by_id))) + ).all() + found = {row.id for row in rows} + missing = set(by_id) - found + if missing: + raise ValueError(f"SubTask(s) not found: {sorted(str(m) for m in missing)}") + now = datetime.now(timezone.utc) + for row in rows: + subtask = by_id[row.id] + row.status = subtask.status.value + row.job_started = subtask.job_started + row.job_completed = subtask.job_completed + row.inputs = ( + json.dumps(subtask.inputs) if subtask.inputs is not None else None + ) + row.outputs = ( + json.dumps(subtask.outputs) + if subtask.outputs is not None + else None + ) + row.cloud_logs_url = subtask.cloud_logs_url + row.updated_at = now + self._session.add(row) + self._session.commit() + def list_by_task(self, task_id: UUID) -> list[SubTask]: rows = self._session.exec( select(SubTaskRow).where(SubTaskRow.task_id == task_id) diff --git a/repositories/tasks/subtask_repository.py b/repositories/tasks/subtask_repository.py index adb36f99..e1e3b5c6 100644 --- a/repositories/tasks/subtask_repository.py +++ b/repositories/tasks/subtask_repository.py @@ -8,11 +8,24 @@ class SubTaskRepository(ABC): @abstractmethod def create(self, subtask: SubTask) -> SubTask: ... + @abstractmethod + def create_many(self, subtasks: list[SubTask]) -> None: + """Persist many SubTasks in one round-trip (one INSERT + one commit) — + the batch form of ``create`` for callers that fan a Task out into many + children at once.""" + ... + @abstractmethod def get(self, subtask_id: UUID) -> SubTask: ... @abstractmethod def save(self, subtask: SubTask) -> None: ... + @abstractmethod + def save_many(self, subtasks: list[SubTask]) -> None: + """Persist updates to many SubTasks in one round-trip (one bulk fetch + + one commit) — the batch form of ``save``.""" + ... + @abstractmethod def list_by_task(self, task_id: UUID) -> list[SubTask]: ... diff --git a/tests/applications/modelling_e2e/test_handler.py b/tests/applications/modelling_e2e/test_handler.py index d55fd272..860e5313 100644 --- a/tests/applications/modelling_e2e/test_handler.py +++ b/tests/applications/modelling_e2e/test_handler.py @@ -18,6 +18,7 @@ from pydantic import ValidationError from applications.modelling_e2e.modelling_e2e_trigger_body import ( ModellingE2ETriggerBody, ) +from domain.tasks.subtasks import SubTask PROPERTY_ID = 12345 UPRN = 987654321 @@ -41,11 +42,26 @@ _BODY = { def _mock_orchestrator() -> MagicMock: + """Mock TaskOrchestrator whose run_subtasks creates one SubTask per input and + runs work for each, isolating per-item failures (mirroring the real method).""" mock = MagicMock() - mock.run_subtask.side_effect = lambda subtask_id, work, **kwargs: work() - child = MagicMock() - child.id = uuid4() - mock.create_child_subtask.return_value = child + + def _run_subtasks( + parent_task_id: UUID, + inputs_per_subtask: list[dict[str, Any]], + work: Any, + **kwargs: Any, + ) -> list[Any]: + results: list[Any] = [] + for inputs in inputs_per_subtask: + subtask = SubTask.create(task_id=parent_task_id, inputs=inputs) + try: + results.append(work(subtask)) + except Exception: # noqa: BLE001 — siblings continue, as in real impl + results.append(None) + return results + + mock.run_subtasks.side_effect = _run_subtasks return mock @@ -194,13 +210,13 @@ def test_handler_creates_one_child_subtask_per_property_id() -> None: None, mock_orch, task_id, ) - # Assert — one child SubTask per property, inputs record the property_id - assert mock_orch.create_child_subtask.call_count == 3 - calls = mock_orch.create_child_subtask.call_args_list - recorded_ids = [c.kwargs["inputs"]["property_id"] for c in calls] - assert recorded_ids == [pid1, pid2, pid3] - # All three calls used the same task_id - assert all(c.args[0] == task_id for c in calls) + # Assert — the batch is fanned out in ONE run_subtasks call, with one input + # per property (each recording its property_id) under the same parent task. + mock_orch.run_subtasks.assert_called_once() + args = mock_orch.run_subtasks.call_args.args + assert args[0] == task_id + inputs_per_subtask = args[1] + assert [i["property_id"] for i in inputs_per_subtask] == [pid1, pid2, pid3] # --------------------------------------------------------------------------- @@ -786,8 +802,8 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None: def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> None: """Two properties: property 1 succeeds, property 2 fails during modelling. - The handler does not raise; property 1's UoW was committed; run_subtask was - called for both properties.""" + The handler does not raise; both properties are run in the one batched + run_subtasks pass, and property 1's write is still committed.""" # Arrange pid1, pid2 = 111, 222 mock_engine = _engine_mock([pid1, pid2], [1001, 1002], [POSTCODE, POSTCODE]) @@ -847,8 +863,10 @@ def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> Non orchestrator=mock_orch, ) - # run_subtask called for both properties; pid1 committed - assert mock_orch.run_subtask.call_count == 2 + # Both properties run in one batched run_subtasks pass; pid1 still committed + # (the failed pid2 is isolated and simply never queued for write). + mock_orch.run_subtasks.assert_called_once() + assert len(mock_orch.run_subtasks.call_args.args[1]) == 2 mock_uow.commit.assert_called_once() diff --git a/tests/orchestration/test_task_orchestrator.py b/tests/orchestration/test_task_orchestrator.py index ae89991d..8767c8c1 100644 --- a/tests/orchestration/test_task_orchestrator.py +++ b/tests/orchestration/test_task_orchestrator.py @@ -195,3 +195,87 @@ def test_run_subtask_failing_work_marks_failed_and_reraises( assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED assert harness.tasks.get(task.id).status is TaskStatus.FAILED + + +def test_run_subtasks_creates_runs_and_completes_a_whole_batch( + harness: Harness, +) -> None: + """run_subtasks fans the parent into one child per item, runs each, and + leaves every child COMPLETE with the parent cascaded to COMPLETE.""" + # arrange — a parent task whose coordinator subtask is done + task, coordinator = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + harness.orchestrator.complete_subtask(coordinator.id) + + # act + results = harness.orchestrator.run_subtasks( + task.id, + [{"property_id": 1}, {"property_id": 2}, {"property_id": 3}], + work=lambda st: (st.inputs or {})["property_id"] * 10, + ) + + # assert — results in order, all children COMPLETE, parent COMPLETE + assert results == [10, 20, 30] + children = [s for s in harness.subtasks.list_by_task(task.id) if s.id != coordinator.id] + assert len(children) == 3 + assert all(c.status is SubTaskStatus.COMPLETE for c in children) + assert {(c.inputs or {})["property_id"] for c in children} == {1, 2, 3} + assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE + + +def test_run_subtasks_isolates_a_failing_item_and_continues( + harness: Harness, +) -> None: + """A raising item is marked FAILED with its error recorded, its siblings still + complete (result None for the failure), and the parent cascades to FAILED.""" + # arrange + task, coordinator = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + harness.orchestrator.complete_subtask(coordinator.id) + + def _work(st: SubTask) -> int: + pid = (st.inputs or {})["property_id"] + if pid == 2: + raise RuntimeError("property 2 exploded") + return pid + + # act — does NOT raise; failure is isolated + results = harness.orchestrator.run_subtasks( + task.id, + [{"property_id": 1}, {"property_id": 2}, {"property_id": 3}], + work=_work, + ) + + # assert + assert results == [1, None, 3] + children = { + (c.inputs or {})["property_id"]: c + for c in harness.subtasks.list_by_task(task.id) + if c.id != coordinator.id + } + assert children[1].status is SubTaskStatus.COMPLETE + assert children[3].status is SubTaskStatus.COMPLETE + assert children[2].status is SubTaskStatus.FAILED + assert children[2].outputs == {"error": "property 2 exploded"} + # any FAILED child → parent FAILED + assert harness.tasks.get(task.id).status is TaskStatus.FAILED + + +def test_cascade_short_circuits_once_task_already_failed(harness: Harness) -> None: + """Once the Task is FAILED, completing another SubTask leaves it FAILED — the + terminal state is not recomputed away.""" + # arrange — two children; fail the first so the task is FAILED + task, coordinator = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + second = harness.orchestrator.create_child_subtask(task.id) + harness.orchestrator.fail_subtask(coordinator.id, RuntimeError("boom")) + assert harness.tasks.get(task.id).status is TaskStatus.FAILED + + # act — complete the other child + harness.orchestrator.complete_subtask(second.id) + + # assert — task stays FAILED (cascade short-circuited on the terminal state) + assert harness.tasks.get(task.id).status is TaskStatus.FAILED diff --git a/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py b/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py index 9cec52ea..81cdbad1 100644 --- a/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py +++ b/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py @@ -98,3 +98,68 @@ def test_get_missing_raises(session: Session) -> None: # act / assert with pytest.raises(ValueError, match="not found"): repo.get(uuid4()) + + +def test_create_many_persists_every_subtask(session: Session) -> None: + # arrange + repo = SubTaskPostgresRepository(session=session) + task_id = _persisted_task_id(session) + subtasks = [ + SubTask.create(task_id=task_id, inputs={"property_id": pid}) + for pid in (1, 2, 3) + ] + + # act + repo.create_many(subtasks) + + # assert — all three persisted, in WAITING, with their inputs + persisted = { + (s.inputs or {})["property_id"]: s for s in repo.list_by_task(task_id) + } + assert set(persisted) == {1, 2, 3} + assert all(s.status is SubTaskStatus.WAITING for s in persisted.values()) + + +def test_create_many_empty_is_a_noop(session: Session) -> None: + # arrange / act / assert — does not raise, writes nothing + SubTaskPostgresRepository(session=session).create_many([]) + + +def test_save_many_persists_each_subtasks_terminal_state(session: Session) -> None: + # arrange — three created subtasks + repo = SubTaskPostgresRepository(session=session) + task_id = _persisted_task_id(session) + subtasks = [ + SubTask.create(task_id=task_id, inputs={"property_id": pid}) + for pid in (1, 2, 3) + ] + repo.create_many(subtasks) + + # act — move each to a terminal state in memory, then one bulk save + subtasks[0].complete({"ok": True}) + subtasks[1].fail(RuntimeError("nope")) + subtasks[2].complete() + repo.save_many(subtasks) + + # assert + persisted = { + (s.inputs or {})["property_id"]: s for s in repo.list_by_task(task_id) + } + assert persisted[1].status is SubTaskStatus.COMPLETE + assert persisted[1].outputs == {"result": {"ok": True}} + assert persisted[2].status is SubTaskStatus.FAILED + assert persisted[2].outputs == {"error": "nope"} + assert persisted[3].status is SubTaskStatus.COMPLETE + + +def test_save_many_raises_when_a_subtask_does_not_exist(session: Session) -> None: + # arrange — one persisted, one never created + repo = SubTaskPostgresRepository(session=session) + task_id = _persisted_task_id(session) + persisted = SubTask.create(task_id=task_id) + repo.create_many([persisted]) + ghost = SubTask.create(task_id=task_id) + + # act / assert + with pytest.raises(ValueError, match="not found"): + repo.save_many([persisted, ghost])