Merge branch 'main' into feature/integrate_new_epc_with_historical_epc

This commit is contained in:
Jun-te Kim 2026-05-13 08:38:50 +00:00
commit 3bcb94f9e5
20 changed files with 239 additions and 148 deletions

View file

@ -9,6 +9,7 @@ from sqlmodel import SQLModel, Field, Relationship
class SourceEnum(enum.Enum): # TODO: move to domain?
PORTFOLIO = "portfolio_id"
HUBSPOT_DEAL = "hubspot_deal_id"
class Task(SQLModel, table=True):

View file

@ -1,5 +1,4 @@
import ast
import os
from typing import Optional
import msgpack
from uuid import UUID
@ -8,6 +7,7 @@ from backend.addresses.Address import Address
from backend.app.config import get_settings
from backend.app.plan.data_classes import PropertyRequestData
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.utils.cloudwatch import build_cloudwatch_log_url
from starlette.responses import Response
from utils.logger import setup_logger
@ -241,33 +241,6 @@ def parse_eco_packages(
return measures, mapped["target_sap"], mapped["plan_type"], already_installed
def build_cloudwatch_log_url(start_ms: Optional[int]) -> str:
"""
Build a CloudWatch Logs URL for the current Lambda invocation,
including timestamp window from start_ms to end_ms (epoch ms).
"""
logger.info("Building cloudwatch logs URL")
region = os.environ["AWS_REGION"]
logger.info("Building cloudwatch logs URL: Got AWS region")
log_group = os.environ["AWS_LAMBDA_LOG_GROUP_NAME"]
logger.info("Building cloudwatch logs URL: Got lambda log group name")
log_stream = os.environ["AWS_LAMBDA_LOG_STREAM_NAME"]
logger.info("Building cloudwatch logs URL: Got lambda log stream name")
# CloudWatch console requires / encoded as $252F
encoded_group = log_group.replace("/", "$252F")
encoded_stream = log_stream.replace("/", "$252F")
# Return the full URL with time range
return (
f"https://console.aws.amazon.com/cloudwatch/home?"
f"region={region}"
f"#logsV2:log-groups/log-group/{encoded_group}"
f"/log-events/{encoded_stream}"
f"$3Fstart={start_ms}"
)
def handle_error(
msg: str,
exception: Exception,

View file

@ -3,7 +3,7 @@ import time
from typing import Any, Mapping
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.app.plan.utils import build_cloudwatch_log_url
from backend.utils.cloudwatch import build_cloudwatch_log_url
from backend.categorisation.categorisation_trigger_request import (
CategorisationTriggerRequest,
)

View file

@ -15,7 +15,8 @@ from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
from backend.app.domain.classes.plan import Plan
from backend.app.domain.classes.scenario import Scenario
from backend.app.plan.utils import build_cloudwatch_log_url, handle_error
from backend.app.plan.utils import handle_error
from backend.utils.cloudwatch import build_cloudwatch_log_url
from backend.categorisation.categorisation_trigger_request import (
CategorisationTriggerRequest,
)

View file

@ -23,8 +23,9 @@ from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.app.plan.schemas import PlanTriggerRequest
from backend.app.plan.utils import (
get_cleaned, patch_epc, extract_property_request_data, handle_error, build_cloudwatch_log_url
get_cleaned, patch_epc, extract_property_request_data, handle_error
)
from backend.utils.cloudwatch import build_cloudwatch_log_url
from backend.app.utils import sap_to_epc
import backend.app.assumptions as assumptions

View file

@ -5,13 +5,14 @@ from backend.magic_plan.magic_plan_client import MagicPlanClient
from backend.magic_plan.magic_plan_service import MagicPlanService
from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest
from datatypes.magicplan.domain.models import Plan
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler
from utils.logger import setup_logger
logger = setup_logger()
@task_handler()
@task_handler(task_source="magic_plan", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: dict[str, Any], context: Any) -> str:
settings = get_settings()
payload = MagicPlanTriggerRequest.model_validate(body)
@ -20,7 +21,9 @@ def handler(body: dict[str, Any], context: Any) -> str:
api_key=settings.MAGICPLAN_API_KEY,
)
# TODO: read s3_bucket from env var so staging/prod use the correct bucket
plan: Plan = MagicPlanService(client, s3_bucket="retrofit-energy-assessments-dev").run(payload)
plan: Plan = MagicPlanService(
client, s3_bucket="retrofit-energy-assessments-dev"
).run(payload)
logger.info("Saved MagicPlan plan uid=%s", plan.uid)
return plan.uid
@ -30,7 +33,6 @@ if __name__ == "__main__":
"Records": [
{
"body": '{"address": "2 Laburnum Way Bromley BR2 8BZ", "hubspot_deal_id": "local-test-deal"}',
"messageId": "local-test",
}
]
}

