postcode splliter working e2e

This commit is contained in:
Jun-te Kim 2026-05-20 11:07:40 +00:00
parent 0a04448217
commit 914a8ed51e
18 changed files with 523 additions and 93 deletions

1
.gitignore vendored
View file

@ -121,6 +121,7 @@ celerybeat.pid
# Environments
.env
.env.local
.venv
env/
venv/

View file

@ -1,5 +1,18 @@
FROM public.ecr.aws/lambda/python:3.11
# Postgres host/port/database are baked into the image at build time from
# the deploy workflow's --build-arg values (GitHub Actions DEV_DB_* secrets),
# mirroring backend/postcode_splitter/handler/Dockerfile. They map onto the
# POSTGRES_* names PostgresConfig.from_env reads. Username/password are NOT
# baked in -- Terraform injects those as Lambda env vars from Secrets Manager.
ARG DEV_DB_HOST
ARG DEV_DB_PORT
ARG DEV_DB_NAME
ENV POSTGRES_HOST=${DEV_DB_HOST}
ENV POSTGRES_PORT=${DEV_DB_PORT}
ENV POSTGRES_DATABASE=${DEV_DB_NAME}
WORKDIR /var/task
COPY applications/postcode_splitter/requirements.txt .

View file

@ -0,0 +1,34 @@
# Local-test environment for the postcode_splitter Lambda.
#
# cp .env.local.example .env.local then fill in the values below.
#
# .env.local is gitignored. The container hits REAL AWS and a REAL Postgres,
# so every value here points at infrastructure that actually exists.
#
# NOTE: the new DDD code uses different env var names than the repo root
# .env. The mapping (root .env name -> var here) is given per section.
# Keep comments on their own lines — docker-compose's env_file parser folds a
# trailing "# ..." into the value.
# --- Postgres (orchestration/default_orchestrator -> PostgresConfig.from_env) ---
# POSTGRES_HOST <- DB_HOST, PORT <- DB_PORT, USERNAME <- DB_USERNAME,
# PASSWORD <- DB_PASSWORD, DATABASE <- DB_NAME.
POSTGRES_HOST=
POSTGRES_PORT=5432
POSTGRES_USERNAME=
POSTGRES_PASSWORD=
POSTGRES_DATABASE=
# POSTGRES_DRIVER=psycopg2 (optional; defaults to psycopg2)
# --- Handler config (applications/postcode_splitter/handler.py) ---
# S3_BUCKET_NAME: bucket holding the input address CSV (root .env: DATA_BUCKET).
# ADDRESS2UPRN_QUEUE_URL: SQS queue the splitter fans batches out to; not in
# the root .env (Terraform sets it in prod).
S3_BUCKET_NAME=
ADDRESS2UPRN_QUEUE_URL=
# --- AWS credentials for boto3 (S3 + SQS clients) ---
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_DEFAULT_REGION=eu-west-2
# AWS_SESSION_TOKEN= (only if using temporary/SSO credentials)

View file

@ -0,0 +1,9 @@
services:
postcode-splitter:
build:
context: ../../../
dockerfile: applications/postcode_splitter/Dockerfile
ports:
- "9001:8080"
env_file:
- .env.local

View file

@ -0,0 +1,37 @@
#!/usr/bin/env python3
"""POST a single SQS-shaped event at the locally-running splitter Lambda.
The container built by docker-compose runs the AWS Lambda Runtime Interface
Emulator, which accepts invocations on the URL below. Replace the three
placeholder values with a real parent Task id, the splitter's own SubTask id
(both must already exist in the Postgres pointed at by .env.local), and the
s3://... URI of an uploaded address CSV.
"""
import json
import requests
HOST = "localhost"
PORT = "9001"
LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"
payload = {
"Records": [
{
"body": json.dumps(
{
"task_id": "f4b3332f-c0cc-481f-96a5-d39860a647cf",
"sub_task_id": "14c042de-40c4-473b-8cd8-72c983a94a8d",
"s3_uri": "s3://retrofit-data-dev/ara_raw_inputs/calico/Calico Homes Full list EPC Properties(Sheet2) (1) (1).csv",
}
)
}
]
}
response = requests.post(LAMBDA_URL, json=payload)
print("Status code:", response.status_code)
print("Response:")
print(response.text)

