Model/utilities/aws_lambda/subtask_handler.py
2026-05-20 11:07:40 +00:00

124 lines
4.9 KiB
Python

"""@subtask_handler decorator for Lambdas that operate on existing SubTasks.
Translates an AWS Lambda invocation (SQS-shaped or direct) into
TaskOrchestrator.run_subtask(...) calls, emitting an INFO log line for each
subtask's start and completion and a logged exception on failure. Those lines
land in CloudWatch via the Lambda runtime's stdout/stderr capture.
Each subtask also records ``cloud_logs_url`` -- a deep link to this
invocation's CloudWatch log stream -- so an operator can jump from a SubTask
row straight to its logs. It is built from the environment variables the AWS
Lambda runtime sets, so it is populated only on real Lambda invocations and
left unset under the local RIE (which does not export them).
"""
import json
import logging
import os
from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from urllib.parse import quote
from utilities.aws_lambda.default_orchestrator import default_orchestrator
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
from orchestration.task_orchestrator import TaskOrchestrator
# Type of the ``orchestrator_cm`` argument accepted by ``subtask_handler``:
# a zero-argument callable returning a context manager that yields a
# TaskOrchestrator.
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]

# Pinned to INFO so this module's per-subtask start/completion records are
# not filtered out by a higher level inherited from a parent logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def subtask_handler(
    *,
    orchestrator_cm: Optional[OrchestratorCM] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorate a Lambda entry point so it runs as the body of an existing SubTask.

    The decorated function is called as ``func(body, context, orchestrator)``
    once per record in the invocation event. ``body`` is the record's parsed
    payload. Before the call, the payload is validated with
    ``SubtaskTriggerBody``, which must find ``task_id`` and ``sub_task_id``.
    The call itself runs inside ``orchestrator.run_subtask(...)``. The
    orchestrator owns the start/complete/fail lifecycle and cascades status
    into the parent Task.

    Args:
        orchestrator_cm: Zero-argument factory returning a context manager
            that yields a ``TaskOrchestrator``. Falls back to
            ``default_orchestrator`` when omitted. A single orchestrator is
            opened per invocation and shared by every record in it.

    Returns:
        A decorator turning ``func`` into a ``(event, context) -> None``
        Lambda handler.

    Records are handled in order. When one fails, its exception is logged
    and re-raised once the SubTask is marked FAILED, so later records in the
    same invocation are not attempted.

    NOTE(review): for SQS-triggered invocations, a raised exception fails the
    whole batch, so records that already completed may be redelivered.
    Confirm that ``run_subtask`` tolerates re-runs, or consider partial batch
    responses.
    """
    factory = orchestrator_cm or default_orchestrator

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any) -> None:
            # The log stream is fixed for this invocation, so build the link
            # once and reuse it for every record.
            logs_link = _cloudwatch_url()

            with factory() as orch:
                for record in _records(event):
                    payload = _parse_body(record)
                    trigger = SubtaskTriggerBody.model_validate(payload)
                    subtask_id = trigger.sub_task_id
                    logger.info("Running subtask %s", subtask_id)

                    # The defaults pin this iteration's payload and
                    # orchestrator, so the callable stays correct even if
                    # run_subtask defers invoking it.
                    def work(body: dict[str, Any] = payload, o: Any = orch) -> Any:
                        return func(body, context, o)

                    try:
                        orch.run_subtask(
                            subtask_id,
                            work=work,
                            cloud_logs_url=logs_link,
                        )
                    except Exception:
                        logger.exception("Subtask %s failed", subtask_id)
                        raise
                    else:
                        logger.info("Subtask %s completed", subtask_id)

        return wrapper

    return decorator
def _parse_body(record: dict[str, Any]) -> dict[str, Any]:
    """Extract a record's payload as a dict.

    An SQS-shaped record carries its payload under ``"body"``, normally as a
    JSON string. A record with no ``"body"`` key (a direct invocation) is
    treated as the payload itself. A ``"body"`` that is already a dict is
    returned as-is.

    Returns:
        The payload dict. Returns ``{}`` when the payload is not valid JSON,
        does not decode to a JSON object, or is neither a string nor a dict.

    Each of those fallback cases logs a warning that includes a truncated
    repr of the payload. Without it, the only trace would be the caller's
    later validation error about missing ``task_id``/``sub_task_id``, which
    hides the actual malformed input.
    """
    raw = record.get("body", record)
    if isinstance(raw, dict):
        return cast(dict[str, Any], raw)
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(
                "Ignoring record body that is not valid JSON: %.200r", raw
            )
            return {}
        if isinstance(parsed, dict):
            return cast(dict[str, Any], parsed)
        logger.warning(
            "Ignoring record body that is not a JSON object: %.200r", raw
        )
        return {}
    logger.warning(
        "Ignoring record body of unsupported type %s", type(raw).__name__
    )
    return {}
def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
    """Split an invocation event into the records to process.

    An event whose ``"Records"`` value is a list (SQS-shaped) yields the
    dict entries of that list; any non-dict entries are skipped. Every other
    event is a direct invocation and becomes a single record.
    """
    batch = event.get("Records")
    if not isinstance(batch, list):
        return [event]

    selected: list[dict[str, Any]] = []
    for entry in cast(list[Any], batch):
        if isinstance(entry, dict):
            selected.append(cast(dict[str, Any], entry))
    return selected
def _console_encode(value: str) -> str:
    """Escape ``value`` for a CloudWatch console deep-link fragment.

    The value is first percent-encoded with no characters treated as safe.
    Then every ``%`` in that result is rewritten as ``$25``, which is the
    form the console expects. For example, ``/`` becomes ``%2F``, then
    ``$252F``.
    """
    percent_encoded = quote(value, safe="")
    return percent_encoded.translate({ord("%"): "$25"})
def _cloudwatch_url() -> Optional[str]:
    """Return a CloudWatch console link to this invocation's log stream.

    All inputs come from environment variables set by the AWS Lambda
    runtime: ``AWS_REGION``, ``AWS_LAMBDA_LOG_GROUP_NAME`` and
    ``AWS_LAMBDA_LOG_STREAM_NAME``.

    Returns:
        The console URL, or None if any of those variables is missing or
        empty. That is the case under the local RIE, which does not export
        them. ``SubTask.cloud_logs_url`` then stays unset instead of holding
        a link that points nowhere.
    """
    env = os.environ
    region = env.get("AWS_REGION")
    log_group = env.get("AWS_LAMBDA_LOG_GROUP_NAME")
    log_stream = env.get("AWS_LAMBDA_LOG_STREAM_NAME")
    if not region or not log_group or not log_stream:
        return None

    console_home = (
        f"https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}"
    )
    stream_path = (
        f"log-groups/log-group/{_console_encode(log_group)}"
        f"/log-events/{_console_encode(log_stream)}"
    )
    return f"{console_home}#logsV2:{stream_path}"