Model/backend/utils/subtasks.py
2026-05-12 16:24:11 +00:00

185 lines
6.6 KiB
Python

import json
import os
import time
from functools import wraps
from typing import Any, Callable, Optional, cast
from uuid import UUID
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
from backend.app.db.models.tasks import SourceEnum
from backend.app.plan.utils import build_cloudwatch_log_url
from utils.logger import setup_logger
def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
# Returns None outside a Lambda runtime so local/non-Lambda runs don't crash.
required = ("AWS_REGION", "AWS_LAMBDA_LOG_GROUP_NAME", "AWS_LAMBDA_LOG_STREAM_NAME")
if not all(k in os.environ for k in required):
return None
return build_cloudwatch_log_url(start_ms)
def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator for Lambdas that operate on an already-existing SubTask.

    For each record in ``event["Records"]`` (or the event itself when there is
    no "Records" key), the wrapper:

    1. decodes the record's "body" (a JSON string or an already-parsed dict).
       When the record has no "body" key the record itself is used, so a direct
       invocation can pass ``task_id``/``sub_task_id`` at the top level. Bodies
       that are not a JSON object are treated as ``{}``;
    2. reads ``task_id`` and ``sub_task_id`` (UUID strings) from the body;
    3. marks the SubTask "in progress" and records the CloudWatch logs URL
       (None outside Lambda);
    4. calls the wrapped function as ``func(body, context, *args, **kwargs)``;
    5. marks the SubTask "complete" (a truthy result is stored as
       ``outputs={"result": result}``) or, if the function raises, "failed"
       with ``outputs={"error": str(exc)}`` and re-raises.

    The wrapped handler returns None.

    Raises:
        RuntimeError: if a record lacks ``task_id`` or ``sub_task_id``.
        ValueError: if either id is present but is not a valid UUID string.
        Exception: anything the wrapped function raises, after the SubTask is
            marked failed. Any exception stops processing of the remaining
            records.
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
            start_ms = int(time.time() * 1000)
            cloud_logs_url = _try_build_cloud_logs_url(start_ms)
            # Batched events carry a "Records" list; anything else is treated
            # as a single record.
            records = event.get("Records", [event])
            interface = SubTaskInterface()
            for record in records:
                # Fall back to the record itself (matching task_handler) so a
                # direct invocation with top-level ids is not rejected.
                raw_body = record.get("body", record)
                decoded: Any = raw_body
                if isinstance(raw_body, str):
                    try:
                        decoded = json.loads(raw_body)
                    except ValueError:  # json.JSONDecodeError subclasses ValueError
                        decoded = None
                # Valid JSON that is not an object (e.g. a list) would crash
                # on .get() below; treat it like an unparseable body.
                body: dict[str, Any] = (
                    cast(dict[str, Any], decoded) if isinstance(decoded, dict) else {}
                )
                task_id_raw = body.get("task_id")
                subtask_id_raw = body.get("sub_task_id")
                task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
                subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
                if not task_id or not subtask_id:
                    raise RuntimeError("task_id or sub_task_id missing")
                interface.update_subtask_status(
                    subtask_id=subtask_id,
                    status="in progress",
                    cloud_logs_url=cloud_logs_url,
                )
                try:
                    result = func(body, context, *args, **kwargs)
                    # Falsy results (None, 0, "", [], {}) are stored as no outputs.
                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="complete",
                        outputs={"result": result} if result else None,
                    )
                except Exception as e:
                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="failed",
                        outputs={"error": str(e)},
                    )
                    raise
        return wrapper
    return decorator
def task_handler(
    task_source: str,
    source: SourceEnum,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator for Lambdas that are themselves the entry point of a pipeline (no
    router in front).

    For each record the decorator does the following:

    1. Creates a fresh Task + SubTask with the given task_source and source.
    2. Records the CloudWatch logs URL.
    3. Marks the SubTask in progress.
    4. Calls the wrapped function as ``func(body, context, *args, **kwargs)``.
    5. Marks the SubTask complete on success, or failed on raise.

    source_id is read from body[source.value] and stringified (silent None if
    absent) — see ADR-0001. A record's "body" may be a JSON string or a dict.
    When the record has no "body" key the record itself is the body. Bodies
    that are not a JSON object are logged and treated as ``{}``.

    Args:
        task_source: Forwarded to ``TasksInterface.create_task`` as task_source.
        source: Forwarded as ``source``; ``source.value`` is the body key that
            holds the source id.

    Returns:
        A decorator. The wrapped handler returns one of the following:

        - For events with a "Records" key:
          ``{"batchItemFailures": [{"itemIdentifier": <messageId>}, ...]}``.
          This lists every record that failed at any stage: task creation,
          status update, or the wrapped function. The other records are not
          retried.
        - Otherwise: a one-element list holding the wrapped function's
          result. Any exception is re-raised.
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
            logger = setup_logger()
            start_ms = int(time.time() * 1000)
            cloud_logs_url = _try_build_cloud_logs_url(start_ms)
            is_batch = "Records" in event
            records = event.get("Records", [event])
            results: list[Any] = []
            failures: list[dict[str, Any]] = []
            interface = SubTaskInterface()

            def run_record(record: dict[str, Any]) -> Any:
                """Create the Task/SubTask for one record, run func, and record its final status."""
                raw_body = record.get("body", record)
                decoded: Any = raw_body
                if isinstance(raw_body, str):
                    try:
                        decoded = json.loads(raw_body)
                    except ValueError:  # json.JSONDecodeError subclasses ValueError
                        decoded = None
                body: dict[str, Any]
                if isinstance(decoded, dict):
                    body = cast(dict[str, Any], decoded)
                else:
                    # Unparseable or non-object bodies still run, but with
                    # empty inputs; make that visible instead of silent.
                    logger.warning("Record body is not a JSON object; using empty inputs")
                    body = {}
                raw_source_id = body.get(source.value)
                source_id: Optional[str] = (
                    str(raw_source_id) if raw_source_id is not None else None
                )
                logger.info("Creating task for source: %s", task_source)
                task_id, subtask_id = TasksInterface.create_task(
                    task_source=task_source,
                    inputs=body,
                    source=source,
                    source_id=source_id,
                )
                if subtask_id is None:
                    raise RuntimeError("create_task did not return a subtask_id")
                logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
                interface.update_subtask_status(
                    subtask_id=subtask_id,
                    status="in progress",
                    cloud_logs_url=cloud_logs_url,
                )
                try:
                    result = func(body, context, *args, **kwargs)
                    # Falsy results (None, 0, "", [], {}) are stored as no outputs.
                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="complete",
                        outputs={"result": result} if result else None,
                    )
                except Exception as e:
                    logger.exception("Task %s failed: %s", task_id, e)
                    interface.update_subtask_status(
                        subtask_id=subtask_id,
                        status="failed",
                        outputs={"error": str(e)},
                    )
                    raise
                logger.info("Task %s completed successfully", task_id)
                return result

            for record in records:
                try:
                    results.append(run_record(record))
                except Exception:
                    if not is_batch:
                        raise
                    # Report only this message as failed. Raising here would
                    # fail the whole batch, so every message would be retried
                    # and Tasks would be re-created for records that already
                    # succeeded.
                    message_id = record["messageId"]
                    logger.warning(
                        "Reporting record %s as a batch item failure",
                        message_id,
                        exc_info=True,
                    )
                    failures.append({"itemIdentifier": message_id})

            if is_batch:
                return {"batchItemFailures": failures}
            return results
        return wrapper
    return decorator