diff --git a/.gitignore b/.gitignore index 888d527a..9e5df0c7 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ celerybeat.pid # Environments .env +.env.local .venv env/ venv/ diff --git a/applications/postcode_splitter/Dockerfile b/applications/postcode_splitter/Dockerfile index 578ee7a7..aea1f914 100644 --- a/applications/postcode_splitter/Dockerfile +++ b/applications/postcode_splitter/Dockerfile @@ -1,5 +1,18 @@ FROM public.ecr.aws/lambda/python:3.11 +# Postgres host/port/database are baked into the image at build time from +# the deploy workflow's --build-arg values (GitHub Actions DEV_DB_* secrets), +# mirroring backend/postcode_splitter/handler/Dockerfile. They map onto the +# POSTGRES_* names PostgresConfig.from_env reads. Username/password are NOT +# baked in -- Terraform injects those as Lambda env vars from Secrets Manager. +ARG DEV_DB_HOST +ARG DEV_DB_PORT +ARG DEV_DB_NAME + +ENV POSTGRES_HOST=${DEV_DB_HOST} +ENV POSTGRES_PORT=${DEV_DB_PORT} +ENV POSTGRES_DATABASE=${DEV_DB_NAME} + WORKDIR /var/task COPY applications/postcode_splitter/requirements.txt . diff --git a/applications/postcode_splitter/local_handler/.env.local.example b/applications/postcode_splitter/local_handler/.env.local.example new file mode 100644 index 00000000..28fa8390 --- /dev/null +++ b/applications/postcode_splitter/local_handler/.env.local.example @@ -0,0 +1,34 @@ +# Local-test environment for the postcode_splitter Lambda. +# +# cp .env.local.example .env.local then fill in the values below. +# +# .env.local is gitignored. The container hits REAL AWS and a REAL Postgres, +# so every value here points at infrastructure that actually exists. +# +# NOTE: the new DDD code uses different env var names than the repo root +# .env. The mapping (root .env name -> var here) is given per section. +# Keep comments on their own lines — docker-compose's env_file parser folds a +# trailing "# ..." into the value. + +# --- Postgres (orchestration/default_orchestrator -> PostgresConfig.from_env) --- +# POSTGRES_HOST <- DB_HOST, PORT <- DB_PORT, USERNAME <- DB_USERNAME, +# PASSWORD <- DB_PASSWORD, DATABASE <- DB_NAME. +POSTGRES_HOST= +POSTGRES_PORT=5432 +POSTGRES_USERNAME= +POSTGRES_PASSWORD= +POSTGRES_DATABASE= +# POSTGRES_DRIVER=psycopg2 (optional; defaults to psycopg2) + +# --- Handler config (applications/postcode_splitter/handler.py) --- +# S3_BUCKET_NAME: bucket holding the input address CSV (root .env: DATA_BUCKET). +# ADDRESS2UPRN_QUEUE_URL: SQS queue the splitter fans batches out to; not in +# the root .env (Terraform sets it in prod). +S3_BUCKET_NAME= +ADDRESS2UPRN_QUEUE_URL= + +# --- AWS credentials for boto3 (S3 + SQS clients) --- +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_DEFAULT_REGION=eu-west-2 +# AWS_SESSION_TOKEN= (only if using temporary/SSO credentials) diff --git a/applications/postcode_splitter/local_handler/docker-compose.yml b/applications/postcode_splitter/local_handler/docker-compose.yml new file mode 100644 index 00000000..68af1c40 --- /dev/null +++ b/applications/postcode_splitter/local_handler/docker-compose.yml @@ -0,0 +1,9 @@ +services: + postcode-splitter: + build: + context: ../../../ + dockerfile: applications/postcode_splitter/Dockerfile + ports: + - "9001:8080" + env_file: + - .env.local diff --git a/applications/postcode_splitter/local_handler/invoke_local_lambda.py b/applications/postcode_splitter/local_handler/invoke_local_lambda.py new file mode 100755 index 00000000..c0ca89ec --- /dev/null +++ b/applications/postcode_splitter/local_handler/invoke_local_lambda.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +"""POST a single SQS-shaped event at the locally-running splitter Lambda. + +The container built by docker-compose runs the AWS Lambda Runtime Interface +Emulator, which accepts invocations on the URL below. Replace the three +placeholder values with a real parent Task id, the splitter's own SubTask id +(both must already exist in the Postgres pointed at by .env.local), and the +s3://... URI of an uploaded address CSV. +""" + +import json +import requests + +HOST = "localhost" +PORT = "9001" + +LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations" + +payload = { + "Records": [ + { + "body": json.dumps( + { + "task_id": "f4b3332f-c0cc-481f-96a5-d39860a647cf", + "sub_task_id": "14c042de-40c4-473b-8cd8-72c983a94a8d", + "s3_uri": "s3://retrofit-data-dev/ara_raw_inputs/calico/Calico Homes Full list EPC Properties(Sheet2) (1) (1).csv", + } + ) + } + ] +} + +response = requests.post(LAMBDA_URL, json=payload) + +print("Status code:", response.status_code) +print("Response:") +print(response.text) diff --git a/applications/postcode_splitter/local_handler/run_local.sh b/applications/postcode_splitter/local_handler/run_local.sh new file mode 100755 index 00000000..345b60ee --- /dev/null +++ b/applications/postcode_splitter/local_handler/run_local.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")" + +if [ ! -f .env.local ]; then + cp .env.local.example .env.local + echo "Created .env.local from the template — fill it in, then re-run." >&2 + exit 1 +fi + +docker compose build --no-cache +docker compose up --force-recreate diff --git a/backend/address2UPRN/handler/requirements.txt b/backend/address2UPRN/handler/requirements.txt index 6ef41b2d..02aaefba 100644 --- a/backend/address2UPRN/handler/requirements.txt +++ b/backend/address2UPRN/handler/requirements.txt @@ -8,4 +8,5 @@ boto3==1.35.44 sqlmodel sqlalchemy==2.0.36 psycopg2-binary==2.9.10 -pydantic-settings==2.6.0 \ No newline at end of file +pydantic-settings==2.6.0 +httpx \ No newline at end of file diff --git a/deployment/terraform/lambda/postcodeSplitter/main.tf b/deployment/terraform/lambda/postcodeSplitter/main.tf index 94c5cd4e..325f7dc7 100644 --- a/deployment/terraform/lambda/postcodeSplitter/main.tf +++ b/deployment/terraform/lambda/postcodeSplitter/main.tf @@ -40,20 +40,6 @@ module "lambda" { LOG_LEVEL = "info" DB_USERNAME = local.db_credentials.db_assessment_model_username DB_PASSWORD = local.db_credentials.db_assessment_model_password - GOOGLE_SOLAR_API_KEY = "test" - SAP_PREDICTIONS_BUCKET = "test" - CARBON_PREDICTIONS_BUCKET = "test" - HEAT_PREDICTIONS_BUCKET = "test" - HEATING_KWH_PREDICTIONS_BUCKET = "test" - HOTWATER_KWH_PREDICTIONS_BUCKET = "test" - API_KEY = "test" - ENVIRONMENT = "test" - SECRET_KEY = "test" - PLAN_TRIGGER_BUCKET = "test" - DATA_BUCKET = "test" - EPC_AUTH_TOKEN = "test" - ENGINE_SQS_URL = "test" - ENERGY_ASSESSMENTS_BUCKET = "test" ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name }, diff --git a/domain/addresses/user_address.py b/domain/addresses/user_address.py index e48dfdec..120a3659 100644 --- a/domain/addresses/user_address.py +++ b/domain/addresses/user_address.py @@ -8,12 +8,17 @@ caller can construct an instance with an un-normalised postcode. from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from domain.postcodes.sanitise import sanitise_postcode +def _empty_source_row() -> dict[str, str]: + """Typed default factory for :attr:`UserAddress.source_row`.""" + return {} + + @dataclass(frozen=True) class UserAddress: """A user-supplied address paired with its canonical postcode. @@ -25,11 +30,20 @@ class UserAddress: :meth:`__post_init__`. internal_reference: Optional customer-side identifier preserved for traceability through the matching pipeline. + source_row: The complete original CSV row this address was parsed + from, column name -> cell value. The splitter is a pass-through + router: it groups rows by postcode but must not drop the other + columns the downstream address2uprn stage relies on, so the raw + row travels alongside the parsed fields. Excluded from equality + and hashing -- identity stays defined by the parsed fields above. """ user_address: str postcode: str internal_reference: Optional[str] = None + source_row: dict[str, str] = field( + default_factory=_empty_source_row, compare=False + ) def __post_init__(self) -> None: # Frozen dataclass: bypass the descriptor with object.__setattr__. diff --git a/infrastructure/csv_s3_client.py b/infrastructure/csv_s3_client.py index 5163705b..0a576b81 100644 --- a/infrastructure/csv_s3_client.py +++ b/infrastructure/csv_s3_client.py @@ -2,7 +2,7 @@ import csv from io import StringIO from infrastructure.s3_client import S3Client -from utils.s3 import parse_s3_uri +from infrastructure.s3_uri import parse_s3_uri class CsvS3Client(S3Client): diff --git a/infrastructure/s3_uri.py b/infrastructure/s3_uri.py new file mode 100644 index 00000000..bf97100e --- /dev/null +++ b/infrastructure/s3_uri.py @@ -0,0 +1,43 @@ +"""Parse S3 URIs into ``(bucket, key)`` pairs. + +A pure-stdlib helper for the infrastructure layer. It deliberately pulls in +neither pandas, boto3, nor the legacy ``utils`` package, so slim Lambda images +that only need URI parsing do not drag the wider data stack along. + +Two input shapes are supported: + +* canonical S3 URIs --- ``s3://bucket/key`` +* AWS S3 console URLs --- ``https://.../s3/object/bucket?prefix=key`` +""" + +from urllib.parse import unquote + + +def parse_s3_uri(s3_uri: str) -> tuple[str, str]: + """Return the ``(bucket, key)`` pair addressed by ``s3_uri``. + + Raises: + ValueError: if ``s3_uri`` is neither a well-formed ``s3://`` URI nor + an AWS console URL carrying a ``prefix`` query parameter. + """ + if s3_uri.startswith("s3://"): + parts = s3_uri[len("s3://") :].split("/", 1) + if len(parts) < 2 or not parts[0] or not parts[1]: + raise ValueError("S3 URI must include both a bucket and a key") + return parts[0], parts[1] + + if "?" not in s3_uri: + raise ValueError(f"Not an s3:// URI and has no query string: {s3_uri!r}") + base, query = s3_uri.split("?", 1) + + if "/s3/object/" not in base: + raise ValueError(f"Console URL has no '/s3/object/' segment: {s3_uri!r}") + bucket = base.split("/s3/object/", 1)[1] + + params: dict[str, str] = {} + for item in query.split("&"): + if "=" in item: + name, value = item.split("=", 1) + params[name] = value + key = unquote(params.get("prefix", "")) + return bucket, key diff --git a/repositories/user_address/user_address_csv_s3_repository.py b/repositories/user_address/user_address_csv_s3_repository.py index be2baa13..7cd10bac 100644 --- a/repositories/user_address/user_address_csv_s3_repository.py +++ b/repositories/user_address/user_address_csv_s3_repository.py @@ -1,12 +1,16 @@ """CSV-on-S3 adapter for :class:`UserAddressRepository`. -Reads canonical upload CSVs (``Address 1``, ``Address 2``, ``Address 3``, -``Postcode``, ``Internal Reference``) and writes the splitter's compact -3-column form (``user_address``, ``postcode``, ``internal_reference``). +Reads upload CSVs that carry a ``postcode`` column (plus optional +``Address 1``/``Address 2``/``Address 3`` and ``Internal Reference``), and +writes batch CSVs that pass *every* original column through unchanged with +one column appended -- ``postcode_clean`` (uppercase, whitespace-stripped) -- +which the downstream address2uprn stage groups on. -The frontend pre-applies the user's column mapping at upload time, so this -adapter does NOT consult any ``BulkAddressUpload.column_mapping``: it always -expects the canonical column names listed above. +The splitter is a pass-through router: it must not reshape or drop columns, +because address2uprn has not been migrated and still consumes the legacy +splitter's full-row output. The frontend pre-applies the user's column +mapping at upload time, so this adapter does NOT consult any +``BulkAddressUpload.column_mapping``. """ from __future__ import annotations @@ -20,8 +24,9 @@ from infrastructure.csv_s3_client import CsvS3Client from repositories.user_address.user_address_repository import UserAddressRepository _ADDRESS_COLUMNS: tuple[str, str, str] = ("Address 1", "Address 2", "Address 3") -_POSTCODE_COLUMN: str = "Postcode" +_POSTCODE_COLUMN: str = "postcode" _INTERNAL_REFERENCE_COLUMN: str = "Internal Reference" +_POSTCODE_CLEAN_COLUMN: str = "postcode_clean" class UserAddressCsvS3Repository(UserAddressRepository): @@ -37,15 +42,27 @@ class UserAddressCsvS3Repository(UserAddressRepository): self._bucket = bucket def load_batch(self, s3_uri: str) -> list[UserAddress]: - """Load canonical upload CSV rows into :class:`UserAddress` objects. + """Load upload CSV rows into :class:`UserAddress` objects. - Concatenates ``Address 1``/``Address 2``/``Address 3`` with ``", "``, - skipping missing or empty parts, into ``user_address``. Falls back to - just ``Address 1`` when 2 and 3 are absent. Passes ``Internal Reference`` - through to :attr:`UserAddress.internal_reference` (``None`` when the - column is missing or empty). + Each row's complete column set is preserved on + :attr:`UserAddress.source_row` so :meth:`save_batch` can pass it + through untouched. The parsed convenience fields are also populated: + ``Address 1``/``Address 2``/``Address 3`` are concatenated with + ``", "`` (skipping missing/empty parts) into ``user_address``, and + ``Internal Reference`` is threaded to + :attr:`UserAddress.internal_reference` (``None`` when missing/empty). + + Raises: + ValueError: if the CSV has rows but no ``postcode`` column -- + without it the splitter cannot group, and silently emitting + empty postcodes would corrupt every downstream batch. """ rows = self._csv_client.read_rows(s3_uri) + if rows and _POSTCODE_COLUMN not in rows[0]: + raise ValueError( + f"Input CSV {s3_uri} has no {_POSTCODE_COLUMN!r} column; " + f"columns present: {sorted(rows[0])}" + ) addresses: list[UserAddress] = [] for row in rows: parts = [ @@ -62,22 +79,24 @@ class UserAddressCsvS3Repository(UserAddressRepository): user_address=user_address, postcode=postcode, internal_reference=internal_reference, + source_row=row, ) ) return addresses def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: - """Write a 3-column CSV under a unique key beneath ``path_prefix``. + """Write a pass-through batch CSV under a unique key. + + Each output row is the address's original ``source_row`` with a + ``postcode_clean`` column appended (the canonical postcode the + downstream address2uprn stage groups on). No original column is + dropped or reshaped. The key is ``{path_prefix}/{ISO-8601 datetime}_{8-char uuid}.csv``. Returns the full ``s3://bucket/key`` URI. """ rows: list[dict[str, str]] = [ - { - "user_address": addr.user_address, - "postcode": addr.postcode, - "internal_reference": addr.internal_reference or "", - } + {**addr.source_row, _POSTCODE_CLEAN_COLUMN: addr.postcode} for addr in addresses ] filename = ( diff --git a/tests/domain/addresses/test_user_address.py b/tests/domain/addresses/test_user_address.py index e722077d..4d8322da 100644 --- a/tests/domain/addresses/test_user_address.py +++ b/tests/domain/addresses/test_user_address.py @@ -43,3 +43,29 @@ def test_user_address_equality_uses_sanitised_postcode() -> None: a = UserAddress(user_address="1 The Street", postcode="sw1a 1aa") b = UserAddress(user_address="1 The Street", postcode="SW1A1AA") assert a == b + + +def test_user_address_source_row_defaults_to_empty_dict() -> None: + addr = UserAddress(user_address="1 The Street", postcode="SW1A1AA") + assert addr.source_row == {} + + +def test_user_address_carries_source_row() -> None: + row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"} + addr = UserAddress( + user_address="1 The Street", postcode="SW1A 1AA", source_row=row + ) + assert addr.source_row == row + + +def test_user_address_equality_ignores_source_row() -> None: + # source_row is excluded from equality (and hashing): identity stays + # defined by the parsed fields, so two addresses parsed from rows with + # different incidental columns still compare equal. + a = UserAddress( + user_address="1 The Street", postcode="SW1A1AA", source_row={"x": "1"} + ) + b = UserAddress( + user_address="1 The Street", postcode="SW1A1AA", source_row={"y": "2"} + ) + assert a == b diff --git a/tests/infrastructure/test_s3_uri.py b/tests/infrastructure/test_s3_uri.py new file mode 100644 index 00000000..896c5959 --- /dev/null +++ b/tests/infrastructure/test_s3_uri.py @@ -0,0 +1,32 @@ +import pytest + +from infrastructure.s3_uri import parse_s3_uri + + +def test_parses_simple_s3_uri() -> None: + assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv") + + +def test_parses_s3_uri_with_nested_key() -> None: + bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv") + assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv") + + +def test_rejects_s3_uri_without_key() -> None: + with pytest.raises(ValueError, match="bucket and a key"): + parse_s3_uri("s3://my-bucket") + + +def test_rejects_s3_uri_with_empty_key() -> None: + with pytest.raises(ValueError, match="bucket and a key"): + parse_s3_uri("s3://my-bucket/") + + +def test_parses_console_url_prefix() -> None: + url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv" + assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv") + + +def test_rejects_unparseable_string() -> None: + with pytest.raises(ValueError): + parse_s3_uri("not-a-uri-at-all") diff --git a/tests/orchestration/test_postcode_splitter_orchestrator.py b/tests/orchestration/test_postcode_splitter_orchestrator.py index 57bd2133..79c60974 100644 --- a/tests/orchestration/test_postcode_splitter_orchestrator.py +++ b/tests/orchestration/test_postcode_splitter_orchestrator.py @@ -132,7 +132,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str: "Address 1": f"{i} High St", "Address 2": "", "Address 3": "", - "Postcode": "AA1 1AA", + "postcode": "AA1 1AA", "Internal Reference": f"AA-{i}", } for i in range(1, 3) @@ -142,7 +142,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str: "Address 1": f"{i} Long Road", "Address 2": "", "Address 3": "", - "Postcode": "BB2 2BB", + "postcode": "BB2 2BB", "Internal Reference": f"BB-{i}", } for i in range(1, 5) @@ -152,7 +152,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str: "Address 1": "1 Final Way", "Address 2": "", "Address 3": "", - "Postcode": "CC3 3CC", + "postcode": "CC3 3CC", "Internal Reference": "CC-1", } ) @@ -281,15 +281,15 @@ def test_split_and_dispatch_returns_child_ids_in_dispatch_order( input_s3_uri=input_uri, ) - # Re-load each child's saved batch and inspect the postcode column to - # confirm the dispatch order matches the postcode-batching algorithm: + # Re-load each child's saved batch and inspect the postcode_clean column + # to confirm the dispatch order matches the postcode-batching algorithm: # AA-batch first, BB oversize batch second, CC final-flush third. postcodes_per_batch: list[set[str]] = [] for cid in child_ids: child = harness.subtasks.get(cid) assert child.inputs is not None rows = harness.csv_client.read_rows(child.inputs["s3_uri"]) - postcodes_per_batch.append({row["postcode"] for row in rows}) + postcodes_per_batch.append({row["postcode_clean"] for row in rows}) assert postcodes_per_batch == [ {"AA11AA"}, diff --git a/tests/repositories/user_address/test_user_address_csv_s3_repository.py b/tests/repositories/user_address/test_user_address_csv_s3_repository.py index ca9e8a57..48733b55 100644 --- a/tests/repositories/user_address/test_user_address_csv_s3_repository.py +++ b/tests/repositories/user_address/test_user_address_csv_s3_repository.py @@ -3,6 +3,7 @@ from collections.abc import Iterator import pytest from moto import mock_aws +from domain.addresses.user_address import UserAddress from infrastructure.csv_s3_client import CsvS3Client from repositories.user_address.user_address_csv_s3_repository import ( UserAddressCsvS3Repository, @@ -27,7 +28,7 @@ def _upload_csv( return repo._csv_client.save_rows(rows, key) # pyright: ignore[reportPrivateUsage] -def test_load_batch_concatenates_three_address_lines( +def test_load_batch_parses_address_postcode_and_reference( repo: UserAddressCsvS3Repository, ) -> None: rows = [ @@ -35,7 +36,7 @@ def test_load_batch_concatenates_three_address_lines( "Address 1": "1 High Street", "Address 2": "Flat 2", "Address 3": "Townville", - "Postcode": "sw1a 1aa", + "postcode": "sw1a 1aa", "Internal Reference": "REF-001", } ] @@ -58,7 +59,7 @@ def test_load_batch_uses_only_address_1_when_others_missing( "Address 1": "10 Cardiff Road", "Address 2": "", "Address 3": "", - "Postcode": "CF10 1AA", + "postcode": "CF10 1AA", "Internal Reference": "REF-002", } ] @@ -80,7 +81,7 @@ def test_load_batch_handles_missing_internal_reference( "Address 1": "5 Park Lane", "Address 2": "", "Address 3": "", - "Postcode": "M1 1AA", + "postcode": "M1 1AA", "Internal Reference": "", } ] @@ -94,16 +95,67 @@ def test_load_batch_handles_missing_internal_reference( assert addresses[0].internal_reference is None +def test_load_batch_captures_full_source_row( + repo: UserAddressCsvS3Repository, +) -> None: + # A raw EPC-export-shaped row: the splitter must preserve every column, + # not just the ones it parses into UserAddress fields. + row = { + "Asset Reference": "511", + "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX", + "postcode": "BB12 7BX", + "Property Type": "House: End Terrace", + "SAP Score": "69", + } + uri = _upload_csv(repo, [row], "uploads/epc.csv") + + addresses = repo.load_batch(uri) + + assert addresses[0].source_row == row + + +def test_load_batch_raises_when_postcode_column_absent( + repo: UserAddressCsvS3Repository, +) -> None: + rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}] + uri = _upload_csv(repo, rows, "uploads/no-postcode.csv") + + with pytest.raises(ValueError, match="no 'postcode' column"): + repo.load_batch(uri) + + +def test_save_batch_passes_through_all_columns_and_appends_postcode_clean( + repo: UserAddressCsvS3Repository, +) -> None: + row = { + "Asset Reference": "511", + "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX", + "postcode": " BB12 7BX", + "Property Type": "House: End Terrace", + } + uri = _upload_csv(repo, [row], "uploads/epc.csv") + addresses = repo.load_batch(uri) + + saved_uri = repo.save_batch(addresses, "tasks/passthrough") + saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] + + assert len(saved_rows) == 1 + saved = saved_rows[0] + # Every original column survives, byte-for-byte. + for column, value in row.items(): + assert saved[column] == value + # Plus the one appended column the downstream address2uprn stage groups on. + assert saved["postcode_clean"] == "BB127BX" + + def test_save_batch_returns_uri_under_path_prefix( repo: UserAddressCsvS3Repository, ) -> None: - from domain.addresses.user_address import UserAddress - addresses = [ UserAddress( - user_address="1 High Street, Flat 2, Townville", + user_address="1 High Street", postcode="SW1A 1AA", - internal_reference="REF-001", + source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"}, ), ] @@ -113,59 +165,42 @@ def test_save_batch_returns_uri_under_path_prefix( assert uri.endswith(".csv") -def test_save_then_reload_round_trip_preserves_values( +def test_save_then_reload_round_trip_preserves_columns( repo: UserAddressCsvS3Repository, ) -> None: - from domain.addresses.user_address import UserAddress - - # save_batch writes the splitter's compact schema - # (user_address/postcode/internal_reference); load_batch reads the - # canonical upload schema. To round-trip through the repo we re-upload - # the saved CSV under the upload schema's column names. - original = [ - UserAddress( - user_address="1 High Street", - postcode="SW1A 1AA", - internal_reference="REF-001", - ), - UserAddress( - user_address="2 Low Street", - postcode="XY9 8ZW", - internal_reference=None, - ), - ] - - saved_uri = repo.save_batch(original, "tasks/round-trip") - - # Re-shape the saved CSV into the canonical upload schema for reload. - saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] - upload_rows: list[dict[str, str]] = [ + rows = [ { - "Address 1": row["user_address"], - "Address 2": "", - "Address 3": "", - "Postcode": row["postcode"], - "Internal Reference": row["internal_reference"], - } - for row in saved_rows + "Address 1": "1 High Street", + "postcode": "SW1A 1AA", + "Internal Reference": "REF-001", + }, + { + "Address 1": "2 Low Street", + "postcode": "XY9 8ZW", + "Internal Reference": "", + }, ] - upload_uri = _upload_csv(repo, upload_rows, "uploads/round-trip.csv") + uri = _upload_csv(repo, rows, "uploads/round-trip.csv") + addresses = repo.load_batch(uri) - reloaded = repo.load_batch(upload_uri) + saved_uri = repo.save_batch(addresses, "tasks/round-trip") + saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] - assert reloaded == original + # Original columns come back verbatim; postcode_clean is the only addition. + assert [ + {k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows + ] == rows + assert [r["postcode_clean"] for r in saved_rows] == ["SW1A1AA", "XY98ZW"] def test_save_batch_uses_unique_filename_per_call( repo: UserAddressCsvS3Repository, ) -> None: - from domain.addresses.user_address import UserAddress - addresses = [ UserAddress( user_address="1 High Street", postcode="SW1A 1AA", - internal_reference="REF-001", + source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"}, ), ] diff --git a/tests/utilities/aws_lambda/test_subtask_handler.py b/tests/utilities/aws_lambda/test_subtask_handler.py index 426b250f..771a49f8 100644 --- a/tests/utilities/aws_lambda/test_subtask_handler.py +++ b/tests/utilities/aws_lambda/test_subtask_handler.py @@ -6,6 +6,7 @@ to the wrapped function — so the handler can compose its own use-case orchestrator that shares the session. """ +import logging from collections.abc import Generator, Iterator from contextlib import contextmanager from dataclasses import dataclass @@ -13,6 +14,8 @@ from typing import Any from uuid import UUID import pytest + +_LOGGER_NAME = "utilities.aws_lambda.subtask_handler" from sqlmodel import Session, SQLModel, create_engine from domain.tasks.subtasks import SubTaskStatus @@ -142,3 +145,111 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask( persisted_child = harness.subtasks.get(child_ids[0]) assert persisted_child.task_id == task.id assert persisted_child.status is SubTaskStatus.WAITING + + +def test_subtask_handler_logs_subtask_lifecycle_on_success( + harness: Harness, caplog: pytest.LogCaptureFixture +) -> None: + """Start and completion are logged at INFO so a successful invocation + leaves a CloudWatch breadcrumb (not just the Lambda runtime lines).""" + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + with caplog.at_level(logging.INFO, logger=_LOGGER_NAME): + handler(_direct_event(task.id, subtask.id), context=None) + + assert f"Running subtask {subtask.id}" in caplog.text + assert f"Subtask {subtask.id} completed" in caplog.text + + +def test_subtask_handler_logs_exception_on_failure( + harness: Harness, caplog: pytest.LogCaptureFixture +) -> None: + """A failing subtask is logged at ERROR with the traceback attached, + before the exception propagates for the Lambda runtime to surface.""" + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + raise RuntimeError("boom") + + with caplog.at_level(logging.INFO, logger=_LOGGER_NAME): + with pytest.raises(RuntimeError, match="boom"): + handler(_direct_event(task.id, subtask.id), context=None) + + failures = [r for r in caplog.records if r.levelno == logging.ERROR] + assert any( + f"Subtask {subtask.id} failed" in r.getMessage() for r in failures + ) + assert any(r.exc_info is not None for r in failures) + + +def test_subtask_handler_records_cloudwatch_url_on_subtask( + harness: Harness, monkeypatch: pytest.MonkeyPatch +) -> None: + """With the AWS Lambda runtime's log env vars present, a CloudWatch deep + link is built and persisted on the SubTask.""" + monkeypatch.setenv("AWS_REGION", "eu-west-2") + monkeypatch.setenv( + "AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter" + ) + monkeypatch.setenv( + "AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123" + ) + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + handler(_direct_event(task.id, subtask.id), context=None) + + saved_url = harness.subtasks.get(subtask.id).cloud_logs_url + assert saved_url is not None + assert saved_url.startswith( + "https://eu-west-2.console.aws.amazon.com/cloudwatch/home" + ) + # Log group / stream are console-encoded ("/" -> "$252F"). + assert "$252Faws$252Flambda$252Fpostcode-splitter" in saved_url + assert "$255B$2524LATEST$255D" in saved_url + + +def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda( + harness: Harness, monkeypatch: pytest.MonkeyPatch +) -> None: + """Outside a real Lambda (e.g. the local RIE) the runtime log env vars + are absent, so cloud_logs_url is left unset rather than storing junk.""" + for var in ( + "AWS_REGION", + "AWS_LAMBDA_LOG_GROUP_NAME", + "AWS_LAMBDA_LOG_STREAM_NAME", + ): + monkeypatch.delenv(var, raising=False) + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + handler(_direct_event(task.id, subtask.id), context=None) + + assert harness.subtasks.get(subtask.id).cloud_logs_url is None diff --git a/utilities/aws_lambda/subtask_handler.py b/utilities/aws_lambda/subtask_handler.py index 5ad5f6e1..40f116ad 100644 --- a/utilities/aws_lambda/subtask_handler.py +++ b/utilities/aws_lambda/subtask_handler.py @@ -1,18 +1,32 @@ """@subtask_handler decorator for Lambdas that operate on existing SubTasks. Translates an AWS Lambda invocation (SQS-shaped or direct) into -TaskOrchestrator.run_subtask(...) calls. +TaskOrchestrator.run_subtask(...) calls, emitting an INFO log line for each +subtask's start and completion and a logged exception on failure. Those lines +land in CloudWatch via the Lambda runtime's stdout/stderr capture. + +Each subtask also records ``cloud_logs_url`` -- a deep link to this +invocation's CloudWatch log stream -- so an operator can jump from a SubTask +row straight to its logs. It is built from the environment variables the AWS +Lambda runtime sets, so it is populated only on real Lambda invocations and +left unset under the local RIE (which does not export them). """ import json +import logging +import os from contextlib import AbstractContextManager from functools import wraps from typing import Any, Callable, Optional, cast +from urllib.parse import quote from utilities.aws_lambda.default_orchestrator import default_orchestrator from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody from orchestration.task_orchestrator import TaskOrchestrator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]] @@ -33,14 +47,26 @@ def subtask_handler( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> None: + cloud_logs_url = _cloudwatch_url() with factory() as orchestrator: for record in _records(event): body = _parse_body(record) trigger = SubtaskTriggerBody.model_validate(body) - orchestrator.run_subtask( - trigger.sub_task_id, - work=lambda body=body, o=orchestrator: func(body, context, o), - ) + logger.info("Running subtask %s", trigger.sub_task_id) + try: + orchestrator.run_subtask( + trigger.sub_task_id, + work=lambda body=body, o=orchestrator: func( + body, context, o + ), + cloud_logs_url=cloud_logs_url, + ) + except Exception: + logger.exception( + "Subtask %s failed", trigger.sub_task_id + ) + raise + logger.info("Subtask %s completed", trigger.sub_task_id) return wrapper @@ -65,3 +91,34 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]: if isinstance(raw_records, list): return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)] return [event] + + +def _console_encode(value: str) -> str: + """Encode a value for a CloudWatch console deep link. + + The console expects URL-encoding with the percent signs themselves + re-encoded as ``$25`` -- e.g. ``/`` becomes ``%2F`` becomes ``$252F``. + """ + return quote(value, safe="").replace("%", "$25") + + +def _cloudwatch_url() -> Optional[str]: + """Build a CloudWatch console URL for this invocation's log stream. + + Sourced entirely from the environment variables the AWS Lambda runtime + sets -- ``AWS_REGION``, ``AWS_LAMBDA_LOG_GROUP_NAME`` and + ``AWS_LAMBDA_LOG_STREAM_NAME``. Returns None when any is absent, which is + the case outside a real Lambda (the local RIE does not export them) -- so + ``SubTask.cloud_logs_url`` is left unset rather than storing a link that + points nowhere. + """ + region = os.environ.get("AWS_REGION") + log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") + log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") + if not (region and log_group and log_stream): + return None + return ( + f"https://{region}.console.aws.amazon.com/cloudwatch/home" + f"?region={region}#logsV2:log-groups/log-group/" + f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}" + )