# Review notes (preamble text before the first "diff --git" header; not part of any hunk).
#
# Purpose: the subtask_handler wrapper now calls the wrapped function as
# func(body, context, orchestrator); the two Lambda handlers below gain an
# `orchestrator` parameter (immediately discarded with `del`), and a new test
# module covers the decorator.
#
# NOTE(review): backend/ordnanceSurvey/main.py now annotates with `Optional[...]`,
#   but no hunk for that file adds `Optional` to its imports (the combiner hunk
#   does extend `from typing import Any` to `Any, Optional`). Confirm it is
#   already imported there; otherwise the annotation fails when the module loads.
# NOTE(review): in backend/ordnanceSurvey/main.py `orchestrator` is inserted
#   before `local`, so `local` moves from the 3rd to the 4th positional slot.
#   Verify no caller passes `local` positionally.
# NOTE(review): the wrapper now always passes a third positional argument.
#   Confirm every other function decorated with @subtask_handler accepts it.
# NOTE(review): this patch had been flattened onto a few physical lines. Line
#   breaks are restored here; indentation and blank context lines were re-inferred
#   from the hunk counts and Python layout (the utilities/aws_lambda hunk depth
#   in particular is a guess). Regenerate from the branch, or apply with
#   whitespace tolerance, before relying on it.
diff --git a/backend/bulk_address2uprn_combiner/main.py b/backend/bulk_address2uprn_combiner/main.py
index 44f0b3f9..37136e52 100644
--- a/backend/bulk_address2uprn_combiner/main.py
+++ b/backend/bulk_address2uprn_combiner/main.py
@@ -2,7 +2,7 @@ import os
 import boto3
 import pandas as pd
 from io import BytesIO
-from typing import Any
+from typing import Any, Optional
 from uuid import UUID
 from datetime import datetime, timezone
 
@@ -12,6 +12,7 @@ from backend.app.db.functions.bulk_address_uploads_functions import (
     set_combined_output_s3_uri,
     set_combining_status,
 )
+from orchestration.task_orchestrator import TaskOrchestrator
 
 logger = setup_logger()
 
@@ -35,7 +36,16 @@ def download_csv(s3_client, bucket: str, key: str) -> pd.DataFrame:
 
 
 @subtask_handler()
-def handler(body: dict[str, Any], context: Any) -> str:
+def handler(
+    body: dict[str, Any],
+    context: Any,
+    orchestrator: Optional[TaskOrchestrator] = None,
+) -> str:
+    # `orchestrator` is injected by the new utilities.aws_lambda.subtask_handler
+    # decorator; unused here but accepted so the contract is uniform across
+    # callers (see issue #1103).
+    del orchestrator
+
     task_id_str: str = body.get("task_id", "")
 
     if not task_id_str:
diff --git a/backend/ordnanceSurvey/main.py b/backend/ordnanceSurvey/main.py
index 6e82b468..18c4e2f2 100644
--- a/backend/ordnanceSurvey/main.py
+++ b/backend/ordnanceSurvey/main.py
@@ -16,6 +16,7 @@ from backend.ordnanceSurvey.helpers import (
     os_places_results_to_dataframe,
 )
 from backend.app.config import get_settings
+from orchestration.task_orchestrator import TaskOrchestrator
 from sqlalchemy import select
 from datetime import datetime
 import uuid
