Merge pull request #1358 from Hestia-Homes/feature/batch-save-and-delete-epc

Batch saving and deleting of epc property data, bulk saving of plans
This commit is contained in:
Daniel Roth 2026-06-29 16:46:19 +01:00 committed by GitHub
commit 6c42065304
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 732 additions and 165 deletions

View file

@ -92,7 +92,8 @@ from repositories.comparable_properties.epc_comparable_properties_repository imp
EpcComparablePropertiesRepository,
SkippedCohortCert,
)
from repositories.epc.epc_postgres_repository import EpcPostgresRepository
from repositories.epc.epc_postgres_repository import EpcPostgresRepository, EpcSaveRequest
from repositories.plan.plan_repository import PlanSaveRequest
from repositories.geospatial.geospatial_s3_repository import (
GeospatialS3Repository,
ParquetReader,
@ -178,31 +179,41 @@ class _PropertyWrite:
def _flush_writes(engine: Engine, writes: list[_PropertyWrite]) -> None:
"""Persist a whole batch of modelled Properties in one Unit of Work.
Replays each Property's saves in dependency order (EPC → spatial → solar →
Plan mark-modelled) and commits once. All-or-nothing per batch: a failed
save rolls the whole transaction back and propagates, so the SQS message is
retried every save is an idempotent upsert, so a retry is safe. This mirrors
the PropertyBaselineOrchestrator's existing one-UoW-per-batch contract
(ADR-0012); per-property failures are isolated earlier, in the modelling loop,
EPC writes are batched by source (lodged group first, predicted group second)
so each source emits one DELETE pass + one INSERT pass regardless of batch
size, rather than N×per-property round-trips (ADR-0012). All other writes
(spatial, solar, plan, mark-modelled) remain per-property inside the same
transaction. All-or-nothing per batch: a failed save rolls the whole
transaction back so the SQS message is retried every save is an idempotent
upsert. Per-property failures are isolated earlier, in the modelling loop,
before a write is ever queued."""
lodged_requests = [
EpcSaveRequest(w.lodged_epc, property_id=w.property_id, portfolio_id=w.portfolio_id, source="lodged")
for w in writes
if w.lodged_epc is not None and w.lodged_epc_is_new
]
predicted_requests = [
EpcSaveRequest(w.predicted_epc, property_id=w.property_id, portfolio_id=w.portfolio_id, source="predicted")
for w in writes
if w.predicted_epc is not None and w.predicted_epc_is_new
]
with PostgresUnitOfWork(lambda: Session(engine)) as uow:
if lodged_requests:
uow.epc.save_batch(lodged_requests)
if predicted_requests:
uow.epc.save_batch(predicted_requests)
plan_requests = [
PlanSaveRequest(
w.plan,
property_id=w.property_id,
scenario_id=w.scenario_id,
portfolio_id=w.portfolio_id,
is_default=w.is_default,
)
for w in writes
]
uow.plan.save_batch(plan_requests)
for w in writes:
if w.lodged_epc is not None and w.lodged_epc_is_new:
uow.epc.save(
w.lodged_epc,
property_id=w.property_id,
portfolio_id=w.portfolio_id,
)
elif w.predicted_epc is not None and w.predicted_epc_is_new:
# Persist the synthesised EPC in the predicted slot (ADR-0031), so
# the Baseline stage can re-hydrate it and downstream sees the
# picture the Plan was modelled from.
uow.epc.save(
w.predicted_epc,
property_id=w.property_id,
portfolio_id=w.portfolio_id,
source="predicted",
)
if w.spatial is not None:
uow.spatial.save(w.uprn, w.spatial)
if w.solar is not None:
@ -212,13 +223,6 @@ def _flush_writes(engine: Engine, writes: list[_PropertyWrite]) -> None:
latitude=w.solar.latitude,
insights=w.solar.insights,
)
uow.plan.save(
w.plan,
property_id=w.property_id,
scenario_id=w.scenario_id,
portfolio_id=w.portfolio_id,
is_default=w.is_default,
)
uow.property.mark_modelled(
w.property_id, has_recommendations=w.has_recommendations
)

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import date, datetime
from typing import Optional, Protocol, TypeVar
from typing import Any, Optional, Protocol, TypeVar
from sqlmodel import Session, col, delete, select
from sqlalchemy import insert as _sa_insert
from sqlmodel import Session, SQLModel, col, delete, select
from datatypes.epc.domain.epc import Epc
from datatypes.epc.domain.epc_property_data import (
@ -54,6 +56,23 @@ from utilities.private import private
_T = TypeVar("_T")
@dataclass(frozen=True)
class EpcSaveRequest:
data: EpcPropertyData
property_id: Optional[int] = None
portfolio_id: Optional[int] = None
source: EpcSource = field(default="lodged")
def _col_values(model: SQLModel, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
"""Extract column-keyed values from a SQLModel instance for Core INSERT."""
return {
c.name: getattr(model, c.name) # type: ignore[union-attr]
for c in model.__table__.c # type: ignore[attr-defined]
if c.name not in exclude # type: ignore[union-attr]
}
def _require(value: Optional[_T], field: str) -> _T:
if value is None:
raise ValueError(f"epc_property row is missing required field {field!r}")
@ -111,75 +130,195 @@ class EpcPostgresRepository(EpcRepository):
portfolio_id: Optional[int] = None,
source: EpcSource = "lodged",
) -> int:
# Idempotent on (property_id, source): a re-run replaces the property's
# EPC graph for THAT source rather than duplicating it (ADR-0012), and a
# predicted save leaves the lodged one intact, and vice versa (ADR-0031).
# Anonymous saves (no property_id) always insert.
if property_id is not None:
self._delete_for_property(property_id, source)
parent = EpcPropertyModel.from_epc_property_data(
data, property_id=property_id, portfolio_id=portfolio_id, source=source
)
self._session.add(parent)
self._session.flush()
epc_property_id = _require(parent.id, "id")
return self.save_batch([EpcSaveRequest(data, property_id, portfolio_id, source)])[0]
self._session.add(
EpcPropertyEnergyPerformanceModel.from_epc_property_data(
data, epc_property_id=epc_property_id
)
)
for detail in data.sap_heating.main_heating_details:
self._session.add(
EpcMainHeatingDetailModel.from_domain(detail, epc_property_id)
)
for part in data.sap_building_parts:
bp = EpcBuildingPartModel.from_domain(part, epc_property_id)
self._session.add(bp)
self._session.flush()
bp_id = _require(bp.id, "epc_building_part.id")
for dim in part.sap_floor_dimensions:
self._session.add(EpcFloorDimensionModel.from_domain(dim, bp_id))
for window in data.sap_windows:
self._session.add(EpcWindowModel.from_domain(window, epc_property_id))
for index, array in enumerate(data.sap_energy_source.photovoltaic_arrays or []):
self._session.add(
EpcPhotovoltaicArrayModel.from_domain(array, index, epc_property_id)
)
def save_batch(self, requests: list[EpcSaveRequest]) -> list[int]:
"""Insert all EPCs in `requests` in one pass per table, returning one
epc_property_id per request in the same order as the input.
for element_type, elements in (
("roof", data.roofs),
("wall", data.walls),
("floor", data.floors),
("main_heating", data.main_heating),
Deletes are batched first (one IN-query per child table per source),
then all parent rows are inserted with a single RETURNING statement so
positional ordering maps each returned id to its request. Building-part
ids are captured the same way so floor-dimension FKs are resolved without
any per-property flush round-trips (ADR-0012).
"""
if not requests:
return []
# Batch-delete existing rows grouped by source so the lodged and predicted
# slots remain independent (ADR-0031).
pids_by_source: dict[EpcSource, list[int]] = {}
for r in requests:
if r.property_id is not None:
pids_by_source.setdefault(r.source, []).append(r.property_id)
for src, pids in pids_by_source.items():
self._delete_for_properties(pids, src)
# Insert all parent (epc_property) rows; capture returned ids positionally.
parent_rows = [
_col_values(
EpcPropertyModel.from_epc_property_data(
r.data, property_id=r.property_id, portfolio_id=r.portfolio_id, source=r.source
),
exclude=frozenset({"id"}),
)
for r in requests
]
returned_parents = self._session.execute( # type: ignore[deprecated]
_sa_insert(EpcPropertyModel).returning(EpcPropertyModel.__table__.c["id"]), # type: ignore[attr-defined]
parent_rows,
).all()
epc_property_ids = [row[0] for row in returned_parents]
# Collect child rows, accumulating building parts in an ordered list so
# the positional RETURNING trick can map part objects to their new ids.
perf_rows: list[dict[str, Any]] = []
heating_rows: list[dict[str, Any]] = []
parts_ordered: list[tuple[Any, int]] = [] # (SapBuildingPart, epc_property_id)
window_rows: list[dict[str, Any]] = []
pv_rows: list[dict[str, Any]] = []
element_rows: list[dict[str, Any]] = []
flat_rows: list[dict[str, Any]] = []
rhi_rows: list[dict[str, Any]] = []
for r, epc_pid in zip(requests, epc_property_ids):
d = r.data
perf_rows.append(
_col_values(
EpcPropertyEnergyPerformanceModel.from_epc_property_data(d, epc_pid),
exclude=frozenset({"id"}),
)
)
for detail in d.sap_heating.main_heating_details:
heating_rows.append(
_col_values(EpcMainHeatingDetailModel.from_domain(detail, epc_pid), frozenset({"id"}))
)
for part in d.sap_building_parts:
parts_ordered.append((part, epc_pid))
for window in d.sap_windows:
window_rows.append(
_col_values(EpcWindowModel.from_domain(window, epc_pid), frozenset({"id"}))
)
for idx, array in enumerate(d.sap_energy_source.photovoltaic_arrays or []):
pv_rows.append(
_col_values(EpcPhotovoltaicArrayModel.from_domain(array, idx, epc_pid), frozenset({"id"}))
)
for etype, els in (
("roof", d.roofs),
("wall", d.walls),
("floor", d.floors),
("main_heating", d.main_heating),
):
for el in els:
element_rows.append(
_col_values(EpcEnergyElementModel.from_domain(el, etype, epc_pid), frozenset({"id"}))
)
for el, etype in (
(d.window, "window"),
(d.lighting, "lighting"),
(d.hot_water, "hot_water"),
(d.secondary_heating, "secondary_heating"),
(d.main_heating_controls, "main_heating_controls"),
):
if el is not None:
element_rows.append(
_col_values(EpcEnergyElementModel.from_domain(el, etype, epc_pid), frozenset({"id"}))
)
if d.sap_flat_details is not None:
flat_rows.append(
_col_values(EpcFlatDetailsModel.from_domain(d.sap_flat_details, epc_pid), frozenset({"id"}))
)
if d.renewable_heat_incentive is not None:
rhi_rows.append(
_col_values(EpcRenewableHeatIncentiveModel.from_domain(d.renewable_heat_incentive, epc_pid), frozenset({"id"}))
)
# Bulk-insert all simple child tables (no downstream FK dependency).
if perf_rows:
self._session.execute(_sa_insert(EpcPropertyEnergyPerformanceModel), perf_rows) # type: ignore[deprecated]
if heating_rows:
self._session.execute(_sa_insert(EpcMainHeatingDetailModel), heating_rows) # type: ignore[deprecated]
if window_rows:
self._session.execute(_sa_insert(EpcWindowModel), window_rows) # type: ignore[deprecated]
if pv_rows:
self._session.execute(_sa_insert(EpcPhotovoltaicArrayModel), pv_rows) # type: ignore[deprecated]
if element_rows:
self._session.execute(_sa_insert(EpcEnergyElementModel), element_rows) # type: ignore[deprecated]
if flat_rows:
self._session.execute(_sa_insert(EpcFlatDetailsModel), flat_rows) # type: ignore[deprecated]
if rhi_rows:
self._session.execute(_sa_insert(EpcRenewableHeatIncentiveModel), rhi_rows) # type: ignore[deprecated]
# Building parts: insert with RETURNING and zip positionally to resolve
# floor-dimension FKs. Do NOT key by id(part) — the same EpcPropertyData
# object can appear in multiple requests (same epc, different property_ids),
# giving identical object ids that collapse the dict and mis-wire FKs.
# Positional zip is safe because PostgreSQL preserves VALUES order in RETURNING.
if parts_ordered:
bp_rows = [
_col_values(EpcBuildingPartModel.from_domain(part, epc_pid), frozenset({"id"}))
for part, epc_pid in parts_ordered
]
returned_bps = self._session.execute( # type: ignore[deprecated]
_sa_insert(EpcBuildingPartModel).returning(EpcBuildingPartModel.__table__.c["id"]), # type: ignore[attr-defined]
bp_rows,
).all()
floor_rows: list[dict[str, Any]] = [
_col_values(EpcFloorDimensionModel.from_domain(dim, bp_row[0]), frozenset({"id"}))
for (part, _), bp_row in zip(parts_ordered, returned_bps)
for dim in part.sap_floor_dimensions
]
if floor_rows:
self._session.execute(_sa_insert(EpcFloorDimensionModel), floor_rows) # type: ignore[deprecated]
return epc_property_ids
def _delete_for_properties(self, property_ids: list[int], source: EpcSource) -> None:
"""Batch-delete every EPC graph for the given property_ids and source in
one pass per child table (IN queries), replacing the per-property flush
loop that drove RDS CPU to saturation during bulk modelling runs."""
epc_ids = [
i
for i in self._session.exec(
select(EpcPropertyModel.id)
.where(col(EpcPropertyModel.property_id).in_(property_ids))
.where(EpcPropertyModel.source == source)
).all()
if i is not None
]
if not epc_ids:
return
part_ids = [
i
for i in self._session.exec(
select(EpcBuildingPartModel.id).where(
col(EpcBuildingPartModel.epc_property_id).in_(epc_ids)
)
).all()
if i is not None
]
if part_ids:
self._session.exec( # type: ignore[call-overload]
delete(EpcFloorDimensionModel).where(
col(EpcFloorDimensionModel.epc_building_part_id).in_(part_ids)
)
)
for child in (
EpcPropertyEnergyPerformanceModel,
EpcEnergyElementModel,
EpcMainHeatingDetailModel,
EpcBuildingPartModel,
EpcWindowModel,
EpcPhotovoltaicArrayModel,
EpcFlatDetailsModel,
EpcRenewableHeatIncentiveModel,
):
for el in elements:
self._session.add(
EpcEnergyElementModel.from_domain(el, element_type, epc_property_id)
)
for el, element_type in (
(data.window, "window"),
(data.lighting, "lighting"),
(data.hot_water, "hot_water"),
(data.secondary_heating, "secondary_heating"),
(data.main_heating_controls, "main_heating_controls"),
):
if el is not None:
self._session.add(
EpcEnergyElementModel.from_domain(el, element_type, epc_property_id)
)
if data.sap_flat_details is not None:
self._session.add(
EpcFlatDetailsModel.from_domain(data.sap_flat_details, epc_property_id)
self._session.exec( # type: ignore[call-overload]
delete(child).where(col(child.epc_property_id).in_(epc_ids))
)
if data.renewable_heat_incentive is not None:
self._session.add(
EpcRenewableHeatIncentiveModel.from_domain(
data.renewable_heat_incentive, epc_property_id
)
)
return epc_property_id
self._session.exec( # type: ignore[call-overload]
delete(EpcPropertyModel).where(col(EpcPropertyModel.id).in_(epc_ids))
)
def _delete_for_property(self, property_id: int, source: EpcSource) -> None:
"""Remove the property's existing EPC graph for `source` (parent + child

View file

@ -1,10 +1,13 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional
from datatypes.epc.domain.epc_property_data import EpcPropertyData
if TYPE_CHECKING:
from repositories.epc.epc_postgres_repository import EpcSaveRequest
# Provenance of a persisted EPC picture (ADR-0031): a real "lodged" EPC, or a
# "predicted" one synthesised by EPC Prediction. A property can hold one of each.
EpcSource = Literal["lodged", "predicted"]
@ -29,6 +32,9 @@ class EpcRepository(ABC):
source: EpcSource = "lodged",
) -> int: ...
@abstractmethod
def save_batch(self, requests: "list[EpcSaveRequest]") -> list[int]: ...
@abstractmethod
def get(self, epc_property_id: int) -> EpcPropertyData: ...

View file

@ -1,10 +1,22 @@
from __future__ import annotations
from typing import Any
from sqlalchemy import insert as _sa_insert
from sqlmodel import Session, col, update
from domain.modelling.plan import Plan
from infrastructure.postgres.modelling import PlanModel, RecommendationModel
from repositories.plan.plan_repository import PlanRepository
from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest
def _col_values(model: Any, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
"""Extract column-keyed values from a SQLModel instance for Core INSERT."""
return {
c.name: getattr(model, c.name)
for c in model.__table__.c
if c.name not in exclude
}
class PlanPostgresRepository(PlanRepository):
@ -29,37 +41,70 @@ class PlanPostgresRepository(PlanRepository):
portfolio_id: int,
is_default: bool,
) -> int:
# Soft-replace (ADR-0012): keep prior Plans as history rather than DELETEing
# them — the cascade delete of recommendation rows was the slow part. When
# this Plan is the default, demote every prior Plan for the same
# (property_id, scenario_id) to is_default=False, so exactly one Plan for
# the pair stays default (the one just inserted).
if is_default:
return self.save_batch(
[PlanSaveRequest(plan, property_id=property_id, scenario_id=scenario_id, portfolio_id=portfolio_id, is_default=is_default)]
)[0]
def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]:
"""Persist all Plans in three statements regardless of batch size.
1. One demote UPDATE (only when any request has ``is_default=True``).
2. One bulk plan INSERT with RETURNING to capture ids positionally.
3. One bulk recommendation INSERT (skipped when no measures exist).
"""
if not requests:
return []
# Demote prior default Plans for every property in the batch that is
# receiving a new default Plan — one UPDATE for the whole batch.
default_pids = [r.property_id for r in requests if r.is_default]
if default_pids:
# scenario_id is uniform per batch (one scenario per SQS message).
scenario_id = requests[0].scenario_id
self._session.exec( # type: ignore[call-overload]
update(PlanModel)
.where(
col(PlanModel.property_id) == property_id,
col(PlanModel.property_id).in_(default_pids),
col(PlanModel.scenario_id) == scenario_id,
)
.values(is_default=False)
)
plan_row = PlanModel.from_domain(
plan,
property_id=property_id,
scenario_id=scenario_id,
portfolio_id=portfolio_id,
is_default=is_default,
)
self._session.add(plan_row)
self._session.flush()
if plan_row.id is None:
raise ValueError("plan row did not receive an id")
for measure in plan.measures:
self._session.add(
RecommendationModel.from_domain(
measure, property_id=property_id, plan_id=plan_row.id
)
# Bulk INSERT all plan rows; capture returned ids positionally.
plan_rows = [
_col_values(
PlanModel.from_domain(
r.plan,
property_id=r.property_id,
scenario_id=r.scenario_id,
portfolio_id=r.portfolio_id,
is_default=r.is_default,
),
exclude=frozenset({"id"}),
)
return plan_row.id
for r in requests
]
returned = self._session.execute( # type: ignore[deprecated]
_sa_insert(PlanModel).returning(PlanModel.__table__.c["id"]), # type: ignore[attr-defined]
plan_rows,
).all()
plan_ids = [row[0] for row in returned]
# Accumulate recommendation rows across all requests; properties with
# zero measures contribute nothing (no special-casing needed).
rec_rows = [
_col_values(
RecommendationModel.from_domain(
measure, property_id=r.property_id, plan_id=plan_id
),
exclude=frozenset({"id"}),
)
for r, plan_id in zip(requests, plan_ids)
for measure in r.plan.measures
]
if rec_rows:
self._session.execute( # type: ignore[deprecated]
_sa_insert(RecommendationModel), rec_rows
)
return plan_ids

View file

@ -1,10 +1,25 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from domain.modelling.plan import Plan
@dataclass(frozen=True)
class PlanSaveRequest:
"""Bundles the five fields the plan repository needs to persist one Plan.
Mirrors ``EpcSaveRequest`` in shape used by ``PlanRepository.save_batch()``
to accumulate write intent before the batch is flushed in one Unit of Work."""
plan: Plan
property_id: int
scenario_id: int
portfolio_id: int
is_default: bool
class PlanRepository(ABC):
"""Persists a Plan (and its Plan Measures) for a Property + Scenario.
@ -30,3 +45,12 @@ class PlanRepository(ABC):
``(property_id, scenario_id)`` as history; when ``is_default`` is True,
demotes those prior Plans to ``is_default=False``."""
...
@abstractmethod
def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]:
"""Persist a batch of Plans in three statements regardless of batch size.
Returns one plan id per request in input order. Fires a single demote
UPDATE only when at least one request has ``is_default=True``. Keeps
prior Plans as history (ADR-0017)."""
...

View file

@ -19,6 +19,7 @@ from applications.modelling_e2e.modelling_e2e_trigger_body import (
ModellingE2ETriggerBody,
)
from domain.tasks.subtasks import SubTask
from repositories.epc.epc_postgres_repository import EpcSaveRequest
PROPERTY_ID = 12345
UPRN = 987654321
@ -355,10 +356,10 @@ def test_lodged_epc_path_saves_epc_plan_and_marks_modelled(
_call_handler(_BODY)
# Assert — EPC saved (lodged path), plan saved, property marked modelled
mock_uow.epc.save.assert_called_once_with(
mock_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID
mock_uow.epc.save_batch.assert_called_once_with(
[EpcSaveRequest(mock_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="lodged")]
)
mock_uow.plan.save.assert_called_once()
mock_uow.plan.save_batch.assert_called_once()
mock_uow.property.mark_modelled.assert_called_once_with(
PROPERTY_ID, has_recommendations=False
)
@ -447,7 +448,7 @@ def test_skipped_cohort_certs_do_not_prevent_plan_being_saved() -> None:
_call_handler(_BODY)
# Assert — plan committed despite the skipped cert
mock_uow.plan.save.assert_called_once()
mock_uow.plan.save_batch.assert_called_once()
mock_uow.commit.assert_called_once()
@ -512,7 +513,7 @@ def test_skipped_cohort_certs_are_logged_and_handler_does_not_raise() -> None:
_call_handler(_BODY)
# Assert — plan committed; skipped cert number surfaced in a log call
mock_uow.plan.save.assert_called_once()
mock_uow.plan.save_batch.assert_called_once()
mock_uow.commit.assert_called_once()
logged_messages = " ".join(
str(c.args) + str(c.kwargs) for c in mock_logger.info.call_args_list
@ -631,13 +632,10 @@ def test_prediction_path_saves_predicted_epc_plan_and_baseline(
_call_handler(_BODY)
# Assert — predicted EPC persisted in the predicted slot, plan saved, baseline run
mock_uow.epc.save.assert_called_once_with(
mock_predicted_epc,
property_id=PROPERTY_ID,
portfolio_id=PORTFOLIO_ID,
source="predicted",
mock_uow.epc.save_batch.assert_called_once_with(
[EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")]
)
mock_uow.plan.save.assert_called_once()
mock_uow.plan.save_batch.assert_called_once()
mock_uow.commit.assert_called_once()
_baseline_orchestrator.return_value.run.assert_called_once_with([PROPERTY_ID])
@ -841,13 +839,10 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None:
# Assert — broadening fired, and the broadened cohort produced a saved plan
# with its predicted EPC persisted in the predicted slot.
MockRepo.return_value.candidates_near.assert_called_once()
mock_uow.epc.save.assert_called_once_with(
mock_predicted_epc,
property_id=PROPERTY_ID,
portfolio_id=PORTFOLIO_ID,
source="predicted",
mock_uow.epc.save_batch.assert_called_once_with(
[EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")]
)
mock_uow.plan.save.assert_called_once()
mock_uow.plan.save_batch.assert_called_once()
mock_uow.commit.assert_called_once()
@ -984,8 +979,9 @@ def test_batch_persists_in_one_transaction_and_one_baseline_run(
"scenario_id": SCENARIO_ID, "refetch_solar": False, "dry_run": False}
)
# Assert — all three Plans saved, but a single shared transaction:
assert mock_uow.plan.save.call_count == 3
# Assert — all three Plans saved in one batch call, but a single shared transaction:
mock_uow.plan.save_batch.assert_called_once()
assert len(mock_uow.plan.save_batch.call_args[0][0]) == 3
assert mock_uow.property.mark_modelled.call_count == 3
mock_uow.commit.assert_called_once()
# One write Unit of Work opened for the whole batch, not one per property.
@ -1170,10 +1166,8 @@ def test_refetch_epc_false_with_stored_epc_skips_api_call() -> None:
# Assert — API not called; stored EPC flows into run_modelling
mock_epc_client.get_by_uprn.assert_not_called()
mock_run_modelling.assert_called_once()
# Stored lodged EPC is persisted in the lodged slot
mock_uow.epc.save.assert_called_once_with(
stored_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID
)
# Stored EPC is NOT re-saved — it was read from DB unchanged (PR #1353)
mock_uow.epc.save_batch.assert_not_called()
def test_refetch_epc_false_without_stored_epc_skips_api_and_goes_to_prediction() -> None:
@ -1258,11 +1252,8 @@ def test_refetch_epc_false_without_stored_epc_skips_api_and_goes_to_prediction()
# Assert — API was NOT called; prediction ran and its output was persisted
mock_epc_client.get_by_uprn.assert_not_called()
mock_uow.epc.save.assert_called_once_with(
mock_predicted_epc,
property_id=PROPERTY_ID,
portfolio_id=PORTFOLIO_ID,
source="predicted",
mock_uow.epc.save_batch.assert_called_once_with(
[EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")]
)
@ -1396,14 +1387,9 @@ def test_repredict_epc_false_with_stored_predicted_epc_skips_prediction() -> Non
# Act
_call_handler({**_BODY, "repredict_epc": False})
# Assert — EpcPrediction.predict never called; stored EPC persisted in predicted slot
# Assert — EpcPrediction.predict never called; stored predicted EPC NOT re-saved (PR #1353)
mock_predictor.predict.assert_not_called()
mock_uow.epc.save.assert_called_once_with(
stored_predicted,
property_id=PROPERTY_ID,
portfolio_id=PORTFOLIO_ID,
source="predicted",
)
mock_uow.epc.save_batch.assert_not_called()
def test_repredict_epc_false_without_stored_predicted_epc_falls_back_to_live_prediction() -> None:
@ -1488,11 +1474,8 @@ def test_repredict_epc_false_without_stored_predicted_epc_falls_back_to_live_pre
# Assert — live prediction was used as fallback
mock_predictor.predict.assert_called_once()
mock_uow.epc.save.assert_called_once_with(
mock_predicted_epc,
property_id=PROPERTY_ID,
portfolio_id=PORTFOLIO_ID,
source="predicted",
mock_uow.epc.save_batch.assert_called_once_with(
[EpcSaveRequest(mock_predicted_epc, property_id=PROPERTY_ID, portfolio_id=PORTFOLIO_ID, source="predicted")]
)

View file

@ -17,7 +17,8 @@ from domain.modelling.scenario import Scenario
from domain.property_baseline.property_baseline_performance import PropertyBaselinePerformance
from domain.property.properties import Properties
from domain.property.property import Property
from repositories.plan.plan_repository import PlanRepository
from repositories.epc.epc_postgres_repository import EpcSaveRequest
from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest
from repositories.product.product_repository import ProductRepository
from repositories.property_baseline.property_baseline_repository import PropertyBaselineRepository
from repositories.epc.epc_repository import EpcRepository, EpcSource
@ -130,6 +131,9 @@ class FakeEpcRepo(EpcRepository):
if property_id in self._predicted_by_property
}
def save_batch(self, requests: list[EpcSaveRequest]) -> list[int]:
return [self.save(r.data, r.property_id, r.portfolio_id, r.source) for r in requests]
class FakeSolarRepo(SolarRepository):
"""In-memory Google Solar insights store keyed by UPRN. Seed `by_uprn` to
@ -218,6 +222,18 @@ class FakePlanRepository(PlanRepository):
self._next_id += 1
return plan_id
def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]:
return [
self.save(
r.plan,
property_id=r.property_id,
scenario_id=r.scenario_id,
portfolio_id=r.portfolio_id,
is_default=r.is_default,
)
for r in requests
]
class _UnsetProductRepo(ProductRepository):
"""Default for a `FakeUnitOfWork` built without a catalogue — raises if a

View file

@ -0,0 +1,167 @@
"""Batch EPC write path — save_batch() correctness and safety tests.
Guards the four user stories from #1348:
1. FK mis-wiring regression: building-part IDs must not be crossed between
properties in the same save_batch() call.
2. save()/save_batch() parity: the single-property delegation path is loss-free.
3. Batch idempotency: a second save_batch() with the same requests replaces,
not duplicates.
4. Source isolation: lodged and predicted slots coexist after separate
save_batch() calls on the same property IDs.
"""
from __future__ import annotations
import json
from dataclasses import replace
from pathlib import Path
from typing import Any
from sqlalchemy import Engine
from sqlmodel import Session
from datatypes.epc.domain.epc_property_data import EpcPropertyData
from datatypes.epc.domain.mapper import EpcPropertyDataMapper
from repositories.epc.epc_postgres_repository import EpcPostgresRepository, EpcSaveRequest
_JSON_SAMPLES = Path(__file__).resolve().parents[3] / "backend/epc_api/json_samples"
def _load_epc(schema_dir: str = "RdSAP-Schema-21.0.0") -> EpcPropertyData:
raw: dict[str, Any] = json.loads(
(_JSON_SAMPLES / schema_dir / "epc.json").read_text()
)
return EpcPropertyDataMapper.from_api_response(raw)
def _with_floor_areas(epc: EpcPropertyData, areas_m2: list[float]) -> EpcPropertyData:
"""Replace the building parts with variants that have a single floor dimension
carrying the given total_floor_area_m2 making them easy to distinguish after
a round-trip without changing anything else about the EPC."""
template_bp = epc.sap_building_parts[0]
template_dim = template_bp.sap_floor_dimensions[0]
new_parts = [
replace(template_bp, sap_floor_dimensions=[replace(template_dim, total_floor_area_m2=a)])
for a in areas_m2
]
return replace(epc, sap_building_parts=new_parts)
# ---------------------------------------------------------------------------
# Tracer bullet: single-request save_batch() is loss-free vs save()
# ---------------------------------------------------------------------------
def test_single_request_save_batch_matches_save(db_engine: Engine) -> None:
# Arrange
epc = _load_epc()
with Session(db_engine) as session:
repo = EpcPostgresRepository(session)
epc_id_via_save = repo.save(epc, property_id=1001)
epc_id_via_batch = repo.save_batch([EpcSaveRequest(epc, property_id=1002)])[0]
session.commit()
# Act
with Session(db_engine) as session:
repo = EpcPostgresRepository(session)
via_save = repo.get(epc_id_via_save)
via_batch = repo.get(epc_id_via_batch)
# Assert — both paths reconstruct the original exactly.
assert via_save == epc
assert via_batch == epc
# ---------------------------------------------------------------------------
# FK mis-wiring regression: building-part IDs must not be crossed
# ---------------------------------------------------------------------------
def test_multi_property_building_part_ids_are_not_crossed(db_engine: Engine) -> None:
# Arrange — property A has 2 parts with distinctive areas; B has 1 with a
# third distinctive area. If part IDs are mis-wired the floor-dimension FK
# rows end up under the wrong property.
base = _load_epc()
epc_a = _with_floor_areas(base, [10.0, 20.0])
epc_b = _with_floor_areas(base, [99.0])
with Session(db_engine) as session:
repo = EpcPostgresRepository(session)
repo.save_batch([
EpcSaveRequest(epc_a, property_id=2001),
EpcSaveRequest(epc_b, property_id=2002),
])
session.commit()
# Act
with Session(db_engine) as session:
repo = EpcPostgresRepository(session)
reloaded_a = repo.get_for_property(2001)
reloaded_b = repo.get_for_property(2002)
# Assert — each property's building parts carry its own floor areas.
assert reloaded_a is not None
assert reloaded_b is not None
areas_a = sorted(
dim.total_floor_area_m2
for part in reloaded_a.sap_building_parts
for dim in part.sap_floor_dimensions
)
areas_b = sorted(
dim.total_floor_area_m2
for part in reloaded_b.sap_building_parts
for dim in part.sap_floor_dimensions
)
assert areas_a == [10.0, 20.0]
assert areas_b == [99.0]
# ---------------------------------------------------------------------------
# Idempotency: second save_batch() replaces, not duplicates
# ---------------------------------------------------------------------------
def test_save_batch_is_idempotent(db_engine: Engine) -> None:
# Arrange
epc = _load_epc()
requests = [EpcSaveRequest(epc, property_id=3001)]
with Session(db_engine) as session:
EpcPostgresRepository(session).save_batch(requests)
session.commit()
# Act — re-save the same batch.
with Session(db_engine) as session:
EpcPostgresRepository(session).save_batch(requests)
session.commit()
# Assert — exactly one EPC survives (no duplicate rows).
with Session(db_engine) as session:
result = EpcPostgresRepository(session).get_for_property(3001)
assert result == epc
# ---------------------------------------------------------------------------
# Source isolation: lodged and predicted slots survive separate batch saves
# ---------------------------------------------------------------------------
def test_lodged_and_predicted_batch_slots_are_independent(db_engine: Engine) -> None:
# Arrange — two properties each get a lodged EPC and then a predicted EPC
# via separate save_batch() calls.
epc = _load_epc()
property_ids = [4001, 4002]
with Session(db_engine) as session:
repo = EpcPostgresRepository(session)
repo.save_batch([EpcSaveRequest(epc, property_id=pid, source="lodged") for pid in property_ids])
repo.save_batch([EpcSaveRequest(epc, property_id=pid, source="predicted") for pid in property_ids])
session.commit()
# Act
with Session(db_engine) as session:
repo = EpcPostgresRepository(session)
lodged = repo.get_for_properties(property_ids)
predicted = repo.get_predicted_for_properties(property_ids)
# Assert — both slots are populated for both properties.
assert lodged == {4001: epc, 4002: epc}
assert predicted == {4001: epc, 4002: epc}

View file

@ -0,0 +1,183 @@
"""Batch plan write path — save_batch() correctness and safety tests.
Guards the four user stories from #1355:
1. save()/save_batch() parity: a single-element save_batch() produces
identical DB state (plan row + recommendation rows) as the equivalent
save() call.
2. Recommendation FK isolation: two properties in the same save_batch() each
get their own recommendation rows; no FK cross-wiring between properties.
3. Demote correctness: a second save_batch() for the same properties demotes
the prior default Plans and inserts fresh ones (history preserved).
4. Non-default batch: a save_batch() where all writes have is_default=False
leaves any pre-existing default Plan untouched.
"""
from __future__ import annotations
from sqlalchemy import Engine
from sqlmodel import Session, col, select
from domain.modelling.measure_type import MeasureType
from domain.modelling.plan import Plan, PlanMeasure
from domain.modelling.recommendation import Cost
from domain.modelling.scoring.package_scorer import Score
from domain.modelling.scoring.scoring import MeasureImpact
from infrastructure.postgres.modelling import PlanModel, RecommendationModel
from repositories.plan.plan_postgres_repository import PlanPostgresRepository
from repositories.plan.plan_repository import PlanSaveRequest
def _plan(*, sap: float = 70.0, measures: int = 1) -> Plan:
ms: tuple[PlanMeasure, ...] = tuple(
PlanMeasure(
measure_type=MeasureType.CAVITY_WALL_INSULATION,
description="Cavity wall insulation",
cost=Cost(total=1000.0, contingency_rate=0.10),
impact=MeasureImpact(
sap_points=8.0,
co2_savings_kg_per_yr=500.0,
energy_savings_kwh_per_yr=2000.0,
),
kwh_savings=1500.0,
energy_cost_savings=300.0,
)
for _ in range(measures)
)
return Plan(
measures=ms,
baseline=Score(sap_continuous=40.0, co2_kg_per_yr=4000.0, primary_energy_kwh_per_yr=20000.0),
post_retrofit=Score(sap_continuous=sap, co2_kg_per_yr=3500.0, primary_energy_kwh_per_yr=18000.0),
)
# ---------------------------------------------------------------------------
# Tracer bullet: single-element save_batch() is loss-free vs save()
# ---------------------------------------------------------------------------
def test_single_request_save_batch_matches_save(db_engine: Engine) -> None:
# Arrange
plan = _plan()
scenario_id = 7
with Session(db_engine) as session:
repo = PlanPostgresRepository(session)
save_id = repo.save(plan, property_id=5001, scenario_id=scenario_id, portfolio_id=1, is_default=True)
batch_id = repo.save_batch([PlanSaveRequest(plan, property_id=5002, scenario_id=scenario_id, portfolio_id=1, is_default=True)])[0]
session.commit()
# Act
with Session(db_engine) as session:
via_save = session.get(PlanModel, save_id)
via_batch = session.get(PlanModel, batch_id)
recs_save = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == save_id)).all()
recs_batch = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == batch_id)).all()
# Assert — both paths produce one plan row + one recommendation row with the
# same field values (modulo property_id which differs by design).
assert via_save is not None
assert via_batch is not None
assert via_save.is_default is True
assert via_batch.is_default is True
assert via_save.post_sap_points == via_batch.post_sap_points
assert via_save.post_co2_emissions == via_batch.post_co2_emissions
assert via_save.co2_savings == via_batch.co2_savings
assert len(recs_save) == 1
assert len(recs_batch) == 1
assert recs_save[0].type == recs_batch[0].type
assert recs_save[0].estimated_cost == recs_batch[0].estimated_cost
assert recs_save[0].sap_points == recs_batch[0].sap_points
assert recs_batch[0].plan_id == batch_id
# ---------------------------------------------------------------------------
# FK isolation: recommendation rows must not be crossed between properties
# ---------------------------------------------------------------------------
def test_multi_property_recommendation_fks_are_not_crossed(db_engine: Engine) -> None:
# Arrange — property A gets 2 measures, property B gets 1 measure.
plan_a = _plan(measures=2)
plan_b = _plan(measures=1)
with Session(db_engine) as session:
[id_a, id_b] = PlanPostgresRepository(session).save_batch([
PlanSaveRequest(plan_a, property_id=6001, scenario_id=7, portfolio_id=1, is_default=True),
PlanSaveRequest(plan_b, property_id=6002, scenario_id=7, portfolio_id=1, is_default=True),
])
session.commit()
# Act
with Session(db_engine) as session:
recs_a = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == id_a)).all()
recs_b = session.exec(select(RecommendationModel).where(col(RecommendationModel.plan_id) == id_b)).all()
# Assert — A has 2 rows, B has 1; none cross-wired.
assert len(recs_a) == 2
assert len(recs_b) == 1
assert all(r.plan_id == id_a and r.property_id == 6001 for r in recs_a)
assert all(r.plan_id == id_b and r.property_id == 6002 for r in recs_b)
# ---------------------------------------------------------------------------
# Demote correctness: second save_batch() demotes prior defaults
# ---------------------------------------------------------------------------
def test_second_save_batch_demotes_prior_default_plans(db_engine: Engine) -> None:
# Arrange — first batch creates default Plans for two properties.
plan = _plan()
requests = [
PlanSaveRequest(plan, property_id=7001, scenario_id=7, portfolio_id=1, is_default=True),
PlanSaveRequest(plan, property_id=7002, scenario_id=7, portfolio_id=1, is_default=True),
]
with Session(db_engine) as session:
first_ids = PlanPostgresRepository(session).save_batch(requests)
session.commit()
# Act — re-run the same batch; new Plans should become default, old ones demoted.
with Session(db_engine) as session:
second_ids = PlanPostgresRepository(session).save_batch(requests)
session.commit()
# Assert — history is preserved (4 plan rows total); exactly one default per property.
with Session(db_engine) as session:
rows_7001 = session.exec(select(PlanModel).where(col(PlanModel.property_id) == 7001)).all()
rows_7002 = session.exec(select(PlanModel).where(col(PlanModel.property_id) == 7002)).all()
by_id_7001 = {p.id: p for p in rows_7001}
by_id_7002 = {p.id: p for p in rows_7002}
assert len(rows_7001) == 2
assert len(rows_7002) == 2
assert by_id_7001[first_ids[0]].is_default is False
assert by_id_7001[second_ids[0]].is_default is True
assert by_id_7002[first_ids[1]].is_default is False
assert by_id_7002[second_ids[1]].is_default is True
# ---------------------------------------------------------------------------
# Non-default batch: existing default Plan is untouched
# ---------------------------------------------------------------------------
def test_non_default_save_batch_does_not_demote_existing_default(db_engine: Engine) -> None:
# Arrange — a default Plan already exists for the property.
plan = _plan()
with Session(db_engine) as session:
default_id = PlanPostgresRepository(session).save(
plan, property_id=8001, scenario_id=7, portfolio_id=1, is_default=True
)
session.commit()
# Act — save a non-default Plan via save_batch(); no demote UPDATE should fire.
with Session(db_engine) as session:
PlanPostgresRepository(session).save_batch([
PlanSaveRequest(plan, property_id=8001, scenario_id=7, portfolio_id=1, is_default=False),
])
session.commit()
# Assert — the original default Plan is still the default.
with Session(db_engine) as session:
rows = session.exec(select(PlanModel).where(col(PlanModel.property_id) == 8001)).all()
by_id = {p.id: p for p in rows}
assert len(rows) == 2
assert by_id[default_id].is_default is True
assert sum(1 for p in rows if p.is_default) == 1