utilities/aws_lambda: @subtask_handler injects TaskOrchestrator as third positional arg

The wrapped function now receives the decorator-owned TaskOrchestrator as
a third positional argument so handlers can compose their own use-case
orchestrator that shares the session, instead of opening a second Postgres
connection per invocation.

Both existing callers (backend/ordnanceSurvey/main.py and
backend/bulk_address2uprn_combiner/main.py) have their signatures extended
to accept the new positional argument (typed Optional[TaskOrchestrator] so
the legacy backend.utils.subtasks.subtask_handler — which only passes two
args — keeps working until the migration to the new decorator lands).

@task_handler is intentionally unchanged in this slice; symmetry is
deferred per issue #1103.
This commit is contained in:
Jun-te Kim 2026-05-19 17:31:27 +00:00
parent d7f14033ba
commit d70e8a9e53
6 changed files with 168 additions and 4 deletions

View file

@ -2,7 +2,7 @@ import os
import boto3
import pandas as pd
from io import BytesIO
from typing import Any
from typing import Any, Optional
from uuid import UUID
from datetime import datetime, timezone
@ -12,6 +12,7 @@ from backend.app.db.functions.bulk_address_uploads_functions import (
set_combined_output_s3_uri,
set_combining_status,
)
from orchestration.task_orchestrator import TaskOrchestrator
logger = setup_logger()
@ -35,7 +36,16 @@ def download_csv(s3_client, bucket: str, key: str) -> pd.DataFrame:
@subtask_handler()
def handler(body: dict[str, Any], context: Any) -> str:
def handler(
body: dict[str, Any],
context: Any,
orchestrator: Optional[TaskOrchestrator] = None,
) -> str:
# `orchestrator` is injected by the new utilities.aws_lambda.subtask_handler
# decorator; unused here but accepted so the contract is uniform across
# callers (see issue #1103).
del orchestrator
task_id_str: str = body.get("task_id", "")
if not task_id_str:

View file

@ -16,6 +16,7 @@ from backend.ordnanceSurvey.helpers import (
os_places_results_to_dataframe,
)
from backend.app.config import get_settings
from orchestration.task_orchestrator import TaskOrchestrator
from sqlalchemy import select
from datetime import datetime
import uuid
@ -105,7 +106,16 @@ def save_results_to_s3(
@subtask_handler() # This assumes task_id and subtask_id is defined in event.Records.body
def handler(body: dict[str, Any], context: Any, local: bool = False) -> None:
def handler(
body: dict[str, Any],
context: Any,
orchestrator: Optional[TaskOrchestrator] = None,
local: bool = False,
) -> None:
# `orchestrator` is injected by the new utilities.aws_lambda.subtask_handler
# decorator; unused here but accepted so the contract is uniform across
# callers (see issue #1103).
del orchestrator
# delete this line after test
# local = True

View file

View file

View file

@ -0,0 +1,144 @@
"""Tests for the @subtask_handler decorator.
Covers the contract that the decorator owns the parent SubTask lifecycle and
injects the decorator-owned TaskOrchestrator as a third positional argument
to the wrapped function so the handler can compose its own use-case
orchestrator that shares the session.
"""
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from uuid import UUID
import pytest
from sqlmodel import Session, SQLModel, create_engine
from domain.tasks.subtasks import SubTaskStatus
from domain.tasks.tasks import TaskStatus
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from utilities.aws_lambda.subtask_handler import subtask_handler
@dataclass
class Harness:
orchestrator: TaskOrchestrator
tasks: TaskPostgresRepository
subtasks: SubTaskPostgresRepository
@contextmanager
def factory(self) -> Generator[TaskOrchestrator, None, None]:
yield self.orchestrator
@pytest.fixture
def harness() -> Iterator[Harness]:
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
tasks = TaskPostgresRepository(session=session)
subtasks = SubTaskPostgresRepository(session=session)
yield Harness(
orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
tasks=tasks,
subtasks=subtasks,
)
def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
return {"task_id": str(task_id), "sub_task_id": str(subtask_id)}
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
harness: Harness,
) -> None:
"""The wrapped function receives the decorator-owned TaskOrchestrator
so it can share the session with its own use-case orchestrator."""
_, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
received: dict[str, Any] = {}
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
received["body"] = body
received["context"] = context
received["orchestrator"] = orchestrator
handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel")
assert received["orchestrator"] is harness.orchestrator
assert received["context"] == "ctx-sentinel"
assert received["body"]["sub_task_id"] == str(subtask.id)
def test_subtask_handler_completes_parent_subtask_on_success(
harness: Harness,
) -> None:
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
handler(_direct_event(task.id, subtask.id), context=None)
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
harness: Harness,
) -> None:
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
raise RuntimeError("boom")
with pytest.raises(RuntimeError, match="boom"):
handler(_direct_event(task.id, subtask.id), context=None)
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
harness: Harness,
) -> None:
"""Smoke check the share-the-session promise: the injected orchestrator
is the same one the decorator owns, so a handler can use it to create
child SubTasks under the same session."""
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
child_ids: list[UUID] = []
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
child = orchestrator.create_child_subtask(task.id, inputs={"split": 1})
child_ids.append(child.id)
handler(_direct_event(task.id, subtask.id), context=None)
assert len(child_ids) == 1
persisted_child = harness.subtasks.get(child_ids[0])
assert persisted_child.task_id == task.id
assert persisted_child.status is SubTaskStatus.WAITING

View file

@ -39,7 +39,7 @@ def subtask_handler(
trigger = SubtaskTriggerBody.model_validate(body)
orchestrator.run_subtask(
trigger.sub_task_id,
work=lambda body=body: func(body, context),
work=lambda body=body, o=orchestrator: func(body, context, o),
)
return wrapper