View file

@ -5,3 +5,7 @@ sqlmodel
psycopg2-binary==2.9.10
pydantic-settings==2.6.0
boto3==1.35.44
pytz==2024.2
pandas==2.2.2
numpy==2.1.2

View file

@ -0,0 +1,11 @@
# Local docker-compose definition that builds and runs the image described by
# backend/magic_plan/handler/Dockerfile.
# NOTE(review): the top-level `version` key is treated as obsolete by current
# Docker Compose releases — confirm whether it is still needed.
version: "3.9"
services:
# NOTE(review): the service is named "ecmk-fetcher-lambda" but it builds the
# magic_plan handler Dockerfile — confirm the name is intentional.
ecmk-fetcher-lambda:
build:
# Build context is three directories above this file; the dockerfile path
# below is given relative to that context.
context: ../../../
dockerfile: backend/magic_plan/handler/Dockerfile
ports:
# Host port 9000 is mapped to container port 8080.
- "9000:8080"
env_file:
# Environment variables are loaded from the .env file at the build-context root.
- ../../../.env

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python3
"""Post a sample SQS-style event to a locally running Lambda container.

Sends a single record to the local invocation endpoint on localhost:9000 and
prints the HTTP status code and the response body.
"""
import json

import requests

HOST = "localhost"
PORT = "9000"
LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"

# requests never times out by default, so a container that is not listening
# would hang this script forever. Fail fast on connect, but leave a generous
# read window for the handler to finish its work.
CONNECT_TIMEOUT_S = 5
READ_TIMEOUT_S = 900

payload = {
    "Records": [
        {
            "messageId": "test-message-id",
            "body": json.dumps(
                # Alternative payload kept for manual testing:
                # {
                #     "address": "2 Laburnum Way, Rombley, BR2 8BZ | Retrofit Assessment",
                #     "hubspot_deal_id": "500262906061",
                # }
                {"address": "33 Wallaby Way, Sydney", "hubspot_deal_id": "123456789"}
            ),
        }
    ]
}


def main() -> None:
    """Send the sample payload and print the status code and response text."""
    response = requests.post(
        LAMBDA_URL,
        json=payload,
        timeout=(CONNECT_TIMEOUT_S, READ_TIMEOUT_S),
    )
    print("Status code:", response.status_code)
    print("Response:")
    print(response.text)


if __name__ == "__main__":
    main()

View file

