"""@subtask_handler decorator for Lambdas that operate on existing SubTasks.

Translates an AWS Lambda invocation (SQS-shaped or direct) into
TaskOrchestrator.run_subtask(...) calls.
"""

import json
import logging
import os
from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from urllib.parse import quote

from utilities.aws_lambda.default_orchestrator import default_orchestrator
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody

from orchestration.task_orchestrator import TaskOrchestrator

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Zero-argument factory returning a context manager that yields a
# TaskOrchestrator (entered once per Lambda invocation).
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]


def subtask_handler(
    *,
    orchestrator_cm: Optional[OrchestratorCM] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Run the wrapped function as the body of an existing SubTask.

    For each record, validates the body via SubtaskTriggerBody (must contain
    task_id and sub_task_id), then runs the function inside
    orchestrator.run_subtask(...). The orchestrator owns the
    start/complete/fail lifecycle and cascades status into the parent Task.
    On failure the underlying exception propagates after the SubTask is
    marked FAILED.

    Args:
        orchestrator_cm: Optional factory for the orchestrator context
            manager. Falls back to ``default_orchestrator`` when omitted.

    Returns:
        A decorator. The decorated function is invoked as
        ``func(body, context, orchestrator)``, where ``body`` is the parsed
        record body dict (not the validated SubtaskTriggerBody). The
        resulting wrapper has the Lambda handler signature
        ``(event, context) -> None``.

    NOTE(review): records are processed sequentially and the first failure
    re-raises, so later records in the same event are not run. For an SQS
    batch, a failed invocation presumably causes the whole batch to be
    redelivered, re-running records that already completed. Confirm that
    run_subtask tolerates this, or use an SQS batch size of 1.
    """
    factory = orchestrator_cm or default_orchestrator

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any) -> None:
            cloud_logs_url = _cloudwatch_url()
            with factory() as orchestrator:
                for record in _records(event):
                    body = _parse_body(record)
                    try:
                        trigger = SubtaskTriggerBody.model_validate(body)
                    except Exception:
                        # Log before propagating so a malformed message can be
                        # identified in the logs. Only keys are logged, never
                        # values. The original exception is re-raised unchanged.
                        logger.exception(
                            "Invalid subtask trigger body (keys: %s)",
                            sorted(body),
                        )
                        raise
                    logger.info("Running subtask %s", trigger.sub_task_id)
                    try:
                        orchestrator.run_subtask(
                            trigger.sub_task_id,
                            # Default arguments pin this record's body and the
                            # orchestrator at lambda-creation time.
                            work=lambda body=body, o=orchestrator: func(
                                body, context, o
                            ),
                            cloud_logs_url=cloud_logs_url,
                        )
                    except Exception:
                        logger.exception(
                            "Subtask %s failed", trigger.sub_task_id
                        )
                        raise
                    logger.info("Subtask %s completed", trigger.sub_task_id)

        return wrapper

    return decorator


def _parse_body(record: dict[str, Any]) -> dict[str, Any]:
    """Return the payload of one record as a dict.

    Uses ``record["body"]`` when present; otherwise the record itself is the
    payload (direct invocation). A string body is decoded as JSON. Returns
    ``{}`` when the payload is not a JSON object or a dict. A warning is
    logged in that case, so the cause is visible when validation later fails.
    """
    raw = record.get("body", record)
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as exc:
            logger.warning("Record body is not valid JSON: %s", exc)
            return {}
        if isinstance(parsed, dict):
            return cast(dict[str, Any], parsed)
        logger.warning(
            "Record body JSON is a %s, expected an object",
            type(parsed).__name__,
        )
        return {}
    if isinstance(raw, dict):
        return cast(dict[str, Any], raw)
    logger.warning("Record body has unsupported type %s", type(raw).__name__)
    return {}


def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the records to process for an invocation event.

    If ``event["Records"]`` is a list, its dict entries are returned and
    non-dict entries are dropped. Otherwise the event itself is treated as a
    single record.
    """
    raw_records = event.get("Records")
    if isinstance(raw_records, list):
        return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
    return [event]


def _console_encode(value: str) -> str:
    """Encode a value for the CloudWatch console URL fragment.

    Percent-encodes every reserved character, including ``/``, then replaces
    ``%`` with ``$25``. For example, ``/`` becomes ``$252F``.
    """
    return quote(value, safe="").replace("%", "$25")


def _cloudwatch_url() -> Optional[str]:
    """Build a CloudWatch console link to this invocation's log stream.

    Reads AWS_REGION, AWS_LAMBDA_LOG_GROUP_NAME and
    AWS_LAMBDA_LOG_STREAM_NAME from the environment. Returns None if any of
    them is missing or empty.
    """
    region = os.environ.get("AWS_REGION")
    log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
    log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
    if not (region and log_group and log_stream):
        return None
    return (
        f"https://{region}.console.aws.amazon.com/cloudwatch/home"
        f"?region={region}#logsV2:log-groups/log-group/"
        f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}"
    )