@@ -105,7 +106,16 @@ def save_results_to_s3(
 
 
 @subtask_handler()  # This assumes task_id and subtask_id is defined in event.Records.body
-def handler(body: dict[str, Any], context: Any, local: bool = False) -> None:
+def handler(
+    body: dict[str, Any],
+    context: Any,
+    orchestrator: Optional[TaskOrchestrator] = None,
+    local: bool = False,
+) -> None:
+    # `orchestrator` is injected by the new utilities.aws_lambda.subtask_handler
+    # decorator; unused here but accepted so the contract is uniform across
+    # callers (see issue #1103).
+    del orchestrator
     # delete this line after test
     # local = True
 
diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utilities/aws_lambda/__init__.py b/tests/utilities/aws_lambda/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utilities/aws_lambda/test_subtask_handler.py b/tests/utilities/aws_lambda/test_subtask_handler.py
new file mode 100644
index 00000000..426b250f
--- /dev/null
+++ b/tests/utilities/aws_lambda/test_subtask_handler.py
@@ -0,0 +1,144 @@
+"""Tests for the @subtask_handler decorator.
+
+Covers the contract that the decorator owns the parent SubTask lifecycle and
+injects the decorator-owned TaskOrchestrator as a third positional argument
+to the wrapped function — so the handler can compose its own use-case
+orchestrator that shares the session.
+"""
+
+from collections.abc import Generator, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any
+from uuid import UUID
+
+import pytest
+from sqlmodel import Session, SQLModel, create_engine
+
+from domain.tasks.subtasks import SubTaskStatus
+from domain.tasks.tasks import TaskStatus
+from orchestration.task_orchestrator import TaskOrchestrator
+from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
+from repositories.tasks.task_postgres_repository import TaskPostgresRepository
+from utilities.aws_lambda.subtask_handler import subtask_handler
+
+
+@dataclass
+class Harness:
+    orchestrator: TaskOrchestrator
+    tasks: TaskPostgresRepository
+    subtasks: SubTaskPostgresRepository
+
+    @contextmanager
+    def factory(self) -> Generator[TaskOrchestrator, None, None]:
+        yield self.orchestrator
+
+
+@pytest.fixture
+def harness() -> Iterator[Harness]:
+    engine = create_engine("sqlite://")
+    SQLModel.metadata.create_all(engine)
+    with Session(engine) as session:
+        tasks = TaskPostgresRepository(session=session)
+        subtasks = SubTaskPostgresRepository(session=session)
+        yield Harness(
+            orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
+            tasks=tasks,
+            subtasks=subtasks,
+        )
+
+
+def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
+    return {"task_id": str(task_id), "sub_task_id": str(subtask_id)}
+
+
+def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
+    harness: Harness,
+) -> None:
+    """The wrapped function receives the decorator-owned TaskOrchestrator
+    so it can share the session with its own use-case orchestrator."""
+    _, subtask = harness.orchestrator.create_task_with_subtask(
+        task_source="manual:test"
+    )
+
+    received: dict[str, Any] = {}
+
+    @subtask_handler(orchestrator_cm=harness.factory)
+    def handler(
+        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
+    ) -> None:
+        received["body"] = body
+        received["context"] = context
+        received["orchestrator"] = orchestrator
+
+    handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel")
+
+    assert received["orchestrator"] is harness.orchestrator
+    assert received["context"] == "ctx-sentinel"
+    assert received["body"]["sub_task_id"] == str(subtask.id)
+
+
+def test_subtask_handler_completes_parent_subtask_on_success(
+    harness: Harness,
+) -> None:
+    task, subtask = harness.orchestrator.create_task_with_subtask(
+        task_source="manual:test"
+    )
+
+    @subtask_handler(orchestrator_cm=harness.factory)
+    def handler(
+        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
+    ) -> None:
+        return None
+
+    handler(_direct_event(task.id, subtask.id), context=None)
+
+    assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
+    assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
+
+
+def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
+    harness: Harness,
+) -> None:
+    task, subtask = harness.orchestrator.create_task_with_subtask(
+        task_source="manual:test"
+    )
+
+    @subtask_handler(orchestrator_cm=harness.factory)
+    def handler(
+        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
+    ) -> None:
+        raise RuntimeError("boom")
+
+    with pytest.raises(RuntimeError, match="boom"):
+        handler(_direct_event(task.id, subtask.id), context=None)
+
+    assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
+    assert harness.tasks.get(task.id).status is TaskStatus.FAILED
+
+
+def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
+    harness: Harness,
+) -> None:
+    """Smoke check the share-the-session promise: the injected orchestrator
+    is the same one the decorator owns, so a handler can use it to create
+    child SubTasks under the same session."""
+    task, subtask = harness.orchestrator.create_task_with_subtask(
+        task_source="manual:test"
+    )
+
+    child_ids: list[UUID] = []
+
+    @subtask_handler(orchestrator_cm=harness.factory)
+    def handler(
+        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
+    ) -> None:
+        child = orchestrator.create_child_subtask(task.id, inputs={"split": 1})
+        child_ids.append(child.id)
+
+    handler(_direct_event(task.id, subtask.id), context=None)
+
+    assert len(child_ids) == 1
+    persisted_child = harness.subtasks.get(child_ids[0])
+    assert persisted_child.task_id == task.id
+    assert persisted_child.status is SubTaskStatus.WAITING
diff --git a/utilities/aws_lambda/subtask_handler.py b/utilities/aws_lambda/subtask_handler.py
index 64c1daa6..5ad5f6e1 100644
--- a/utilities/aws_lambda/subtask_handler.py
+++ b/utilities/aws_lambda/subtask_handler.py
@@ -39,7 +39,7 @@ def subtask_handler(
                 trigger = SubtaskTriggerBody.model_validate(body)
                 orchestrator.run_subtask(
                     trigger.sub_task_id,
-                    work=lambda body=body: func(body, context),
+                    work=lambda body=body, o=orchestrator: func(body, context, o),
                 )
 
         return wrapper