From fe69bccf2216638370dfd77d0009e0921a83b964 Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 12:19:49 +0000 Subject: [PATCH 1/8] =?UTF-8?q?Batch=20EPC=20writes=20via=20save=5Fbatch()?= =?UTF-8?q?=20on=20EpcPostgresRepository=20=F0=9F=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- repositories/epc/epc_postgres_repository.py | 29 ++- repositories/epc/epc_repository.py | 8 +- tests/orchestration/fakes.py | 4 + tests/repositories/epc/test_epc_batch_save.py | 168 ++++++++++++++++++ 4 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 tests/repositories/epc/test_epc_batch_save.py diff --git a/repositories/epc/epc_postgres_repository.py b/repositories/epc/epc_postgres_repository.py index 60383fd1..4e15afb4 100644 --- a/repositories/epc/epc_postgres_repository.py +++ b/repositories/epc/epc_postgres_repository.py @@ -1,10 +1,12 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import dataclass, field from datetime import date, datetime -from typing import Optional, Protocol, TypeVar +from typing import Any, Optional, Protocol, TypeVar -from sqlmodel import Session, col, delete, select +from sqlalchemy import insert as _sa_insert +from sqlmodel import Session, SQLModel, col, delete, select from datatypes.epc.domain.epc import Epc from datatypes.epc.domain.epc_property_data import ( @@ -54,6 +56,23 @@ from utilities.private import private _T = TypeVar("_T") +@dataclass(frozen=True) +class EpcSaveRequest: + data: EpcPropertyData + property_id: Optional[int] = None + portfolio_id: Optional[int] = None + source: EpcSource = field(default="lodged") + + +def _col_values(model: SQLModel, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract column-keyed values from a SQLModel instance for Core INSERT.""" + return { + c.name: getattr(model, c.name) + for c in model.__table__.c # type: ignore[attr-defined] + if c.name not in exclude + } + + def _require(value: Optional[_T], field: str) -> _T: if value is None: raise ValueError(f"epc_property row is missing required field {field!r}") @@ -181,6 +200,12 @@ class EpcPostgresRepository(EpcRepository): ) return epc_property_id + def save_batch(self, requests: list[EpcSaveRequest]) -> list[int]: + raise NotImplementedError + + def _delete_for_properties(self, property_ids: list[int], source: EpcSource) -> None: + raise NotImplementedError + def _delete_for_property(self, property_id: int, source: EpcSource) -> None: """Remove the property's existing EPC graph for `source` (parent + child tables) so a re-save replaces rather than duplicates (ADR-0012), without diff --git a/repositories/epc/epc_repository.py b/repositories/epc/epc_repository.py index e0af06c0..a9fb6316 100644 --- a/repositories/epc/epc_repository.py +++ b/repositories/epc/epc_repository.py @@ -1,10 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional from datatypes.epc.domain.epc_property_data import EpcPropertyData +if TYPE_CHECKING: + from repositories.epc.epc_postgres_repository import EpcSaveRequest + # Provenance of a persisted EPC picture (ADR-0031): a real "lodged" EPC, or a # "predicted" one synthesised by EPC Prediction. A property can hold one of each. EpcSource = Literal["lodged", "predicted"] @@ -29,6 +32,9 @@ class EpcRepository(ABC): source: EpcSource = "lodged", ) -> int: ... + @abstractmethod + def save_batch(self, requests: "list[EpcSaveRequest]") -> list[int]: ... + @abstractmethod def get(self, epc_property_id: int) -> EpcPropertyData: ... diff --git a/tests/orchestration/fakes.py b/tests/orchestration/fakes.py index 7b180ca3..b068874e 100644 --- a/tests/orchestration/fakes.py +++ b/tests/orchestration/fakes.py @@ -17,6 +17,7 @@ from domain.modelling.scenario import Scenario from domain.property_baseline.property_baseline_performance import PropertyBaselinePerformance from domain.property.properties import Properties from domain.property.property import Property +from repositories.epc.epc_postgres_repository import EpcSaveRequest from repositories.plan.plan_repository import PlanRepository from repositories.product.product_repository import ProductRepository from repositories.property_baseline.property_baseline_repository import PropertyBaselineRepository @@ -130,6 +131,9 @@ class FakeEpcRepo(EpcRepository): if property_id in self._predicted_by_property } + def save_batch(self, requests: list[EpcSaveRequest]) -> list[int]: + return [self.save(r.data, r.property_id, r.portfolio_id, r.source) for r in requests] + class FakeSolarRepo(SolarRepository): """In-memory Google Solar insights store keyed by UPRN. Seed `by_uprn` to diff --git a/tests/repositories/epc/test_epc_batch_save.py b/tests/repositories/epc/test_epc_batch_save.py new file mode 100644 index 00000000..03489884 --- /dev/null +++ b/tests/repositories/epc/test_epc_batch_save.py @@ -0,0 +1,168 @@ +"""Batch EPC write path — save_batch() correctness and safety tests. + +Guards the four user stories from #1348: + 1. FK mis-wiring regression: building-part IDs must not be crossed between + properties in the same save_batch() call. + 2. save()/save_batch() parity: the single-property delegation path is loss-free. + 3. Batch idempotency: a second save_batch() with the same requests replaces, + not duplicates. + 4. Source isolation: lodged and predicted slots coexist after separate + save_batch() calls on the same property IDs. +""" + +from __future__ import annotations + +import json +from dataclasses import replace +from pathlib import Path +from typing import Any + +import pytest +from sqlalchemy import Engine +from sqlmodel import Session + +from datatypes.epc.domain.epc_property_data import EpcPropertyData, SapFloorDimension +from datatypes.epc.domain.mapper import EpcPropertyDataMapper +from repositories.epc.epc_postgres_repository import EpcPostgresRepository, EpcSaveRequest + +_JSON_SAMPLES = Path(__file__).resolve().parents[3] / "backend/epc_api/json_samples" + + +def _load_epc(schema_dir: str = "RdSAP-Schema-21.0.0") -> EpcPropertyData: + raw: dict[str, Any] = json.loads( + (_JSON_SAMPLES / schema_dir / "epc.json").read_text() + ) + return EpcPropertyDataMapper.from_api_response(raw) + + +def _with_floor_areas(epc: EpcPropertyData, areas_m2: list[float]) -> EpcPropertyData: + """Replace the building parts with variants that have a single floor dimension + carrying the given total_floor_area_m2 — making them easy to distinguish after + a round-trip without changing anything else about the EPC.""" + template_bp = epc.sap_building_parts[0] + template_dim = template_bp.sap_floor_dimensions[0] + new_parts = [ + replace(template_bp, sap_floor_dimensions=[replace(template_dim, total_floor_area_m2=a)]) + for a in areas_m2 + ] + return replace(epc, sap_building_parts=new_parts) + + +# --------------------------------------------------------------------------- +# Tracer bullet: single-request save_batch() is loss-free vs save() +# --------------------------------------------------------------------------- + +def test_single_request_save_batch_matches_save(db_engine: Engine) -> None: + # Arrange + epc = _load_epc() + + with Session(db_engine) as session: + repo = EpcPostgresRepository(session) + epc_id_via_save = repo.save(epc, property_id=1001) + epc_id_via_batch = repo.save_batch([EpcSaveRequest(epc, property_id=1002)])[0] + session.commit() + + # Act + with Session(db_engine) as session: + repo = EpcPostgresRepository(session) + via_save = repo.get(epc_id_via_save) + via_batch = repo.get(epc_id_via_batch) + + # Assert — both paths reconstruct the original exactly. + assert via_save == epc + assert via_batch == epc + + +# --------------------------------------------------------------------------- +# FK mis-wiring regression: building-part IDs must not be crossed +# --------------------------------------------------------------------------- + +def test_multi_property_building_part_ids_are_not_crossed(db_engine: Engine) -> None: + # Arrange — property A has 2 parts with distinctive areas; B has 1 with a + # third distinctive area. If part IDs are mis-wired the floor-dimension FK + # rows end up under the wrong property. + base = _load_epc() + epc_a = _with_floor_areas(base, [10.0, 20.0]) + epc_b = _with_floor_areas(base, [99.0]) + + with Session(db_engine) as session: + repo = EpcPostgresRepository(session) + repo.save_batch([ + EpcSaveRequest(epc_a, property_id=2001), + EpcSaveRequest(epc_b, property_id=2002), + ]) + session.commit() + + # Act + with Session(db_engine) as session: + repo = EpcPostgresRepository(session) + reloaded_a = repo.get_for_property(2001) + reloaded_b = repo.get_for_property(2002) + + # Assert — each property's building parts carry its own floor areas. + assert reloaded_a is not None + assert reloaded_b is not None + areas_a = sorted( + dim.total_floor_area_m2 + for part in reloaded_a.sap_building_parts + for dim in part.sap_floor_dimensions + ) + areas_b = sorted( + dim.total_floor_area_m2 + for part in reloaded_b.sap_building_parts + for dim in part.sap_floor_dimensions + ) + assert areas_a == [10.0, 20.0] + assert areas_b == [99.0] + + +# --------------------------------------------------------------------------- +# Idempotency: second save_batch() replaces, not duplicates +# --------------------------------------------------------------------------- + +def test_save_batch_is_idempotent(db_engine: Engine) -> None: + # Arrange + epc = _load_epc() + requests = [EpcSaveRequest(epc, property_id=3001)] + + with Session(db_engine) as session: + EpcPostgresRepository(session).save_batch(requests) + session.commit() + + # Act — re-save the same batch. + with Session(db_engine) as session: + EpcPostgresRepository(session).save_batch(requests) + session.commit() + + # Assert — exactly one EPC survives (no duplicate rows). + with Session(db_engine) as session: + result = EpcPostgresRepository(session).get_for_property(3001) + + assert result == epc + + +# --------------------------------------------------------------------------- +# Source isolation: lodged and predicted slots survive separate batch saves +# --------------------------------------------------------------------------- + +def test_lodged_and_predicted_batch_slots_are_independent(db_engine: Engine) -> None: + # Arrange — two properties each get a lodged EPC and then a predicted EPC + # via separate save_batch() calls. + epc = _load_epc() + property_ids = [4001, 4002] + + with Session(db_engine) as session: + repo = EpcPostgresRepository(session) + repo.save_batch([EpcSaveRequest(epc, property_id=pid, source="lodged") for pid in property_ids]) + repo.save_batch([EpcSaveRequest(epc, property_id=pid, source="predicted") for pid in property_ids]) + session.commit() + + # Act + with Session(db_engine) as session: + repo = EpcPostgresRepository(session) + lodged = repo.get_for_properties(property_ids) + predicted = repo.get_predicted_for_properties(property_ids) + + # Assert — both slots are populated for both properties. + assert lodged == {4001: epc, 4002: epc} + assert predicted == {4001: epc, 4002: epc} From 587465bff77b47bc562cd586411d281b4f1897c1 Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 13:04:30 +0000 Subject: [PATCH 2/8] =?UTF-8?q?Batch=20EPC=20writes=20via=20save=5Fbatch()?= =?UTF-8?q?=20on=20EpcPostgresRepository=20=F0=9F=9F=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- repositories/epc/epc_postgres_repository.py | 256 ++++++++++++++------ 1 file changed, 185 insertions(+), 71 deletions(-) diff --git a/repositories/epc/epc_postgres_repository.py b/repositories/epc/epc_postgres_repository.py index 4e15afb4..95f09fdf 100644 --- a/repositories/epc/epc_postgres_repository.py +++ b/repositories/epc/epc_postgres_repository.py @@ -130,81 +130,195 @@ class EpcPostgresRepository(EpcRepository): portfolio_id: Optional[int] = None, source: EpcSource = "lodged", ) -> int: - # Idempotent on (property_id, source): a re-run replaces the property's - # EPC graph for THAT source rather than duplicating it (ADR-0012), and a - # predicted save leaves the lodged one intact, and vice versa (ADR-0031). - # Anonymous saves (no property_id) always insert. - if property_id is not None: - self._delete_for_property(property_id, source) - parent = EpcPropertyModel.from_epc_property_data( - data, property_id=property_id, portfolio_id=portfolio_id, source=source - ) - self._session.add(parent) - self._session.flush() - epc_property_id = _require(parent.id, "id") - - self._session.add( - EpcPropertyEnergyPerformanceModel.from_epc_property_data( - data, epc_property_id=epc_property_id - ) - ) - for detail in data.sap_heating.main_heating_details: - self._session.add( - EpcMainHeatingDetailModel.from_domain(detail, epc_property_id) - ) - for part in data.sap_building_parts: - bp = EpcBuildingPartModel.from_domain(part, epc_property_id) - self._session.add(bp) - self._session.flush() - bp_id = _require(bp.id, "epc_building_part.id") - for dim in part.sap_floor_dimensions: - self._session.add(EpcFloorDimensionModel.from_domain(dim, bp_id)) - for window in data.sap_windows: - self._session.add(EpcWindowModel.from_domain(window, epc_property_id)) - for index, array in enumerate(data.sap_energy_source.photovoltaic_arrays or []): - self._session.add( - EpcPhotovoltaicArrayModel.from_domain(array, index, epc_property_id) - ) - - for element_type, elements in ( - ("roof", data.roofs), - ("wall", data.walls), - ("floor", data.floors), - ("main_heating", data.main_heating), - ): - for el in elements: - self._session.add( - EpcEnergyElementModel.from_domain(el, element_type, epc_property_id) - ) - for el, element_type in ( - (data.window, "window"), - (data.lighting, "lighting"), - (data.hot_water, "hot_water"), - (data.secondary_heating, "secondary_heating"), - (data.main_heating_controls, "main_heating_controls"), - ): - if el is not None: - self._session.add( - EpcEnergyElementModel.from_domain(el, element_type, epc_property_id) - ) - - if data.sap_flat_details is not None: - self._session.add( - EpcFlatDetailsModel.from_domain(data.sap_flat_details, epc_property_id) - ) - if data.renewable_heat_incentive is not None: - self._session.add( - EpcRenewableHeatIncentiveModel.from_domain( - data.renewable_heat_incentive, epc_property_id - ) - ) - return epc_property_id + return self.save_batch([EpcSaveRequest(data, property_id, portfolio_id, source)])[0] def save_batch(self, requests: list[EpcSaveRequest]) -> list[int]: - raise NotImplementedError + """Insert all EPCs in `requests` in one pass per table, returning one + epc_property_id per request in the same order as the input. + + Deletes are batched first (one IN-query per child table per source), + then all parent rows are inserted with a single RETURNING statement so + positional ordering maps each returned id to its request. Building-part + ids are captured the same way so floor-dimension FKs are resolved without + any per-property flush round-trips (ADR-0012). + """ + if not requests: + return [] + + # Batch-delete existing rows grouped by source so the lodged and predicted + # slots remain independent (ADR-0031). + pids_by_source: dict[EpcSource, list[int]] = {} + for r in requests: + if r.property_id is not None: + pids_by_source.setdefault(r.source, []).append(r.property_id) + for src, pids in pids_by_source.items(): + self._delete_for_properties(pids, src) + + # Insert all parent (epc_property) rows; capture returned ids positionally. + parent_rows = [ + _col_values( + EpcPropertyModel.from_epc_property_data( + r.data, property_id=r.property_id, portfolio_id=r.portfolio_id, source=r.source + ), + exclude=frozenset({"id"}), + ) + for r in requests + ] + returned_parents = self._session.execute( + _sa_insert(EpcPropertyModel).returning(EpcPropertyModel.id), + parent_rows, + ).all() + epc_property_ids = [row[0] for row in returned_parents] + + # Collect child rows, accumulating building parts in an ordered list so + # the positional RETURNING trick can map part objects to their new ids. + perf_rows: list[dict[str, Any]] = [] + heating_rows: list[dict[str, Any]] = [] + parts_ordered: list[tuple[Any, int]] = [] # (SapBuildingPart, epc_property_id) + window_rows: list[dict[str, Any]] = [] + pv_rows: list[dict[str, Any]] = [] + element_rows: list[dict[str, Any]] = [] + flat_rows: list[dict[str, Any]] = [] + rhi_rows: list[dict[str, Any]] = [] + + for r, epc_pid in zip(requests, epc_property_ids): + d = r.data + perf_rows.append( + _col_values( + EpcPropertyEnergyPerformanceModel.from_epc_property_data(d, epc_pid), + exclude=frozenset({"id"}), + ) + ) + for detail in d.sap_heating.main_heating_details: + heating_rows.append( + _col_values(EpcMainHeatingDetailModel.from_domain(detail, epc_pid), frozenset({"id"})) + ) + for part in d.sap_building_parts: + parts_ordered.append((part, epc_pid)) + for window in d.sap_windows: + window_rows.append( + _col_values(EpcWindowModel.from_domain(window, epc_pid), frozenset({"id"})) + ) + for idx, array in enumerate(d.sap_energy_source.photovoltaic_arrays or []): + pv_rows.append( + _col_values(EpcPhotovoltaicArrayModel.from_domain(array, idx, epc_pid), frozenset({"id"})) + ) + for etype, els in ( + ("roof", d.roofs), + ("wall", d.walls), + ("floor", d.floors), + ("main_heating", d.main_heating), + ): + for el in els: + element_rows.append( + _col_values(EpcEnergyElementModel.from_domain(el, etype, epc_pid), frozenset({"id"})) + ) + for el, etype in ( + (d.window, "window"), + (d.lighting, "lighting"), + (d.hot_water, "hot_water"), + (d.secondary_heating, "secondary_heating"), + (d.main_heating_controls, "main_heating_controls"), + ): + if el is not None: + element_rows.append( + _col_values(EpcEnergyElementModel.from_domain(el, etype, epc_pid), frozenset({"id"})) + ) + if d.sap_flat_details is not None: + flat_rows.append( + _col_values(EpcFlatDetailsModel.from_domain(d.sap_flat_details, epc_pid), frozenset({"id"})) + ) + if d.renewable_heat_incentive is not None: + rhi_rows.append( + _col_values(EpcRenewableHeatIncentiveModel.from_domain(d.renewable_heat_incentive, epc_pid), frozenset({"id"})) + ) + + # Bulk-insert all simple child tables (no downstream FK dependency). + if perf_rows: + self._session.execute(_sa_insert(EpcPropertyEnergyPerformanceModel), perf_rows) + if heating_rows: + self._session.execute(_sa_insert(EpcMainHeatingDetailModel), heating_rows) + if window_rows: + self._session.execute(_sa_insert(EpcWindowModel), window_rows) + if pv_rows: + self._session.execute(_sa_insert(EpcPhotovoltaicArrayModel), pv_rows) + if element_rows: + self._session.execute(_sa_insert(EpcEnergyElementModel), element_rows) + if flat_rows: + self._session.execute(_sa_insert(EpcFlatDetailsModel), flat_rows) + if rhi_rows: + self._session.execute(_sa_insert(EpcRenewableHeatIncentiveModel), rhi_rows) + + # Building parts: insert with RETURNING and zip positionally to resolve + # floor-dimension FKs. Do NOT key by id(part) — the same EpcPropertyData + # object can appear in multiple requests (same epc, different property_ids), + # giving identical object ids that collapse the dict and mis-wire FKs. + # Positional zip is safe because PostgreSQL preserves VALUES order in RETURNING. + if parts_ordered: + bp_rows = [ + _col_values(EpcBuildingPartModel.from_domain(part, epc_pid), frozenset({"id"})) + for part, epc_pid in parts_ordered + ] + returned_bps = self._session.execute( + _sa_insert(EpcBuildingPartModel).returning(EpcBuildingPartModel.id), + bp_rows, + ).all() + floor_rows: list[dict[str, Any]] = [ + _col_values(EpcFloorDimensionModel.from_domain(dim, bp_row[0]), frozenset({"id"})) + for (part, _), bp_row in zip(parts_ordered, returned_bps) + for dim in part.sap_floor_dimensions + ] + if floor_rows: + self._session.execute(_sa_insert(EpcFloorDimensionModel), floor_rows) + + return epc_property_ids def _delete_for_properties(self, property_ids: list[int], source: EpcSource) -> None: - raise NotImplementedError + """Batch-delete every EPC graph for the given property_ids and source in + one pass per child table (IN queries), replacing the per-property flush + loop that drove RDS CPU to saturation during bulk modelling runs.""" + epc_ids = [ + i + for i in self._session.exec( + select(EpcPropertyModel.id) + .where(col(EpcPropertyModel.property_id).in_(property_ids)) + .where(EpcPropertyModel.source == source) + ).all() + if i is not None + ] + if not epc_ids: + return + part_ids = [ + i + for i in self._session.exec( + select(EpcBuildingPartModel.id).where( + col(EpcBuildingPartModel.epc_property_id).in_(epc_ids) + ) + ).all() + if i is not None + ] + if part_ids: + self._session.exec( # type: ignore[call-overload] + delete(EpcFloorDimensionModel).where( + col(EpcFloorDimensionModel.epc_building_part_id).in_(part_ids) + ) + ) + for child in ( + EpcPropertyEnergyPerformanceModel, + EpcEnergyElementModel, + EpcMainHeatingDetailModel, + EpcBuildingPartModel, + EpcWindowModel, + EpcPhotovoltaicArrayModel, + EpcFlatDetailsModel, + EpcRenewableHeatIncentiveModel, + ): + self._session.exec( # type: ignore[call-overload] + delete(child).where(col(child.epc_property_id).in_(epc_ids)) + ) + self._session.exec( # type: ignore[call-overload] + delete(EpcPropertyModel).where(col(EpcPropertyModel.id).in_(epc_ids)) + ) def _delete_for_property(self, property_id: int, source: EpcSource) -> None: """Remove the property's existing EPC graph for `source` (parent + child From 0fa1b9001c84bfea0d0681853ead0b777473e99a Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 13:09:56 +0000 Subject: [PATCH 3/8] =?UTF-8?q?Batch=20EPC=20writes=20in=20=5Fflush=5Fwrit?= =?UTF-8?q?es:=20two=20save=5Fbatch()=20calls=20instead=20of=20N=20save()?= =?UTF-8?q?=20calls=20=F0=9F=9F=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- applications/modelling_e2e/handler.py | 45 +++++++++---------- .../modelling_e2e/test_handler.py | 44 +++++++----------- 2 files changed, 37 insertions(+), 52 deletions(-) diff --git a/applications/modelling_e2e/handler.py b/applications/modelling_e2e/handler.py index 532d5fb4..caa8a7d7 100644 --- a/applications/modelling_e2e/handler.py +++ b/applications/modelling_e2e/handler.py @@ -92,7 +92,7 @@ from repositories.comparable_properties.epc_comparable_properties_repository imp EpcComparablePropertiesRepository, SkippedCohortCert, ) -from repositories.epc.epc_postgres_repository import EpcPostgresRepository +from repositories.epc.epc_postgres_repository import EpcPostgresRepository, EpcSaveRequest from repositories.geospatial.geospatial_s3_repository import ( GeospatialS3Repository, ParquetReader, @@ -176,31 +176,30 @@ class _PropertyWrite: def _flush_writes(engine: Engine, writes: list[_PropertyWrite]) -> None: """Persist a whole batch of modelled Properties in one Unit of Work. - Replays each Property's saves in dependency order (EPC → spatial → solar → - Plan → mark-modelled) and commits once. All-or-nothing per batch: a failed - save rolls the whole transaction back and propagates, so the SQS message is - retried — every save is an idempotent upsert, so a retry is safe. This mirrors - the PropertyBaselineOrchestrator's existing one-UoW-per-batch contract - (ADR-0012); per-property failures are isolated earlier, in the modelling loop, + EPC writes are batched by source (lodged group first, predicted group second) + so each source emits one DELETE pass + one INSERT pass regardless of batch + size, rather than N×per-property round-trips (ADR-0012). All other writes + (spatial, solar, plan, mark-modelled) remain per-property inside the same + transaction. All-or-nothing per batch: a failed save rolls the whole + transaction back so the SQS message is retried — every save is an idempotent + upsert. Per-property failures are isolated earlier, in the modelling loop, before a write is ever queued.""" + lodged_requests = [ + EpcSaveRequest(w.lodged_epc, property_id=w.property_id, portfolio_id=w.portfolio_id, source="lodged") + for w in writes + if w.lodged_epc is not None + ] + predicted_requests = [ + EpcSaveRequest(w.predicted_epc, property_id=w.property_id, portfolio_id=w.portfolio_id, source="predicted") + for w in writes + if w.predicted_epc is not None + ] with PostgresUnitOfWork(lambda: Session(engine)) as uow: + if lodged_requests: + uow.epc.save_batch(lodged_requests) + if predicted_requests: + uow.epc.save_batch(predicted_requests) for w in writes: - if w.lodged_epc is not None: - uow.epc.save( - w.lodged_epc, - property_id=w.property_id, - portfolio_id=w.portfolio_id, - ) - elif w.predicted_epc is not None: - # Persist the synthesised EPC in the predicted slot (ADR-0031), so - # the Baseline stage can re-hydrate it and downstream sees the - # picture the Plan was modelled from. - uow.epc.save( - w.predicted_epc, - property_id=w.property_id, - portfolio_id=w.portfolio_id, - source="predicted", - ) if w.spatial is not None: uow.spatial.save(w.uprn, w.spatial) if w.solar is not None: diff --git a/tests/applications/modelling_e2e/test_handler.py b/tests/applications/modelling_e2e/test_handler.py index 90b3e97c..880ac9f4 100644 --- a/tests/applications/modelling_e2e/test_handler.py +++ b/tests/applications/modelling_e2e/test_handler.py @@ -19,6 +19,7 @@ from applications.modelling_e2e.modelling_e2e_trigger_body import ( ModellingE2ETriggerBody, ) from domain.tasks.subtasks import SubTask +from repositories.epc.epc_postgres_repository import EpcSaveRequest PROPERTY_ID = 12345 UPRN = 987654321 @@ -355,8 +356,8 @@ def test_lodged_epc_path_saves_epc_plan_and_marks_modelled( _call_handler(_BODY) # Assert — EPC saved (lodged path), plan saved, property marked modelled - mock_uow.epc.save.assert_called_once_with( - mock_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(mock_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="lodged")] ) mock_uow.plan.save.assert_called_once() mock_uow.property.mark_modelled.assert_called_once_with( @@ -631,11 +632,8 @@ def test_prediction_path_saves_predicted_epc_plan_and_baseline( _call_handler(_BODY) # Assert — predicted EPC persisted in the predicted slot, plan saved, baseline run - mock_uow.epc.save.assert_called_once_with( - mock_predicted_epc, - property_id=PROPERTY_ID, - portfolio_id=PORTFOLIO_ID, - source="predicted", + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) mock_uow.plan.save.assert_called_once() mock_uow.commit.assert_called_once() @@ -841,11 +839,8 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None: # Assert — broadening fired, and the broadened cohort produced a saved plan # with its predicted EPC persisted in the predicted slot. MockRepo.return_value.candidates_near.assert_called_once() - mock_uow.epc.save.assert_called_once_with( - mock_predicted_epc, - property_id=PROPERTY_ID, - portfolio_id=PORTFOLIO_ID, - source="predicted", + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) mock_uow.plan.save.assert_called_once() mock_uow.commit.assert_called_once() @@ -1171,8 +1166,8 @@ def test_refetch_epc_false_with_stored_epc_skips_api_call() -> None: mock_epc_client.get_by_uprn.assert_not_called() mock_run_modelling.assert_called_once() # Stored lodged EPC is persisted in the lodged slot - mock_uow.epc.save.assert_called_once_with( - stored_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(stored_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="lodged")] ) @@ -1258,11 +1253,8 @@ def test_refetch_epc_false_without_stored_epc_skips_api_and_goes_to_prediction() # Assert — API was NOT called; prediction ran and its output was persisted mock_epc_client.get_by_uprn.assert_not_called() - mock_uow.epc.save.assert_called_once_with( - mock_predicted_epc, - property_id=PROPERTY_ID, - portfolio_id=PORTFOLIO_ID, - source="predicted", + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) @@ -1398,11 +1390,8 @@ def test_repredict_epc_false_with_stored_predicted_epc_skips_prediction() -> Non # Assert — EpcPrediction.predict never called; stored EPC persisted in predicted slot mock_predictor.predict.assert_not_called() - mock_uow.epc.save.assert_called_once_with( - stored_predicted, - property_id=PROPERTY_ID, - portfolio_id=PORTFOLIO_ID, - source="predicted", + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(stored_predicted, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) @@ -1488,11 +1477,8 @@ def test_repredict_epc_false_without_stored_predicted_epc_falls_back_to_live_pre # Assert — live prediction was used as fallback mock_predictor.predict.assert_called_once() - mock_uow.epc.save.assert_called_once_with( - mock_predicted_epc, - property_id=PROPERTY_ID, - portfolio_id=PORTFOLIO_ID, - source="predicted", + mock_uow.epc.save_batch.assert_called_once_with( + [EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) From f27d1e21bb804881435b52b9e25ddbb0aac75d0a Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 13:17:29 +0000 Subject: [PATCH 4/8] =?UTF-8?q?Batch=20EPC=20writes=20in=20EpcPostgresRepo?= =?UTF-8?q?sitory=20pass=20pyright=20strict=20=F0=9F=9F=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- repositories/epc/epc_postgres_repository.py | 28 +++++++++---------- tests/repositories/epc/test_epc_batch_save.py | 3 +- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/repositories/epc/epc_postgres_repository.py b/repositories/epc/epc_postgres_repository.py index 95f09fdf..1c92e1bf 100644 --- a/repositories/epc/epc_postgres_repository.py +++ b/repositories/epc/epc_postgres_repository.py @@ -67,9 +67,9 @@ class EpcSaveRequest: def _col_values(model: SQLModel, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: """Extract column-keyed values from a SQLModel instance for Core INSERT.""" return { - c.name: getattr(model, c.name) + c.name: getattr(model, c.name) # type: ignore[union-attr] for c in model.__table__.c # type: ignore[attr-defined] - if c.name not in exclude + if c.name not in exclude # type: ignore[union-attr] } @@ -164,8 +164,8 @@ class EpcPostgresRepository(EpcRepository): ) for r in requests ] - returned_parents = self._session.execute( - _sa_insert(EpcPropertyModel).returning(EpcPropertyModel.id), + returned_parents = self._session.execute( # type: ignore[deprecated] + _sa_insert(EpcPropertyModel).returning(EpcPropertyModel.__table__.c["id"]), # type: ignore[attr-defined] parent_rows, ).all() epc_property_ids = [row[0] for row in returned_parents] @@ -235,19 +235,19 @@ class EpcPostgresRepository(EpcRepository): # Bulk-insert all simple child tables (no downstream FK dependency). if perf_rows: - self._session.execute(_sa_insert(EpcPropertyEnergyPerformanceModel), perf_rows) + self._session.execute(_sa_insert(EpcPropertyEnergyPerformanceModel), perf_rows) # type: ignore[deprecated] if heating_rows: - self._session.execute(_sa_insert(EpcMainHeatingDetailModel), heating_rows) + self._session.execute(_sa_insert(EpcMainHeatingDetailModel), heating_rows) # type: ignore[deprecated] if window_rows: - self._session.execute(_sa_insert(EpcWindowModel), window_rows) + self._session.execute(_sa_insert(EpcWindowModel), window_rows) # type: ignore[deprecated] if pv_rows: - self._session.execute(_sa_insert(EpcPhotovoltaicArrayModel), pv_rows) + self._session.execute(_sa_insert(EpcPhotovoltaicArrayModel), pv_rows) # type: ignore[deprecated] if element_rows: - self._session.execute(_sa_insert(EpcEnergyElementModel), element_rows) + self._session.execute(_sa_insert(EpcEnergyElementModel), element_rows) # type: ignore[deprecated] if flat_rows: - self._session.execute(_sa_insert(EpcFlatDetailsModel), flat_rows) + self._session.execute(_sa_insert(EpcFlatDetailsModel), flat_rows) # type: ignore[deprecated] if rhi_rows: - self._session.execute(_sa_insert(EpcRenewableHeatIncentiveModel), rhi_rows) + self._session.execute(_sa_insert(EpcRenewableHeatIncentiveModel), rhi_rows) # type: ignore[deprecated] # Building parts: insert with RETURNING and zip positionally to resolve # floor-dimension FKs. Do NOT key by id(part) — the same EpcPropertyData @@ -259,8 +259,8 @@ class EpcPostgresRepository(EpcRepository): _col_values(EpcBuildingPartModel.from_domain(part, epc_pid), frozenset({"id"})) for part, epc_pid in parts_ordered ] - returned_bps = self._session.execute( - _sa_insert(EpcBuildingPartModel).returning(EpcBuildingPartModel.id), + returned_bps = self._session.execute( # type: ignore[deprecated] + _sa_insert(EpcBuildingPartModel).returning(EpcBuildingPartModel.__table__.c["id"]), # type: ignore[attr-defined] bp_rows, ).all() floor_rows: list[dict[str, Any]] = [ @@ -269,7 +269,7 @@ class EpcPostgresRepository(EpcRepository): for dim in part.sap_floor_dimensions ] if floor_rows: - self._session.execute(_sa_insert(EpcFloorDimensionModel), floor_rows) + self._session.execute(_sa_insert(EpcFloorDimensionModel), floor_rows) # type: ignore[deprecated] return epc_property_ids diff --git a/tests/repositories/epc/test_epc_batch_save.py b/tests/repositories/epc/test_epc_batch_save.py index 03489884..c31e22b2 100644 --- a/tests/repositories/epc/test_epc_batch_save.py +++ b/tests/repositories/epc/test_epc_batch_save.py @@ -17,11 +17,10 @@ from dataclasses import replace from pathlib import Path from typing import Any -import pytest from sqlalchemy import Engine from sqlmodel import Session -from datatypes.epc.domain.epc_property_data import EpcPropertyData, SapFloorDimension +from datatypes.epc.domain.epc_property_data import EpcPropertyData from datatypes.epc.domain.mapper import EpcPropertyDataMapper from repositories.epc.epc_postgres_repository import EpcPostgresRepository, EpcSaveRequest From 9c6b47702515b1b329e16d1107f79dd365777d36 Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 15:04:54 +0000 Subject: [PATCH 5/8] =?UTF-8?q?Batch=20plan=20saves=20reduce=20RDS=20CPU?= =?UTF-8?q?=20during=20bulk=20modelling=20runs=20=F0=9F=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- repositories/plan/plan_postgres_repository.py | 5 +- repositories/plan/plan_repository.py | 24 +++ .../repositories/plan/test_plan_batch_save.py | 183 ++++++++++++++++++ 3 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 tests/repositories/plan/test_plan_batch_save.py diff --git a/repositories/plan/plan_postgres_repository.py b/repositories/plan/plan_postgres_repository.py index 7be21bac..36253764 100644 --- a/repositories/plan/plan_postgres_repository.py +++ b/repositories/plan/plan_postgres_repository.py @@ -4,7 +4,7 @@ from sqlmodel import Session, col, update from domain.modelling.plan import Plan from infrastructure.postgres.modelling import PlanModel, RecommendationModel -from repositories.plan.plan_repository import PlanRepository +from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest class PlanPostgresRepository(PlanRepository): @@ -63,3 +63,6 @@ class PlanPostgresRepository(PlanRepository): ) ) return plan_row.id + + def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]: + raise NotImplementedError diff --git a/repositories/plan/plan_repository.py b/repositories/plan/plan_repository.py index b534e8ea..6fdc16d6 100644 --- a/repositories/plan/plan_repository.py +++ b/repositories/plan/plan_repository.py @@ -1,10 +1,25 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from domain.modelling.plan import Plan +@dataclass(frozen=True) +class PlanSaveRequest: + """Bundles the five fields the plan repository needs to persist one Plan. + + Mirrors ``EpcSaveRequest`` in shape — used by ``PlanRepository.save_batch()`` + to accumulate write intent before the batch is flushed in one Unit of Work.""" + + plan: Plan + property_id: int + scenario_id: int + portfolio_id: int + is_default: bool + + class PlanRepository(ABC): """Persists a Plan (and its Plan Measures) for a Property + Scenario. @@ -30,3 +45,12 @@ class PlanRepository(ABC): ``(property_id, scenario_id)`` as history; when ``is_default`` is True, demotes those prior Plans to ``is_default=False``.""" ... + + @abstractmethod + def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]: + """Persist a batch of Plans in three statements regardless of batch size. + + Returns one plan id per request in input order. Fires a single demote + UPDATE only when at least one request has ``is_default=True``. Keeps + prior Plans as history (ADR-0017).""" + ... diff --git a/tests/repositories/plan/test_plan_batch_save.py b/tests/repositories/plan/test_plan_batch_save.py new file mode 100644 index 00000000..e82d0f1e --- /dev/null +++ b/tests/repositories/plan/test_plan_batch_save.py @@ -0,0 +1,183 @@ +"""Batch plan write path — save_batch() correctness and safety tests. + +Guards the four user stories from #1355: + 1. save()/save_batch() parity: a single-element save_batch() produces + identical DB state (plan row + recommendation rows) as the equivalent + save() call. + 2. Recommendation FK isolation: two properties in the same save_batch() each + get their own recommendation rows; no FK cross-wiring between properties. + 3. Demote correctness: a second save_batch() for the same properties demotes + the prior default Plans and inserts fresh ones (history preserved). + 4. Non-default batch: a save_batch() where all writes have is_default=False + leaves any pre-existing default Plan untouched. +""" + +from __future__ import annotations + +from sqlalchemy import Engine +from sqlmodel import Session, col, select + +from domain.modelling.measure_type import MeasureType +from domain.modelling.plan import Plan, PlanMeasure +from domain.modelling.recommendation import Cost +from domain.modelling.scoring.package_scorer import Score +from domain.modelling.scoring.scoring import MeasureImpact +from infrastructure.postgres.modelling import PlanModel, RecommendationModel +from repositories.plan.plan_postgres_repository import PlanPostgresRepository +from repositories.plan.plan_repository import PlanSaveRequest + + +def _plan(*, sap: float = 70.0, measures: int = 1) -> Plan: + ms: tuple[PlanMeasure, ...] = tuple( + PlanMeasure( + measure_type=MeasureType.CAVITY_WALL_INSULATION, + description="Cavity wall insulation", + cost=Cost(total=1000.0, contingency_rate=0.10), + impact=MeasureImpact( + sap_points=8.0, + co2_savings_kg_per_yr=500.0, + energy_savings_kwh_per_yr=2000.0, + ), + kwh_savings=1500.0, + energy_cost_savings=300.0, + ) + for _ in range(measures) + ) + return Plan( + measures=ms, + baseline=Score(sap_continuous=40.0, co2_kg_per_yr=4000.0, primary_energy_kwh_per_yr=20000.0), + post_retrofit=Score(sap_continuous=sap, co2_kg_per_yr=3500.0, primary_energy_kwh_per_yr=18000.0), + ) + + +# --------------------------------------------------------------------------- +# Tracer bullet: single-element save_batch() is loss-free vs save() +# --------------------------------------------------------------------------- + +def test_single_request_save_batch_matches_save(db_engine: Engine) -> None: + # Arrange + plan = _plan() + scenario_id = 7 + + with Session(db_engine) as session: + repo = PlanPostgresRepository(session) + save_id = repo.save(plan, property_id=5001, scenario_id=scenario_id, portfolio_id=1, is_default=True) + batch_id = repo.save_batch([PlanSaveRequest(plan, property_id=5002, scenario_id=scenario_id, portfolio_id=1, is_default=True)])[0] + session.commit() + + # Act + with Session(db_engine) as session: + via_save = session.get(PlanModel, save_id) + via_batch = session.get(PlanModel, batch_id) + recs_save = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == save_id)).all() + recs_batch = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == batch_id)).all() + + # Assert — both paths produce one plan row + one recommendation row with the + # same field values (modulo property_id which differs by design). + assert via_save is not None + assert via_batch is not None + assert via_save.is_default is True + assert via_batch.is_default is True + assert via_save.post_sap_points == via_batch.post_sap_points + assert via_save.post_co2_emissions == via_batch.post_co2_emissions + assert via_save.co2_savings == via_batch.co2_savings + assert len(recs_save) == 1 + assert len(recs_batch) == 1 + assert recs_save[0].type == recs_batch[0].type + assert recs_save[0].estimated_cost == recs_batch[0].estimated_cost + assert recs_save[0].sap_points == recs_batch[0].sap_points + assert recs_save[0].plan_id == batch_id + + +# --------------------------------------------------------------------------- +# FK isolation: recommendation rows must not be crossed between properties +# --------------------------------------------------------------------------- + +def test_multi_property_recommendation_fks_are_not_crossed(db_engine: Engine) -> None: + # Arrange — property A gets 2 measures, property B gets 1 measure. + plan_a = _plan(measures=2) + plan_b = _plan(measures=1) + + with Session(db_engine) as session: + [id_a, id_b] = PlanPostgresRepository(session).save_batch([ + PlanSaveRequest(plan_a, property_id=6001, scenario_id=7, portfolio_id=1, is_default=True), + PlanSaveRequest(plan_b, property_id=6002, scenario_id=7, portfolio_id=1, is_default=True), + ]) + session.commit() + + # Act + with Session(db_engine) as session: + recs_a = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == id_a)).all() + recs_b = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == id_b)).all() + + # Assert — A has 2 rows, B has 1; none cross-wired. + assert len(recs_a) == 2 + assert len(recs_b) == 1 + assert all(r.plan_id == id_a and r.property_id == 6001 for r in recs_a) + assert all(r.plan_id == id_b and r.property_id == 6002 for r in recs_b) + + +# --------------------------------------------------------------------------- +# Demote correctness: second save_batch() demotes prior defaults +# --------------------------------------------------------------------------- + +def test_second_save_batch_demotes_prior_default_plans(db_engine: Engine) -> None: + # Arrange — first batch creates default Plans for two properties. + plan = _plan() + requests = [ + PlanSaveRequest(plan, property_id=7001, scenario_id=7, portfolio_id=1, is_default=True), + PlanSaveRequest(plan, property_id=7002, scenario_id=7, portfolio_id=1, is_default=True), + ] + with Session(db_engine) as session: + first_ids = PlanPostgresRepository(session).save_batch(requests) + session.commit() + + # Act — re-run the same batch; new Plans should become default, old ones demoted. + with Session(db_engine) as session: + second_ids = PlanPostgresRepository(session).save_batch(requests) + session.commit() + + # Assert — history is preserved (4 plan rows total); exactly one default per property. + with Session(db_engine) as session: + rows_7001 = session.exec(select(PlanModel).where(col(PlanModel.property_id) == 7001)).all() + rows_7002 = session.exec(select(PlanModel).where(col(PlanModel.property_id) == 7002)).all() + + by_id_7001 = {p.id: p for p in rows_7001} + by_id_7002 = {p.id: p for p in rows_7002} + + assert len(rows_7001) == 2 + assert len(rows_7002) == 2 + assert by_id_7001[first_ids[0]].is_default is False + assert by_id_7001[second_ids[0]].is_default is True + assert by_id_7002[first_ids[1]].is_default is False + assert by_id_7002[second_ids[1]].is_default is True + + +# --------------------------------------------------------------------------- +# Non-default batch: existing default Plan is untouched +# --------------------------------------------------------------------------- + +def test_non_default_save_batch_does_not_demote_existing_default(db_engine: Engine) -> None: + # Arrange — a default Plan already exists for the property. + plan = _plan() + with Session(db_engine) as session: + default_id = PlanPostgresRepository(session).save( + plan, property_id=8001, scenario_id=7, portfolio_id=1, is_default=True + ) + session.commit() + + # Act — save a non-default Plan via save_batch(); no demote UPDATE should fire. + with Session(db_engine) as session: + PlanPostgresRepository(session).save_batch([ + PlanSaveRequest(plan, property_id=8001, scenario_id=7, portfolio_id=1, is_default=False), + ]) + session.commit() + + # Assert — the original default Plan is still the default. + with Session(db_engine) as session: + rows = session.exec(select(PlanModel).where(col(PlanModel.property_id) == 8001)).all() + by_id = {p.id: p for p in rows} + + assert len(rows) == 2 + assert by_id[default_id].is_default is True + assert sum(1 for p in rows if p.is_default) == 1 From 46ca714ef992726a95e6b2a6e6aed85726ab399e Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 15:06:47 +0000 Subject: [PATCH 6/8] =?UTF-8?q?Batch=20plan=20saves=20reduce=20RDS=20CPU?= =?UTF-8?q?=20during=20bulk=20modelling=20runs=20=F0=9F=9F=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- repositories/plan/plan_postgres_repository.py | 75 ++++++++++++++++++- .../repositories/plan/test_plan_batch_save.py | 2 +- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/repositories/plan/plan_postgres_repository.py b/repositories/plan/plan_postgres_repository.py index 36253764..9bf22aea 100644 --- a/repositories/plan/plan_postgres_repository.py +++ b/repositories/plan/plan_postgres_repository.py @@ -1,5 +1,8 @@ from __future__ import annotations +from typing import Any + +from sqlalchemy import insert as _sa_insert from sqlmodel import Session, col, update from domain.modelling.plan import Plan @@ -7,6 +10,15 @@ from infrastructure.postgres.modelling import PlanModel, RecommendationModel from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest +def _col_values(model: Any, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract column-keyed values from a SQLModel instance for Core INSERT.""" + return { + c.name: getattr(model, c.name) + for c in model.__table__.c + if c.name not in exclude + } + + class PlanPostgresRepository(PlanRepository): """Maps a Plan and its Plan Measures onto the live ``plan`` / ``recommendation`` tables (ADR-0017). Does not commit — the Unit of Work @@ -65,4 +77,65 @@ class PlanPostgresRepository(PlanRepository): return plan_row.id def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]: - raise NotImplementedError + """Persist all Plans in three statements regardless of batch size. + + 1. One demote UPDATE (only when any request has ``is_default=True``). + 2. One bulk plan INSERT with RETURNING to capture ids positionally. + 3. One bulk recommendation INSERT (skipped when no measures exist). + """ + if not requests: + return [] + + # Demote prior default Plans for every property in the batch that is + # receiving a new default Plan — one UPDATE for the whole batch. + default_pids = [r.property_id for r in requests if r.is_default] + if default_pids: + # scenario_id is uniform per batch (one scenario per SQS message). + scenario_id = requests[0].scenario_id + self._session.exec( # type: ignore[call-overload] + update(PlanModel) + .where( + col(PlanModel.property_id).in_(default_pids), + col(PlanModel.scenario_id) == scenario_id, + ) + .values(is_default=False) + ) + + # Bulk INSERT all plan rows; capture returned ids positionally. + plan_rows = [ + _col_values( + PlanModel.from_domain( + r.plan, + property_id=r.property_id, + scenario_id=r.scenario_id, + portfolio_id=r.portfolio_id, + is_default=r.is_default, + ), + exclude=frozenset({"id"}), + ) + for r in requests + ] + returned = self._session.execute( # type: ignore[deprecated] + _sa_insert(PlanModel).returning(PlanModel.__table__.c["id"]), # type: ignore[attr-defined] + plan_rows, + ).all() + plan_ids = [row[0] for row in returned] + + # Accumulate recommendation rows across all requests; properties with + # zero measures contribute nothing (no special-casing needed). + rec_rows = [ + _col_values( + RecommendationModel.from_domain( + measure, property_id=r.property_id, plan_id=plan_id + ), + exclude=frozenset({"id"}), + ) + for r, plan_id in zip(requests, plan_ids) + for measure in r.plan.measures + ] + if rec_rows: + self._session.execute( # type: ignore[deprecated] + _sa_insert(RecommendationModel), rec_rows + ) + + return plan_ids diff --git a/tests/repositories/plan/test_plan_batch_save.py b/tests/repositories/plan/test_plan_batch_save.py index e82d0f1e..400b2cdf 100644 --- a/tests/repositories/plan/test_plan_batch_save.py +++ b/tests/repositories/plan/test_plan_batch_save.py @@ -86,7 +86,7 @@ def test_single_request_save_batch_matches_save(db_engine: Engine) -> None: assert recs_save[0].type == recs_batch[0].type assert recs_save[0].estimated_cost == recs_batch[0].estimated_cost assert recs_save[0].sap_points == recs_batch[0].sap_points - assert recs_save[0].plan_id == batch_id + assert recs_batch[0].plan_id == batch_id # --------------------------------------------------------------------------- From 4764bc7c155904a0ba2f6cb0ef9546bd7494a930 Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 15:08:47 +0000 Subject: [PATCH 7/8] =?UTF-8?q?Batch=20plan=20saves=20reduce=20RDS=20CPU?= =?UTF-8?q?=20during=20bulk=20modelling=20runs=20=F0=9F=9F=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- applications/modelling_e2e/handler.py | 19 ++++++---- repositories/plan/plan_postgres_repository.py | 37 ++----------------- 2 files changed, 15 insertions(+), 41 deletions(-) diff --git a/applications/modelling_e2e/handler.py b/applications/modelling_e2e/handler.py index 88ab14a0..916c2c24 100644 --- a/applications/modelling_e2e/handler.py +++ b/applications/modelling_e2e/handler.py @@ -93,6 +93,7 @@ from repositories.comparable_properties.epc_comparable_properties_repository imp SkippedCohortCert, ) from repositories.epc.epc_postgres_repository import EpcPostgresRepository, EpcSaveRequest +from repositories.plan.plan_repository import PlanSaveRequest from repositories.geospatial.geospatial_s3_repository import ( GeospatialS3Repository, ParquetReader, @@ -201,6 +202,17 @@ def _flush_writes(engine: Engine, writes: list[_PropertyWrite]) -> None: uow.epc.save_batch(lodged_requests) if predicted_requests: uow.epc.save_batch(predicted_requests) + plan_requests = [ + PlanSaveRequest( + w.plan, + property_id=w.property_id, + scenario_id=w.scenario_id, + portfolio_id=w.portfolio_id, + is_default=w.is_default, + ) + for w in writes + ] + uow.plan.save_batch(plan_requests) for w in writes: if w.spatial is not None: uow.spatial.save(w.uprn, w.spatial) @@ -211,13 +223,6 @@ def _flush_writes(engine: Engine, writes: list[_PropertyWrite]) -> None: latitude=w.solar.latitude, insights=w.solar.insights, ) - uow.plan.save( - w.plan, - property_id=w.property_id, - scenario_id=w.scenario_id, - portfolio_id=w.portfolio_id, - is_default=w.is_default, - ) uow.property.mark_modelled( w.property_id, has_recommendations=w.has_recommendations ) diff --git a/repositories/plan/plan_postgres_repository.py b/repositories/plan/plan_postgres_repository.py index 9bf22aea..e81e9084 100644 --- a/repositories/plan/plan_postgres_repository.py +++ b/repositories/plan/plan_postgres_repository.py @@ -41,40 +41,9 @@ class PlanPostgresRepository(PlanRepository): portfolio_id: int, is_default: bool, ) -> int: - # Soft-replace (ADR-0012): keep prior Plans as history rather than DELETEing - # them — the cascade delete of recommendation rows was the slow part. When - # this Plan is the default, demote every prior Plan for the same - # (property_id, scenario_id) to is_default=False, so exactly one Plan for - # the pair stays default (the one just inserted). - if is_default: - self._session.exec( # type: ignore[call-overload] - update(PlanModel) - .where( - col(PlanModel.property_id) == property_id, - col(PlanModel.scenario_id) == scenario_id, - ) - .values(is_default=False) - ) - - plan_row = PlanModel.from_domain( - plan, - property_id=property_id, - scenario_id=scenario_id, - portfolio_id=portfolio_id, - is_default=is_default, - ) - self._session.add(plan_row) - self._session.flush() - if plan_row.id is None: - raise ValueError("plan row did not receive an id") - - for measure in plan.measures: - self._session.add( - RecommendationModel.from_domain( - measure, property_id=property_id, plan_id=plan_row.id - ) - ) - return plan_row.id + return self.save_batch( + [PlanSaveRequest(plan, property_id=property_id, scenario_id=scenario_id, portfolio_id=portfolio_id, is_default=is_default)] + )[0] def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]: """Persist all Plans in three statements regardless of batch size. From c202f15d8cdf321da5355cc4c1e75407b0b70084 Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 15:40:22 +0000 Subject: [PATCH 8/8] fix broken unit tests --- .../modelling_e2e/test_handler.py | 27 +++++++++---------- tests/orchestration/fakes.py | 14 +++++++++- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/applications/modelling_e2e/test_handler.py b/tests/applications/modelling_e2e/test_handler.py index 880ac9f4..a09fe266 100644 --- a/tests/applications/modelling_e2e/test_handler.py +++ b/tests/applications/modelling_e2e/test_handler.py @@ -359,7 +359,7 @@ def test_lodged_epc_path_saves_epc_plan_and_marks_modelled( mock_uow.epc.save_batch.assert_called_once_with( [EpcSaveRequest(mock_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="lodged")] ) - mock_uow.plan.save.assert_called_once() + mock_uow.plan.save_batch.assert_called_once() mock_uow.property.mark_modelled.assert_called_once_with( PROPERTY_ID, has_recommendations=False ) @@ -448,7 +448,7 @@ def test_skipped_cohort_certs_do_not_prevent_plan_being_saved() -> None: _call_handler(_BODY) # Assert — plan committed despite the skipped cert - mock_uow.plan.save.assert_called_once() + mock_uow.plan.save_batch.assert_called_once() mock_uow.commit.assert_called_once() @@ -513,7 +513,7 @@ def test_skipped_cohort_certs_are_logged_and_handler_does_not_raise() -> None: _call_handler(_BODY) # Assert — plan committed; skipped cert number surfaced in a log call - mock_uow.plan.save.assert_called_once() + mock_uow.plan.save_batch.assert_called_once() mock_uow.commit.assert_called_once() logged_messages = " ".join( str(c.args) + str(c.kwargs) for c in mock_logger.info.call_args_list @@ -635,7 +635,7 @@ def test_prediction_path_saves_predicted_epc_plan_and_baseline( mock_uow.epc.save_batch.assert_called_once_with( [EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) - mock_uow.plan.save.assert_called_once() + mock_uow.plan.save_batch.assert_called_once() mock_uow.commit.assert_called_once() _baseline_orchestrator.return_value.run.assert_called_once_with([PROPERTY_ID]) @@ -842,7 +842,7 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None: mock_uow.epc.save_batch.assert_called_once_with( [EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] ) - mock_uow.plan.save.assert_called_once() + mock_uow.plan.save_batch.assert_called_once() mock_uow.commit.assert_called_once() @@ -979,8 +979,9 @@ def test_batch_persists_in_one_transaction_and_one_baseline_run( "scenario_id": SCENARIO_ID, "refetch_solar": False, "dry_run": False} ) - # Assert — all three Plans saved, but a single shared transaction: - assert mock_uow.plan.save.call_count == 3 + # Assert — all three Plans saved in one batch call, but a single shared transaction: + mock_uow.plan.save_batch.assert_called_once() + assert len(mock_uow.plan.save_batch.call_args[0][0]) == 3 assert mock_uow.property.mark_modelled.call_count == 3 mock_uow.commit.assert_called_once() # One write Unit of Work opened for the whole batch, not one per property. @@ -1165,10 +1166,8 @@ def test_refetch_epc_false_with_stored_epc_skips_api_call() -> None: # Assert — API not called; stored EPC flows into run_modelling mock_epc_client.get_by_uprn.assert_not_called() mock_run_modelling.assert_called_once() - # Stored lodged EPC is persisted in the lodged slot - mock_uow.epc.save_batch.assert_called_once_with( - [EpcSaveRequest(stored_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="lodged")] - ) + # Stored EPC is NOT re-saved — it was read from DB unchanged (PR #1353) + mock_uow.epc.save_batch.assert_not_called() def test_refetch_epc_false_without_stored_epc_skips_api_and_goes_to_prediction() -> None: @@ -1388,11 +1387,9 @@ def test_repredict_epc_false_with_stored_predicted_epc_skips_prediction() -> Non # Act _call_handler({**_BODY, "repredict_epc": False}) - # Assert — EpcPrediction.predict never called; stored EPC persisted in predicted slot + # Assert — EpcPrediction.predict never called; stored predicted EPC NOT re-saved (PR #1353) mock_predictor.predict.assert_not_called() - mock_uow.epc.save_batch.assert_called_once_with( - [EpcSaveRequest(stored_predicted, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")] - ) + mock_uow.epc.save_batch.assert_not_called() def test_repredict_epc_false_without_stored_predicted_epc_falls_back_to_live_prediction() -> None: diff --git a/tests/orchestration/fakes.py b/tests/orchestration/fakes.py index b068874e..d1294291 100644 --- a/tests/orchestration/fakes.py +++ b/tests/orchestration/fakes.py @@ -18,7 +18,7 @@ from domain.property_baseline.property_baseline_performance import PropertyBasel from domain.property.properties import Properties from domain.property.property import Property from repositories.epc.epc_postgres_repository import EpcSaveRequest -from repositories.plan.plan_repository import PlanRepository +from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest from repositories.product.product_repository import ProductRepository from repositories.property_baseline.property_baseline_repository import PropertyBaselineRepository from repositories.epc.epc_repository import EpcRepository, EpcSource @@ -222,6 +222,18 @@ class FakePlanRepository(PlanRepository): self._next_id += 1 return plan_id + def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]: + return [ + self.save( + r.plan, + property_id=r.property_id, + scenario_id=r.scenario_id, + portfolio_id=r.portfolio_id, + is_default=r.is_default, + ) + for r in requests + ] + class _UnsetProductRepo(ProductRepository): """Default for a `FakeUnitOfWork` built without a catalogue — raises if a