diff --git a/tests/orchestration/magic_plan/test_magic_plan_orchestrator.py b/tests/orchestration/magic_plan/test_magic_plan_orchestrator.py index 15d1a723..4eadebc9 100644 --- a/tests/orchestration/magic_plan/test_magic_plan_orchestrator.py +++ b/tests/orchestration/magic_plan/test_magic_plan_orchestrator.py @@ -1,5 +1,6 @@ import json from pathlib import Path +from typing import Optional from unittest.mock import ANY, MagicMock, patch import pytest @@ -14,6 +15,7 @@ from backend.app.db.models.uploaded_file import ( UploadedFile, ) from infrastructure.magic_plan.magic_plan_client import MagicPlanClient +from infrastructure.s3.s3_client import S3Client from orchestration.magic_plan_orchestrator import MagicPlanOrchestrator from applications.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest @@ -49,20 +51,38 @@ def mock_client() -> MagicMock: return client -def _make_service(mock_client: MagicMock) -> MagicPlanOrchestrator: - return MagicPlanOrchestrator(magic_plan_api_client=mock_client, s3_bucket=S3_BUCKET) +def _make_s3_client(bucket: str = S3_BUCKET) -> MagicMock: + s3 = MagicMock(spec=S3Client) + s3.bucket = bucket + return s3 + + +def _make_service( + mock_client: MagicMock, mock_s3: Optional[MagicMock] = None +) -> MagicPlanOrchestrator: + if mock_s3 is None: + mock_s3 = _make_s3_client() + return MagicPlanOrchestrator(magic_plan_api_client=mock_client, s3_client=mock_s3) def _make_request( address: str = "2 Laburnum Way Bromley BR2 8BZ", hubspot_deal_id: str = "deal-123", - uprn: str | None = None, + uprn: Optional[str] = None, ) -> MagicPlanTriggerRequest: return MagicPlanTriggerRequest( address=address, hubspot_deal_id=hubspot_deal_id, uprn=uprn ) +def _patch_db() -> tuple[patch, patch, patch]: # type: ignore[type-arg] + return ( + patch("orchestration.magic_plan_orchestrator.PostgresConfig"), + patch("orchestration.magic_plan_orchestrator.make_engine"), + patch("orchestration.magic_plan_orchestrator.make_session"), + ) + + # --- no match --- @@ -88,16 +108,13 @@ def test_run_fetches_plan_with_matched_id( mock_client.get_plans.return_value = [plan_summary] mock_client.get_plan.return_value = api_magic_plan service = _make_service(mock_client) + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository" - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ), patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ): + ), p_config, p_engine, p_session: service.run(_make_request()) # Assert mock_client.get_plan_raw.assert_called_once_with(plan_summary.id) @@ -113,16 +130,13 @@ def test_run_returns_mapped_plan( mock_client.get_plans.return_value = [plan_summary] mock_client.get_plan.return_value = api_magic_plan service = _make_service(mock_client) + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository" - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ), patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ): + ), p_config, p_engine, p_session: result = service.run(_make_request()) # Assert assert isinstance(result, Plan) @@ -139,17 +153,14 @@ def test_run_calls_save_with_mapped_plan( mock_client.get_plan.return_value = api_magic_plan service = _make_service(mock_client) mock_repo = MagicMock() + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository", return_value=mock_repo, - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ), patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ): + ), p_config, p_engine, p_session: service.run(_make_request()) # Assert — save called with a Plan whose uid matches saved_plan: Plan = mock_repo.save.call_args[0][0] @@ -165,16 +176,13 @@ def test_run_accepts_uprn_without_error( mock_client.get_plans.return_value = [plan_summary] mock_client.get_plan.return_value = api_magic_plan service = _make_service(mock_client) + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository" - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ), patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ): + ), p_config, p_engine, p_session: service.run(_make_request(uprn="100023336956")) @@ -183,64 +191,51 @@ def test_run_accepts_uprn_without_error( def test_run_uploads_to_s3_with_uprn_key( mock_client: MagicMock, - api_magic_plan: MagicPlanPlan, plan_summary: PlanSummary, ) -> None: # Arrange mock_client.get_plans.return_value = [plan_summary] + mock_s3 = _make_s3_client() request = _make_request(uprn="100023336956") - service = MagicPlanOrchestrator( - magic_plan_api_client=mock_client, s3_bucket=S3_BUCKET - ) + service = _make_service(mock_client, mock_s3) + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository" - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ), patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ) as mock_s3: + ), p_config, p_engine, p_session: # Act service.run(request) # Assert - mock_s3.assert_called_once_with( - ANY, - S3_BUCKET, + mock_s3.put_object.assert_called_once_with( f"documents/uprn/100023336956/magic_plan_{plan_summary.id}.json.gz", + ANY, ) def test_run_uploads_to_s3_with_deal_id_key_when_uprn_absent( mock_client: MagicMock, - api_magic_plan: MagicPlanPlan, plan_summary: PlanSummary, ) -> None: # Arrange mock_client.get_plans.return_value = [plan_summary] - mock_client.get_plan.return_value = api_magic_plan + mock_s3 = _make_s3_client() request = _make_request(hubspot_deal_id="deal-456", uprn=None) - service = MagicPlanOrchestrator( - magic_plan_api_client=mock_client, s3_bucket=S3_BUCKET - ) + service = _make_service(mock_client, mock_s3) + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository" - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ), patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ) as mock_s3: + ), p_config, p_engine, p_session: # Act service.run(request) # Assert - mock_s3.assert_called_once_with( - ANY, - S3_BUCKET, + mock_s3.put_object.assert_called_once_with( f"documents/hubspot_deal_id/deal-456/magic_plan_{plan_summary.id}.json.gz", + ANY, ) @@ -249,28 +244,22 @@ def test_run_uploads_to_s3_with_deal_id_key_when_uprn_absent( def test_run_creates_uploaded_file_record( mock_client: MagicMock, - api_magic_plan: MagicPlanPlan, plan_summary: PlanSummary, ) -> None: # Arrange mock_client.get_plans.return_value = [plan_summary] - mock_client.get_plan.return_value = api_magic_plan + mock_s3 = _make_s3_client() request = _make_request(hubspot_deal_id="deal-789", uprn="100023336956") - service = MagicPlanOrchestrator( - magic_plan_api_client=mock_client, s3_bucket=S3_BUCKET - ) + service = _make_service(mock_client, mock_s3) mock_session = MagicMock() + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", return_value=plan_summary, ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository" - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ) as mock_db, patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ): - mock_db.return_value.__enter__.return_value = mock_session + ), p_config, p_engine, p_session as mock_make_session: + mock_make_session.return_value = mock_session # Act service.run(request) # Assert @@ -299,7 +288,7 @@ def test_run_passes_flushed_uploaded_file_id_to_save( mock_client.get_plans.return_value = [plan_summary] service = _make_service(mock_client) mock_session = MagicMock() - added_objects: list = [] + added_objects: list[object] = [] mock_session.add.side_effect = added_objects.append @@ -310,6 +299,7 @@ def test_run_passes_flushed_uploaded_file_id_to_save( mock_session.flush.side_effect = simulate_flush mock_repo = MagicMock() + p_config, p_engine, p_session = _patch_db() with patch( "orchestration.magic_plan_orchestrator.find_matching_plan", @@ -317,12 +307,8 @@ def test_run_passes_flushed_uploaded_file_id_to_save( ), patch( "orchestration.magic_plan_orchestrator.MagicPlanPostgresRepository", return_value=mock_repo, - ), patch( - "orchestration.magic_plan_orchestrator.db_session" - ) as mock_db, patch( - "orchestration.magic_plan_orchestrator.save_data_to_s3" - ): - mock_db.return_value.__enter__.return_value = mock_session + ), p_config, p_engine, p_session as mock_make_session: + mock_make_session.return_value = mock_session # Act service.run(_make_request())