"""@subtask_handler decorator for Lambdas that operate on existing SubTasks. Translates an AWS Lambda invocation (SQS-shaped or direct) into TaskOrchestrator.run_subtask(...) calls, emitting an INFO log line for each subtask's start and completion and a logged exception on failure. Those lines land in CloudWatch via the Lambda runtime's stdout/stderr capture. Each subtask also records ``cloud_logs_url`` -- a deep link to this invocation's CloudWatch log stream -- so an operator can jump from a SubTask row straight to its logs. It is built from the environment variables the AWS Lambda runtime sets, so it is populated only on real Lambda invocations and left unset under the local RIE (which does not export them). """ import json import logging import os from contextlib import AbstractContextManager from functools import wraps from typing import Any, Callable, Optional, cast from urllib.parse import quote from utilities.aws_lambda.default_orchestrator import default_orchestrator from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody from orchestration.task_orchestrator import TaskOrchestrator logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]] def subtask_handler( *, orchestrator_cm: Optional[OrchestratorCM] = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Run the wrapped function as the body of an existing SubTask. For each record, validates the body via SubtaskTriggerBody (must contain task_id and sub_task_id), then runs the function inside orchestrator.run_subtask(...). The orchestrator owns the start/complete/ fail lifecycle and cascades status into the parent Task. On failure the underlying exception propagates after the SubTask is marked FAILED. """ factory = orchestrator_cm or default_orchestrator def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> None: cloud_logs_url = _cloudwatch_url() with factory() as orchestrator: for record in _records(event): body = _parse_body(record) trigger = SubtaskTriggerBody.model_validate(body) logger.info("Running subtask %s", trigger.sub_task_id) try: orchestrator.run_subtask( trigger.sub_task_id, work=lambda body=body, o=orchestrator: func( body, context, o ), cloud_logs_url=cloud_logs_url, ) except Exception: logger.exception( "Subtask %s failed", trigger.sub_task_id ) raise logger.info("Subtask %s completed", trigger.sub_task_id) return wrapper return decorator def _parse_body(record: dict[str, Any]) -> dict[str, Any]: raw = record.get("body", record) if isinstance(raw, str): try: parsed = json.loads(raw) except json.JSONDecodeError: return {} return cast(dict[str, Any], parsed) if isinstance(parsed, dict) else {} if isinstance(raw, dict): return cast(dict[str, Any], raw) return {} def _records(event: dict[str, Any]) -> list[dict[str, Any]]: raw_records = event.get("Records") if isinstance(raw_records, list): return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)] return [event] def _console_encode(value: str) -> str: """Encode a value for a CloudWatch console deep link. The console expects URL-encoding with the percent signs themselves re-encoded as ``$25`` -- e.g. ``/`` becomes ``%2F`` becomes ``$252F``. """ return quote(value, safe="").replace("%", "$25") def _cloudwatch_url() -> Optional[str]: """Build a CloudWatch console URL for this invocation's log stream. Sourced entirely from the environment variables the AWS Lambda runtime sets -- ``AWS_REGION``, ``AWS_LAMBDA_LOG_GROUP_NAME`` and ``AWS_LAMBDA_LOG_STREAM_NAME``. Returns None when any is absent, which is the case outside a real Lambda (the local RIE does not export them) -- so ``SubTask.cloud_logs_url`` is left unset rather than storing a link that points nowhere. """ region = os.environ.get("AWS_REGION") log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") if not (region and log_group and log_stream): return None return ( f"https://{region}.console.aws.amazon.com/cloudwatch/home" f"?region={region}#logsV2:log-groups/log-group/" f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}" )