diff --git a/backend/app/db/functions/magic_plan_functions.py b/backend/app/db/functions/magic_plan_functions.py index 9400f36f..143e4172 100644 --- a/backend/app/db/functions/magic_plan_functions.py +++ b/backend/app/db/functions/magic_plan_functions.py @@ -14,15 +14,15 @@ from backend.app.db.models.magic_plan import ( ) -def save_plan(session: Session, plan: Plan) -> None: - plan_id: int = _upsert_plan(session, plan) +def save_plan(session: Session, plan: Plan, uploaded_file_id: int) -> None: + plan_id: int = _upsert_plan(session, plan, uploaded_file_id) _delete_children(session, plan_id) floor_ids: list[int] = _insert_floors(session, plan.floors, plan_id) room_ids: list[int] = _insert_rooms(session, plan.floors, floor_ids) _insert_windows_and_doors(session, plan.floors, room_ids) -def _upsert_plan(session: Session, plan: Plan) -> int: +def _upsert_plan(session: Session, plan: Plan, uploaded_file_id: int) -> int: stmt = ( pg_insert(MagicPlanPlanModel) .values( @@ -30,6 +30,7 @@ def _upsert_plan(session: Session, plan: Plan) -> int: name=plan.name, address=plan.address, postcode=plan.postcode, + uploaded_file_id=uploaded_file_id, ) .on_conflict_do_update( index_elements=["magic_plan_uid"], @@ -37,6 +38,7 @@ def _upsert_plan(session: Session, plan: Plan) -> int: "name": plan.name, "address": plan.address, "postcode": plan.postcode, + "uploaded_file_id": uploaded_file_id, }, ) .returning(col(MagicPlanPlanModel.id)) diff --git a/backend/app/db/functions/tests/test_magic_plan_functions.py b/backend/app/db/functions/tests/test_magic_plan_functions.py index e58d0528..0b93685c 100644 --- a/backend/app/db/functions/tests/test_magic_plan_functions.py +++ b/backend/app/db/functions/tests/test_magic_plan_functions.py @@ -36,7 +36,7 @@ def _count(session: Session, model: type[SQLModel]) -> int: def test_plan_row_present_after_save(db_session: Session, domain_plan: Plan) -> None: # Act - save_plan(db_session, domain_plan) + save_plan(db_session, domain_plan, 1) # Assert assert _count(db_session, MagicPlanPlanModel) == 1 @@ -45,7 +45,7 @@ def test_floor_count_matches_domain(db_session: Session, domain_plan: Plan) -> N # Arrange expected = len(domain_plan.floors) # Act - save_plan(db_session, domain_plan) + save_plan(db_session, domain_plan, 1) # Assert assert _count(db_session, MagicPlanFloorModel) == expected @@ -54,7 +54,7 @@ def test_room_count_matches_domain(db_session: Session, domain_plan: Plan) -> No # Arrange expected = sum(len(f.rooms) for f in domain_plan.floors) # Act - save_plan(db_session, domain_plan) + save_plan(db_session, domain_plan, 1) # Assert assert _count(db_session, MagicPlanRoomModel) == expected @@ -63,7 +63,7 @@ def test_window_count_matches_domain(db_session: Session, domain_plan: Plan) -> # Arrange expected = sum(len(r.windows) for f in domain_plan.floors for r in f.rooms) # Act - save_plan(db_session, domain_plan) + save_plan(db_session, domain_plan, 1) # Assert assert _count(db_session, MagicPlanWindowModel) == expected @@ -72,15 +72,15 @@ def test_door_count_matches_domain(db_session: Session, domain_plan: Plan) -> No # Arrange expected = sum(len(r.doors) for f in domain_plan.floors for r in f.rooms) # Act - save_plan(db_session, domain_plan) + save_plan(db_session, domain_plan, 1) # Assert assert _count(db_session, MagicPlanDoorModel) == expected def test_save_plan_idempotent(db_session: Session, domain_plan: Plan) -> None: # Act — call twice within the same session - save_plan(db_session, domain_plan) - save_plan(db_session, domain_plan) + save_plan(db_session, domain_plan, 1) + save_plan(db_session, domain_plan, 1) # Assert — same row counts as a single call assert _count(db_session, MagicPlanPlanModel) == 1 assert _count(db_session, MagicPlanFloorModel) == len(domain_plan.floors) @@ -93,3 +93,23 @@ def test_save_plan_idempotent(db_session: Session, domain_plan: Plan) -> None: assert _count(db_session, MagicPlanDoorModel) == sum( len(r.doors) for f in domain_plan.floors for r in f.rooms ) + + +def test_uploaded_file_id_stored_after_save(db_session: Session, domain_plan: Plan) -> None: + # Act + save_plan(db_session, domain_plan, 1) + # Assert + row = db_session.execute(select(MagicPlanPlanModel)).scalar_one() + assert row.uploaded_file_id == 1 + + +def test_save_plan_updates_uploaded_file_id_on_reingest( + db_session: Session, domain_plan: Plan +) -> None: + # Arrange + save_plan(db_session, domain_plan, 1) + # Act + save_plan(db_session, domain_plan, 2) + # Assert + row = db_session.execute(select(MagicPlanPlanModel)).scalar_one() + assert row.uploaded_file_id == 2 diff --git a/backend/app/db/models/magic_plan.py b/backend/app/db/models/magic_plan.py index 38e9de18..77ca52fd 100644 --- a/backend/app/db/models/magic_plan.py +++ b/backend/app/db/models/magic_plan.py @@ -11,6 +11,7 @@ class MagicPlanPlanModel(SQLModel, table=True): name: Optional[str] = None address: Optional[str] = None postcode: Optional[str] = None + uploaded_file_id: Optional[int] = Field(default=None) class MagicPlanFloorModel(SQLModel, table=True): diff --git a/backend/magic_plan/magic_plan_service.py b/backend/magic_plan/magic_plan_service.py index 22e19ddf..8a75c716 100644 --- a/backend/magic_plan/magic_plan_service.py +++ b/backend/magic_plan/magic_plan_service.py @@ -1,7 +1,7 @@ import gzip import json from datetime import datetime, timezone -from typing import Optional +from typing import Optional, cast from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary from datatypes.magicplan.domain.mapper import map_plan @@ -55,8 +55,9 @@ class MagicPlanService: ) with db_session() as session: - save_plan(session, plan) session.add(uploaded_file) + session.flush() + save_plan(session, plan, cast(int, uploaded_file.id)) return plan diff --git a/backend/magic_plan/tests/test_magic_plan_service.py b/backend/magic_plan/tests/test_magic_plan_service.py index 158cf4d6..a2302ab4 100644 --- a/backend/magic_plan/tests/test_magic_plan_service.py +++ b/backend/magic_plan/tests/test_magic_plan_service.py @@ -271,3 +271,38 @@ def test_run_creates_uploaded_file_record( assert uploaded_file.s3_upload_timestamp is not None assert uploaded_file.uprn == 100023336956 assert uploaded_file.hubspot_deal_id == "deal-789" + + +def test_run_passes_flushed_uploaded_file_id_to_save_plan( + mock_client: MagicMock, + plan_summary: PlanSummary, +) -> None: + # Arrange + mock_client.get_plans.return_value = [plan_summary] + service = _make_service(mock_client) + mock_session = MagicMock() + added_objects: list = [] + + mock_session.add.side_effect = added_objects.append + + def simulate_flush() -> None: + for obj in added_objects: + if isinstance(obj, UploadedFile): + obj.id = 42 + + mock_session.flush.side_effect = simulate_flush + + with patch( + "backend.magic_plan.magic_plan_service.find_matching_plan", + return_value=plan_summary, + ), patch("backend.magic_plan.magic_plan_service.save_plan") as mock_save, patch( + "backend.magic_plan.magic_plan_service.db_session" + ) as mock_db, patch( + "backend.magic_plan.magic_plan_service.save_data_to_s3" + ): + mock_db.return_value.__enter__.return_value = mock_session + # Act + service.run(_make_request()) + + # Assert + assert mock_save.call_args[0][2] == 42