View file

@ -0,0 +1,12 @@
#!/usr/bin/env bash
set -euo pipefail
cd "$(dirname "$0")"
if [ ! -f .env.local ]; then
cp .env.local.example .env.local
echo "Created .env.local from the template — fill it in, then re-run." >&2
exit 1
fi
docker compose build --no-cache
docker compose up --force-recreate

View file

@ -8,4 +8,5 @@ boto3==1.35.44
sqlmodel
sqlalchemy==2.0.36
psycopg2-binary==2.9.10
pydantic-settings==2.6.0
pydantic-settings==2.6.0
httpx

View file

@ -40,20 +40,6 @@ module "lambda" {
LOG_LEVEL = "info"
DB_USERNAME = local.db_credentials.db_assessment_model_username
DB_PASSWORD = local.db_credentials.db_assessment_model_password
GOOGLE_SOLAR_API_KEY = "test"
SAP_PREDICTIONS_BUCKET = "test"
CARBON_PREDICTIONS_BUCKET = "test"
HEAT_PREDICTIONS_BUCKET = "test"
HEATING_KWH_PREDICTIONS_BUCKET = "test"
HOTWATER_KWH_PREDICTIONS_BUCKET = "test"
API_KEY = "test"
ENVIRONMENT = "test"
SECRET_KEY = "test"
PLAN_TRIGGER_BUCKET = "test"
DATA_BUCKET = "test"
EPC_AUTH_TOKEN = "test"
ENGINE_SQS_URL = "test"
ENERGY_ASSESSMENTS_BUCKET = "test"
ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url
S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name
},

View file

@ -8,12 +8,17 @@ caller can construct an instance with an un-normalised postcode.
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from domain.postcodes.sanitise import sanitise_postcode
def _empty_source_row() -> dict[str, str]:
"""Typed default factory for :attr:`UserAddress.source_row`."""
return {}
@dataclass(frozen=True)
class UserAddress:
"""A user-supplied address paired with its canonical postcode.
@ -25,11 +30,20 @@ class UserAddress:
:meth:`__post_init__`.
internal_reference: Optional customer-side identifier preserved for
traceability through the matching pipeline.
source_row: The complete original CSV row this address was parsed
from, column name -> cell value. The splitter is a pass-through
router: it groups rows by postcode but must not drop the other
columns the downstream address2uprn stage relies on, so the raw
row travels alongside the parsed fields. Excluded from equality
and hashing -- identity stays defined by the parsed fields above.
"""
user_address: str
postcode: str
internal_reference: Optional[str] = None
source_row: dict[str, str] = field(
default_factory=_empty_source_row, compare=False
)
def __post_init__(self) -> None:
# Frozen dataclass: bypass the descriptor with object.__setattr__.

View file

@ -2,7 +2,7 @@ import csv
from io import StringIO
from infrastructure.s3_client import S3Client
from utils.s3 import parse_s3_uri
from infrastructure.s3_uri import parse_s3_uri
class CsvS3Client(S3Client):

43
infrastructure/s3_uri.py Normal file
View file

@ -0,0 +1,43 @@
"""Parse S3 URIs into ``(bucket, key)`` pairs.
A pure-stdlib helper for the infrastructure layer. It deliberately pulls in
neither pandas, boto3, nor the legacy ``utils`` package, so slim Lambda images
that only need URI parsing do not drag the wider data stack along.
Two input shapes are supported:
* canonical S3 URIs --- ``s3://bucket/key``
* AWS S3 console URLs --- ``https://.../s3/object/bucket?prefix=key``
"""
from urllib.parse import unquote
def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
"""Return the ``(bucket, key)`` pair addressed by ``s3_uri``.
Raises:
ValueError: if ``s3_uri`` is neither a well-formed ``s3://`` URI nor
an AWS console URL carrying a ``prefix`` query parameter.
"""
if s3_uri.startswith("s3://"):
parts = s3_uri[len("s3://") :].split("/", 1)
if len(parts) < 2 or not parts[0] or not parts[1]:
raise ValueError("S3 URI must include both a bucket and a key")
return parts[0], parts[1]
if "?" not in s3_uri:
raise ValueError(f"Not an s3:// URI and has no query string: {s3_uri!r}")
base, query = s3_uri.split("?", 1)
if "/s3/object/" not in base:
raise ValueError(f"Console URL has no '/s3/object/' segment: {s3_uri!r}")
bucket = base.split("/s3/object/", 1)[1]
params: dict[str, str] = {}
for item in query.split("&"):
if "=" in item:
name, value = item.split("=", 1)
params[name] = value
key = unquote(params.get("prefix", ""))
return bucket, key

