"""@task_handler decorator for Lambdas that own the entire pipeline. Translates an AWS Lambda invocation (SQS-shaped or direct) into TaskOrchestrator.create_task_with_subtask(...) + run_subtask(...). """ import json from contextlib import AbstractContextManager from functools import wraps from typing import Any, Callable, Optional, cast from utilities.aws_lambda.default_orchestrator import default_orchestrator from domain.tasks.tasks import Source from orchestration.task_orchestrator import TaskOrchestrator OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]] def task_handler( *, task_source: str, source: Source, orchestrator_cm: Optional[OrchestratorCM] = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Run the wrapped function as the body of a freshly-created Task + SubTask. For each record, creates a new Task + initial SubTask, then runs the wrapped function inside orchestrator.run_subtask(...). `source_id` is read from body[source.value] (silent None if absent — preserved from legacy ADR-0001). Records-style events use SQS partial-batch-failure semantics: individual failures are reported via {"batchItemFailures": [...]} rather than propagating. Direct invocations re-raise. """ factory = orchestrator_cm or default_orchestrator def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> Any: with factory() as orchestrator: results: list[Any] = [] failures: list[dict[str, Any]] = [] for record in _records(event): body = _parse_body(record) raw_source_id = body.get(source.value) source_id = ( str(raw_source_id) if raw_source_id is not None else None ) _, subtask = orchestrator.create_task_with_subtask( task_source=task_source, inputs=body, source=source, source_id=source_id, ) try: result = orchestrator.run_subtask( subtask.id, work=lambda body=body: func(body, context), ) results.append(result) except Exception: if "Records" in event: message_id = record.get("messageId", "") failures.append({"itemIdentifier": message_id}) else: raise if "Records" in event: return {"batchItemFailures": failures} return results return wrapper return decorator def _parse_body(record: dict[str, Any]) -> dict[str, Any]: raw = record.get("body", record) if isinstance(raw, str): try: parsed = json.loads(raw) except json.JSONDecodeError: return {} return cast(dict[str, Any], parsed) if isinstance(parsed, dict) else {} if isinstance(raw, dict): return cast(dict[str, Any], raw) return {} def _records(event: dict[str, Any]) -> list[dict[str, Any]]: raw_records = event.get("Records") if isinstance(raw_records, list): return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)] return [event]