mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
102 lines
3.7 KiB
Python
102 lines
3.7 KiB
Python
"""@subtask_handler decorator for Lambdas that operate on existing SubTasks.
|
|
|
|
Translates an AWS Lambda invocation (SQS-shaped or direct) into
|
|
TaskOrchestrator.run_subtask(...) calls.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import AbstractContextManager
|
|
from functools import wraps
|
|
from typing import Any, Callable, Optional, cast
|
|
from urllib.parse import quote
|
|
|
|
from utilities.aws_lambda.default_orchestrator import default_orchestrator
|
|
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
|
|
from orchestration.task_orchestrator import TaskOrchestrator
|
|
|
|
# Module-level logger. Its level is pinned to INFO so the per-subtask
# "Running"/"completed" messages are not filtered out by a higher level
# inherited from an ancestor logger (handlers may still filter them).
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Type of the ``orchestrator_cm`` argument: a zero-argument callable that
# returns a context manager yielding a TaskOrchestrator.
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
|
|
|
|
|
|
def subtask_handler(
    *,
    orchestrator_cm: Optional[OrchestratorCM] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Turn a function into a Lambda handler that runs existing SubTasks.

    The decorated function is invoked as ``func(body, context, orchestrator)``
    once per record of the Lambda event, where ``body`` is the decoded record
    payload. Before each call the payload is validated with
    ``SubtaskTriggerBody`` (it must carry ``task_id`` and ``sub_task_id``).
    The call itself is handed to ``orchestrator.run_subtask(...)`` as its
    ``work`` callable, along with a CloudWatch link for this invocation when
    one can be built. The orchestrator owns the start/complete/fail lifecycle
    and cascades status into the parent Task. On failure the exception is
    logged and re-raised after the SubTask is marked FAILED.

    Args:
        orchestrator_cm: Zero-argument factory returning a context manager
            that yields a ``TaskOrchestrator``. One orchestrator is opened
            per invocation and shared by every record. When omitted (or
            falsy), ``default_orchestrator`` is used.

    Returns:
        A decorator producing a ``handler(event, context)`` entry point.
        The handler itself always returns ``None``.
    """
    make_orchestrator = orchestrator_cm if orchestrator_cm else default_orchestrator

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any) -> None:
            # The link depends only on environment variables, so it is
            # computed once and reused for every record.
            logs_url = _cloudwatch_url()
            with make_orchestrator() as orch:
                for rec in _records(event):
                    payload = _parse_body(rec)
                    trigger = SubtaskTriggerBody.model_validate(payload)
                    sub_task_id = trigger.sub_task_id
                    logger.info("Running subtask %s", sub_task_id)
                    try:
                        orch.run_subtask(
                            sub_task_id,
                            # Default arguments pin this iteration's payload
                            # and orchestrator into the callable.
                            work=lambda payload=payload, orchestrator=orch: func(
                                payload, context, orchestrator
                            ),
                            cloud_logs_url=logs_url,
                        )
                    except Exception:
                        # The first failing record aborts the loop, so later
                        # records in the same event are never attempted.
                        # NOTE(review): for multi-record SQS batches this
                        # presumably redelivers the whole batch, including
                        # records that already completed. Confirm the batch
                        # size / partial-batch-failure configuration.
                        logger.exception("Subtask %s failed", sub_task_id)
                        raise
                    else:
                        logger.info("Subtask %s completed", sub_task_id)

        return wrapper

    return decorator
|
|
|
|
|
|
def _parse_body(record: dict[str, Any]) -> dict[str, Any]:
|
|
raw = record.get("body", record)
|
|
if isinstance(raw, str):
|
|
try:
|
|
parsed = json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
return cast(dict[str, Any], parsed) if isinstance(parsed, dict) else {}
|
|
if isinstance(raw, dict):
|
|
return cast(dict[str, Any], raw)
|
|
return {}
|
|
|
|
|
|
def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
|
|
raw_records = event.get("Records")
|
|
if isinstance(raw_records, list):
|
|
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
|
|
return [event]
|
|
|
|
|
|
def _console_encode(value: str) -> str:
|
|
return quote(value, safe="").replace("%", "$25")
|
|
|
|
|
|
def _cloudwatch_url() -> Optional[str]:
    """Build a CloudWatch console link to the current log stream.

    Reads ``AWS_REGION``, ``AWS_LAMBDA_LOG_GROUP_NAME`` and
    ``AWS_LAMBDA_LOG_STREAM_NAME`` from the environment. These are the
    variables the Lambda runtime provides. Returns ``None`` when any of them
    is unset or empty, e.g. when running outside Lambda.
    """
    env = os.environ
    region = env.get("AWS_REGION")
    log_group = env.get("AWS_LAMBDA_LOG_GROUP_NAME")
    log_stream = env.get("AWS_LAMBDA_LOG_STREAM_NAME")
    if not region or not log_group or not log_stream:
        return None

    # Group and stream names contain '/' and other reserved characters, so
    # they need the console's double-escaping. The region is used verbatim.
    encoded_group = _console_encode(log_group)
    encoded_stream = _console_encode(log_stream)
    url = (
        f"https://{region}.console.aws.amazon.com/cloudwatch/home"
        f"?region={region}#logsV2:log-groups/log-group/"
        f"{encoded_group}/log-events/{encoded_stream}"
    )
    return url
|