diff --git a/backend/magic_plan/magic_plan_service.py b/backend/magic_plan/magic_plan_service.py index 2be3379d..8a75c716 100644 --- a/backend/magic_plan/magic_plan_service.py +++ b/backend/magic_plan/magic_plan_service.py @@ -1,7 +1,7 @@ import gzip import json from datetime import datetime, timezone -from typing import Optional +from typing import Optional, cast from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary from datatypes.magicplan.domain.mapper import map_plan @@ -57,7 +57,7 @@ class MagicPlanService: with db_session() as session: session.add(uploaded_file) session.flush() - save_plan(session, plan, uploaded_file.id) + save_plan(session, plan, cast(int, uploaded_file.id)) return plan diff --git a/backend/magic_plan/tests/test_magic_plan_service.py b/backend/magic_plan/tests/test_magic_plan_service.py index 158cf4d6..a2302ab4 100644 --- a/backend/magic_plan/tests/test_magic_plan_service.py +++ b/backend/magic_plan/tests/test_magic_plan_service.py @@ -271,3 +271,38 @@ def test_run_creates_uploaded_file_record( assert uploaded_file.s3_upload_timestamp is not None assert uploaded_file.uprn == 100023336956 assert uploaded_file.hubspot_deal_id == "deal-789" + + +def test_run_passes_flushed_uploaded_file_id_to_save_plan( + mock_client: MagicMock, + plan_summary: PlanSummary, +) -> None: + # Arrange + mock_client.get_plans.return_value = [plan_summary] + service = _make_service(mock_client) + mock_session = MagicMock() + added_objects: list = [] + + mock_session.add.side_effect = added_objects.append + + def simulate_flush() -> None: + for obj in added_objects: + if isinstance(obj, UploadedFile): + obj.id = 42 + + mock_session.flush.side_effect = simulate_flush + + with patch( + "backend.magic_plan.magic_plan_service.find_matching_plan", + return_value=plan_summary, + ), patch("backend.magic_plan.magic_plan_service.save_plan") as mock_save, patch( + "backend.magic_plan.magic_plan_service.db_session" + ) as mock_db, patch( + "backend.magic_plan.magic_plan_service.save_data_to_s3" + ): + mock_db.return_value.__enter__.return_value = mock_session + # Act + service.run(_make_request()) + + # Assert + assert mock_save.call_args[0][2] == 42