task and sub tasks improved

This commit is contained in:
Jun-te Kim 2026-05-12 16:24:11 +00:00
parent 272bfbde13
commit e458f0a2b7
5 changed files with 72 additions and 69 deletions

View file

@ -9,6 +9,7 @@ from sqlmodel import SQLModel, Field, Relationship
class SourceEnum(enum.Enum): # TODO: move to domain? class SourceEnum(enum.Enum): # TODO: move to domain?
PORTFOLIO = "portfolio_id" PORTFOLIO = "portfolio_id"
HUBSPOT_DEAL = "hubspot_deal_id"
class Task(SQLModel, table=True): class Task(SQLModel, table=True):

View file

@ -5,13 +5,14 @@ from backend.magic_plan.magic_plan_client import MagicPlanClient
from backend.magic_plan.magic_plan_service import MagicPlanService from backend.magic_plan.magic_plan_service import MagicPlanService
from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest
from datatypes.magicplan.domain.models import Plan from datatypes.magicplan.domain.models import Plan
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler from backend.utils.subtasks import task_handler
from utils.logger import setup_logger from utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@task_handler() @task_handler(task_source="magic_plan", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: dict[str, Any], context: Any) -> str: def handler(body: dict[str, Any], context: Any) -> str:
settings = get_settings() settings = get_settings()
payload = MagicPlanTriggerRequest.model_validate(body) payload = MagicPlanTriggerRequest.model_validate(body)

View file

@ -5,6 +5,7 @@ from backend.pashub_fetcher.pashub_client import PashubClient, UnauthorizedError
from backend.pashub_fetcher.pashub_service import PashubService from backend.pashub_fetcher.pashub_service import PashubService
from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest
from backend.pashub_fetcher.token_getter import get_token_from_local_storage from backend.pashub_fetcher.token_getter import get_token_from_local_storage
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler from backend.utils.subtasks import task_handler
from utils.logger import setup_logger from utils.logger import setup_logger
from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient
@ -21,7 +22,7 @@ def get_pashub_client(email: str, password: str) -> PashubClient:
return PashubClient(token=token) return PashubClient(token=token)
@task_handler() @task_handler(task_source="pashub_fetcher", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: Dict[str, Any], context: Any) -> List[str]: def handler(body: Dict[str, Any], context: Any) -> List[str]:
logger.info("Received message") logger.info("Received message")

View file

