task and sub tasks improved

This commit is contained in:
Jun-te Kim 2026-05-12 16:24:11 +00:00
parent 272bfbde13
commit e458f0a2b7
5 changed files with 72 additions and 69 deletions

View file

@ -9,6 +9,7 @@ from sqlmodel import SQLModel, Field, Relationship
class SourceEnum(enum.Enum): # TODO: move to domain?
PORTFOLIO = "portfolio_id"
HUBSPOT_DEAL = "hubspot_deal_id"
class Task(SQLModel, table=True):

View file

@ -5,13 +5,14 @@ from backend.magic_plan.magic_plan_client import MagicPlanClient
from backend.magic_plan.magic_plan_service import MagicPlanService
from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest
from datatypes.magicplan.domain.models import Plan
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler
from utils.logger import setup_logger
logger = setup_logger()
@task_handler()
@task_handler(task_source="magic_plan", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: dict[str, Any], context: Any) -> str:
settings = get_settings()
payload = MagicPlanTriggerRequest.model_validate(body)

View file

@ -5,6 +5,7 @@ from backend.pashub_fetcher.pashub_client import PashubClient, UnauthorizedError
from backend.pashub_fetcher.pashub_service import PashubService
from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest
from backend.pashub_fetcher.token_getter import get_token_from_local_storage
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler
from utils.logger import setup_logger
from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient
@ -21,7 +22,7 @@ def get_pashub_client(email: str, password: str) -> PashubClient:
return PashubClient(token=token)
@task_handler()
@task_handler(task_source="pashub_fetcher", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: Dict[str, Any], context: Any) -> List[str]:
logger.info("Received message")

View file