@ -1,20 +1,27 @@
import requests
from datatypes.magicplan.api.response import MagicPlanPlan, PlansListResponse
from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary, PlansListResponse
_BASE_URL = "https://cloud.magicplan.app/api/v2"
class MagicPlanClient:
def __init__(self, customer_id: str, api_key: str) -> None:
self._api_key = api_key
self._session = requests.Session()
self._session.headers.update({"customer": customer_id})
self._session.headers.update({"customer": customer_id, "key": api_key})
def get_plans(self) -> PlansListResponse:
r = self._session.get(f"{_BASE_URL}/plans", params={"key": self._api_key})
r.raise_for_status()
return PlansListResponse.model_validate(r.json()["data"])
def get_plans(self) -> list[PlanSummary]:
    """Fetch every plan summary from the workgroup plans endpoint.

    Requests pages starting at page 1 and keeps requesting the next page
    number for as long as the parsed response's ``paging.next_page`` is
    truthy. Each HTTP response is checked with ``raise_for_status`` before
    its ``data`` payload is validated as a ``PlansListResponse``.

    Returns:
        The plans from all pages, in the order the pages were fetched.
    """
    plans_url = f"{_BASE_URL}/workgroups/plans"
    collected: list[PlanSummary] = []
    page_number = 1
    more_pages = True
    while more_pages:
        http_response = self._session.get(plans_url, params={"page": page_number})
        http_response.raise_for_status()
        listing = PlansListResponse.model_validate(http_response.json()["data"])
        collected.extend(listing.plans)
        # Stop once the API reports there is no further page.
        more_pages = bool(listing.paging.next_page)
        page_number += 1
    return collected
def get_plan(self, plan_id: str) -> MagicPlanPlan:
return MagicPlanPlan.model_validate(self._fetch_plan(plan_id).json()["data"])
@ -23,8 +30,6 @@ class MagicPlanClient:
return self._fetch_plan(plan_id).content
def _fetch_plan(self, plan_id: str) -> requests.Response:
r = self._session.get(
f"{_BASE_URL}/plans/{plan_id}", params={"key": self._api_key}
)
r = self._session.get(f"{_BASE_URL}/plans/get/{plan_id}")
r.raise_for_status()
return r

View file

@ -3,11 +3,7 @@ import json
from datetime import datetime, timezone
from typing import Optional
from datatypes.magicplan.api.response import (
MagicPlanPlan,
PlanSummary,
PlansListResponse,
)
from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary
from datatypes.magicplan.domain.mapper import map_plan
from datatypes.magicplan.domain.models import Plan
@ -39,10 +35,8 @@ class MagicPlanService:
if uprn is not None:
logger.info("MagicPlanService.run uprn=%s", uprn)
plans_response: PlansListResponse = self._client.get_plans()
matched: Optional[PlanSummary] = find_matching_plan(
plans_response.plans, address
)
plans: list[PlanSummary] = self._client.get_plans()
matched: Optional[PlanSummary] = find_matching_plan(plans, address)
if matched is None:
raise ValueError(f"No MagicPlan found for address: {address!r}")

View file

@ -54,7 +54,7 @@ def test_handler_raises_on_missing_address(mock_plan: MagicMock) -> None:
def test_handler_constructs_client_from_settings(mock_service: MagicMock) -> None:
# Arrange
body = {"address": ADDRESS}
body = {"address": ADDRESS, "hubspot_deal_id": "deal-123"}
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings(customer_id="cust-xyz", api_key="key-xyz")), \
patch("backend.magic_plan.handler.MagicPlanClient") as MockClient, \
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
@ -69,31 +69,37 @@ def test_handler_constructs_client_from_settings(mock_service: MagicMock) -> Non
def test_handler_calls_service_run_with_address(mock_service: MagicMock) -> None:
# Arrange
body = {"address": ADDRESS}
body = {"address": ADDRESS, "hubspot_deal_id": "deal-123"}
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings()), \
patch("backend.magic_plan.handler.MagicPlanClient"), \
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
# Act
_call_handler(body)
# Assert
mock_service.run.assert_called_once_with(ADDRESS, None)
mock_service.run.assert_called_once()
request = mock_service.run.call_args.args[0]
assert request.address == ADDRESS
assert request.uprn is None
def test_handler_passes_uprn_to_service(mock_service: MagicMock) -> None:
# Arrange
body = {"address": ADDRESS, "uprn": "100023336956"}
body = {"address": ADDRESS, "uprn": "100023336956", "hubspot_deal_id": "deal-123"}
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings()), \
patch("backend.magic_plan.handler.MagicPlanClient"), \
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
# Act
_call_handler(body)
# Assert
mock_service.run.assert_called_once_with(ADDRESS, "100023336956")
mock_service.run.assert_called_once()
request = mock_service.run.call_args.args[0]
assert request.address == ADDRESS
assert request.uprn == "100023336956"
def test_handler_returns_plan_uid(mock_service: MagicMock) -> None:
# Arrange
body = {"address": ADDRESS}
body = {"address": ADDRESS, "hubspot_deal_id": "deal-123"}
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings()), \
patch("backend.magic_plan.handler.MagicPlanClient"), \
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):

