From e458f0a2b718987bfc64635b690cb068293463dc Mon Sep 17 00:00:00 2001 From: Jun-te Kim Date: Tue, 12 May 2026 16:24:11 +0000 Subject: [PATCH] task and sub tasks improved --- backend/app/db/models/tasks.py | 1 + backend/magic_plan/handler.py | 3 +- backend/pashub_fetcher/handler/handler.py | 3 +- backend/utils/subtasks.py | 131 +++++++++++----------- etl/hubspot/scripts/scraper/main.py | 3 +- 5 files changed, 72 insertions(+), 69 deletions(-) diff --git a/backend/app/db/models/tasks.py b/backend/app/db/models/tasks.py index e97a939f..db1b7c04 100644 --- a/backend/app/db/models/tasks.py +++ b/backend/app/db/models/tasks.py @@ -9,6 +9,7 @@ from sqlmodel import SQLModel, Field, Relationship class SourceEnum(enum.Enum): # TODO: move to domain? PORTFOLIO = "portfolio_id" + HUBSPOT_DEAL = "hubspot_deal_id" class Task(SQLModel, table=True): diff --git a/backend/magic_plan/handler.py b/backend/magic_plan/handler.py index 5fd90b7a..e7dc6484 100644 --- a/backend/magic_plan/handler.py +++ b/backend/magic_plan/handler.py @@ -5,13 +5,14 @@ from backend.magic_plan.magic_plan_client import MagicPlanClient from backend.magic_plan.magic_plan_service import MagicPlanService from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest from datatypes.magicplan.domain.models import Plan +from backend.app.db.models.tasks import SourceEnum from backend.utils.subtasks import task_handler from utils.logger import setup_logger logger = setup_logger() -@task_handler() +@task_handler(task_source="magic_plan", source=SourceEnum.HUBSPOT_DEAL) def handler(body: dict[str, Any], context: Any) -> str: settings = get_settings() payload = MagicPlanTriggerRequest.model_validate(body) diff --git a/backend/pashub_fetcher/handler/handler.py b/backend/pashub_fetcher/handler/handler.py index 0d12b6bf..cd0c8113 100644 --- a/backend/pashub_fetcher/handler/handler.py +++ b/backend/pashub_fetcher/handler/handler.py @@ -5,6 +5,7 @@ from backend.pashub_fetcher.pashub_client import 
PashubClient, UnauthorizedError from backend.pashub_fetcher.pashub_service import PashubService from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest from backend.pashub_fetcher.token_getter import get_token_from_local_storage +from backend.app.db.models.tasks import SourceEnum from backend.utils.subtasks import task_handler from utils.logger import setup_logger from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient @@ -21,7 +22,7 @@ def get_pashub_client(email: str, password: str) -> PashubClient: return PashubClient(token=token) -@task_handler() +@task_handler(task_source="pashub_fetcher", source=SourceEnum.HUBSPOT_DEAL) def handler(body: Dict[str, Any], context: Any) -> List[str]: logger.info("Received message") diff --git a/backend/utils/subtasks.py b/backend/utils/subtasks.py index 6be3a742..36e67b78 100644 --- a/backend/utils/subtasks.py +++ b/backend/utils/subtasks.py @@ -1,75 +1,72 @@ -# decorators/subtask_handler.py - -from functools import wraps -from typing import Callable, Any -from uuid import UUID import json +import os +import time +from functools import wraps +from typing import Any, Callable, Optional, cast +from uuid import UUID from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface +from backend.app.db.models.tasks import SourceEnum +from backend.app.plan.utils import build_cloudwatch_log_url from utils.logger import setup_logger -def subtask_handler(): - """ - Decorator that wraps your existing handler and automatically: +def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]: + # Returns None outside a Lambda runtime so local/non-Lambda runs don't crash. 
+ required = ("AWS_REGION", "AWS_LAMBDA_LOG_GROUP_NAME", "AWS_LAMBDA_LOG_STREAM_NAME") + if not all(k in os.environ for k in required): + return None + return build_cloudwatch_log_url(start_ms) - - Extracts task_id + sub_task_id from event - - Marks subtask as in progress - - Executes handler logic - - Marks subtask complete on success - - Marks failed on exception + +def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorator for Lambdas that operate on an already-existing SubTask. Extracts + task_id + sub_task_id from each record, records the CloudWatch logs URL, + marks the SubTask in progress, then complete on success / failed on raise. """ - def decorator(func: Callable[..., Any]): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def wrapper(event: dict[str, Any], context: Any, *args, **kwargs): + def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None: + start_ms = int(time.time() * 1000) + cloud_logs_url = _try_build_cloud_logs_url(start_ms) records = event.get("Records", [event]) - interface = SubTaskInterface() for record in records: - - # ------------------------------- - # Parse body safely - # ------------------------------- - body = {} - - if isinstance(record.get("body"), str): + raw_body = record.get("body") + body: dict[str, Any] + if isinstance(raw_body, str): try: - body = json.loads(record["body"]) + body = json.loads(raw_body) except Exception: body = {} + elif isinstance(raw_body, dict): + body = cast(dict[str, Any], raw_body) else: - body = record.get("body", {}) or {} + body = {} task_id_raw = body.get("task_id") subtask_id_raw = body.get("sub_task_id") task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None - subtask_id = ( - UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None - ) + subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None if not task_id or not subtask_id: raise RuntimeError("task_id 
or sub_task_id missing") - # ------------------------------- - # Mark in progress - # ------------------------------- interface.update_subtask_status( subtask_id=subtask_id, status="in progress", + cloud_logs_url=cloud_logs_url, ) try: - # Pass the parsed body into your function result = func(body, context, *args, **kwargs) - # ------------------------------- - # Success → mark complete - # ------------------------------- interface.update_subtask_status( subtask_id=subtask_id, status="complete", @@ -77,75 +74,79 @@ def subtask_handler(): ) except Exception as e: - - # ------------------------------- - # Failure → mark failed - # ------------------------------- interface.update_subtask_status( subtask_id=subtask_id, status="failed", outputs={"error": str(e)}, ) - raise - return None - return wrapper return decorator -def task_handler(): +def task_handler( + task_source: str, + source: SourceEnum, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ - Decorator that wraps a Lambda handler and automatically: - - - Parses body from the first SQS record (or uses the event dict directly) - - Creates a fresh Task + SubTask in the database - - Marks the subtask as in progress - - Executes the handler, passing the parsed body - - Marks complete on success, failed on exception (and re-raises) + Decorator for Lambdas that are themselves the entry point of a pipeline (no + router in front). For each record the decorator creates a fresh Task + + SubTask with the given task_source and source. source_id is read from + body[source.value] (silent None if absent) — see ADR-0001. Records the + CloudWatch logs URL, marks the SubTask in progress, then complete on + success / failed on raise. 
""" - def decorator(func: Callable[..., Any]): - - task_source = f"{func.__module__}.{func.__qualname__}" + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def wrapper(event: dict[str, Any], context: Any, *args, **kwargs): + def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any: logger = setup_logger() + start_ms = int(time.time() * 1000) + cloud_logs_url = _try_build_cloud_logs_url(start_ms) - records = event.get("Records", [event]) # fallback for non-SQS - - results = [] - failures = [] + records = event.get("Records", [event]) + results: list[Any] = [] + failures: list[dict[str, Any]] = [] + interface = SubTaskInterface() for record in records: - # Parse body raw_body = record.get("body", record) - + body: dict[str, Any] if isinstance(raw_body, str): try: body = json.loads(raw_body) except Exception: body = {} + elif isinstance(raw_body, dict): + body = cast(dict[str, Any], raw_body) else: - body = raw_body or {} + body = {} + + raw_source_id = body.get(source.value) + source_id: Optional[str] = ( + str(raw_source_id) if raw_source_id is not None else None + ) - # Create task per message logger.info("Creating task for source: %s", task_source) task_id, subtask_id = TasksInterface.create_task( task_source=task_source, inputs=body, + source=source, + source_id=source_id, ) - logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id) + if subtask_id is None: + raise RuntimeError("create_task did not return a subtask_id") - interface = SubTaskInterface() + logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id) interface.update_subtask_status( subtask_id=subtask_id, status="in progress", + cloud_logs_url=cloud_logs_url, ) try: @@ -172,13 +173,11 @@ def task_handler(): if "Records" in event: failures.append({"itemIdentifier": record["messageId"]}) else: - # Handle non-SQS events raise if "Records" in event: return {"batchItemFailures": failures} - # Handle non-SQS events return results 
return wrapper diff --git a/etl/hubspot/scripts/scraper/main.py b/etl/hubspot/scripts/scraper/main.py index 86844352..a7b640cf 100644 --- a/etl/hubspot/scripts/scraper/main.py +++ b/etl/hubspot/scripts/scraper/main.py @@ -9,6 +9,7 @@ from etl.hubspot.hubspot_deal_differ import HubspotDealDiffer from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import ( HubspotTriggerOrchestratorTriggerRequest, ) +from backend.app.db.models.tasks import SourceEnum from backend.utils.subtasks import task_handler from backend.app.db.models.hubspot_deal_data import HubspotDealData from utils.logger import setup_logger @@ -16,7 +17,7 @@ from utils.logger import setup_logger logger = setup_logger() -@task_handler() +@task_handler(task_source="hubspot_scraper", source=SourceEnum.HUBSPOT_DEAL) def handler(body: dict[str, Any], context: Any) -> None: db_client = HubspotDataToDb() hubspot_client = HubspotClient()