Model/utilities/aws_lambda/subtask_handler.py
Jun-te Kim d70e8a9e53 utilities/aws_lambda: @subtask_handler injects TaskOrchestrator as third positional arg
The wrapped function now receives the decorator-owned TaskOrchestrator as
a third positional argument so handlers can compose their own use-case
orchestrator that shares the session, instead of opening a second Postgres
connection per invocation.

Both existing callers (backend/ordnanceSurvey/main.py and
backend/bulk_address2uprn_combiner/main.py) have their signatures extended
to accept the new positional argument (typed Optional[TaskOrchestrator] so
the legacy backend.utils.subtasks.subtask_handler — which only passes two
args — keeps working until the migration to the new decorator lands).

@task_handler is intentionally left unchanged in this slice; bringing it into
symmetry with @subtask_handler is deferred per issue #1103.
2026-05-19 17:31:27 +00:00

67 lines
2.4 KiB
Python

"""@subtask_handler decorator for Lambdas that operate on existing SubTasks.
Translates an AWS Lambda invocation (SQS-shaped or direct) into
TaskOrchestrator.run_subtask(...) calls.
"""
import json
from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from utilities.aws_lambda.default_orchestrator import default_orchestrator
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
from orchestration.task_orchestrator import TaskOrchestrator
# Zero-argument factory returning a context manager that yields the
# TaskOrchestrator used for one Lambda invocation.
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]


def subtask_handler(
    *,
    orchestrator_cm: Optional[OrchestratorCM] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorate a Lambda body so that it executes as an existing SubTask.

    The decorated function is called once per record as
    ``func(body, context, orchestrator)``:

    * ``body`` is the parsed record payload.
    * ``context`` is the Lambda context, passed through unchanged.
    * ``orchestrator`` is the TaskOrchestrator opened by this decorator for
      the whole invocation. The handler can reuse it instead of opening its
      own.

    Each payload is validated with SubtaskTriggerBody, which must carry
    task_id and sub_task_id. The call is then handed to
    ``orchestrator.run_subtask(sub_task_id, work=...)``. The orchestrator owns
    the start/complete/fail lifecycle and cascades status into the parent
    Task.

    Records are processed in order inside a single orchestrator context. If
    the function raises, the SubTask is marked FAILED and the exception
    propagates, which also skips any records remaining in the event. The
    generated Lambda entry point itself returns None.

    Args:
        orchestrator_cm: Optional factory for the orchestrator context
            manager; ``default_orchestrator`` is used when it is omitted.
    """
    make_orchestrator = orchestrator_cm or default_orchestrator

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any) -> None:
            with make_orchestrator() as orchestrator:
                for record in _records(event):
                    payload = _parse_body(record)
                    trigger = SubtaskTriggerBody.model_validate(payload)

                    # Bind this record's payload (and the orchestrator) as
                    # default arguments now, so the callable stays correct
                    # even if run_subtask defers invoking it.
                    def work(
                        body: dict[str, Any] = payload,
                        o: TaskOrchestrator = orchestrator,
                    ) -> Any:
                        return func(body, context, o)

                    orchestrator.run_subtask(trigger.sub_task_id, work=work)

        return wrapper

    return decorator
def _parse_body(record: dict[str, Any]) -> dict[str, Any]:
raw = record.get("body", record)
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return cast(dict[str, Any], parsed) if isinstance(parsed, dict) else {}
if isinstance(raw, dict):
return cast(dict[str, Any], raw)
return {}
def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
raw_records = event.get("Records")
if isinstance(raw_records, list):
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
return [event]