@ -1,75 +1,72 @@
# decorators/subtask_handler.py
from functools import wraps
from typing import Callable, Any
from uuid import UUID
import json import json
import os
import time
from functools import wraps
from typing import Any, Callable, Optional, cast
from uuid import UUID
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
from backend.app.db.models.tasks import SourceEnum
from backend.app.plan.utils import build_cloudwatch_log_url
from utils.logger import setup_logger from utils.logger import setup_logger
def subtask_handler(): def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
""" # Returns None outside a Lambda runtime so local/non-Lambda runs don't crash.
Decorator that wraps your existing handler and automatically: required = ("AWS_REGION", "AWS_LAMBDA_LOG_GROUP_NAME", "AWS_LAMBDA_LOG_STREAM_NAME")
if not all(k in os.environ for k in required):
return None
return build_cloudwatch_log_url(start_ms)
- Extracts task_id + sub_task_id from event
- Marks subtask as in progress def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
- Executes handler logic """
- Marks subtask complete on success Decorator for Lambdas that operate on an already-existing SubTask. Extracts
- Marks failed on exception task_id + sub_task_id from each record, records the CloudWatch logs URL,
marks the SubTask in progress, then complete on success / failed on raise.
""" """
def decorator(func: Callable[..., Any]): def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func) @wraps(func)
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs): def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
start_ms = int(time.time() * 1000)
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
records = event.get("Records", [event]) records = event.get("Records", [event])
interface = SubTaskInterface() interface = SubTaskInterface()
for record in records: for record in records:
raw_body = record.get("body")
# ------------------------------- body: dict[str, Any]
# Parse body safely if isinstance(raw_body, str):
# -------------------------------
body = {}
if isinstance(record.get("body"), str):
try: try:
body = json.loads(record["body"]) body = json.loads(raw_body)
except Exception: except Exception:
body = {} body = {}
elif isinstance(raw_body, dict):
body = cast(dict[str, Any], raw_body)
else: else:
body = record.get("body", {}) or {} body = {}
task_id_raw = body.get("task_id") task_id_raw = body.get("task_id")
subtask_id_raw = body.get("sub_task_id") subtask_id_raw = body.get("sub_task_id")
task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
subtask_id = ( subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
)
if not task_id or not subtask_id: if not task_id or not subtask_id:
raise RuntimeError("task_id or sub_task_id missing") raise RuntimeError("task_id or sub_task_id missing")
# -------------------------------
# Mark in progress
# -------------------------------
interface.update_subtask_status( interface.update_subtask_status(
subtask_id=subtask_id, subtask_id=subtask_id,
status="in progress", status="in progress",
cloud_logs_url=cloud_logs_url,
) )
try: try:
# Pass the parsed body into your function
result = func(body, context, *args, **kwargs) result = func(body, context, *args, **kwargs)
# -------------------------------
# Success → mark complete
# -------------------------------
interface.update_subtask_status( interface.update_subtask_status(
subtask_id=subtask_id, subtask_id=subtask_id,
status="complete", status="complete",
@ -77,75 +74,79 @@ def subtask_handler():
) )
except Exception as e: except Exception as e:
# -------------------------------
# Failure → mark failed
# -------------------------------
interface.update_subtask_status( interface.update_subtask_status(
subtask_id=subtask_id, subtask_id=subtask_id,
status="failed", status="failed",
outputs={"error": str(e)}, outputs={"error": str(e)},
) )
raise raise
return None
return wrapper return wrapper
return decorator return decorator
def task_handler(): def task_handler(
task_source: str,
source: SourceEnum,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
""" """
Decorator that wraps a Lambda handler and automatically: Decorator for Lambdas that are themselves the entry point of a pipeline (no
router in front). For each record the decorator creates a fresh Task +
- Parses body from the first SQS record (or uses the event dict directly) SubTask with the given task_source and source. source_id is read from
- Creates a fresh Task + SubTask in the database body[source.value] (silent None if absent) — see ADR-0001. Records the
- Marks the subtask as in progress CloudWatch logs URL, marks the SubTask in progress, then complete on
- Executes the handler, passing the parsed body success / failed on raise.
- Marks complete on success, failed on exception (and re-raises)
""" """
def decorator(func: Callable[..., Any]): def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
task_source = f"{func.__module__}.{func.__qualname__}"
@wraps(func) @wraps(func)
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs): def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
logger = setup_logger() logger = setup_logger()
start_ms = int(time.time() * 1000)
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
records = event.get("Records", [event]) # fallback for non-SQS records = event.get("Records", [event])
results: list[Any] = []
results = [] failures: list[dict[str, Any]] = []
failures = [] interface = SubTaskInterface()
for record in records: for record in records:
# Parse body
raw_body = record.get("body", record) raw_body = record.get("body", record)
body: dict[str, Any]
if isinstance(raw_body, str): if isinstance(raw_body, str):
try: try:
body = json.loads(raw_body) body = json.loads(raw_body)
except Exception: except Exception:
body = {} body = {}
elif isinstance(raw_body, dict):
body = cast(dict[str, Any], raw_body)
else: else:
body = raw_body or {} body = {}
raw_source_id = body.get(source.value)
source_id: Optional[str] = (
str(raw_source_id) if raw_source_id is not None else None
)
# Create task per message
logger.info("Creating task for source: %s", task_source) logger.info("Creating task for source: %s", task_source)
task_id, subtask_id = TasksInterface.create_task( task_id, subtask_id = TasksInterface.create_task(
task_source=task_source, task_source=task_source,
inputs=body, inputs=body,
source=source,
source_id=source_id,
) )
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id) if subtask_id is None:
raise RuntimeError("create_task did not return a subtask_id")
interface = SubTaskInterface() logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
interface.update_subtask_status( interface.update_subtask_status(
subtask_id=subtask_id, subtask_id=subtask_id,
status="in progress", status="in progress",
cloud_logs_url=cloud_logs_url,
) )
try: try:
@ -172,13 +173,11 @@ def task_handler():
if "Records" in event: if "Records" in event:
failures.append({"itemIdentifier": record["messageId"]}) failures.append({"itemIdentifier": record["messageId"]})
else: else:
# Handle non-SQS events
raise raise
if "Records" in event: if "Records" in event:
return {"batchItemFailures": failures} return {"batchItemFailures": failures}
# Handle non-SQS events
return results return results
return wrapper return wrapper

View file

@ -9,6 +9,7 @@ from etl.hubspot.hubspot_deal_differ import HubspotDealDiffer
from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import ( from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import (
HubspotTriggerOrchestratorTriggerRequest, HubspotTriggerOrchestratorTriggerRequest,
) )
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler from backend.utils.subtasks import task_handler
from backend.app.db.models.hubspot_deal_data import HubspotDealData from backend.app.db.models.hubspot_deal_data import HubspotDealData
from utils.logger import setup_logger from utils.logger import setup_logger
@ -16,7 +17,7 @@ from utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@task_handler() @task_handler(task_source="hubspot_scraper", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: dict[str, Any], context: Any) -> None: def handler(body: dict[str, Any], context: Any) -> None:
db_client = HubspotDataToDb() db_client = HubspotDataToDb()
hubspot_client = HubspotClient() hubspot_client = HubspotClient()