View file

@ -7,7 +7,7 @@ import pytest
import requests
from backend.magic_plan.magic_plan_client import MagicPlanClient
from datatypes.magicplan.api.response import MagicPlanPlan, PlansListResponse
from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary
FIXTURE_DIR = Path(__file__).parents[2] / "magic_plan"
BASE_URL = "https://cloud.magicplan.app/api/v2"
@ -20,6 +20,7 @@ def _load_fixture(name: str) -> dict[str, Any]:
def _make_client(mock_session: MagicMock) -> MagicPlanClient:
mock_session.headers = {}
with patch(
"backend.magic_plan.magic_plan_client.requests.Session",
return_value=mock_session,
@ -44,7 +45,14 @@ def test_customer_header_set_on_session(mock_session: MagicMock) -> None:
# Act
_make_client(mock_session)
# Assert
mock_session.headers.update.assert_called_once_with({"customer": CUSTOMER_ID})
assert mock_session.headers["customer"] == CUSTOMER_ID
def test_api_key_header_set_on_session(mock_session: MagicMock) -> None:
# Act
_make_client(mock_session)
# Assert
assert mock_session.headers["key"] == API_KEY
# --- get_plans ---
@ -63,7 +71,7 @@ def test_get_plans_calls_correct_url(
client.get_plans()
# Assert
mock_session.get.assert_called_once_with(
f"{BASE_URL}/plans", params={"key": API_KEY}
f"{BASE_URL}/workgroups/plans", params={"page": 1}
)
@ -82,7 +90,7 @@ def test_get_plans_calls_raise_for_status(
mock_session.get.return_value.raise_for_status.assert_called_once()
def test_get_plans_returns_plans_list_response(
def test_get_plans_returns_list_of_plan_summaries(
client: MagicPlanClient, mock_session: MagicMock
) -> None:
# Arrange
@ -94,8 +102,9 @@ def test_get_plans_returns_plans_list_response(
# Act
result = client.get_plans()
# Assert
assert isinstance(result, PlansListResponse)
assert len(result.plans) == 1
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], PlanSummary)
def test_get_plans_propagates_http_error(
@ -110,6 +119,34 @@ def test_get_plans_propagates_http_error(
client.get_plans()
def test_get_plans_multi_page_fetches_all_pages(
client: MagicPlanClient, mock_session: MagicMock
) -> None:
# Arrange
page1_plan = _load_fixture("magicplan_api_plans_response_example.json")["data"][
"plans"
][0]
page2_plan = {**page1_plan, "id": "page-2-plan-id"}
page1_response = MagicMock()
page1_response.json.return_value = {
"data": {"paging": {"page": 1, "next_page": True, "count": 2}, "plans": [page1_plan]}
}
page2_response = MagicMock()
page2_response.json.return_value = {
"data": {"paging": {"page": 2, "next_page": False, "count": 2}, "plans": [page2_plan]}
}
mock_session.get.side_effect = [page1_response, page2_response]
# Act
result = client.get_plans()
# Assert
assert mock_session.get.call_count == 2
mock_session.get.assert_any_call(f"{BASE_URL}/workgroups/plans", params={"page": 1})
mock_session.get.assert_any_call(f"{BASE_URL}/workgroups/plans", params={"page": 2})
assert len(result) == 2
assert result[0].id == page1_plan["id"]
assert result[1].id == "page-2-plan-id"
# --- get_plan ---
@ -126,9 +163,7 @@ def test_get_plan_calls_correct_url(
# Act
client.get_plan(plan_id)
# Assert
mock_session.get.assert_called_once_with(
f"{BASE_URL}/plans/{plan_id}", params={"key": API_KEY}
)
mock_session.get.assert_called_once_with(f"{BASE_URL}/plans/get/{plan_id}")
def test_get_plan_calls_raise_for_status(
@ -198,9 +233,7 @@ def test_get_plan_raw_calls_correct_url(
# Act
client.get_plan_raw(plan_id)
# Assert
mock_session.get.assert_called_once_with(
f"{BASE_URL}/plans/{plan_id}", params={"key": API_KEY}
)
mock_session.get.assert_called_once_with(f"{BASE_URL}/plans/get/{plan_id}")
def test_get_plan_raw_calls_raise_for_status(

View file

@ -91,7 +91,7 @@ def test_run_fetches_plan_with_matched_id(
domain_plan: Plan,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
mock_client.get_plan.return_value = api_magic_plan
service = _make_service(mock_client)
with patch(
@ -114,7 +114,7 @@ def test_run_returns_mapped_plan(
domain_plan: Plan,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
mock_client.get_plan.return_value = api_magic_plan
service = _make_service(mock_client)
with patch(
@ -137,7 +137,7 @@ def test_run_calls_save_plan_with_mapped_plan(
plan_summary: PlanSummary,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
mock_client.get_plan.return_value = api_magic_plan
service = _make_service(mock_client)
with patch(
@ -161,7 +161,7 @@ def test_run_accepts_uprn_without_error(
plan_summary: PlanSummary,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
mock_client.get_plan.return_value = api_magic_plan
service = _make_service(mock_client)
with patch(
@ -184,7 +184,7 @@ def test_run_uploads_to_s3_with_uprn_key(
plan_summary: PlanSummary,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
request = _make_request(uprn="100023336956")
service = MagicPlanService(client=mock_client, s3_bucket=S3_BUCKET)
with patch(
@ -211,7 +211,7 @@ def test_run_uploads_to_s3_with_deal_id_key_when_uprn_absent(
plan_summary: PlanSummary,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
mock_client.get_plan.return_value = api_magic_plan
request = _make_request(hubspot_deal_id="deal-456", uprn=None)
service = MagicPlanService(client=mock_client, s3_bucket=S3_BUCKET)
@ -242,7 +242,7 @@ def test_run_creates_uploaded_file_record(
plan_summary: PlanSummary,
) -> None:
# Arrange
mock_client.get_plans.return_value.plans = [plan_summary]
mock_client.get_plans.return_value = [plan_summary]
mock_client.get_plan.return_value = api_magic_plan
request = _make_request(hubspot_deal_id="deal-789", uprn="100023336956")
service = MagicPlanService(client=mock_client, s3_bucket=S3_BUCKET)

View file

@ -5,6 +5,7 @@ from backend.pashub_fetcher.pashub_client import PashubClient, UnauthorizedError
from backend.pashub_fetcher.pashub_service import PashubService
from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest
from backend.pashub_fetcher.token_getter import get_token_from_local_storage
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler
from utils.logger import setup_logger
from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient
@ -21,7 +22,7 @@ def get_pashub_client(email: str, password: str) -> PashubClient:
return PashubClient(token=token)
@task_handler()
@task_handler(task_source="pashub_fetcher", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: Dict[str, Any], context: Any) -> List[str]:
logger.info("Received message")

View file

@ -0,0 +1,30 @@
import os
from typing import Optional
from utils.logger import setup_logger
logger = setup_logger()
def build_cloudwatch_log_url(start_ms: Optional[int]) -> str:
    """
    Build a CloudWatch Logs console URL for the current Lambda invocation.

    Args:
        start_ms: Epoch milliseconds at which the log view should start. When
            None, the URL links to the log stream without a start filter
            instead of embedding the literal text "None".

    Returns:
        The console URL for the log group and stream named by the environment.

    Raises:
        KeyError: If AWS_REGION, AWS_LAMBDA_LOG_GROUP_NAME or
            AWS_LAMBDA_LOG_STREAM_NAME is not set in the environment, i.e.
            this is only safe to call inside a Lambda runtime.
    """
    logger.info("Building cloudwatch logs URL")
    region = os.environ["AWS_REGION"]
    log_group = os.environ["AWS_LAMBDA_LOG_GROUP_NAME"]
    log_stream = os.environ["AWS_LAMBDA_LOG_STREAM_NAME"]

    # The console URL fragment needs "/" written as "$252F".
    # NOTE(review): only "/" is encoded here; any other special characters in
    # the group or stream name pass through unchanged — confirm that is enough.
    encoded_group = log_group.replace("/", "$252F")
    encoded_stream = log_stream.replace("/", "$252F")

    url = (
        f"https://console.aws.amazon.com/cloudwatch/home?"
        f"region={region}"
        f"#logsV2:log-groups/log-group/{encoded_group}"
        f"/log-events/{encoded_stream}"
    )
    # Only append the start filter when a timestamp was actually supplied.
    if start_ms is not None:
        url += f"$3Fstart={start_ms}"
    return url

View file

@ -1,75 +1,72 @@
# decorators/subtask_handler.py
from functools import wraps
from typing import Callable, Any
from uuid import UUID
import json
import os
import time
from functools import wraps
from typing import Any, Callable, Optional, cast
from uuid import UUID
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
from backend.app.db.models.tasks import SourceEnum
from backend.utils.cloudwatch import build_cloudwatch_log_url
from utils.logger import setup_logger
def subtask_handler():
"""
Decorator that wraps your existing handler and automatically:
def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
    """Return the CloudWatch logs URL, or None when the Lambda env vars are absent.

    The URL is only built when AWS_REGION, AWS_LAMBDA_LOG_GROUP_NAME and
    AWS_LAMBDA_LOG_STREAM_NAME are all present in the environment, so local
    or non-Lambda runs get None instead of crashing.
    """
    for env_name in (
        "AWS_REGION",
        "AWS_LAMBDA_LOG_GROUP_NAME",
        "AWS_LAMBDA_LOG_STREAM_NAME",
    ):
        if env_name not in os.environ:
            return None
    return build_cloudwatch_log_url(start_ms)
- Extracts task_id + sub_task_id from event
- Marks subtask as in progress
- Executes handler logic
- Marks subtask complete on success
- Marks failed on exception
def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator for Lambdas that operate on an already-existing SubTask. Extracts
task_id + sub_task_id from each record, records the CloudWatch logs URL,
marks the SubTask in progress, then complete on success / failed on raise.
"""
def decorator(func: Callable[..., Any]):
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
start_ms = int(time.time() * 1000)
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
records = event.get("Records", [event])
interface = SubTaskInterface()
for record in records:
# -------------------------------
# Parse body safely
# -------------------------------
body = {}
if isinstance(record.get("body"), str):
raw_body = record.get("body")
body: dict[str, Any]
if isinstance(raw_body, str):
try:
body = json.loads(record["body"])
body = json.loads(raw_body)
except Exception:
body = {}
elif isinstance(raw_body, dict):
body = cast(dict[str, Any], raw_body)
else:
body = record.get("body", {}) or {}
body = {}
task_id_raw = body.get("task_id")
subtask_id_raw = body.get("sub_task_id")
task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
subtask_id = (
UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
)
subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
if not task_id or not subtask_id:
raise RuntimeError("task_id or sub_task_id missing")
# -------------------------------
# Mark in progress
# -------------------------------
interface.update_subtask_status(
subtask_id=subtask_id,
status="in progress",
cloud_logs_url=cloud_logs_url,
)
try:
# Pass the parsed body into your function
result = func(body, context, *args, **kwargs)
# -------------------------------
# Success → mark complete
# -------------------------------
interface.update_subtask_status(
subtask_id=subtask_id,
status="complete",
@ -77,75 +74,79 @@ def subtask_handler():
)
except Exception as e:
# -------------------------------
# Failure → mark failed
# -------------------------------
interface.update_subtask_status(
subtask_id=subtask_id,
status="failed",
outputs={"error": str(e)},
)
raise
return None
return wrapper
return decorator
def task_handler():
def task_handler(
task_source: str,
source: SourceEnum,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator that wraps a Lambda handler and automatically:
- Parses body from the first SQS record (or uses the event dict directly)
- Creates a fresh Task + SubTask in the database
- Marks the subtask as in progress
- Executes the handler, passing the parsed body
- Marks complete on success, failed on exception (and re-raises)
Decorator for Lambdas that are themselves the entry point of a pipeline (no
router in front). For each record the decorator creates a fresh Task +
SubTask with the given task_source and source. source_id is read from
body[source.value] (silently None if absent); see ADR-0001. Records the
CloudWatch logs URL, marks the SubTask in progress, then complete on
success / failed on raise.
"""
def decorator(func: Callable[..., Any]):
task_source = f"{func.__module__}.{func.__qualname__}"
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
logger = setup_logger()
start_ms = int(time.time() * 1000)
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
records = event.get("Records", [event]) # fallback for non-SQS
results = []
failures = []
records = event.get("Records", [event])
results: list[Any] = []
failures: list[dict[str, Any]] = []
interface = SubTaskInterface()
for record in records:
# Parse body
raw_body = record.get("body", record)
body: dict[str, Any]
if isinstance(raw_body, str):
try:
body = json.loads(raw_body)
except Exception:
body = {}
elif isinstance(raw_body, dict):
body = cast(dict[str, Any], raw_body)
else:
body = raw_body or {}
body = {}
raw_source_id = body.get(source.value)
source_id: Optional[str] = (
str(raw_source_id) if raw_source_id is not None else None
)
# Create task per message
logger.info("Creating task for source: %s", task_source)
task_id, subtask_id = TasksInterface.create_task(
task_source=task_source,
inputs=body,
source=source,
source_id=source_id,
)
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
if subtask_id is None:
raise RuntimeError("create_task did not return a subtask_id")
interface = SubTaskInterface()
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
interface.update_subtask_status(
subtask_id=subtask_id,
status="in progress",
cloud_logs_url=cloud_logs_url,
)
try:
@ -172,13 +173,11 @@ def task_handler():
if "Records" in event:
failures.append({"itemIdentifier": record["messageId"]})
else:
# Handle non-SQS events
raise
if "Records" in event:
return {"batchItemFailures": failures}
# Handle non-SQS events
return results
return wrapper

View file

@ -9,6 +9,7 @@ from etl.hubspot.hubspot_deal_differ import HubspotDealDiffer
from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import (
HubspotTriggerOrchestratorTriggerRequest,
)
from backend.app.db.models.tasks import SourceEnum
from backend.utils.subtasks import task_handler
from backend.app.db.models.hubspot_deal_data import HubspotDealData
from utils.logger import setup_logger
@ -16,7 +17,7 @@ from utils.logger import setup_logger
logger = setup_logger()
@task_handler()
@task_handler(task_source="hubspot_scraper", source=SourceEnum.HUBSPOT_DEAL)
def handler(body: dict[str, Any], context: Any) -> None:
db_client = HubspotDataToDb()
hubspot_client = HubspotClient()

View file

@ -19,7 +19,7 @@ data "terraform_remote_state" "pashub_to_ara" {
data "terraform_remote_state" "magic_plan" {
backend = "s3"
config = {
bucket = "magic-plan-hubspot-trigger-terraform-state"
bucket = "magic-plan-client-terraform-state"
key = "env:/${var.stage}/terraform.tfstate"
region = "eu-west-2"
}

View file

@ -7,7 +7,7 @@ terraform {
}
backend "s3" {
bucket = "magic-plan-hubspot-trigger-terraform-state"
bucket = "magic-plan-client-terraform-state"
key = "terraform.tfstate"
region = "eu-west-2"
}