mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
Merge branch 'main' into feature/integrate_new_epc_with_historical_epc
This commit is contained in:
commit
3bcb94f9e5
20 changed files with 239 additions and 148 deletions
|
|
@ -9,6 +9,7 @@ from sqlmodel import SQLModel, Field, Relationship
|
|||
|
||||
class SourceEnum(enum.Enum): # TODO: move to domain?
|
||||
PORTFOLIO = "portfolio_id"
|
||||
HUBSPOT_DEAL = "hubspot_deal_id"
|
||||
|
||||
|
||||
class Task(SQLModel, table=True):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import ast
|
||||
import os
|
||||
from typing import Optional
|
||||
import msgpack
|
||||
from uuid import UUID
|
||||
|
|
@ -8,6 +7,7 @@ from backend.addresses.Address import Address
|
|||
from backend.app.config import get_settings
|
||||
from backend.app.plan.data_classes import PropertyRequestData
|
||||
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
|
||||
from backend.utils.cloudwatch import build_cloudwatch_log_url
|
||||
from starlette.responses import Response
|
||||
from utils.logger import setup_logger
|
||||
|
||||
|
|
@ -241,33 +241,6 @@ def parse_eco_packages(
|
|||
return measures, mapped["target_sap"], mapped["plan_type"], already_installed
|
||||
|
||||
|
||||
def build_cloudwatch_log_url(start_ms: Optional[int]) -> str:
|
||||
"""
|
||||
Build a CloudWatch Logs URL for the current Lambda invocation,
|
||||
including timestamp window from start_ms to end_ms (epoch ms).
|
||||
"""
|
||||
logger.info("Building cloudwatch logs URL")
|
||||
region = os.environ["AWS_REGION"]
|
||||
logger.info("Building cloudwatch logs URL: Got AWS region")
|
||||
log_group = os.environ["AWS_LAMBDA_LOG_GROUP_NAME"]
|
||||
logger.info("Building cloudwatch logs URL: Got lambda log group name")
|
||||
log_stream = os.environ["AWS_LAMBDA_LOG_STREAM_NAME"]
|
||||
logger.info("Building cloudwatch logs URL: Got lambda log stream name")
|
||||
|
||||
# CloudWatch console requires / encoded as $252F
|
||||
encoded_group = log_group.replace("/", "$252F")
|
||||
encoded_stream = log_stream.replace("/", "$252F")
|
||||
|
||||
# Return the full URL with time range
|
||||
return (
|
||||
f"https://console.aws.amazon.com/cloudwatch/home?"
|
||||
f"region={region}"
|
||||
f"#logsV2:log-groups/log-group/{encoded_group}"
|
||||
f"/log-events/{encoded_stream}"
|
||||
f"$3Fstart={start_ms}"
|
||||
)
|
||||
|
||||
|
||||
def handle_error(
|
||||
msg: str,
|
||||
exception: Exception,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
from typing import Any, Mapping
|
||||
|
||||
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
|
||||
from backend.app.plan.utils import build_cloudwatch_log_url
|
||||
from backend.utils.cloudwatch import build_cloudwatch_log_url
|
||||
from backend.categorisation.categorisation_trigger_request import (
|
||||
CategorisationTriggerRequest,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from backend.app.db.functions.tasks.Tasks import SubTaskInterface
|
|||
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
|
||||
from backend.app.domain.classes.plan import Plan
|
||||
from backend.app.domain.classes.scenario import Scenario
|
||||
from backend.app.plan.utils import build_cloudwatch_log_url, handle_error
|
||||
from backend.app.plan.utils import handle_error
|
||||
from backend.utils.cloudwatch import build_cloudwatch_log_url
|
||||
from backend.categorisation.categorisation_trigger_request import (
|
||||
CategorisationTriggerRequest,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,8 +23,9 @@ from backend.app.db.functions.tasks.Tasks import SubTaskInterface
|
|||
|
||||
from backend.app.plan.schemas import PlanTriggerRequest
|
||||
from backend.app.plan.utils import (
|
||||
get_cleaned, patch_epc, extract_property_request_data, handle_error, build_cloudwatch_log_url
|
||||
get_cleaned, patch_epc, extract_property_request_data, handle_error
|
||||
)
|
||||
from backend.utils.cloudwatch import build_cloudwatch_log_url
|
||||
from backend.app.utils import sap_to_epc
|
||||
import backend.app.assumptions as assumptions
|
||||
|
||||
|
|
|
|||
|
|
@ -5,13 +5,14 @@ from backend.magic_plan.magic_plan_client import MagicPlanClient
|
|||
from backend.magic_plan.magic_plan_service import MagicPlanService
|
||||
from backend.magic_plan.magic_plan_trigger_request import MagicPlanTriggerRequest
|
||||
from datatypes.magicplan.domain.models import Plan
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.subtasks import task_handler
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@task_handler()
|
||||
@task_handler(task_source="magic_plan", source=SourceEnum.HUBSPOT_DEAL)
|
||||
def handler(body: dict[str, Any], context: Any) -> str:
|
||||
settings = get_settings()
|
||||
payload = MagicPlanTriggerRequest.model_validate(body)
|
||||
|
|
@ -20,7 +21,9 @@ def handler(body: dict[str, Any], context: Any) -> str:
|
|||
api_key=settings.MAGICPLAN_API_KEY,
|
||||
)
|
||||
# TODO: read s3_bucket from env var so staging/prod use the correct bucket
|
||||
plan: Plan = MagicPlanService(client, s3_bucket="retrofit-energy-assessments-dev").run(payload)
|
||||
plan: Plan = MagicPlanService(
|
||||
client, s3_bucket="retrofit-energy-assessments-dev"
|
||||
).run(payload)
|
||||
logger.info("Saved MagicPlan plan uid=%s", plan.uid)
|
||||
return plan.uid
|
||||
|
||||
|
|
@ -30,7 +33,6 @@ if __name__ == "__main__":
|
|||
"Records": [
|
||||
{
|
||||
"body": '{"address": "2 Laburnum Way Bromley BR2 8BZ", "hubspot_deal_id": "local-test-deal"}',
|
||||
"messageId": "local-test",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,3 +5,7 @@ sqlmodel
|
|||
psycopg2-binary==2.9.10
|
||||
pydantic-settings==2.6.0
|
||||
boto3==1.35.44
|
||||
|
||||
pytz==2024.2
|
||||
pandas==2.2.2
|
||||
numpy==2.1.2
|
||||
|
|
|
|||
11
backend/magic_plan/local_handler/docker-compose.yml
Normal file
11
backend/magic_plan/local_handler/docker-compose.yml
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
version: "3.9"
|
||||
|
||||
services:
|
||||
ecmk-fetcher-lambda:
|
||||
build:
|
||||
context: ../../../
|
||||
dockerfile: backend/magic_plan/handler/Dockerfile
|
||||
ports:
|
||||
- "9000:8080"
|
||||
env_file:
|
||||
- ../../../.env
|
||||
29
backend/magic_plan/local_handler/invoke_local_lambda.py
Normal file
29
backend/magic_plan/local_handler/invoke_local_lambda.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python3
|
||||
import json
|
||||
import requests
|
||||
|
||||
HOST = "localhost"
|
||||
PORT = "9000"
|
||||
|
||||
LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"
|
||||
|
||||
payload = {
|
||||
"Records": [
|
||||
{
|
||||
"messageId": "test-message-id",
|
||||
"body": json.dumps(
|
||||
# {
|
||||
# "address": "2 Laburnum Way, Rombley, BR2 8BZ | Retrofit Assessment",
|
||||
# "hubspot_deal_id": "500262906061",
|
||||
# }
|
||||
{"address": "33 Wallaby Way, Sydney", "hubspot_deal_id": "123456789"}
|
||||
),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(LAMBDA_URL, json=payload)
|
||||
|
||||
print("Status code:", response.status_code)
|
||||
print("Response:")
|
||||
print(response.text)
|
||||
|
|
@ -1,20 +1,27 @@
|
|||
import requests
|
||||
|
||||
from datatypes.magicplan.api.response import MagicPlanPlan, PlansListResponse
|
||||
from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary, PlansListResponse
|
||||
|
||||
_BASE_URL = "https://cloud.magicplan.app/api/v2"
|
||||
|
||||
|
||||
class MagicPlanClient:
|
||||
def __init__(self, customer_id: str, api_key: str) -> None:
|
||||
self._api_key = api_key
|
||||
self._session = requests.Session()
|
||||
self._session.headers.update({"customer": customer_id})
|
||||
self._session.headers.update({"customer": customer_id, "key": api_key})
|
||||
|
||||
def get_plans(self) -> PlansListResponse:
|
||||
r = self._session.get(f"{_BASE_URL}/plans", params={"key": self._api_key})
|
||||
r.raise_for_status()
|
||||
return PlansListResponse.model_validate(r.json()["data"])
|
||||
def get_plans(self) -> list[PlanSummary]:
|
||||
all_plans: list[PlanSummary] = []
|
||||
page = 1
|
||||
while True:
|
||||
r = self._session.get(f"{_BASE_URL}/workgroups/plans", params={"page": page})
|
||||
r.raise_for_status()
|
||||
response = PlansListResponse.model_validate(r.json()["data"])
|
||||
all_plans.extend(response.plans)
|
||||
if not response.paging.next_page:
|
||||
break
|
||||
page += 1
|
||||
return all_plans
|
||||
|
||||
def get_plan(self, plan_id: str) -> MagicPlanPlan:
|
||||
return MagicPlanPlan.model_validate(self._fetch_plan(plan_id).json()["data"])
|
||||
|
|
@ -23,8 +30,6 @@ class MagicPlanClient:
|
|||
return self._fetch_plan(plan_id).content
|
||||
|
||||
def _fetch_plan(self, plan_id: str) -> requests.Response:
|
||||
r = self._session.get(
|
||||
f"{_BASE_URL}/plans/{plan_id}", params={"key": self._api_key}
|
||||
)
|
||||
r = self._session.get(f"{_BASE_URL}/plans/get/{plan_id}")
|
||||
r.raise_for_status()
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -3,11 +3,7 @@ import json
|
|||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from datatypes.magicplan.api.response import (
|
||||
MagicPlanPlan,
|
||||
PlanSummary,
|
||||
PlansListResponse,
|
||||
)
|
||||
from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary
|
||||
from datatypes.magicplan.domain.mapper import map_plan
|
||||
from datatypes.magicplan.domain.models import Plan
|
||||
|
||||
|
|
@ -39,10 +35,8 @@ class MagicPlanService:
|
|||
if uprn is not None:
|
||||
logger.info("MagicPlanService.run uprn=%s", uprn)
|
||||
|
||||
plans_response: PlansListResponse = self._client.get_plans()
|
||||
matched: Optional[PlanSummary] = find_matching_plan(
|
||||
plans_response.plans, address
|
||||
)
|
||||
plans: list[PlanSummary] = self._client.get_plans()
|
||||
matched: Optional[PlanSummary] = find_matching_plan(plans, address)
|
||||
|
||||
if matched is None:
|
||||
raise ValueError(f"No MagicPlan found for address: {address!r}")
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ def test_handler_raises_on_missing_address(mock_plan: MagicMock) -> None:
|
|||
|
||||
def test_handler_constructs_client_from_settings(mock_service: MagicMock) -> None:
|
||||
# Arrange
|
||||
body = {"address": ADDRESS}
|
||||
body = {"address": ADDRESS, "hubspot_deal_id": "deal-123"}
|
||||
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings(customer_id="cust-xyz", api_key="key-xyz")), \
|
||||
patch("backend.magic_plan.handler.MagicPlanClient") as MockClient, \
|
||||
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
|
||||
|
|
@ -69,31 +69,37 @@ def test_handler_constructs_client_from_settings(mock_service: MagicMock) -> Non
|
|||
|
||||
def test_handler_calls_service_run_with_address(mock_service: MagicMock) -> None:
|
||||
# Arrange
|
||||
body = {"address": ADDRESS}
|
||||
body = {"address": ADDRESS, "hubspot_deal_id": "deal-123"}
|
||||
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings()), \
|
||||
patch("backend.magic_plan.handler.MagicPlanClient"), \
|
||||
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
|
||||
# Act
|
||||
_call_handler(body)
|
||||
# Assert
|
||||
mock_service.run.assert_called_once_with(ADDRESS, None)
|
||||
mock_service.run.assert_called_once()
|
||||
request = mock_service.run.call_args.args[0]
|
||||
assert request.address == ADDRESS
|
||||
assert request.uprn is None
|
||||
|
||||
|
||||
def test_handler_passes_uprn_to_service(mock_service: MagicMock) -> None:
|
||||
# Arrange
|
||||
body = {"address": ADDRESS, "uprn": "100023336956"}
|
||||
body = {"address": ADDRESS, "uprn": "100023336956", "hubspot_deal_id": "deal-123"}
|
||||
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings()), \
|
||||
patch("backend.magic_plan.handler.MagicPlanClient"), \
|
||||
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
|
||||
# Act
|
||||
_call_handler(body)
|
||||
# Assert
|
||||
mock_service.run.assert_called_once_with(ADDRESS, "100023336956")
|
||||
mock_service.run.assert_called_once()
|
||||
request = mock_service.run.call_args.args[0]
|
||||
assert request.address == ADDRESS
|
||||
assert request.uprn == "100023336956"
|
||||
|
||||
|
||||
def test_handler_returns_plan_uid(mock_service: MagicMock) -> None:
|
||||
# Arrange
|
||||
body = {"address": ADDRESS}
|
||||
body = {"address": ADDRESS, "hubspot_deal_id": "deal-123"}
|
||||
with patch("backend.magic_plan.handler.get_settings", return_value=_make_settings()), \
|
||||
patch("backend.magic_plan.handler.MagicPlanClient"), \
|
||||
patch("backend.magic_plan.handler.MagicPlanService", return_value=mock_service):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
import requests
|
||||
|
||||
from backend.magic_plan.magic_plan_client import MagicPlanClient
|
||||
from datatypes.magicplan.api.response import MagicPlanPlan, PlansListResponse
|
||||
from datatypes.magicplan.api.response import MagicPlanPlan, PlanSummary
|
||||
|
||||
FIXTURE_DIR = Path(__file__).parents[2] / "magic_plan"
|
||||
BASE_URL = "https://cloud.magicplan.app/api/v2"
|
||||
|
|
@ -20,6 +20,7 @@ def _load_fixture(name: str) -> dict[str, Any]:
|
|||
|
||||
|
||||
def _make_client(mock_session: MagicMock) -> MagicPlanClient:
|
||||
mock_session.headers = {}
|
||||
with patch(
|
||||
"backend.magic_plan.magic_plan_client.requests.Session",
|
||||
return_value=mock_session,
|
||||
|
|
@ -44,7 +45,14 @@ def test_customer_header_set_on_session(mock_session: MagicMock) -> None:
|
|||
# Act
|
||||
_make_client(mock_session)
|
||||
# Assert
|
||||
mock_session.headers.update.assert_called_once_with({"customer": CUSTOMER_ID})
|
||||
assert mock_session.headers["customer"] == CUSTOMER_ID
|
||||
|
||||
|
||||
def test_api_key_header_set_on_session(mock_session: MagicMock) -> None:
|
||||
# Act
|
||||
_make_client(mock_session)
|
||||
# Assert
|
||||
assert mock_session.headers["key"] == API_KEY
|
||||
|
||||
|
||||
# --- get_plans ---
|
||||
|
|
@ -63,7 +71,7 @@ def test_get_plans_calls_correct_url(
|
|||
client.get_plans()
|
||||
# Assert
|
||||
mock_session.get.assert_called_once_with(
|
||||
f"{BASE_URL}/plans", params={"key": API_KEY}
|
||||
f"{BASE_URL}/workgroups/plans", params={"page": 1}
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -82,7 +90,7 @@ def test_get_plans_calls_raise_for_status(
|
|||
mock_session.get.return_value.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
def test_get_plans_returns_plans_list_response(
|
||||
def test_get_plans_returns_list_of_plan_summaries(
|
||||
client: MagicPlanClient, mock_session: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
|
|
@ -94,8 +102,9 @@ def test_get_plans_returns_plans_list_response(
|
|||
# Act
|
||||
result = client.get_plans()
|
||||
# Assert
|
||||
assert isinstance(result, PlansListResponse)
|
||||
assert len(result.plans) == 1
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], PlanSummary)
|
||||
|
||||
|
||||
def test_get_plans_propagates_http_error(
|
||||
|
|
@ -110,6 +119,34 @@ def test_get_plans_propagates_http_error(
|
|||
client.get_plans()
|
||||
|
||||
|
||||
def test_get_plans_multi_page_fetches_all_pages(
|
||||
client: MagicPlanClient, mock_session: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
page1_plan = _load_fixture("magicplan_api_plans_response_example.json")["data"][
|
||||
"plans"
|
||||
][0]
|
||||
page2_plan = {**page1_plan, "id": "page-2-plan-id"}
|
||||
page1_response = MagicMock()
|
||||
page1_response.json.return_value = {
|
||||
"data": {"paging": {"page": 1, "next_page": True, "count": 2}, "plans": [page1_plan]}
|
||||
}
|
||||
page2_response = MagicMock()
|
||||
page2_response.json.return_value = {
|
||||
"data": {"paging": {"page": 2, "next_page": False, "count": 2}, "plans": [page2_plan]}
|
||||
}
|
||||
mock_session.get.side_effect = [page1_response, page2_response]
|
||||
# Act
|
||||
result = client.get_plans()
|
||||
# Assert
|
||||
assert mock_session.get.call_count == 2
|
||||
mock_session.get.assert_any_call(f"{BASE_URL}/workgroups/plans", params={"page": 1})
|
||||
mock_session.get.assert_any_call(f"{BASE_URL}/workgroups/plans", params={"page": 2})
|
||||
assert len(result) == 2
|
||||
assert result[0].id == page1_plan["id"]
|
||||
assert result[1].id == "page-2-plan-id"
|
||||
|
||||
|
||||
# --- get_plan ---
|
||||
|
||||
|
||||
|
|
@ -126,9 +163,7 @@ def test_get_plan_calls_correct_url(
|
|||
# Act
|
||||
client.get_plan(plan_id)
|
||||
# Assert
|
||||
mock_session.get.assert_called_once_with(
|
||||
f"{BASE_URL}/plans/{plan_id}", params={"key": API_KEY}
|
||||
)
|
||||
mock_session.get.assert_called_once_with(f"{BASE_URL}/plans/get/{plan_id}")
|
||||
|
||||
|
||||
def test_get_plan_calls_raise_for_status(
|
||||
|
|
@ -198,9 +233,7 @@ def test_get_plan_raw_calls_correct_url(
|
|||
# Act
|
||||
client.get_plan_raw(plan_id)
|
||||
# Assert
|
||||
mock_session.get.assert_called_once_with(
|
||||
f"{BASE_URL}/plans/{plan_id}", params={"key": API_KEY}
|
||||
)
|
||||
mock_session.get.assert_called_once_with(f"{BASE_URL}/plans/get/{plan_id}")
|
||||
|
||||
|
||||
def test_get_plan_raw_calls_raise_for_status(
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ def test_run_fetches_plan_with_matched_id(
|
|||
domain_plan: Plan,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
mock_client.get_plan.return_value = api_magic_plan
|
||||
service = _make_service(mock_client)
|
||||
with patch(
|
||||
|
|
@ -114,7 +114,7 @@ def test_run_returns_mapped_plan(
|
|||
domain_plan: Plan,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
mock_client.get_plan.return_value = api_magic_plan
|
||||
service = _make_service(mock_client)
|
||||
with patch(
|
||||
|
|
@ -137,7 +137,7 @@ def test_run_calls_save_plan_with_mapped_plan(
|
|||
plan_summary: PlanSummary,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
mock_client.get_plan.return_value = api_magic_plan
|
||||
service = _make_service(mock_client)
|
||||
with patch(
|
||||
|
|
@ -161,7 +161,7 @@ def test_run_accepts_uprn_without_error(
|
|||
plan_summary: PlanSummary,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
mock_client.get_plan.return_value = api_magic_plan
|
||||
service = _make_service(mock_client)
|
||||
with patch(
|
||||
|
|
@ -184,7 +184,7 @@ def test_run_uploads_to_s3_with_uprn_key(
|
|||
plan_summary: PlanSummary,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
request = _make_request(uprn="100023336956")
|
||||
service = MagicPlanService(client=mock_client, s3_bucket=S3_BUCKET)
|
||||
with patch(
|
||||
|
|
@ -211,7 +211,7 @@ def test_run_uploads_to_s3_with_deal_id_key_when_uprn_absent(
|
|||
plan_summary: PlanSummary,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
mock_client.get_plan.return_value = api_magic_plan
|
||||
request = _make_request(hubspot_deal_id="deal-456", uprn=None)
|
||||
service = MagicPlanService(client=mock_client, s3_bucket=S3_BUCKET)
|
||||
|
|
@ -242,7 +242,7 @@ def test_run_creates_uploaded_file_record(
|
|||
plan_summary: PlanSummary,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_client.get_plans.return_value.plans = [plan_summary]
|
||||
mock_client.get_plans.return_value = [plan_summary]
|
||||
mock_client.get_plan.return_value = api_magic_plan
|
||||
request = _make_request(hubspot_deal_id="deal-789", uprn="100023336956")
|
||||
service = MagicPlanService(client=mock_client, s3_bucket=S3_BUCKET)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from backend.pashub_fetcher.pashub_client import PashubClient, UnauthorizedError
|
|||
from backend.pashub_fetcher.pashub_service import PashubService
|
||||
from backend.pashub_fetcher.pashub_to_ara_trigger_request import PashubToAraTriggerRequest
|
||||
from backend.pashub_fetcher.token_getter import get_token_from_local_storage
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.subtasks import task_handler
|
||||
from utils.logger import setup_logger
|
||||
from utils.sharepoint.domna_sharepoint_client import DomnaSharepointClient
|
||||
|
|
@ -21,7 +22,7 @@ def get_pashub_client(email: str, password: str) -> PashubClient:
|
|||
return PashubClient(token=token)
|
||||
|
||||
|
||||
@task_handler()
|
||||
@task_handler(task_source="pashub_fetcher", source=SourceEnum.HUBSPOT_DEAL)
|
||||
def handler(body: Dict[str, Any], context: Any) -> List[str]:
|
||||
logger.info("Received message")
|
||||
|
||||
|
|
|
|||
30
backend/utils/cloudwatch.py
Normal file
30
backend/utils/cloudwatch.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_cloudwatch_log_url(start_ms: Optional[int]) -> str:
|
||||
"""
|
||||
Build a CloudWatch Logs URL for the current Lambda invocation, including a
|
||||
timestamp window starting at start_ms. Requires AWS_REGION,
|
||||
AWS_LAMBDA_LOG_GROUP_NAME, and AWS_LAMBDA_LOG_STREAM_NAME to be set in the
|
||||
environment — i.e. only safe to call inside a Lambda runtime.
|
||||
"""
|
||||
logger.info("Building cloudwatch logs URL")
|
||||
region = os.environ["AWS_REGION"]
|
||||
log_group = os.environ["AWS_LAMBDA_LOG_GROUP_NAME"]
|
||||
log_stream = os.environ["AWS_LAMBDA_LOG_STREAM_NAME"]
|
||||
|
||||
encoded_group = log_group.replace("/", "$252F")
|
||||
encoded_stream = log_stream.replace("/", "$252F")
|
||||
|
||||
return (
|
||||
f"https://console.aws.amazon.com/cloudwatch/home?"
|
||||
f"region={region}"
|
||||
f"#logsV2:log-groups/log-group/{encoded_group}"
|
||||
f"/log-events/{encoded_stream}"
|
||||
f"$3Fstart={start_ms}"
|
||||
)
|
||||
|
|
@ -1,75 +1,72 @@
|
|||
# decorators/subtask_handler.py
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Any
|
||||
from uuid import UUID
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from uuid import UUID
|
||||
|
||||
from backend.app.db.functions.tasks.Tasks import SubTaskInterface, TasksInterface
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.cloudwatch import build_cloudwatch_log_url
|
||||
from utils.logger import setup_logger
|
||||
|
||||
|
||||
def subtask_handler():
|
||||
"""
|
||||
Decorator that wraps your existing handler and automatically:
|
||||
def _try_build_cloud_logs_url(start_ms: int) -> Optional[str]:
|
||||
# Returns None outside a Lambda runtime so local/non-Lambda runs don't crash.
|
||||
required = ("AWS_REGION", "AWS_LAMBDA_LOG_GROUP_NAME", "AWS_LAMBDA_LOG_STREAM_NAME")
|
||||
if not all(k in os.environ for k in required):
|
||||
return None
|
||||
return build_cloudwatch_log_url(start_ms)
|
||||
|
||||
- Extracts task_id + sub_task_id from event
|
||||
- Marks subtask as in progress
|
||||
- Executes handler logic
|
||||
- Marks subtask complete on success
|
||||
- Marks failed on exception
|
||||
|
||||
def subtask_handler() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
Decorator for Lambdas that operate on an already-existing SubTask. Extracts
|
||||
task_id + sub_task_id from each record, records the CloudWatch logs URL,
|
||||
marks the SubTask in progress, then complete on success / failed on raise.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]):
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
|
||||
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> None:
|
||||
start_ms = int(time.time() * 1000)
|
||||
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
|
||||
|
||||
records = event.get("Records", [event])
|
||||
|
||||
interface = SubTaskInterface()
|
||||
|
||||
for record in records:
|
||||
|
||||
# -------------------------------
|
||||
# Parse body safely
|
||||
# -------------------------------
|
||||
body = {}
|
||||
|
||||
if isinstance(record.get("body"), str):
|
||||
raw_body = record.get("body")
|
||||
body: dict[str, Any]
|
||||
if isinstance(raw_body, str):
|
||||
try:
|
||||
body = json.loads(record["body"])
|
||||
body = json.loads(raw_body)
|
||||
except Exception:
|
||||
body = {}
|
||||
elif isinstance(raw_body, dict):
|
||||
body = cast(dict[str, Any], raw_body)
|
||||
else:
|
||||
body = record.get("body", {}) or {}
|
||||
body = {}
|
||||
|
||||
task_id_raw = body.get("task_id")
|
||||
subtask_id_raw = body.get("sub_task_id")
|
||||
|
||||
task_id = UUID(task_id_raw) if isinstance(task_id_raw, str) else None
|
||||
subtask_id = (
|
||||
UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
|
||||
)
|
||||
subtask_id = UUID(subtask_id_raw) if isinstance(subtask_id_raw, str) else None
|
||||
|
||||
if not task_id or not subtask_id:
|
||||
raise RuntimeError("task_id or sub_task_id missing")
|
||||
|
||||
# -------------------------------
|
||||
# Mark in progress
|
||||
# -------------------------------
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="in progress",
|
||||
cloud_logs_url=cloud_logs_url,
|
||||
)
|
||||
|
||||
try:
|
||||
# Pass the parsed body into your function
|
||||
result = func(body, context, *args, **kwargs)
|
||||
|
||||
# -------------------------------
|
||||
# Success → mark complete
|
||||
# -------------------------------
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="complete",
|
||||
|
|
@ -77,75 +74,79 @@ def subtask_handler():
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
# -------------------------------
|
||||
# Failure → mark failed
|
||||
# -------------------------------
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="failed",
|
||||
outputs={"error": str(e)},
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def task_handler():
|
||||
def task_handler(
|
||||
task_source: str,
|
||||
source: SourceEnum,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
Decorator that wraps a Lambda handler and automatically:
|
||||
|
||||
- Parses body from the first SQS record (or uses the event dict directly)
|
||||
- Creates a fresh Task + SubTask in the database
|
||||
- Marks the subtask as in progress
|
||||
- Executes the handler, passing the parsed body
|
||||
- Marks complete on success, failed on exception (and re-raises)
|
||||
Decorator for Lambdas that are themselves the entry point of a pipeline (no
|
||||
router in front). For each record the decorator creates a fresh Task +
|
||||
SubTask with the given task_source and source. source_id is read from
|
||||
body[source.value] (silent None if absent) — see ADR-0001. Records the
|
||||
CloudWatch logs URL, marks the SubTask in progress, then complete on
|
||||
success / failed on raise.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]):
|
||||
|
||||
task_source = f"{func.__module__}.{func.__qualname__}"
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(event: dict[str, Any], context: Any, *args, **kwargs):
|
||||
def wrapper(event: dict[str, Any], context: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
logger = setup_logger()
|
||||
start_ms = int(time.time() * 1000)
|
||||
cloud_logs_url = _try_build_cloud_logs_url(start_ms)
|
||||
|
||||
records = event.get("Records", [event]) # fallback for non-SQS
|
||||
|
||||
results = []
|
||||
failures = []
|
||||
records = event.get("Records", [event])
|
||||
results: list[Any] = []
|
||||
failures: list[dict[str, Any]] = []
|
||||
interface = SubTaskInterface()
|
||||
|
||||
for record in records:
|
||||
# Parse body
|
||||
raw_body = record.get("body", record)
|
||||
|
||||
body: dict[str, Any]
|
||||
if isinstance(raw_body, str):
|
||||
try:
|
||||
body = json.loads(raw_body)
|
||||
except Exception:
|
||||
body = {}
|
||||
elif isinstance(raw_body, dict):
|
||||
body = cast(dict[str, Any], raw_body)
|
||||
else:
|
||||
body = raw_body or {}
|
||||
body = {}
|
||||
|
||||
raw_source_id = body.get(source.value)
|
||||
source_id: Optional[str] = (
|
||||
str(raw_source_id) if raw_source_id is not None else None
|
||||
)
|
||||
|
||||
# Create task per message
|
||||
logger.info("Creating task for source: %s", task_source)
|
||||
task_id, subtask_id = TasksInterface.create_task(
|
||||
task_source=task_source,
|
||||
inputs=body,
|
||||
source=source,
|
||||
source_id=source_id,
|
||||
)
|
||||
|
||||
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
|
||||
if subtask_id is None:
|
||||
raise RuntimeError("create_task did not return a subtask_id")
|
||||
|
||||
interface = SubTaskInterface()
|
||||
logger.info("Created task_id=%s subtask_id=%s", task_id, subtask_id)
|
||||
|
||||
interface.update_subtask_status(
|
||||
subtask_id=subtask_id,
|
||||
status="in progress",
|
||||
cloud_logs_url=cloud_logs_url,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -172,13 +173,11 @@ def task_handler():
|
|||
if "Records" in event:
|
||||
failures.append({"itemIdentifier": record["messageId"]})
|
||||
else:
|
||||
# Handle non-SQS events
|
||||
raise
|
||||
|
||||
if "Records" in event:
|
||||
return {"batchItemFailures": failures}
|
||||
|
||||
# Handle non-SQS events
|
||||
return results
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from etl.hubspot.hubspot_deal_differ import HubspotDealDiffer
|
|||
from etl.hubspot.hubspot_trigger_orchestrator_trigger_request import (
|
||||
HubspotTriggerOrchestratorTriggerRequest,
|
||||
)
|
||||
from backend.app.db.models.tasks import SourceEnum
|
||||
from backend.utils.subtasks import task_handler
|
||||
from backend.app.db.models.hubspot_deal_data import HubspotDealData
|
||||
from utils.logger import setup_logger
|
||||
|
|
@ -16,7 +17,7 @@ from utils.logger import setup_logger
|
|||
logger = setup_logger()
|
||||
|
||||
|
||||
@task_handler()
|
||||
@task_handler(task_source="hubspot_scraper", source=SourceEnum.HUBSPOT_DEAL)
|
||||
def handler(body: dict[str, Any], context: Any) -> None:
|
||||
db_client = HubspotDataToDb()
|
||||
hubspot_client = HubspotClient()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ data "terraform_remote_state" "pashub_to_ara" {
|
|||
data "terraform_remote_state" "magic_plan" {
|
||||
backend = "s3"
|
||||
config = {
|
||||
bucket = "magic-plan-hubspot-trigger-terraform-state"
|
||||
bucket = "magic-plan-client-terraform-state"
|
||||
key = "env:/${var.stage}/terraform.tfstate"
|
||||
region = "eu-west-2"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ terraform {
|
|||
}
|
||||
|
||||
backend "s3" {
|
||||
bucket = "magic-plan-hubspot-trigger-terraform-state"
|
||||
bucket = "magic-plan-client-terraform-state"
|
||||
key = "terraform.tfstate"
|
||||
region = "eu-west-2"
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue