mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
185 lines
6.6 KiB
Python
185 lines
6.6 KiB
Python
import json
|
|
import os
|
|
import time
|
|
from functools import wraps
|
|
from typing import Any, Callable, Optional, cast
|
|
from uuid import UUID
|
|
|
|
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
|
|
from backend.app.db.models.tasks import SourceEnum
|
|
from backend.utils.cloudwatch import build_cloudwatch_log_url
|
|
from utils.logger import setup_logger
|
|
|
|
|
|
def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
    """Build the CloudWatch logs URL, or return None when not running in Lambda.

    The URL is only attempted when AWS_REGION, AWS_LAMBDA_LOG_GROUP_NAME and
    AWS_LAMBDA_LOG_STREAM_NAME are all present in the environment, so that
    local / non-Lambda runs don't crash. Only presence is checked, not the
    values.

    Args:
        start_ms: Timestamp in milliseconds, forwarded unchanged to
            build_cloudwatch_log_url.

    Returns:
        Whatever build_cloudwatch_log_url(start_ms) returns, or None if any of
        the three environment variables is missing.
    """
    for env_name in (
        "AWS_REGION",
        "AWS_LAMBDA_LOG_GROUP_NAME",
        "AWS_LAMBDA_LOG_STREAM_NAME",
    ):
        if env_name not in os.environ:
            return None
    return build_cloudwatch_log_url(start_ms)
|
|
|
|
|
|
def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator for Lambdas that operate on an already-existing SubTask.

    For each record in the event the wrapper extracts task_id + sub_task_id
    from the record body, records the CloudWatch logs URL, marks the SubTask
    "in progress", calls the wrapped function as func(body, context, ...),
    then marks the SubTask "complete" on success / "failed" on raise.

    Accepted event shapes:
      * a batch event: {"Records": [{"body": <JSON string or dict>}, ...]}
      * a direct invocation: the event has no "Records" key and is treated as
        a single record; if it has no "body" key the event itself is the body.

    A body that is not valid JSON, or that does not decode to a dict, is
    treated as an empty dict (and therefore fails the id check below).

    The wrapper returns None; the wrapped function's result is only persisted
    via the "complete" status update (as {"result": result} when truthy).

    Raises (from the wrapper):
        RuntimeError: task_id or sub_task_id is absent or not a string.
        ValueError: an id is a string but not a valid UUID.
        Exception: anything raised by the wrapped function is re-raised after
            the SubTask is marked "failed"; later records are not processed.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
            start_ms = int(time.time() * 1000)
            cloud_logs_url = _try_build_cloud_logs_url(start_ms)

            records = event.get("Records", [event])
            interface = SubTaskInterface()

            for record in records:
                # Fall back to the record itself so a direct invocation whose
                # ids sit at the top level of the event is handled the same
                # way task_handler handles it.
                raw_body = record.get("body", record)
                if isinstance(raw_body, str):
                    try:
                        raw_body = json.loads(raw_body)
                    except ValueError:
                        # json.JSONDecodeError subclasses ValueError.
                        raw_body = None

                # Valid JSON that is not an object (list, number, ...) cannot
                # carry the ids; treat it like an unparseable body.
                body: dict[str, Any] = (
                    cast(dict[str, Any], raw_body) if isinstance(raw_body, dict) else {}
                )

                task_id_raw = body.get("task_id")
                subtask_id_raw = body.get("sub_task_id")

                task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
                subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None

                # task_id is validated here but only subtask_id is used below.
                if task_id is None or subtask_id is None:
                    raise RuntimeError("task_id or sub_task_id missing")

                interface.update_subtask_status(
                    subtask_id=subtask_id,
                    status="in progress",
                    cloud_logs_url=cloud_logs_url,
                )

                try:
                    result = func(body, context, *args, **kwargs)

                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="complete",
                        outputs={"result": result} if result else None,
                    )

                except Exception as e:
                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="failed",
                        outputs={"error": str(e)},
                    )
                    raise

        return wrapper

    return decorator
|
|
|
|
|
|
def task_handler(
    task_source: str,
    source: SourceEnum,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator for Lambdas that are themselves the entry point of a pipeline (no
    router in front). For each record the decorator creates a fresh Task +
    SubTask with the given task_source and source. source_id is read from
    body[source.value] (silent None if absent) — see ADR-0001. Records the
    CloudWatch logs URL, marks the SubTask in progress, then complete on
    success / failed on raise.

    Accepted event shapes:
      * a batch event: {"Records": [{"messageId": ..., "body": ...}, ...]}
      * a direct invocation: the event has no "Records" key and is treated as
        a single record; if it has no "body" key the event itself is the body.

    A body that is not valid JSON, or that does not decode to a dict, is
    treated as an empty dict.

    Args:
        task_source: Passed to TasksInterface.create_task and logged.
        source: Passed to create_task; source.value is the body key holding
            the source id.

    Returns (from the wrapper):
        For batch events, {"batchItemFailures": [{"itemIdentifier": messageId},
        ...]} listing the records whose wrapped call (or "complete" update)
        raised. For direct invocations, the list of wrapped-function results.

    Raises (from the wrapper):
        RuntimeError: create_task returned no subtask_id.
        Exception: for direct invocations, anything raised by the wrapped
            function is re-raised after the SubTask is marked "failed".
            Errors from create_task or the "in progress" update are not
            caught and propagate for both event shapes.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
            logger = setup_logger()
            start_ms = int(time.time() * 1000)
            cloud_logs_url = _try_build_cloud_logs_url(start_ms)

            # Batch events report per-record failures; direct invocations
            # re-raise instead.
            is_batch = "Records" in event
            records = event.get("Records", [event])
            results: list[Any] = []
            failures: list[dict[str, Any]] = []
            interface = SubTaskInterface()

            for record in records:
                raw_body = record.get("body", record)
                if isinstance(raw_body, str):
                    try:
                        raw_body = json.loads(raw_body)
                    except ValueError:
                        # json.JSONDecodeError subclasses ValueError.
                        raw_body = None

                # Valid JSON that is not an object (list, number, ...) has no
                # keys to read; treat it like an unparseable body rather than
                # crashing on body.get below.
                body: dict[str, Any] = (
                    cast(dict[str, Any], raw_body) if isinstance(raw_body, dict) else {}
                )

                raw_source_id = body.get(source.value)
                source_id: Optional[str] = (
                    str(raw_source_id) if raw_source_id is not None else None
                )

                logger.info("Creating task for source: %s", task_source)
                task_id, subtask_id = TasksInterface.create_task(
                    task_source=task_source,
                    inputs=body,
                    source=source,
                    source_id=source_id,
                )

                if subtask_id is None:
                    raise RuntimeError("create_task did not return a subtask_id")

                logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)

                interface.update_subtask_status(
                    subtask_id=subtask_id,
                    status="in progress",
                    cloud_logs_url=cloud_logs_url,
                )

                try:
                    result = func(body, context, *args, **kwargs)

                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="complete",
                        outputs={"result": result} if result else None,
                    )

                    logger.info("Task %s completed successfully", task_id)
                    results.append(result)

                except Exception as e:
                    logger.exception("Task %s failed: %s", task_id, e)

                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="failed",
                        outputs={"error": str(e)},
                    )

                    if is_batch:
                        failures.append({"itemIdentifier": record["messageId"]})
                    else:
                        raise

            if is_batch:
                return {"batchItemFailures": failures}

            return results

        return wrapper

    return decorator
|