diff --git a/applications/modelling_e2e/handler.py b/applications/modelling_e2e/handler.py index 7d011fef..9454e6d4 100644 --- a/applications/modelling_e2e/handler.py +++ b/applications/modelling_e2e/handler.py @@ -4,9 +4,9 @@ One SQS message = one batch of properties sharing a portfolio, scenario, and (by caller convention) postcode. The handler reads ``property_ids``, ``portfolio_id``, ``scenario_id``, ``no_solar``, and ``dry_run`` from the message body, fetches or predicts each property's EPC, runs the full modelling -pipeline (SAP10 → optimiser) via ``harness.console.run_modelling``, and -persists the resulting Plan via ``PostgresUnitOfWork`` in one atomic transaction -per property. +pipeline (SAP10 → optimiser) via ``harness.console.run_modelling``, buffers each +resulting Plan in memory, and persists the whole batch via ``PostgresUnitOfWork`` +in one atomic transaction at the end. When no lodged EPC is found, EPC Prediction (Path 3, ADR-0031) synthesises one from the postcode cohort. ``_cohort_cache`` is module-level so warm Lambda @@ -21,16 +21,18 @@ crashing. The DB engine is module-scoped (ADR-0012). Architecturally each invocation uses one DB connection at a time: the handler reads everything up front — overrides, Scenario, a catalogue snapshot, and stored Solar — through one short-lived read -Session, closes it, then writes each Property in a sequential Unit of Work whose -overrides resolve on its own session, so no two Sessions ever overlap. The engine -uses ``NullPool`` rather than a fixed pool so that target is a graceful ceiling, -not a hard one: a fresh connection is opened per checkout and closed on return, -so there is no pool slot to exhaust — any future accidental overlap opens a -transient second connection instead of dead-locking the Lambda. 
+Session, closes it, models the batch (buffering each Plan in memory), then +persists the whole batch in one end-of-batch Unit of Work whose overrides resolve +on its own session, so no two Sessions ever overlap. The engine uses ``NullPool`` +rather than a fixed pool so that target is a graceful ceiling, not a hard one: a +fresh connection is opened per checkout and closed on return, so there is no pool +slot to exhaust — any future accidental overlap opens a transient second +connection instead of dead-locking the Lambda. """ from __future__ import annotations +import dataclasses import io import os from collections.abc import Callable, Generator @@ -59,6 +61,7 @@ from domain.epc_prediction.prediction_target import ( from domain.geospatial.coordinates import Coordinates from domain.geospatial.planning_restrictions import PlanningRestrictions from domain.geospatial.spatial_reference import SpatialReference +from domain.modelling.plan import Plan from domain.property.property import Property, PropertyIdentity from domain.property_baseline.calculator_rebaseliner import CalculatorRebaseliner from domain.sap10_calculator.calculator import Sap10Calculator @@ -134,6 +137,90 @@ _nearby_cohort_cache: dict[tuple[str, str], list[ComparableProperty]] = {} logger = setup_logger() +@dataclasses.dataclass(frozen=True) +class _SolarWrite: + """A freshly-fetched Solar insight queued for persistence. Only set when the + insight was fetched this run — stored insights are never re-written.""" + + uprn: int + longitude: float + latitude: float + insights: dict[str, Any] + + +@dataclasses.dataclass(frozen=True) +class _PropertyWrite: + """One modelled Property's full persistence intent, accumulated in memory + during the compute loop and replayed in a single end-of-batch Unit of Work. 
+ + Buffering the writes (rather than committing per property) keeps write + transactions out of the CPU-bound modelling loop, then collapses + the whole batch into one transaction — far fewer statements for RDS to parse, + plan, and commit, which is the RDS-CPU bottleneck this targets (ADR-0012).""" + + property_id: int + uprn: int + portfolio_id: int + scenario_id: int + is_default: bool + lodged_epc: Optional[EpcPropertyData] + predicted_epc: Optional[EpcPropertyData] + spatial: Optional[SpatialReference] + solar: Optional[_SolarWrite] + plan: Plan + has_recommendations: bool + + +def _flush_writes(engine: Engine, writes: list[_PropertyWrite]) -> None: + """Persist a whole batch of modelled Properties in one Unit of Work. + + Replays each Property's saves in dependency order (EPC → spatial → solar → + Plan → mark-modelled) and commits once. All-or-nothing per batch: a failed + save rolls the whole transaction back and propagates, so the SQS message is + retried — every save is an idempotent upsert, so a retry is safe. This mirrors + the PropertyBaselineOrchestrator's existing one-UoW-per-batch contract + (ADR-0012); per-property failures are isolated earlier, in the modelling loop, + before a write is ever queued.""" + with PostgresUnitOfWork(lambda: Session(engine)) as uow: + for w in writes: + if w.lodged_epc is not None: + uow.epc.save( + w.lodged_epc, + property_id=w.property_id, + portfolio_id=w.portfolio_id, + ) + elif w.predicted_epc is not None: + # Persist the synthesised EPC in the predicted slot (ADR-0031), so + # the Baseline stage can re-hydrate it and downstream sees the + # picture the Plan was modelled from. 
+ uow.epc.save( + w.predicted_epc, + property_id=w.property_id, + portfolio_id=w.portfolio_id, + source="predicted", + ) + if w.spatial is not None: + uow.spatial.save(w.uprn, w.spatial) + if w.solar is not None: + uow.solar.save( + w.solar.uprn, + longitude=w.solar.longitude, + latitude=w.solar.latitude, + insights=w.solar.insights, + ) + uow.plan.save( + w.plan, + property_id=w.property_id, + scenario_id=w.scenario_id, + portfolio_id=w.portfolio_id, + is_default=w.is_default, + ) + uow.property.mark_modelled( + w.property_id, has_recommendations=w.has_recommendations + ) + uow.commit() + + def _get_engine() -> Engine: global _engine if _engine is None: @@ -317,25 +404,22 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator, solar_client = GoogleSolarApiClient(os.environ["GOOGLE_SOLAR_API_KEY"]) with engine.connect() as conn: - uprn_rows = conn.execute( - text("SELECT id, uprn FROM property WHERE id = ANY(:ids)"), - {"ids": property_ids}, - ).fetchall() - postcode_rows = conn.execute( - text("SELECT id, postcode FROM property WHERE id = ANY(:ids)"), + property_rows = conn.execute( + text("SELECT id, uprn, postcode FROM property WHERE id = ANY(:ids)"), {"ids": property_ids}, ).fetchall() - uprns: dict[int, int] = {int(row[0]): int(row[1]) for row in uprn_rows} - postcodes: dict[int, str] = {int(row[0]): (row[1] or "") for row in postcode_rows} + uprns: dict[int, int] = {int(row[0]): int(row[1]) for row in property_rows} + postcodes: dict[int, str] = {int(row[0]): (row[2] or "") for row in property_rows} - # Pre-fetch every Property's overrides up front (each call opens and closes - # its own short read Session) and serve them from memory through the loop, so - # no override read Session is held open alongside a write Unit of Work. 
+ # Pre-fetch every Property's overrides up front in one query (one short read + # Session, opened and closed before the modelling loop) and serve them from + # memory through the loop, so no override read Session is held open alongside + # the end-of-batch write Unit of Work. overrides_postgres_reader = PropertyOverridesPostgresReader(lambda: Session(engine)) - overrides_by_pid: dict[int, ResolvedPropertyOverrides] = { - pid: overrides_postgres_reader.overrides_for(pid) for pid in property_ids - } + overrides_by_pid: dict[int, ResolvedPropertyOverrides] = ( + overrides_postgres_reader.overrides_for_many(property_ids) + ) overrides_reader = InMemoryPropertyOverridesReader(overrides_by_pid) prediction_attrs_reader = OverrideBackedPredictionAttributesReader(overrides_reader) comparables_repo = EpcComparablePropertiesRepository( @@ -368,10 +452,10 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator, ) return _nearby_cohort_cache[key] - # Re-establishes each lodged Property's Baseline Performance from the just- - # persisted EPC (one UoW per property, committed after the Plan's). Predicted - # Properties have no lodged figures, so they get no baseline (mirrors the e2e - # runner and the ara_first_run Baseline stage). + # Re-establishes every written Property's Baseline Performance from the just- + # persisted EPCs. Run once for the whole batch after the write flush — the + # orchestrator already does the batch in one UoW (ADR-0012) — rather than once + # per property, so the batch costs one baseline transaction, not N. 
baseline_orchestrator = PropertyBaselineOrchestrator( unit_of_work=lambda: PostgresUnitOfWork(lambda: Session(engine)), rebaseliner=CalculatorRebaseliner(Sap10Calculator()), @@ -380,24 +464,28 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator, read_session = Session(engine) try: - # Read everything the write loop needs up front: the Scenario, an in-memory - # snapshot of the catalogue (priced after the Session closes), and each - # UPRN's stored Solar insights. Then close the read Session immediately so - # its pooled connection is free before the loop — each Property's write - # Unit of Work reuses that single connection rather than opening a second - # alongside a held-open read Session. (The ``finally`` is the safety net.) + # Read everything the modelling loop needs up front: the Scenario, an + # in-memory snapshot of the catalogue (priced after the Session closes), + # and each UPRN's stored Solar insights. Then close the read Session + # immediately so its pooled connection is free for the single end-of-batch + # write Unit of Work — no write ever opens a second connection alongside a + # held-open read Session. (The ``finally`` is the safety net.) scenario = ScenarioPostgresRepository(read_session).get_many([scenario_id])[0] products = catalogue_snapshot_with_off_catalogue_overrides(read_session) stored_solar: dict[int, Optional[dict[str, Any]]] = ( {} if no_solar - else { - uprn: SolarPostgresRepository(read_session).get(uprn) - for uprn in set(uprns.values()) - } + else SolarPostgresRepository(read_session).get_many( + list(set(uprns.values())) + ) ) read_session.close() + # Each Property models in its own child SubTask (failures isolated here), + # appending its persistence intent to this buffer instead of writing — the + # whole batch is flushed in one transaction after the loop. 
+ accumulated: list[_PropertyWrite] = [] + for property_id in property_ids: child = orchestrator.create_child_subtask( task_id, inputs={"property_id": property_id} @@ -495,56 +583,60 @@ def handler(body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator, ) return - with PostgresUnitOfWork(lambda: Session(engine)) as uow: - if epc is not None: - uow.epc.save( - epc, property_id=pid, portfolio_id=portfolio_id - ) - elif predicted_epc is not None: - # Persist the synthesised EPC in the predicted slot (ADR-0031), - # so the Baseline stage can re-hydrate it and downstream sees - # the picture the Plan was modelled from. - uow.epc.save( - predicted_epc, - property_id=pid, - portfolio_id=portfolio_id, - source="predicted", - ) - if spatial is not None: - uow.spatial.save(uprn, spatial) - if ( - solar_was_fetched - and solar_insights is not None - and spatial is not None - and spatial.coordinates is not None - ): - uow.solar.save( - uprn, - longitude=spatial.coordinates.longitude, - latitude=spatial.coordinates.latitude, - insights=solar_insights, - ) - uow.plan.save( - plan, - property_id=pid, - scenario_id=scenario_id, - portfolio_id=portfolio_id, - is_default=scenario.is_default, + solar_write: Optional[_SolarWrite] = None + if ( + solar_was_fetched + and solar_insights is not None + and spatial is not None + and spatial.coordinates is not None + ): + solar_write = _SolarWrite( + uprn=uprn, + longitude=spatial.coordinates.longitude, + latitude=spatial.coordinates.latitude, + insights=solar_insights, ) - uow.property.mark_modelled( - pid, has_recommendations=bool(plan.measures) - ) - uow.commit() - logger.info(f"property={pid} plan saved") - baseline_orchestrator.run([pid]) - logger.info(f"property={pid} baseline saved") + # Queue this Property's writes rather than committing now — the + # whole batch is persisted in one Unit of Work after the loop + # (see _flush_writes). 
The EPC is saved in its lodged or predicted + # slot (ADR-0031) at flush time depending on which is set here. + accumulated.append( + _PropertyWrite( + property_id=pid, + uprn=uprn, + portfolio_id=portfolio_id, + scenario_id=scenario_id, + is_default=scenario.is_default, + lodged_epc=epc, + predicted_epc=predicted_epc, + spatial=spatial, + solar=solar_write, + plan=plan, + has_recommendations=bool(plan.measures), + ) + ) + logger.info(f"property={pid} queued for write") try: orchestrator.run_subtask(child.id, work=_work) except Exception: # noqa: BLE001 pass + # Persist the whole batch in one transaction, then re-establish every + # written Property's Baseline (the orchestrator batches its own UoW). The + # N per-property write transactions plus N baseline transactions collapse + # to two — the RDS-CPU win. Skipped entirely on a dry run or an all-failed + # batch, where nothing was queued. + if accumulated: + _flush_writes(engine, accumulated) + baseline_orchestrator.run([w.property_id for w in accumulated]) + logger.info( + f"persisted {len(accumulated)} " + f"{'property' if len(accumulated) == 1 else 'properties'} " + f"and baselines" + ) + # Cohort certs the mapper could not consume were skipped (not aborted on) # so prediction could proceed; surface them — with cert numbers — in the # subtask outputs so the mapper gaps can be closed later. diff --git a/repositories/property/property_overrides_postgres_reader.py b/repositories/property/property_overrides_postgres_reader.py index c30993b6..c3aa2b1c 100644 --- a/repositories/property/property_overrides_postgres_reader.py +++ b/repositories/property/property_overrides_postgres_reader.py @@ -46,6 +46,33 @@ def _resolve_overrides(session: Session, property_id: int) -> ResolvedPropertyOv ) +def _resolve_overrides_many( + session: Session, property_ids: list[int] +) -> dict[int, ResolvedPropertyOverrides]: + """Resolve overrides for many Properties in one query. 
Returns an entry for + every requested id; a Property with no rows resolves to an empty snapshot.""" + rows = session.exec( + select(PropertyOverrideRow).where( + col(PropertyOverrideRow.property_id).in_(property_ids) + ) + ).all() + grouped: dict[int, list[ResolvedPropertyOverride]] = { + property_id: [] for property_id in property_ids + } + for row in rows: + grouped.setdefault(row.property_id, []).append( + ResolvedPropertyOverride( + override_component=row.override_component, + building_part=row.building_part, + override_value=row.override_value, + ) + ) + return { + property_id: ResolvedPropertyOverrides(rows=tuple(overrides)) + for property_id, overrides in grouped.items() + } + + class PropertyOverridesPostgresReader(PropertyOverridesReader): """Opens its own short read session per call — for standalone use outside a Unit of Work, where there is no shared session/connection to reuse.""" @@ -57,6 +84,16 @@ class PropertyOverridesPostgresReader(PropertyOverridesReader): with self._session_factory() as session: return _resolve_overrides(session, property_id) + def overrides_for_many( + self, property_ids: list[int] + ) -> dict[int, ResolvedPropertyOverrides]: + """Every requested Property's resolved overrides in one query — the batch + form of ``overrides_for``. 
The returned dict has an entry for every + requested id; a Property with no rows resolves to an empty snapshot + (exactly as ``overrides_for`` returns for a Property with no rows).""" + with self._session_factory() as session: + return _resolve_overrides_many(session, property_ids) + class OpenSessionPropertyOverridesReader(PropertyOverridesReader): """Reads on a caller-owned, already-open session without closing it — for use diff --git a/repositories/solar/solar_postgres_repository.py b/repositories/solar/solar_postgres_repository.py index c6911f6b..e797df80 100644 --- a/repositories/solar/solar_postgres_repository.py +++ b/repositories/solar/solar_postgres_repository.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any, Optional -from sqlmodel import Session, select +from sqlmodel import Session, col, select from infrastructure.postgres.solar_table import SolarRow from repositories.solar.solar_repository import SolarRepository @@ -38,3 +38,15 @@ class SolarPostgresRepository(SolarRepository): select(SolarRow).where(SolarRow.uprn == uprn) ).first() return row.google_api_response if row is not None else None + + def get_many(self, uprns: list[int]) -> dict[int, Optional[dict[str, Any]]]: + """Stored insights for many UPRNs in one query — the batch form of + ``get``. 
The returned dict has an entry for every requested UPRN; a UPRN + with no stored row maps to None (exactly as ``get`` returns).""" + rows = self._session.exec( + select(SolarRow).where(col(SolarRow.uprn).in_(uprns)) + ).all() + stored: dict[int, Optional[dict[str, Any]]] = { + row.uprn: row.google_api_response for row in rows + } + return {uprn: stored.get(uprn) for uprn in uprns} diff --git a/tests/applications/modelling_e2e/test_handler.py b/tests/applications/modelling_e2e/test_handler.py index 8cd41c67..d55fd272 100644 --- a/tests/applications/modelling_e2e/test_handler.py +++ b/tests/applications/modelling_e2e/test_handler.py @@ -63,17 +63,17 @@ def _engine_mock( uprns: list[int], postcodes: list[str], ) -> MagicMock: - """Mock engine whose connect() returns UPRN then postcode rows.""" + """Mock engine whose connect() returns one (id, uprn, postcode) row set — + the handler reads all three columns in a single query.""" mock_engine = MagicMock() mock_conn = mock_engine.connect.return_value.__enter__.return_value - uprn_result = MagicMock() - uprn_result.fetchall.return_value = list(zip(property_ids, uprns)) + property_result = MagicMock() + property_result.fetchall.return_value = list( + zip(property_ids, uprns, postcodes) + ) - postcode_result = MagicMock() - postcode_result.fetchall.return_value = list(zip(property_ids, postcodes)) - - mock_conn.execute.side_effect = [uprn_result, postcode_result] + mock_conn.execute.return_value = property_result return mock_engine @@ -170,7 +170,9 @@ def test_handler_creates_one_child_subtask_per_property_id() -> None: stack.enter_context( patch("applications.modelling_e2e.handler.overlays_from", return_value=[]) ) - stack.enter_context(patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader")) + stack.enter_context( + patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") + ).return_value.overrides_for_many.return_value = {} stack.enter_context( 
patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") ).return_value.get_many.return_value = [MagicMock()] @@ -255,7 +257,7 @@ def test_lodged_epc_path_saves_epc_plan_and_marks_modelled( ) stack.enter_context( patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} stack.enter_context( patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") ).return_value.get_many.return_value = [MagicMock()] @@ -343,7 +345,7 @@ def test_skipped_cohort_certs_do_not_prevent_plan_being_saved() -> None: ) stack.enter_context( patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} stack.enter_context( patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") ).return_value.get_many.return_value = [MagicMock()] @@ -413,7 +415,9 @@ def test_skipped_cohort_certs_are_logged_and_handler_does_not_raise() -> None: stack.enter_context( patch("applications.modelling_e2e.handler.overlays_from", return_value=[]) ) - stack.enter_context(patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader")) + stack.enter_context( + patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") + ).return_value.overrides_for_many.return_value = {} stack.enter_context( patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") ).return_value.get_many.return_value = [MagicMock()] @@ -508,7 +512,7 @@ def test_prediction_path_saves_predicted_epc_plan_and_baseline( ) stack.enter_context( patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} # Prediction infrastructure from domain.epc_prediction.prediction_target import PredictionTargetAttributes @@ -612,7 +616,7 @@ def test_empty_cohort_gates_property_out_without_saving() -> None: ) stack.enter_context( 
patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} from domain.epc_prediction.prediction_target import PredictionTargetAttributes stack.enter_context( @@ -712,7 +716,7 @@ def test_empty_own_postcode_broadens_to_nearby_and_predicts() -> None: ) stack.enter_context( patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} from domain.epc_prediction.prediction_target import PredictionTargetAttributes stack.enter_context( @@ -818,7 +822,9 @@ def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> Non stack.enter_context( patch("applications.modelling_e2e.handler.overlays_from", return_value=[]) ) - stack.enter_context(patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader")) + stack.enter_context( + patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") + ).return_value.overrides_for_many.return_value = {} stack.enter_context( patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") ).return_value.get_many.return_value = [MagicMock()] @@ -846,6 +852,74 @@ def test_per_property_failure_fails_child_subtask_and_siblings_continue() -> Non mock_uow.commit.assert_called_once() +# --------------------------------------------------------------------------- +# End-of-batch write — one transaction for the whole batch +# --------------------------------------------------------------------------- + + +def test_batch_persists_in_one_transaction_and_one_baseline_run( + _baseline_orchestrator: MagicMock, +) -> None: + """Three properties model independently but persist together: every Plan is + saved, yet the write Unit of Work commits exactly once and the Baseline + orchestrator runs once for the whole batch (the RDS-CPU win).""" + # Arrange + pid1, pid2, pid3 = 111, 222, 333 + mock_engine = _engine_mock( + [pid1, pid2, pid3], [1001, 1002, 1003], 
[POSTCODE, POSTCODE, POSTCODE] + ) + mock_uow = MagicMock() + + with ExitStack() as stack: + stack.enter_context(patch("applications.modelling_e2e.handler.os.environ", _ENV)) + stack.enter_context( + patch("applications.modelling_e2e.handler._get_engine", return_value=mock_engine) + ) + stack.enter_context( + patch("applications.modelling_e2e.handler.EpcClientService") + ).return_value.get_by_uprn.return_value = MagicMock() + stack.enter_context(patch("applications.modelling_e2e.handler.GeospatialS3Repository")) + stack.enter_context(patch("applications.modelling_e2e.handler.GoogleSolarApiClient")) + stack.enter_context( + patch("applications.modelling_e2e.handler._spatial_for", return_value=None) + ) + stack.enter_context( + patch("applications.modelling_e2e.handler._solar_insights_for", return_value=None) + ) + stack.enter_context( + patch("applications.modelling_e2e.handler.overlays_from", return_value=[]) + ) + stack.enter_context( + patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") + ).return_value.overrides_for_many.return_value = {} + stack.enter_context( + patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") + ).return_value.get_many.return_value = [MagicMock()] + stack.enter_context(patch("applications.modelling_e2e.handler.catalogue_snapshot_with_off_catalogue_overrides")) + stack.enter_context(patch("applications.modelling_e2e.handler.Session")) + stack.enter_context( + patch("applications.modelling_e2e.handler.run_modelling", return_value=_plan_mock()) + ) + MockUoW = stack.enter_context(patch("applications.modelling_e2e.handler.PostgresUnitOfWork")) + MockUoW.return_value.__enter__.return_value = mock_uow + MockUoW.return_value.__exit__.return_value = False + + # Act + _call_handler( + {"property_ids": [pid1, pid2, pid3], "portfolio_id": PORTFOLIO_ID, + "scenario_id": SCENARIO_ID, "no_solar": True, "dry_run": False} + ) + + # Assert — all three Plans saved, but a single shared transaction: + assert 
mock_uow.plan.save.call_count == 3 + assert mock_uow.property.mark_modelled.call_count == 3 + mock_uow.commit.assert_called_once() + # One write Unit of Work opened for the whole batch, not one per property. + MockUoW.return_value.__enter__.assert_called_once() + # Baseline re-established once, for every written property together. + _baseline_orchestrator.return_value.run.assert_called_once_with([pid1, pid2, pid3]) + + # --------------------------------------------------------------------------- # Cohort cache hit # --------------------------------------------------------------------------- @@ -903,7 +977,7 @@ def test_cohort_cache_prevents_duplicate_candidates_for_calls() -> None: ) stack.enter_context( patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} from domain.epc_prediction.prediction_target import PredictionTargetAttributes stack.enter_context( @@ -1007,7 +1081,7 @@ def test_dry_run_skips_all_db_writes() -> None: ) stack.enter_context( patch("applications.modelling_e2e.handler.PropertyOverridesPostgresReader") - ) + ).return_value.overrides_for_many.return_value = {} stack.enter_context( patch("applications.modelling_e2e.handler.ScenarioPostgresRepository") ).return_value.get_many.return_value = [MagicMock()] diff --git a/tests/repositories/property/test_property_overrides_postgres_reader.py b/tests/repositories/property/test_property_overrides_postgres_reader.py index 053af9a2..0aec3afc 100644 --- a/tests/repositories/property/test_property_overrides_postgres_reader.py +++ b/tests/repositories/property/test_property_overrides_postgres_reader.py @@ -115,3 +115,45 @@ def test_property_without_overrides_reads_empty(db_engine: Engine) -> None: # Assert assert resolved.rows == () assert resolved.value("property_type", 0) is None + + +def test_overrides_for_many_reads_every_property_in_one_pass( + db_engine: Engine, +) -> None: + """The batch read returns an entry for every 
requested Property — each + Property's own rows, faithfully, and an empty snapshot for one with none.""" + # Arrange — Property 7 has two overrides, Property 8 has one, Property 9 none. + with Session(db_engine) as session: + _seed( + session, + property_id=7, + override_component="property_type", + override_value="House", + ) + _seed( + session, + property_id=7, + building_part=1, + override_component="wall_type", + override_value="Solid brick, with internal insulation", + ) + _seed( + session, + property_id=8, + override_component="property_type", + override_value="Flat", + ) + session.commit() + + reader = PropertyOverridesPostgresReader(lambda: Session(db_engine)) + + # Act + resolved = reader.overrides_for_many([7, 8, 9]) + + # Assert — every requested id present; rows attributed to the right Property. + assert set(resolved) == {7, 8, 9} + assert len(resolved[7].rows) == 2 + assert resolved[7].value("property_type", 0) == "House" + assert resolved[7].value("wall_type", 1) == "Solid brick, with internal insulation" + assert resolved[8].value("property_type", 0) == "Flat" + assert resolved[9].rows == () diff --git a/tests/repositories/solar/test_solar_repository.py b/tests/repositories/solar/test_solar_repository.py index 3cbfd394..c18a5cbe 100644 --- a/tests/repositories/solar/test_solar_repository.py +++ b/tests/repositories/solar/test_solar_repository.py @@ -41,3 +41,23 @@ def test_get_returns_none_when_no_insights_stored(db_engine: Engine) -> None: # Assert assert reloaded is None + + +def test_get_many_returns_entry_for_every_requested_uprn(db_engine: Engine) -> None: + """get_many reads many UPRNs in one query: stored UPRNs return their + insights, unstored ones map to None, and every requested UPRN is present.""" + # Arrange + insights_10: dict[str, Any] = {"name": "buildings/A"} + insights_20: dict[str, Any] = {"name": "buildings/B"} + with Session(db_engine) as session: + repo = SolarPostgresRepository(session) + repo.save(10, longitude=-0.1, 
latitude=51.5, insights=insights_10) + repo.save(20, longitude=-0.2, latitude=52.0, insights=insights_20) + session.commit() + + # Act — 30 has no stored row + with Session(db_engine) as session: + loaded = SolarPostgresRepository(session).get_many([10, 20, 30]) + + # Assert + assert loaded == {10: insights_10, 20: insights_20, 30: None}