@ -1,75 +1,72 @@
# decorators/subtask_handler.py
from functools import wraps
from typing import Callable, Any
from uuid import UUID
import json
import os
import time
from functools import wraps
from typing import Any, Callable, Optional, cast
from uuid import UUID
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
from backend.app.db.models.tasks import SourceEnum
from backend.app.plan.utils import build_cloudwatch_log_url
from utils.logger import setup_logger
def subtask_handler():
"""
Decorator that wraps your existing handler and automatically:
def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
    """Return the CloudWatch logs URL for this run, or None outside Lambda.

    The URL is only built when AWS_REGION, AWS_LAMBDA_LOG_GROUP_NAME and
    AWS_LAMBDA_LOG_STREAM_NAME are all present in the environment. If any of
    them is absent, None is returned so local/non-Lambda runs don't crash.

    Args:
        start_ms: Start timestamp handed through unchanged to
            build_cloudwatch_log_url (callers pass epoch milliseconds).

    Returns:
        Whatever build_cloudwatch_log_url returns, or None when a required
        environment variable is not set.
    """
    lambda_env_names = (
        "AWS_REGION",
        "AWS_LAMBDA_LOG_GROUP_NAME",
        "AWS_LAMBDA_LOG_STREAM_NAME",
    )
    for env_name in lambda_env_names:
        # Presence check only: a variable set to an empty string still counts.
        if env_name not in os.environ:
            return None
    return build_cloudwatch_log_url(start_ms)
- Extracts task_id + sub_task_id from event
- Marks subtask as in progress
- Executes handler logic
- Marks subtask complete on success
- Marks failed on exception
def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator for Lambdas that operate on an already-existing SubTask. Extracts
task_id + sub_task_id from each record, records the CloudWatch logs URL,
marks the SubTask in progress, then complete on success / failed on raise.
"""
def decorator(func: Callable[..., Any]):
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
start_ms = int(time.time() * 1000)
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
records = event.get("Records", [event])
interface = SubTaskInterface()
for record in records:
# -------------------------------
# Parse body safely
# -------------------------------
body = {}
if isinstance(record.get("body"), str):
raw_body = record.get("body")
body: dict[str, Any]
if isinstance(raw_body, str):
try:
body = json.loads(record["body"])
body = json.loads(raw_body)
except Exception:
body = {}
elif isinstance(raw_body, dict):
body = cast(dict[str, Any], raw_body)
else:
body = record.get("body", {}) or {}
body = {}
task_id_raw = body.get("task_id")
subtask_id_raw = body.get("sub_task_id")
task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
subtask_id = (
UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
)
subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
if not task_id or not subtask_id:
raise RuntimeError("task_id or sub_task_id missing")
# -------------------------------
# Mark in progress
# -------------------------------
interface.update_subtask_status(
subtask_id=subtask_id,
status="in progress",
cloud_logs_url=cloud_logs_url,
)
try:
# Pass the parsed body into your function
result = func(body, context, *args, **kwargs)
# -------------------------------
# Success → mark complete
# -------------------------------
interface.update_subtask_status(
subtask_id=subtask_id,
status="complete",
@ -77,75 +74,79 @@ def subtask_handler():
)
except Exception as e:
# -------------------------------
# Failure → mark failed
# -------------------------------
interface.update_subtask_status(
subtask_id=subtask_id,
status="failed",
outputs={"error": str(e)},
)
raise
return None
return wrapper
return decorator
def task_handler():
def task_handler(
task_source: str,
source: SourceEnum,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator that wraps a Lambda handler and automatically:
- Parses body from the first SQS record (or uses the event dict directly)
- Creates a fresh Task + SubTask in the database
- Marks the subtask as in progress
- Executes the handler, passing the parsed body
- Marks complete on success, failed on exception (and re-raises)
Decorator for Lambdas that are themselves the entry point of a pipeline (no
router in front). For each record the decorator creates a fresh Task +
SubTask with the given task_source and source. source_id is read from
body[source.value] (silent None if absent) see ADR-0001. Records the
CloudWatch logs URL, marks the SubTask in progress, then complete on
success / failed on raise.
"""
def decorator(func: Callable[..., Any]):
task_source = f"{func.__module__}.{func.__qualname__}"
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
logger = setup_logger()
start_ms = int(time.time() * 1000)
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
records = event.get("Records", [event]) # fallback for non-SQS
results = []
failures = []
records = event.get("Records", [event])
results: list[Any] = []
failures: list[dict[str, Any]] = []
interface = SubTaskInterface()
for record in records:
# Parse body
raw_body = record.get("body", record)
body: dict[str, Any]
if isinstance(raw_body, str):
try:
body = json.loads(raw_body)
except Exception:
body = {}
elif isinstance(raw_body, dict):
body = cast(dict[str, Any], raw_body)
else:
body = raw_body or {}
body = {}
raw_source_id = body.get(source.value)
source_id: Optional[str] = (
str(raw_source_id) if raw_source_id is not None else None
)
# Create task per message
logger.info("Creating task for source: %s", task_source)
task_id, subtask_id = TasksInterface.create_task(
task_source=task_source,
inputs=body,
source=source,
source_id=source_id,
)
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
if subtask_id is None:
raise RuntimeError("create_task did not return a subtask_id")
interface = SubTaskInterface()
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
interface.update_subtask_status(
subtask_id=subtask_id,
status="in progress",
cloud_logs_url=cloud_logs_url,
)
try:
@ -172,13 +173,11 @@ def task_handler():
if "Records" in event:
failures.append({"itemIdentifier": record["messageId"]})
else:
# Handle non-SQS events
raise
if "Records" in event:
return {"batchItemFailures": failures}
# Handle non-SQS events
return results
return wrapper

View file

@ -9,6 +9,7 @@ from etl.hubspot.hubspot_deal_differ import HubspotDealDiffer
from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import (
HubspotTriggerOrchestratorTriggerRequest,
)
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler
from backend.app.db.models.hubspot_deal_data import HubspotDealData
from utils.logger import setup_logger
@ -16,7 +17,7 @@ from utils.logger import setup_logger
logger = setup_logger()
@task_handler()
@task_handler(task_source="hubspot_scraper", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: dict[str, Any], context: Any) -> None:
db_client = HubspotDataToDb()
hubspot_client = HubspotClient()