View file

@ -1,12 +1,16 @@
"""CSV-on-S3 adapter for :class:`UserAddressRepository`.
Reads canonical upload CSVs (``Address 1``, ``Address 2``, ``Address 3``,
``Postcode``, ``Internal Reference``) and writes the splitter's compact
3-column form (``user_address``, ``postcode``, ``internal_reference``).
Reads upload CSVs that carry a ``postcode`` column (plus optional
``Address 1``/``Address 2``/``Address 3`` and ``Internal Reference``), and
writes batch CSVs that pass *every* original column through unchanged with
one column appended -- ``postcode_clean`` (uppercase, whitespace-stripped) --
which the downstream address2uprn stage groups on.
The frontend pre-applies the user's column mapping at upload time, so this
adapter does NOT consult any ``BulkAddressUpload.column_mapping``: it always
expects the canonical column names listed above.
The splitter is a pass-through router: it must not reshape or drop columns,
because address2uprn has not been migrated and still consumes the legacy
splitter's full-row output. The frontend pre-applies the user's column
mapping at upload time, so this adapter does NOT consult any
``BulkAddressUpload.column_mapping``.
"""
from __future__ import annotations
@ -20,8 +24,9 @@ from infrastructure.csv_s3_client import CsvS3Client
from repositories.user_address.user_address_repository import UserAddressRepository
_ADDRESS_COLUMNS: tuple[str, str, str] = ("Address 1", "Address 2", "Address 3")
_POSTCODE_COLUMN: str = "Postcode"
_POSTCODE_COLUMN: str = "postcode"
_INTERNAL_REFERENCE_COLUMN: str = "Internal Reference"
_POSTCODE_CLEAN_COLUMN: str = "postcode_clean"
class UserAddressCsvS3Repository(UserAddressRepository):
@ -37,15 +42,27 @@ class UserAddressCsvS3Repository(UserAddressRepository):
self._bucket = bucket
def load_batch(self, s3_uri: str) -> list[UserAddress]:
"""Load canonical upload CSV rows into :class:`UserAddress` objects.
"""Load upload CSV rows into :class:`UserAddress` objects.
Concatenates ``Address 1``/``Address 2``/``Address 3`` with ``", "``,
skipping missing or empty parts, into ``user_address``. Falls back to
just ``Address 1`` when 2 and 3 are absent. Passes ``Internal Reference``
through to :attr:`UserAddress.internal_reference` (``None`` when the
column is missing or empty).
Each row's complete column set is preserved on
:attr:`UserAddress.source_row` so :meth:`save_batch` can pass it
through untouched. The parsed convenience fields are also populated:
``Address 1``/``Address 2``/``Address 3`` are concatenated with
``", "`` (skipping missing/empty parts) into ``user_address``, and
``Internal Reference`` is threaded to
:attr:`UserAddress.internal_reference` (``None`` when missing/empty).
Raises:
ValueError: if the CSV has rows but no ``postcode`` column --
without it the splitter cannot group, and silently emitting
empty postcodes would corrupt every downstream batch.
"""
rows = self._csv_client.read_rows(s3_uri)
if rows and _POSTCODE_COLUMN not in rows[0]:
raise ValueError(
f"Input CSV {s3_uri} has no {_POSTCODE_COLUMN!r} column; "
f"columns present: {sorted(rows[0])}"
)
addresses: list[UserAddress] = []
for row in rows:
parts = [
@ -62,22 +79,24 @@ class UserAddressCsvS3Repository(UserAddressRepository):
user_address=user_address,
postcode=postcode,
internal_reference=internal_reference,
source_row=row,
)
)
return addresses
def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str:
"""Write a 3-column CSV under a unique key beneath ``path_prefix``.
"""Write a pass-through batch CSV under a unique key.
Each output row is the address's original ``source_row`` with a
``postcode_clean`` column appended (the canonical postcode the
downstream address2uprn stage groups on). No original column is
dropped or reshaped.
The key is ``{path_prefix}/{ISO-8601 datetime}_{8-char uuid}.csv``.
Returns the full ``s3://bucket/key`` URI.
"""
rows: list[dict[str, str]] = [
{
"user_address": addr.user_address,
"postcode": addr.postcode,
"internal_reference": addr.internal_reference or "",
}
{**addr.source_row, _POSTCODE_CLEAN_COLUMN: addr.postcode}
for addr in addresses
]
filename = (

View file

@ -43,3 +43,29 @@ def test_user_address_equality_uses_sanitised_postcode() -> None:
a = UserAddress(user_address="1 The Street", postcode="sw1a 1aa")
b = UserAddress(user_address="1 The Street", postcode="SW1A1AA")
assert a == b
def test_user_address_source_row_defaults_to_empty_dict() -> None:
addr = UserAddress(user_address="1 The Street", postcode="SW1A1AA")
assert addr.source_row == {}
def test_user_address_carries_source_row() -> None:
row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"}
addr = UserAddress(
user_address="1 The Street", postcode="SW1A 1AA", source_row=row
)
assert addr.source_row == row
def test_user_address_equality_ignores_source_row() -> None:
# source_row is excluded from equality (and hashing): identity stays
# defined by the parsed fields, so two addresses parsed from rows with
# different incidental columns still compare equal.
a = UserAddress(
user_address="1 The Street", postcode="SW1A1AA", source_row={"x": "1"}
)
b = UserAddress(
user_address="1 The Street", postcode="SW1A1AA", source_row={"y": "2"}
)
assert a == b

View file

@ -0,0 +1,32 @@
import pytest
from infrastructure.s3_uri import parse_s3_uri
def test_parses_simple_s3_uri() -> None:
assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv")
def test_parses_s3_uri_with_nested_key() -> None:
bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv")
assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv")
def test_rejects_s3_uri_without_key() -> None:
with pytest.raises(ValueError, match="bucket and a key"):
parse_s3_uri("s3://my-bucket")
def test_rejects_s3_uri_with_empty_key() -> None:
with pytest.raises(ValueError, match="bucket and a key"):
parse_s3_uri("s3://my-bucket/")
def test_parses_console_url_prefix() -> None:
url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv"
assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv")
def test_rejects_unparseable_string() -> None:
with pytest.raises(ValueError):
parse_s3_uri("not-a-uri-at-all")

View file

@ -132,7 +132,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
"Address 1": f"{i} High St",
"Address 2": "",
"Address 3": "",
"Postcode": "AA1 1AA",
"postcode": "AA1 1AA",
"Internal Reference": f"AA-{i}",
}
for i in range(1, 3)
@ -142,7 +142,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
"Address 1": f"{i} Long Road",
"Address 2": "",
"Address 3": "",
"Postcode": "BB2 2BB",
"postcode": "BB2 2BB",
"Internal Reference": f"BB-{i}",
}
for i in range(1, 5)
@ -152,7 +152,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
"Address 1": "1 Final Way",
"Address 2": "",
"Address 3": "",
"Postcode": "CC3 3CC",
"postcode": "CC3 3CC",
"Internal Reference": "CC-1",
}
)
@ -281,15 +281,15 @@ def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
input_s3_uri=input_uri,
)
# Re-load each child's saved batch and inspect the postcode column to
# confirm the dispatch order matches the postcode-batching algorithm:
# Re-load each child's saved batch and inspect the postcode_clean column
# to confirm the dispatch order matches the postcode-batching algorithm:
# AA-batch first, BB oversize batch second, CC final-flush third.
postcodes_per_batch: list[set[str]] = []
for cid in child_ids:
child = harness.subtasks.get(cid)
assert child.inputs is not None
rows = harness.csv_client.read_rows(child.inputs["s3_uri"])
postcodes_per_batch.append({row["postcode"] for row in rows})
postcodes_per_batch.append({row["postcode_clean"] for row in rows})
assert postcodes_per_batch == [
{"AA11AA"},

View file

@ -3,6 +3,7 @@ from collections.abc import Iterator
import pytest
from moto import mock_aws
from domain.addresses.user_address import UserAddress
from infrastructure.csv_s3_client import CsvS3Client
from repositories.user_address.user_address_csv_s3_repository import (
UserAddressCsvS3Repository,
@ -27,7 +28,7 @@ def _upload_csv(
return repo._csv_client.save_rows(rows, key) # pyright: ignore[reportPrivateUsage]
def test_load_batch_concatenates_three_address_lines(
def test_load_batch_parses_address_postcode_and_reference(
repo: UserAddressCsvS3Repository,
) -> None:
rows = [
@ -35,7 +36,7 @@ def test_load_batch_concatenates_three_address_lines(
"Address 1": "1 High Street",
"Address 2": "Flat 2",
"Address 3": "Townville",
"Postcode": "sw1a 1aa",
"postcode": "sw1a 1aa",
"Internal Reference": "REF-001",
}
]
@ -58,7 +59,7 @@ def test_load_batch_uses_only_address_1_when_others_missing(
"Address 1": "10 Cardiff Road",
"Address 2": "",
"Address 3": "",
"Postcode": "CF10 1AA",
"postcode": "CF10 1AA",
"Internal Reference": "REF-002",
}
]
@ -80,7 +81,7 @@ def test_load_batch_handles_missing_internal_reference(
"Address 1": "5 Park Lane",
"Address 2": "",
"Address 3": "",
"Postcode": "M1 1AA",
"postcode": "M1 1AA",
"Internal Reference": "",
}
]
@ -94,16 +95,67 @@ def test_load_batch_handles_missing_internal_reference(
assert addresses[0].internal_reference is None
def test_load_batch_captures_full_source_row(
repo: UserAddressCsvS3Repository,
) -> None:
# A raw EPC-export-shaped row: the splitter must preserve every column,
# not just the ones it parses into UserAddress fields.
row = {
"Asset Reference": "511",
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
"postcode": "BB12 7BX",
"Property Type": "House: End Terrace",
"SAP Score": "69",
}
uri = _upload_csv(repo, [row], "uploads/epc.csv")
addresses = repo.load_batch(uri)
assert addresses[0].source_row == row
def test_load_batch_raises_when_postcode_column_absent(
repo: UserAddressCsvS3Repository,
) -> None:
rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}]
uri = _upload_csv(repo, rows, "uploads/no-postcode.csv")
with pytest.raises(ValueError, match="no 'postcode' column"):
repo.load_batch(uri)
def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
repo: UserAddressCsvS3Repository,
) -> None:
row = {
"Asset Reference": "511",
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
"postcode": " BB12 7BX",
"Property Type": "House: End Terrace",
}
uri = _upload_csv(repo, [row], "uploads/epc.csv")
addresses = repo.load_batch(uri)
saved_uri = repo.save_batch(addresses, "tasks/passthrough")
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
assert len(saved_rows) == 1
saved = saved_rows[0]
# Every original column survives, byte-for-byte.
for column, value in row.items():
assert saved[column] == value
# Plus the one appended column the downstream address2uprn stage groups on.
assert saved["postcode_clean"] == "BB127BX"
def test_save_batch_returns_uri_under_path_prefix(
repo: UserAddressCsvS3Repository,
) -> None:
from domain.addresses.user_address import UserAddress
addresses = [
UserAddress(
user_address="1 High Street, Flat 2, Townville",
user_address="1 High Street",
postcode="SW1A 1AA",
internal_reference="REF-001",
source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
),
]
@ -113,59 +165,42 @@ def test_save_batch_returns_uri_under_path_prefix(
assert uri.endswith(".csv")
def test_save_then_reload_round_trip_preserves_values(
def test_save_then_reload_round_trip_preserves_columns(
repo: UserAddressCsvS3Repository,
) -> None:
from domain.addresses.user_address import UserAddress
# save_batch writes the splitter's compact schema
# (user_address/postcode/internal_reference); load_batch reads the
# canonical upload schema. To round-trip through the repo we re-upload
# the saved CSV under the upload schema's column names.
original = [
UserAddress(
user_address="1 High Street",
postcode="SW1A 1AA",
internal_reference="REF-001",
),
UserAddress(
user_address="2 Low Street",
postcode="XY9 8ZW",
internal_reference=None,
),
]
saved_uri = repo.save_batch(original, "tasks/round-trip")
# Re-shape the saved CSV into the canonical upload schema for reload.
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
upload_rows: list[dict[str, str]] = [
rows = [
{
"Address 1": row["user_address"],
"Address 2": "",
"Address 3": "",
"Postcode": row["postcode"],
"Internal Reference": row["internal_reference"],
}
for row in saved_rows
"Address 1": "1 High Street",
"postcode": "SW1A 1AA",
"Internal Reference": "REF-001",
},
{
"Address 1": "2 Low Street",
"postcode": "XY9 8ZW",
"Internal Reference": "",
},
]
upload_uri = _upload_csv(repo, upload_rows, "uploads/round-trip.csv")
uri = _upload_csv(repo, rows, "uploads/round-trip.csv")
addresses = repo.load_batch(uri)
reloaded = repo.load_batch(upload_uri)
saved_uri = repo.save_batch(addresses, "tasks/round-trip")
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
assert reloaded == original
# Original columns come back verbatim; postcode_clean is the only addition.
assert [
{k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows
] == rows
assert [r["postcode_clean"] for r in saved_rows] == ["SW1A1AA", "XY98ZW"]
def test_save_batch_uses_unique_filename_per_call(
repo: UserAddressCsvS3Repository,
) -> None:
from domain.addresses.user_address import UserAddress
addresses = [
UserAddress(
user_address="1 High Street",
postcode="SW1A 1AA",
internal_reference="REF-001",
source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
),
]

View file

@ -6,6 +6,7 @@ to the wrapped function — so the handler can compose its own use-case
orchestrator that shares the session.
"""
import logging
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
@ -13,6 +14,8 @@ from typing import Any
from uuid import UUID
import pytest
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"
from sqlmodel import Session, SQLModel, create_engine
from domain.tasks.subtasks import SubTaskStatus
@ -142,3 +145,111 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
persisted_child = harness.subtasks.get(child_ids[0])
assert persisted_child.task_id == task.id
assert persisted_child.status is SubTaskStatus.WAITING
def test_subtask_handler_logs_subtask_lifecycle_on_success(
harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
"""Start and completion are logged at INFO so a successful invocation
leaves a CloudWatch breadcrumb (not just the Lambda runtime lines)."""
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
handler(_direct_event(task.id, subtask.id), context=None)
assert f"Running subtask {subtask.id}" in caplog.text
assert f"Subtask {subtask.id} completed" in caplog.text
def test_subtask_handler_logs_exception_on_failure(
harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
"""A failing subtask is logged at ERROR with the traceback attached,
before the exception propagates for the Lambda runtime to surface."""
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
raise RuntimeError("boom")
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
with pytest.raises(RuntimeError, match="boom"):
handler(_direct_event(task.id, subtask.id), context=None)
failures = [r for r in caplog.records if r.levelno == logging.ERROR]
assert any(
f"Subtask {subtask.id} failed" in r.getMessage() for r in failures
)
assert any(r.exc_info is not None for r in failures)
def test_subtask_handler_records_cloudwatch_url_on_subtask(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
"""With the AWS Lambda runtime's log env vars present, a CloudWatch deep
link is built and persisted on the SubTask."""
monkeypatch.setenv("AWS_REGION", "eu-west-2")
monkeypatch.setenv(
"AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter"
)
monkeypatch.setenv(
"AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123"
)
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
handler(_direct_event(task.id, subtask.id), context=None)
saved_url = harness.subtasks.get(subtask.id).cloud_logs_url
assert saved_url is not None
assert saved_url.startswith(
"https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
)
# Log group / stream are console-encoded ("/" -> "$252F").
assert "$252Faws$252Flambda$252Fpostcode-splitter" in saved_url
assert "$255B$2524LATEST$255D" in saved_url
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Outside a real Lambda (e.g. the local RIE) the runtime log env vars
are absent, so cloud_logs_url is left unset rather than storing junk."""
for var in (
"AWS_REGION",
"AWS_LAMBDA_LOG_GROUP_NAME",
"AWS_LAMBDA_LOG_STREAM_NAME",
):
monkeypatch.delenv(var, raising=False)
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
handler(_direct_event(task.id, subtask.id), context=None)
assert harness.subtasks.get(subtask.id).cloud_logs_url is None

View file

@ -1,18 +1,32 @@
"""@subtask_handler decorator for Lambdas that operate on existing SubTasks.
Translates an AWS Lambda invocation (SQS-shaped or direct) into
TaskOrchestrator.run_subtask(...) calls.
TaskOrchestrator.run_subtask(...) calls, emitting an INFO log line for each
subtask's start and completion and a logged exception on failure. Those lines
land in CloudWatch via the Lambda runtime's stdout/stderr capture.
Each subtask also records ``cloud_logs_url`` -- a deep link to this
invocation's CloudWatch log stream -- so an operator can jump from a SubTask
row straight to its logs. It is built from the environment variables the AWS
Lambda runtime sets, so it is populated only on real Lambda invocations and
left unset under the local RIE (which does not export them).
"""
import json
import logging
import os
from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from urllib.parse import quote
from utilities.aws_lambda.default_orchestrator import default_orchestrator
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
from orchestration.task_orchestrator import TaskOrchestrator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
@ -33,14 +47,26 @@ def subtask_handler(
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any) -> None:
cloud_logs_url = _cloudwatch_url()
with factory() as orchestrator:
for record in _records(event):
body = _parse_body(record)
trigger = SubtaskTriggerBody.model_validate(body)
orchestrator.run_subtask(
trigger.sub_task_id,
work=lambda body=body, o=orchestrator: func(body, context, o),
)
logger.info("Running subtask %s", trigger.sub_task_id)
try:
orchestrator.run_subtask(
trigger.sub_task_id,
work=lambda body=body, o=orchestrator: func(
body, context, o
),
cloud_logs_url=cloud_logs_url,
)
except Exception:
logger.exception(
"Subtask %s failed", trigger.sub_task_id
)
raise
logger.info("Subtask %s completed", trigger.sub_task_id)
return wrapper
@ -65,3 +91,34 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
if isinstance(raw_records, list):
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
return [event]
def _console_encode(value: str) -> str:
"""Encode a value for a CloudWatch console deep link.
The console expects URL-encoding with the percent signs themselves
re-encoded as ``$25`` -- e.g. ``/`` becomes ``%2F`` becomes ``$252F``.
"""
return quote(value, safe="").replace("%", "$25")
def _cloudwatch_url() -> Optional[str]:
"""Build a CloudWatch console URL for this invocation's log stream.
Sourced entirely from the environment variables the AWS Lambda runtime
sets -- ``AWS_REGION``, ``AWS_LAMBDA_LOG_GROUP_NAME`` and
``AWS_LAMBDA_LOG_STREAM_NAME``. Returns None when any is absent, which is
the case outside a real Lambda (the local RIE does not export them) -- so
``SubTask.cloud_logs_url`` is left unset rather than storing a link that
points nowhere.
"""
region = os.environ.get("AWS_REGION")
log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
if not (region and log_group and log_stream):
return None
return (
f"https://{region}.console.aws.amazon.com/cloudwatch/home"
f"?region={region}#logsV2:log-groups/log-group/"
f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}"
)