fix(task_handler): persist cloud_logs_url for modelling_e2e

@task_handler never built or passed cloud_logs_url, so every app using
it (incl. modelling_e2e) ran run_subtask with the None default and the
CloudWatch deep-link was never saved onto the SubTask. @subtask_handler
did this correctly.

Extract the URL builder into a shared utilities/aws_lambda/cloud_logs.py
(public cloudwatch_url()), use it from both handlers, and pass the URL
into run_subtask from @task_handler. Add regression tests.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Jun-te Kim 2026-06-23 15:35:47 +00:00
parent 4e3eb52a37
commit 119ff3740c
4 changed files with 136 additions and 20 deletions

View file

@ -0,0 +1,104 @@
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from uuid import UUID
import pytest
from sqlalchemy import Engine
from sqlmodel import Session
from domain.tasks.tasks import Source
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from utilities.aws_lambda.task_handler import task_handler
@dataclass
class Harness:
orchestrator: TaskOrchestrator
tasks: TaskPostgresRepository
subtasks: SubTaskPostgresRepository
@contextmanager
def factory(self) -> Generator[TaskOrchestrator, None, None]:
yield self.orchestrator
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
with Session(db_engine) as session:
tasks = TaskPostgresRepository(session=session)
subtasks = SubTaskPostgresRepository(session=session)
yield Harness(
orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
tasks=tasks,
subtasks=subtasks,
)
def _direct_event(property_id: str) -> dict[str, Any]:
return {"property_id": property_id}
def test_task_handler_records_cloudwatch_url_on_subtask(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
# arrange
monkeypatch.setenv("AWS_REGION", "eu-west-2")
monkeypatch.setenv(
"AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/modelling-e2e"
)
monkeypatch.setenv(
"AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123"
)
@task_handler(
task_source="modelling_e2e",
source=Source.PROPERTY,
orchestrator_cm=harness.factory,
)
def handler(body: dict[str, Any], context: Any) -> None:
return None
# act
result = handler(_direct_event("prop-1"), context=None)
# assert
subtask_id = result[0]["subtask_id"]
saved_url = harness.subtasks.get(UUID(subtask_id)).cloud_logs_url
assert saved_url is not None
assert saved_url.startswith(
"https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
)
# Log group / stream are console-encoded ("/" -> "$252F").
assert "$252Faws$252Flambda$252Fmodelling-e2e" in saved_url
assert "$255B$2524LATEST$255D" in saved_url
def test_task_handler_leaves_cloudwatch_url_unset_outside_lambda(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
# arrange
for var in (
"AWS_REGION",
"AWS_LAMBDA_LOG_GROUP_NAME",
"AWS_LAMBDA_LOG_STREAM_NAME",
):
monkeypatch.delenv(var, raising=False)
@task_handler(
task_source="modelling_e2e",
source=Source.PROPERTY,
orchestrator_cm=harness.factory,
)
def handler(body: dict[str, Any], context: Any) -> None:
return None
# act
result = handler(_direct_event("prop-1"), context=None)
# assert
subtask_id = result[0]["subtask_id"]
assert harness.subtasks.get(UUID(subtask_id)).cloud_logs_url is None

View file

@ -0,0 +1,27 @@
"""Build a CloudWatch console deep-link for the running Lambda invocation.
Shared by @task_handler and @subtask_handler so both persist the same
`cloud_logs_url` onto the SubTask they run.
"""
import os
from typing import Optional
from urllib.parse import quote
def _console_encode(value: str) -> str:
return quote(value, safe="").replace("%", "$25")
def cloudwatch_url() -> Optional[str]:
"""Deep-link to this invocation's log stream, or None outside Lambda."""
region = os.environ.get("AWS_REGION")
log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
if not (region and log_group and log_stream):
return None
return (
f"https://{region}.console.aws.amazon.com/cloudwatch/home"
f"?region={region}#logsV2:log-groups/log-group/"
f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}"
)

View file

@ -6,12 +6,11 @@ TaskOrchestrator.run_subtask(...) calls.
import json
import logging
import os
from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from urllib.parse import quote
from utilities.aws_lambda.cloud_logs import cloudwatch_url
from utilities.aws_lambda.default_orchestrator import default_orchestrator
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
from orchestration.task_orchestrator import TaskOrchestrator
@ -42,7 +41,7 @@ def subtask_handler(
@wraps(func)
def wrapper(event: dict[str, Any], context: Any) -> None:
cloud_logs_url = _cloudwatch_url()
cloud_logs_url = cloudwatch_url()
with factory() as orchestrator:
for record in _records(event):
body = _parse_body(record)
@ -95,20 +94,3 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
if isinstance(raw_records, list):
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
return [event]
def _console_encode(value: str) -> str:
return quote(value, safe="").replace("%", "$25")
def _cloudwatch_url() -> Optional[str]:
region = os.environ.get("AWS_REGION")
log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
if not (region and log_group and log_stream):
return None
return (
f"https://{region}.console.aws.amazon.com/cloudwatch/home"
f"?region={region}#logsV2:log-groups/log-group/"
f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}"
)

View file

@ -10,6 +10,7 @@ from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from utilities.aws_lambda.cloud_logs import cloudwatch_url
from utilities.aws_lambda.default_orchestrator import default_orchestrator
from domain.tasks.tasks import Source
from orchestration.task_orchestrator import TaskOrchestrator
@ -41,6 +42,7 @@ def task_handler(
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any) -> Any:
cloud_logs_url = cloudwatch_url()
with factory() as orchestrator:
task_ids: list[dict[str, str]] = []
failures: list[dict[str, Any]] = []
@ -66,6 +68,7 @@ def task_handler(
orchestrator.run_subtask(
subtask.id,
work=lambda body=body: func(body, context),
cloud_logs_url=cloud_logs_url,
)
except Exception:
logger.exception(