diff --git a/tests/utilities/aws_lambda/test_task_handler.py b/tests/utilities/aws_lambda/test_task_handler.py new file mode 100644 index 00000000..fae35de2 --- /dev/null +++ b/tests/utilities/aws_lambda/test_task_handler.py @@ -0,0 +1,104 @@ +from collections.abc import Generator, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any +from uuid import UUID + +import pytest +from sqlalchemy import Engine +from sqlmodel import Session + +from domain.tasks.tasks import Source +from orchestration.task_orchestrator import TaskOrchestrator +from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository +from repositories.tasks.task_postgres_repository import TaskPostgresRepository +from utilities.aws_lambda.task_handler import task_handler + + +@dataclass +class Harness: + orchestrator: TaskOrchestrator + tasks: TaskPostgresRepository + subtasks: SubTaskPostgresRepository + + @contextmanager + def factory(self) -> Generator[TaskOrchestrator, None, None]: + yield self.orchestrator + + +@pytest.fixture +def harness(db_engine: Engine) -> Iterator[Harness]: + with Session(db_engine) as session: + tasks = TaskPostgresRepository(session=session) + subtasks = SubTaskPostgresRepository(session=session) + yield Harness( + orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks), + tasks=tasks, + subtasks=subtasks, + ) + + +def _direct_event(property_id: str) -> dict[str, Any]: + return {"property_id": property_id} + + +def test_task_handler_records_cloudwatch_url_on_subtask( + harness: Harness, monkeypatch: pytest.MonkeyPatch +) -> None: + # arrange + monkeypatch.setenv("AWS_REGION", "eu-west-2") + monkeypatch.setenv( + "AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/modelling-e2e" + ) + monkeypatch.setenv( + "AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123" + ) + + @task_handler( + task_source="modelling_e2e", + source=Source.PROPERTY, + orchestrator_cm=harness.factory, + ) + def handler(body: dict[str, Any], context: Any) -> None: + return None + + # act + result = handler(_direct_event("prop-1"), context=None) + + # assert + subtask_id = result[0]["subtask_id"] + saved_url = harness.subtasks.get(UUID(subtask_id)).cloud_logs_url + assert saved_url is not None + assert saved_url.startswith( + "https://eu-west-2.console.aws.amazon.com/cloudwatch/home" + ) + # Log group / stream are console-encoded ("/" -> "$252F"). + assert "$252Faws$252Flambda$252Fmodelling-e2e" in saved_url + assert "$255B$2524LATEST$255D" in saved_url + + +def test_task_handler_leaves_cloudwatch_url_unset_outside_lambda( + harness: Harness, monkeypatch: pytest.MonkeyPatch +) -> None: + # arrange + for var in ( + "AWS_REGION", + "AWS_LAMBDA_LOG_GROUP_NAME", + "AWS_LAMBDA_LOG_STREAM_NAME", + ): + monkeypatch.delenv(var, raising=False) + + @task_handler( + task_source="modelling_e2e", + source=Source.PROPERTY, + orchestrator_cm=harness.factory, + ) + def handler(body: dict[str, Any], context: Any) -> None: + return None + + # act + result = handler(_direct_event("prop-1"), context=None) + + # assert + subtask_id = result[0]["subtask_id"] + assert harness.subtasks.get(UUID(subtask_id)).cloud_logs_url is None diff --git a/utilities/aws_lambda/cloud_logs.py b/utilities/aws_lambda/cloud_logs.py new file mode 100644 index 00000000..9a8da920 --- /dev/null +++ b/utilities/aws_lambda/cloud_logs.py @@ -0,0 +1,27 @@ +"""Build a CloudWatch console deep-link for the running Lambda invocation. + +Shared by @task_handler and @subtask_handler so both persist the same +`cloud_logs_url` onto the SubTask they run. +""" + +import os +from typing import Optional +from urllib.parse import quote + + +def _console_encode(value: str) -> str: + return quote(value, safe="").replace("%", "$25") + + +def cloudwatch_url() -> Optional[str]: + """Deep-link to this invocation's log stream, or None outside Lambda.""" + region = os.environ.get("AWS_REGION") + log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") + log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") + if not (region and log_group and log_stream): + return None + return ( + f"https://{region}.console.aws.amazon.com/cloudwatch/home" + f"?region={region}#logsV2:log-groups/log-group/" + f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}" + ) diff --git a/utilities/aws_lambda/subtask_handler.py b/utilities/aws_lambda/subtask_handler.py index e5ac086a..6b513ba4 100644 --- a/utilities/aws_lambda/subtask_handler.py +++ b/utilities/aws_lambda/subtask_handler.py @@ -6,12 +6,11 @@ TaskOrchestrator.run_subtask(...) calls. import json import logging -import os from contextlib import AbstractContextManager from functools import wraps from typing import Any, Callable, Optional, cast -from urllib.parse import quote +from utilities.aws_lambda.cloud_logs import cloudwatch_url from utilities.aws_lambda.default_orchestrator import default_orchestrator from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody from orchestration.task_orchestrator import TaskOrchestrator @@ -42,7 +41,7 @@ def subtask_handler( @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> None: - cloud_logs_url = _cloudwatch_url() + cloud_logs_url = cloudwatch_url() with factory() as orchestrator: for record in _records(event): body = _parse_body(record) @@ -95,20 +94,3 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]: if isinstance(raw_records, list): return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)] return [event] - - -def _console_encode(value: str) -> str: - return quote(value, safe="").replace("%", "$25") - - -def _cloudwatch_url() -> Optional[str]: - region = os.environ.get("AWS_REGION") - log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") - log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") - if not (region and log_group and log_stream): - return None - return ( - f"https://{region}.console.aws.amazon.com/cloudwatch/home" - f"?region={region}#logsV2:log-groups/log-group/" - f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}" - ) diff --git a/utilities/aws_lambda/task_handler.py b/utilities/aws_lambda/task_handler.py index 34811515..43699aee 100644 --- a/utilities/aws_lambda/task_handler.py +++ b/utilities/aws_lambda/task_handler.py @@ -10,6 +10,7 @@ from contextlib import AbstractContextManager from functools import wraps from typing import Any, Callable, Optional, cast +from utilities.aws_lambda.cloud_logs import cloudwatch_url from utilities.aws_lambda.default_orchestrator import default_orchestrator from domain.tasks.tasks import Source from orchestration.task_orchestrator import TaskOrchestrator @@ -41,6 +42,7 @@ def task_handler( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> Any: + cloud_logs_url = cloudwatch_url() with factory() as orchestrator: task_ids: list[dict[str, str]] = [] failures: list[dict[str, Any]] = [] @@ -66,6 +68,7 @@ def task_handler( orchestrator.run_subtask( subtask.id, work=lambda body=body: func(body, context), + cloud_logs_url=cloud_logs_url, ) except Exception: logger.exception(