mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
task and sub tasks imrpvoed
This commit is contained in:
parent
272bfbde13
commit
e458f0a2b7
5 changed files with 72 additions and 69 deletions
|
|
@ -9,6 +9,7 @@ from sqlmodel import SQLModel, Field, Relationship
|
|||
|
||||
class SourceEnum(enum.Enum): # TODO: move to domain?
|
||||
PORTFOLIO = "portfolio_id"
|
||||
HUBSPOT_DEAL = "hubspot_deal_id"
|
||||
|
||||
|
||||
class Task(SQLModel, table=True):
|
||||
|
|
|
|||
|
|
@ -5,13 +5,14 @@ from backend.magic_plan.magic_plan_client import MagicPlanClient
|
|||
from backend.magic_plan.magic_plan_service import MagicPlanService
|
||||
from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest
|
||||
from datatypes.magicplan.domain.models import Plan
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.subtasks import task_handler
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@task_handler()
|
||||
@task_handler(task_source="magic_plan", source=SourceEnum.HUBSPOT_DEAL)
|
||||
def handler(body: dict[str, Any], context: Any) -> str:
|
||||
settings = get_settings()
|
||||
payload = MagicPlanTriggerRequest.model_validate(body)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from backend.pashub_fetcher.pashub_client import PashubClient, UnauthorizedError
|
|||
from backend.pashub_fetcher.pashub_service import PashubService
|
||||
from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest
|
||||
from backend.pashub_fetcher.token_getter import get_token_from_local_storage
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.subtasks import task_handler
|
||||
from utils.logger import setup_logger
|
||||
from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient
|
||||
|
|
@ -21,7 +22,7 @@ def get_pashub_client(email: str, password: str) -> PashubClient:
|
|||
return PashubClient(token=token)
|
||||
|
||||
|
||||
@task_handler()
|
||||
@task_handler(task_source="pashub_fetcher", source=SourceEnum.HUBSPOT_DEAL)
|
||||
def handler(body: Dict[str, Any], context: Any) -> List[str]:
|
||||
logger.info("Received message")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,75 +1,72 @@
|
|||
# decorators/subtask_handler.py
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Any
|
||||
from uuid import UUID
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from uuid import UUID
|
||||
|
||||
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.app.plan.utils import build_cloudwatch_log_url
|
||||
from utils.logger import setup_logger
|
||||
|
||||
|
||||
def subtask_handler():
|
||||
"""
|
||||
Decorator that wraps your existing handler and automatically:
|
||||
def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
|
||||
# Returns None outside a Lambda runtime so local/non-Lambda runs don't crash.
|
||||
required = ("AWS_REGION", "AWS_LAMBDA_LOG_GROUP_NAME", "AWS_LAMBDA_LOG_STREAM_NAME")
|
||||
if not all(k in os.environ for k in required):
|
||||
return None
|
||||
return build_cloudwatch_log_url(start_ms)
|
||||
|
||||
- Extracts task_id + sub_task_id from event
|
||||
- Marks subtask as in progress
|
||||
- Executes handler logic
|
||||
- Marks subtask complete on success
|
||||
- Marks failed on exception
|
||||
|
||||
def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
Decorator for Lambdas that operate on an already-existing SubTask. Extracts
|
||||
task_id + sub_task_id from each record, records the CloudWatch logs URL,
|
||||
marks the SubTask in progress, then complete on success / failed on raise.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]):
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
|
||||
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
|
||||
start_ms = int(time.time() * 1000)
|
||||
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
|
||||
|
||||
records = event.get("Records", [event])
|
||||
|
||||
interface = SubTaskInterface()
|
||||
|
||||
for record in records:
|
||||
|
||||
# -------------------------------
|
||||
# Parse body safely
|
||||
# -------------------------------
|
||||
body = {}
|
||||
|
||||
if isinstance(record.get("body"), str):
|
||||
raw_body = record.get("body")
|
||||
body: dict[str, Any]
|
||||
if isinstance(raw_body, str):
|
||||
try:
|
||||
body = json.loads(record["body"])
|
||||
body = json.loads(raw_body)
|
||||
except Exception:
|
||||
body = {}
|
||||
elif isinstance(raw_body, dict):
|
||||
body = cast(dict[str, Any], raw_body)
|
||||
else:
|
||||
body = record.get("body", {}) or {}
|
||||
body = {}
|
||||
|
||||
task_id_raw = body.get("task_id")
|
||||
subtask_id_raw = body.get("sub_task_id")
|
||||
|
||||
task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
|
||||
subtask_id = (
|
||||
UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
|
||||
)
|
||||
subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
|
||||
|
||||
if not task_id or not subtask_id:
|
||||
raise RuntimeError("task_id or sub_task_id missing")
|
||||
|
||||
# -------------------------------
|
||||
# Mark in progress
|
||||
# -------------------------------
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="in progress",
|
||||
cloud_logs_url=cloud_logs_url,
|
||||
)
|
||||
|
||||
try:
|
||||
# Pass the parsed body into your function
|
||||
result = func(body, context, *args, **kwargs)
|
||||
|
||||
# -------------------------------
|
||||
# Success → mark complete
|
||||
# -------------------------------
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="complete",
|
||||
|
|
@ -77,75 +74,79 @@ def subtask_handler():
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
# -------------------------------
|
||||
# Failure → mark failed
|
||||
# -------------------------------
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="failed",
|
||||
outputs={"error": str(e)},
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def task_handler():
|
||||
def task_handler(
|
||||
task_source: str,
|
||||
source: SourceEnum,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
Decorator that wraps a Lambda handler and automatically:
|
||||
|
||||
- Parses body from the first SQS record (or uses the event dict directly)
|
||||
- Creates a fresh Task + SubTask in the database
|
||||
- Marks the subtask as in progress
|
||||
- Executes the handler, passing the parsed body
|
||||
- Marks complete on success, failed on exception (and re-raises)
|
||||
Decorator for Lambdas that are themselves the entry point of a pipeline (no
|
||||
router in front). For each record the decorator creates a fresh Task +
|
||||
SubTask with the given task_source and source. source_id is read from
|
||||
body[source.value] (silent None if absent) — see ADR-0001. Records the
|
||||
CloudWatch logs URL, marks the SubTask in progress, then complete on
|
||||
success / failed on raise.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]):
|
||||
|
||||
task_source = f"{func.__module__}.{func.__qualname__}"
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
|
||||
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
logger = setup_logger()
|
||||
start_ms = int(time.time() * 1000)
|
||||
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
|
||||
|
||||
records = event.get("Records", [event]) # fallback for non-SQS
|
||||
|
||||
results = []
|
||||
failures = []
|
||||
records = event.get("Records", [event])
|
||||
results: list[Any] = []
|
||||
failures: list[dict[str, Any]] = []
|
||||
interface = SubTaskInterface()
|
||||
|
||||
for record in records:
|
||||
# Parse body
|
||||
raw_body = record.get("body", record)
|
||||
|
||||
body: dict[str, Any]
|
||||
if isinstance(raw_body, str):
|
||||
try:
|
||||
body = json.loads(raw_body)
|
||||
except Exception:
|
||||
body = {}
|
||||
elif isinstance(raw_body, dict):
|
||||
body = cast(dict[str, Any], raw_body)
|
||||
else:
|
||||
body = raw_body or {}
|
||||
body = {}
|
||||
|
||||
raw_source_id = body.get(source.value)
|
||||
source_id: Optional[str] = (
|
||||
str(raw_source_id) if raw_source_id is not None else None
|
||||
)
|
||||
|
||||
# Create task per message
|
||||
logger.info("Creating task for source: %s", task_source)
|
||||
task_id, subtask_id = TasksInterface.create_task(
|
||||
task_source=task_source,
|
||||
inputs=body,
|
||||
source=source,
|
||||
source_id=source_id,
|
||||
)
|
||||
|
||||
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
|
||||
if subtask_id is None:
|
||||
raise RuntimeError("create_task did not return a subtask_id")
|
||||
|
||||
interface = SubTaskInterface()
|
||||
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
|
||||
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="in progress",
|
||||
cloud_logs_url=cloud_logs_url,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -172,13 +173,11 @@ def task_handler():
|
|||
if "Records" in event:
|
||||
failures.append({"itemIdentifier": record["messageId"]})
|
||||
else:
|
||||
# Handle non-SQS events
|
||||
raise
|
||||
|
||||
if "Records" in event:
|
||||
return {"batchItemFailures": failures}
|
||||
|
||||
# Handle non-SQS events
|
||||
return results
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from etl.hubspot.hubspot_deal_differ import HubspotDealDiffer
|
|||
from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import (
|
||||
HubspotTriggerOrchestratorTriggerRequest,
|
||||
)
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.subtasks import task_handler
|
||||
from backend.app.db.models.hubspot_deal_data import HubspotDealData
|
||||
from utils.logger import setup_logger
|
||||
|
|
@ -16,7 +17,7 @@ from utils.logger import setup_logger
|
|||
logger = setup_logger()
|
||||
|
||||
|
||||
@task_handler()
|
||||
@task_handler(task_source="hubspot_scraper", source=SourceEnum.HUBSPOT_DEAL)
|
||||
def handler(body: dict[str, Any], context: Any) -> None:
|
||||
db_client = HubspotDataToDb()
|
||||
hubspot_client = HubspotClient()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue