mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
"""@task_handler decorator for Lambdas that own the entire pipeline.
|
|
|
|
Translates an AWS Lambda invocation (SQS-shaped or direct) into
|
|
TaskOrchestrator.create_task_with_subtask(...) + run_subtask(...).
|
|
"""
|
|
|
|
import json
|
|
from contextlib import AbstractContextManager
|
|
from functools import wraps
|
|
from typing import Any, Callable, Optional, cast
|
|
|
|
from utilities.aws_lambda.default_orchestrator import default_orchestrator
|
|
from domain.tasks.tasks import Source
|
|
from orchestration.task_orchestrator import TaskOrchestrator
|
|
|
|
# Zero-argument factory returning a context manager that yields a
# TaskOrchestrator; task_handler enters one such context per call of the
# decorated handler.
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
|
|
|
|
|
|
def task_handler(
    *,
    task_source: str,
    source: Source,
    orchestrator_cm: Optional[OrchestratorCM] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Run the wrapped function as the body of a freshly-created Task + SubTask.

    For each record, creates a new Task + initial SubTask, then runs the
    wrapped function inside orchestrator.run_subtask(...). `source_id` is
    read from body[source.value] (silent None if absent — preserved from
    legacy ADR-0001).

    Records-style events (an event whose "Records" value is a list) use SQS
    partial-batch-failure semantics. A failure while creating OR running a
    record's Task/SubTask is logged and reported via
    {"batchItemFailures": [{"itemIdentifier": <messageId>}, ...]} rather
    than propagating, so one bad record does not fail the whole batch. Any
    other event is treated as a single direct invocation, and exceptions
    re-raise.

    Args:
        task_source: Forwarded as ``task_source`` to
            ``create_task_with_subtask``.
        source: Its ``.value`` names the body key holding the source id; also
            forwarded as ``source`` to ``create_task_with_subtask``.
        orchestrator_cm: Zero-argument factory returning a context manager
            that yields a TaskOrchestrator. Defaults to
            ``default_orchestrator``.

    Returns:
        A decorator. The decorated ``(event, context)`` handler runs
        ``func(body, context)`` through ``orchestrator.run_subtask`` for each
        record. It returns ``{"batchItemFailures": [...]}`` for Records-style
        events; otherwise it returns the list of ``run_subtask`` results (one
        element for a direct invocation).
    """
    import logging

    logger = logging.getLogger(__name__)
    factory = orchestrator_cm or default_orchestrator

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any) -> Any:
            # Same test _records() uses: only a list-valued "Records" key is
            # a batch. Anything else is processed as one direct record, so it
            # must also get direct-invocation (re-raise) error handling.
            is_batch = isinstance(event.get("Records"), list)

            with factory() as orchestrator:
                results: list[Any] = []
                failures: list[dict[str, Any]] = []

                for record in _records(event):
                    body = _parse_body(record)
                    raw_source_id = body.get(source.value)
                    source_id = (
                        str(raw_source_id) if raw_source_id is not None else None
                    )

                    try:
                        # Task creation lives inside the try: if it raised
                        # outside, one bad record would abort the whole batch
                        # and re-deliver records that already succeeded.
                        _, subtask = orchestrator.create_task_with_subtask(
                            task_source=task_source,
                            inputs=body,
                            source=source,
                            source_id=source_id,
                        )
                        result = orchestrator.run_subtask(
                            subtask.id,
                            # Default arg pins this record's body to the callback.
                            work=lambda body=body: func(body, context),
                        )
                    except Exception:
                        if not is_batch:
                            raise
                        message_id = record.get("messageId", "")
                        logger.exception(
                            "Record %r failed; reporting it in batchItemFailures",
                            message_id,
                        )
                        failures.append({"itemIdentifier": message_id})
                    else:
                        results.append(result)

                if is_batch:
                    return {"batchItemFailures": failures}
                return results

        return wrapper

    return decorator
|
|
|
|
|
|
def _parse_body(record: dict[str, Any]) -> dict[str, Any]:
    """Return the payload of a single record as a dict.

    The payload is ``record["body"]`` when that key is present; otherwise it
    is the record itself (direct invocations). A dict payload is returned
    as-is (the same object, not a copy). A string payload is JSON-decoded and
    kept only if it decodes to a JSON object. Everything else yields an empty
    dict instead of raising: malformed JSON, a JSON array or scalar, or a
    payload that is neither str nor dict.
    """
    payload = record.get("body", record)

    if isinstance(payload, dict):
        return cast(dict[str, Any], payload)

    if not isinstance(payload, str):
        return {}

    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        # Deliberately lenient: an unparseable body becomes an empty payload.
        return {}

    if not isinstance(decoded, dict):
        return {}
    return cast(dict[str, Any], decoded)
|
|
|
|
|
|
def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
    """Split an invocation event into the records to process.

    When the event's "Records" value is a list, the event is a batch and its
    dict entries are returned in their original order (non-dict entries are
    skipped). Any other event is treated as one record and returned as
    ``[event]``.
    """
    batch = event.get("Records")
    if not isinstance(batch, list):
        return [event]
    return [item for item in cast(list[Any], batch) if isinstance(item, dict)]
|