mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
postcode splitter working e2e
This commit is contained in:
parent
0a04448217
commit
914a8ed51e
18 changed files with 523 additions and 93 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -121,6 +121,7 @@ celerybeat.pid
|
|||
|
||||
# Environments
|
||||
.env
|
||||
.env.local
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
|
|
|
|||
|
|
@ -1,5 +1,18 @@
|
|||
FROM public.ecr.aws/lambda/python:3.11
|
||||
|
||||
# Postgres host/port/database are baked into the image at build time from
|
||||
# the deploy workflow's --build-arg values (GitHub Actions DEV_DB_* secrets),
|
||||
# mirroring backend/postcode_splitter/handler/Dockerfile. They map onto the
|
||||
# POSTGRES_* names PostgresConfig.from_env reads. Username/password are NOT
|
||||
# baked in -- Terraform injects those as Lambda env vars from Secrets Manager.
|
||||
ARG DEV_DB_HOST
|
||||
ARG DEV_DB_PORT
|
||||
ARG DEV_DB_NAME
|
||||
|
||||
ENV POSTGRES_HOST=${DEV_DB_HOST}
|
||||
ENV POSTGRES_PORT=${DEV_DB_PORT}
|
||||
ENV POSTGRES_DATABASE=${DEV_DB_NAME}
|
||||
|
||||
WORKDIR /var/task
|
||||
|
||||
COPY applications/postcode_splitter/requirements.txt .
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
# Local-test environment for the postcode_splitter Lambda.
|
||||
#
|
||||
# cp .env.local.example .env.local then fill in the values below.
|
||||
#
|
||||
# .env.local is gitignored. The container hits REAL AWS and a REAL Postgres,
|
||||
# so every value here points at infrastructure that actually exists.
|
||||
#
|
||||
# NOTE: the new DDD code uses different env var names than the repo root
|
||||
# .env. The mapping (root .env name -> var here) is given per section.
|
||||
# Keep comments on their own lines — docker-compose's env_file parser folds a
|
||||
# trailing "# ..." into the value.
|
||||
|
||||
# --- Postgres (orchestration/default_orchestrator -> PostgresConfig.from_env) ---
|
||||
# POSTGRES_HOST <- DB_HOST, PORT <- DB_PORT, USERNAME <- DB_USERNAME,
|
||||
# PASSWORD <- DB_PASSWORD, DATABASE <- DB_NAME.
|
||||
POSTGRES_HOST=
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USERNAME=
|
||||
POSTGRES_PASSWORD=
|
||||
POSTGRES_DATABASE=
|
||||
# POSTGRES_DRIVER=psycopg2 (optional; defaults to psycopg2)
|
||||
|
||||
# --- Handler config (applications/postcode_splitter/handler.py) ---
|
||||
# S3_BUCKET_NAME: bucket holding the input address CSV (root .env: DATA_BUCKET).
|
||||
# ADDRESS2UPRN_QUEUE_URL: SQS queue the splitter fans batches out to; not in
|
||||
# the root .env (Terraform sets it in prod).
|
||||
S3_BUCKET_NAME=
|
||||
ADDRESS2UPRN_QUEUE_URL=
|
||||
|
||||
# --- AWS credentials for boto3 (S3 + SQS clients) ---
|
||||
AWS_ACCESS_KEY_ID=
|
||||
AWS_SECRET_ACCESS_KEY=
|
||||
AWS_DEFAULT_REGION=eu-west-2
|
||||
# AWS_SESSION_TOKEN= (only if using temporary/SSO credentials)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
services:
|
||||
postcode-splitter:
|
||||
build:
|
||||
context: ../../../
|
||||
dockerfile: applications/postcode_splitter/Dockerfile
|
||||
ports:
|
||||
- "9001:8080"
|
||||
env_file:
|
||||
- .env.local
|
||||
37
applications/postcode_splitter/local_handler/invoke_local_lambda.py
Executable file
37
applications/postcode_splitter/local_handler/invoke_local_lambda.py
Executable file
|
|
@ -0,0 +1,37 @@
|
|||
#!/usr/bin/env python3
"""POST a single SQS-shaped event at the locally-running splitter Lambda.

The container built by docker-compose runs the AWS Lambda Runtime Interface
Emulator, which accepts invocations on the URL below. Replace the three
placeholder values with a real parent Task id, the splitter's own SubTask id
(both must already exist in the Postgres pointed at by .env.local), and the
s3://... URI of an uploaded address CSV.

The request is only sent when this file is executed as a script, so importing
it (e.g. to reuse ``payload``) has no network side effect.
"""

import json

import requests

HOST = "localhost"
PORT = "9001"

LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"

# (connect, read) timeouts in seconds. Without a timeout, requests waits
# forever if the container accepts the connection but never answers. The read
# timeout is deliberately generous because the emulator only responds once the
# handler has finished.
REQUEST_TIMEOUT_SECONDS = (5, 900)

# One SQS-shaped record whose ``body`` is the JSON-encoded trigger message.
payload = {
    "Records": [
        {
            "body": json.dumps(
                {
                    "task_id": "f4b3332f-c0cc-481f-96a5-d39860a647cf",
                    "sub_task_id": "14c042de-40c4-473b-8cd8-72c983a94a8d",
                    "s3_uri": "s3://retrofit-data-dev/ara_raw_inputs/calico/Calico Homes Full list EPC Properties(Sheet2) (1) (1).csv",
                }
            )
        }
    ]
}


def main() -> None:
    """Send ``payload`` to the local Lambda and print the status and body."""
    response = requests.post(
        LAMBDA_URL, json=payload, timeout=REQUEST_TIMEOUT_SECONDS
    )

    print("Status code:", response.status_code)
    print("Response:")
    print(response.text)


if __name__ == "__main__":
    main()
|
||||
12
applications/postcode_splitter/local_handler/run_local.sh
Executable file
12
applications/postcode_splitter/local_handler/run_local.sh
Executable file
|
|
@ -0,0 +1,12 @@
|
|||
#!/usr/bin/env bash
# Build and start the local postcode_splitter container via docker compose.
# On first run, seed .env.local from the template and stop so it can be
# filled in before anything is built.
set -euo pipefail

# Work from this script's own directory so the relative .env.local paths
# below resolve regardless of where the script is invoked from.
script_dir="$(dirname "$0")"
cd "$script_dir"

if [[ -f .env.local ]]; then
  # Always rebuild from scratch and recreate the container.
  docker compose build --no-cache
  docker compose up --force-recreate
else
  cp .env.local.example .env.local
  echo "Created .env.local from the template — fill it in, then re-run." >&2
  exit 1
fi
|
||||
|
|
@ -8,4 +8,5 @@ boto3==1.35.44
|
|||
sqlmodel
|
||||
sqlalchemy==2.0.36
|
||||
psycopg2-binary==2.9.10
|
||||
pydantic-settings==2.6.0
|
||||
pydantic-settings==2.6.0
|
||||
httpx
|
||||
|
|
@ -40,20 +40,6 @@ module "lambda" {
|
|||
LOG_LEVEL = "info"
|
||||
DB_USERNAME = local.db_credentials.db_assessment_model_username
|
||||
DB_PASSWORD = local.db_credentials.db_assessment_model_password
|
||||
GOOGLE_SOLAR_API_KEY = "test"
|
||||
SAP_PREDICTIONS_BUCKET = "test"
|
||||
CARBON_PREDICTIONS_BUCKET = "test"
|
||||
HEAT_PREDICTIONS_BUCKET = "test"
|
||||
HEATING_KWH_PREDICTIONS_BUCKET = "test"
|
||||
HOTWATER_KWH_PREDICTIONS_BUCKET = "test"
|
||||
API_KEY = "test"
|
||||
ENVIRONMENT = "test"
|
||||
SECRET_KEY = "test"
|
||||
PLAN_TRIGGER_BUCKET = "test"
|
||||
DATA_BUCKET = "test"
|
||||
EPC_AUTH_TOKEN = "test"
|
||||
ENGINE_SQS_URL = "test"
|
||||
ENERGY_ASSESSMENTS_BUCKET = "test"
|
||||
ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url
|
||||
S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name
|
||||
},
|
||||
|
|
|
|||
|
|
@ -8,12 +8,17 @@ caller can construct an instance with an un-normalised postcode.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from domain.postcodes.sanitise import sanitise_postcode
|
||||
|
||||
|
||||
def _empty_source_row() -> dict[str, str]:
|
||||
"""Typed default factory for :attr:`UserAddress.source_row`."""
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UserAddress:
|
||||
"""A user-supplied address paired with its canonical postcode.
|
||||
|
|
@ -25,11 +30,20 @@ class UserAddress:
|
|||
:meth:`__post_init__`.
|
||||
internal_reference: Optional customer-side identifier preserved for
|
||||
traceability through the matching pipeline.
|
||||
source_row: The complete original CSV row this address was parsed
|
||||
from, column name -> cell value. The splitter is a pass-through
|
||||
router: it groups rows by postcode but must not drop the other
|
||||
columns the downstream address2uprn stage relies on, so the raw
|
||||
row travels alongside the parsed fields. Excluded from equality
|
||||
and hashing -- identity stays defined by the parsed fields above.
|
||||
"""
|
||||
|
||||
user_address: str
|
||||
postcode: str
|
||||
internal_reference: Optional[str] = None
|
||||
source_row: dict[str, str] = field(
|
||||
default_factory=_empty_source_row, compare=False
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Frozen dataclass: bypass the descriptor with object.__setattr__.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import csv
|
|||
from io import StringIO
|
||||
|
||||
from infrastructure.s3_client import S3Client
|
||||
from utils.s3 import parse_s3_uri
|
||||
from infrastructure.s3_uri import parse_s3_uri
|
||||
|
||||
|
||||
class CsvS3Client(S3Client):
|
||||
|
|
|
|||
43
infrastructure/s3_uri.py
Normal file
43
infrastructure/s3_uri.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""Parse S3 URIs into ``(bucket, key)`` pairs.
|
||||
|
||||
A pure-stdlib helper for the infrastructure layer. It deliberately pulls in
|
||||
neither pandas, boto3, nor the legacy ``utils`` package, so slim Lambda images
|
||||
that only need URI parsing do not drag the wider data stack along.
|
||||
|
||||
Two input shapes are supported:
|
||||
|
||||
* canonical S3 URIs --- ``s3://bucket/key``
|
||||
* AWS S3 console URLs --- ``https://.../s3/object/bucket?prefix=key``
|
||||
"""
|
||||
|
||||
from urllib.parse import unquote
|
||||
|
||||
|
||||
def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
    """Return the ``(bucket, key)`` pair addressed by ``s3_uri``.

    Two input shapes are accepted:

    * ``s3://bucket/key`` -- split on the first ``/`` after the scheme; the
      key is returned verbatim (no percent-decoding).
    * AWS console URLs, ``https://.../s3/object/bucket?prefix=key`` -- the
      bucket is whatever follows ``/s3/object/`` in the path and the key is
      the percent-decoded ``prefix`` query parameter. Other query parameters
      are ignored; if ``prefix`` is repeated, the last occurrence wins.

    Raises:
        ValueError: if ``s3_uri`` is neither a well-formed ``s3://`` URI nor
            an AWS console URL carrying a bucket and a non-empty ``prefix``
            query parameter.
    """
    scheme = "s3://"
    if s3_uri.startswith(scheme):
        bucket, _, key = s3_uri[len(scheme):].partition("/")
        if not bucket or not key:
            raise ValueError("S3 URI must include both a bucket and a key")
        return bucket, key

    if "?" not in s3_uri:
        raise ValueError(f"Not an s3:// URI and has no query string: {s3_uri!r}")
    base, query = s3_uri.split("?", 1)

    if "/s3/object/" not in base:
        raise ValueError(f"Console URL has no '/s3/object/' segment: {s3_uri!r}")
    bucket = base.split("/s3/object/", 1)[1]
    if not bucket:
        raise ValueError(
            f"Console URL has no bucket after '/s3/object/': {s3_uri!r}"
        )

    # Hand-rolled query parsing: only "name=value" items are considered, and
    # only the key itself is percent-decoded.
    params: dict[str, str] = {}
    for item in query.split("&"):
        name, sep, value = item.partition("=")
        if sep:
            params[name] = value

    key = unquote(params.get("prefix", ""))
    if not key:
        # Mirror the s3:// branch: an empty key is never a usable object
        # address, so fail here rather than hand back (bucket, "").
        raise ValueError(
            f"Console URL has no 'prefix' query parameter: {s3_uri!r}"
        )
    return bucket, key
|
||||
|
|
@ -1,12 +1,16 @@
|
|||
"""CSV-on-S3 adapter for :class:`UserAddressRepository`.
|
||||
|
||||
Reads canonical upload CSVs (``Address 1``, ``Address 2``, ``Address 3``,
|
||||
``Postcode``, ``Internal Reference``) and writes the splitter's compact
|
||||
3-column form (``user_address``, ``postcode``, ``internal_reference``).
|
||||
Reads upload CSVs that carry a ``postcode`` column (plus optional
|
||||
``Address 1``/``Address 2``/``Address 3`` and ``Internal Reference``), and
|
||||
writes batch CSVs that pass *every* original column through unchanged with
|
||||
one column appended -- ``postcode_clean`` (uppercase, whitespace-stripped) --
|
||||
which the downstream address2uprn stage groups on.
|
||||
|
||||
The frontend pre-applies the user's column mapping at upload time, so this
|
||||
adapter does NOT consult any ``BulkAddressUpload.column_mapping``: it always
|
||||
expects the canonical column names listed above.
|
||||
The splitter is a pass-through router: it must not reshape or drop columns,
|
||||
because address2uprn has not been migrated and still consumes the legacy
|
||||
splitter's full-row output. The frontend pre-applies the user's column
|
||||
mapping at upload time, so this adapter does NOT consult any
|
||||
``BulkAddressUpload.column_mapping``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -20,8 +24,9 @@ from infrastructure.csv_s3_client import CsvS3Client
|
|||
from repositories.user_address.user_address_repository import UserAddressRepository
|
||||
|
||||
_ADDRESS_COLUMNS: tuple[str, str, str] = ("Address 1", "Address 2", "Address 3")
|
||||
_POSTCODE_COLUMN: str = "Postcode"
|
||||
_POSTCODE_COLUMN: str = "postcode"
|
||||
_INTERNAL_REFERENCE_COLUMN: str = "Internal Reference"
|
||||
_POSTCODE_CLEAN_COLUMN: str = "postcode_clean"
|
||||
|
||||
|
||||
class UserAddressCsvS3Repository(UserAddressRepository):
|
||||
|
|
@ -37,15 +42,27 @@ class UserAddressCsvS3Repository(UserAddressRepository):
|
|||
self._bucket = bucket
|
||||
|
||||
def load_batch(self, s3_uri: str) -> list[UserAddress]:
|
||||
"""Load canonical upload CSV rows into :class:`UserAddress` objects.
|
||||
"""Load upload CSV rows into :class:`UserAddress` objects.
|
||||
|
||||
Concatenates ``Address 1``/``Address 2``/``Address 3`` with ``", "``,
|
||||
skipping missing or empty parts, into ``user_address``. Falls back to
|
||||
just ``Address 1`` when 2 and 3 are absent. Passes ``Internal Reference``
|
||||
through to :attr:`UserAddress.internal_reference` (``None`` when the
|
||||
column is missing or empty).
|
||||
Each row's complete column set is preserved on
|
||||
:attr:`UserAddress.source_row` so :meth:`save_batch` can pass it
|
||||
through untouched. The parsed convenience fields are also populated:
|
||||
``Address 1``/``Address 2``/``Address 3`` are concatenated with
|
||||
``", "`` (skipping missing/empty parts) into ``user_address``, and
|
||||
``Internal Reference`` is threaded to
|
||||
:attr:`UserAddress.internal_reference` (``None`` when missing/empty).
|
||||
|
||||
Raises:
|
||||
ValueError: if the CSV has rows but no ``postcode`` column --
|
||||
without it the splitter cannot group, and silently emitting
|
||||
empty postcodes would corrupt every downstream batch.
|
||||
"""
|
||||
rows = self._csv_client.read_rows(s3_uri)
|
||||
if rows and _POSTCODE_COLUMN not in rows[0]:
|
||||
raise ValueError(
|
||||
f"Input CSV {s3_uri} has no {_POSTCODE_COLUMN!r} column; "
|
||||
f"columns present: {sorted(rows[0])}"
|
||||
)
|
||||
addresses: list[UserAddress] = []
|
||||
for row in rows:
|
||||
parts = [
|
||||
|
|
@ -62,22 +79,24 @@ class UserAddressCsvS3Repository(UserAddressRepository):
|
|||
user_address=user_address,
|
||||
postcode=postcode,
|
||||
internal_reference=internal_reference,
|
||||
source_row=row,
|
||||
)
|
||||
)
|
||||
return addresses
|
||||
|
||||
def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str:
|
||||
"""Write a 3-column CSV under a unique key beneath ``path_prefix``.
|
||||
"""Write a pass-through batch CSV under a unique key.
|
||||
|
||||
Each output row is the address's original ``source_row`` with a
|
||||
``postcode_clean`` column appended (the canonical postcode the
|
||||
downstream address2uprn stage groups on). No original column is
|
||||
dropped or reshaped.
|
||||
|
||||
The key is ``{path_prefix}/{ISO-8601 datetime}_{8-char uuid}.csv``.
|
||||
Returns the full ``s3://bucket/key`` URI.
|
||||
"""
|
||||
rows: list[dict[str, str]] = [
|
||||
{
|
||||
"user_address": addr.user_address,
|
||||
"postcode": addr.postcode,
|
||||
"internal_reference": addr.internal_reference or "",
|
||||
}
|
||||
{**addr.source_row, _POSTCODE_CLEAN_COLUMN: addr.postcode}
|
||||
for addr in addresses
|
||||
]
|
||||
filename = (
|
||||
|
|
|
|||
|
|
@ -43,3 +43,29 @@ def test_user_address_equality_uses_sanitised_postcode() -> None:
|
|||
a = UserAddress(user_address="1 The Street", postcode="sw1a 1aa")
|
||||
b = UserAddress(user_address="1 The Street", postcode="SW1A1AA")
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_user_address_source_row_defaults_to_empty_dict() -> None:
|
||||
addr = UserAddress(user_address="1 The Street", postcode="SW1A1AA")
|
||||
assert addr.source_row == {}
|
||||
|
||||
|
||||
def test_user_address_carries_source_row() -> None:
|
||||
row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"}
|
||||
addr = UserAddress(
|
||||
user_address="1 The Street", postcode="SW1A 1AA", source_row=row
|
||||
)
|
||||
assert addr.source_row == row
|
||||
|
||||
|
||||
def test_user_address_equality_ignores_source_row() -> None:
|
||||
# source_row is excluded from equality (and hashing): identity stays
|
||||
# defined by the parsed fields, so two addresses parsed from rows with
|
||||
# different incidental columns still compare equal.
|
||||
a = UserAddress(
|
||||
user_address="1 The Street", postcode="SW1A1AA", source_row={"x": "1"}
|
||||
)
|
||||
b = UserAddress(
|
||||
user_address="1 The Street", postcode="SW1A1AA", source_row={"y": "2"}
|
||||
)
|
||||
assert a == b
|
||||
|
|
|
|||
32
tests/infrastructure/test_s3_uri.py
Normal file
32
tests/infrastructure/test_s3_uri.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import pytest
|
||||
|
||||
from infrastructure.s3_uri import parse_s3_uri
|
||||
|
||||
|
||||
def test_parses_simple_s3_uri() -> None:
|
||||
assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv")
|
||||
|
||||
|
||||
def test_parses_s3_uri_with_nested_key() -> None:
|
||||
bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv")
|
||||
assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv")
|
||||
|
||||
|
||||
def test_rejects_s3_uri_without_key() -> None:
|
||||
with pytest.raises(ValueError, match="bucket and a key"):
|
||||
parse_s3_uri("s3://my-bucket")
|
||||
|
||||
|
||||
def test_rejects_s3_uri_with_empty_key() -> None:
|
||||
with pytest.raises(ValueError, match="bucket and a key"):
|
||||
parse_s3_uri("s3://my-bucket/")
|
||||
|
||||
|
||||
def test_parses_console_url_prefix() -> None:
|
||||
url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv"
|
||||
assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv")
|
||||
|
||||
|
||||
def test_rejects_unparseable_string() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
parse_s3_uri("not-a-uri-at-all")
|
||||
|
|
@ -132,7 +132,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
|
|||
"Address 1": f"{i} High St",
|
||||
"Address 2": "",
|
||||
"Address 3": "",
|
||||
"Postcode": "AA1 1AA",
|
||||
"postcode": "AA1 1AA",
|
||||
"Internal Reference": f"AA-{i}",
|
||||
}
|
||||
for i in range(1, 3)
|
||||
|
|
@ -142,7 +142,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
|
|||
"Address 1": f"{i} Long Road",
|
||||
"Address 2": "",
|
||||
"Address 3": "",
|
||||
"Postcode": "BB2 2BB",
|
||||
"postcode": "BB2 2BB",
|
||||
"Internal Reference": f"BB-{i}",
|
||||
}
|
||||
for i in range(1, 5)
|
||||
|
|
@ -152,7 +152,7 @@ def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
|
|||
"Address 1": "1 Final Way",
|
||||
"Address 2": "",
|
||||
"Address 3": "",
|
||||
"Postcode": "CC3 3CC",
|
||||
"postcode": "CC3 3CC",
|
||||
"Internal Reference": "CC-1",
|
||||
}
|
||||
)
|
||||
|
|
@ -281,15 +281,15 @@ def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
|
|||
input_s3_uri=input_uri,
|
||||
)
|
||||
|
||||
# Re-load each child's saved batch and inspect the postcode column to
|
||||
# confirm the dispatch order matches the postcode-batching algorithm:
|
||||
# Re-load each child's saved batch and inspect the postcode_clean column
|
||||
# to confirm the dispatch order matches the postcode-batching algorithm:
|
||||
# AA-batch first, BB oversize batch second, CC final-flush third.
|
||||
postcodes_per_batch: list[set[str]] = []
|
||||
for cid in child_ids:
|
||||
child = harness.subtasks.get(cid)
|
||||
assert child.inputs is not None
|
||||
rows = harness.csv_client.read_rows(child.inputs["s3_uri"])
|
||||
postcodes_per_batch.append({row["postcode"] for row in rows})
|
||||
postcodes_per_batch.append({row["postcode_clean"] for row in rows})
|
||||
|
||||
assert postcodes_per_batch == [
|
||||
{"AA11AA"},
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from collections.abc import Iterator
|
|||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from domain.addresses.user_address import UserAddress
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
from repositories.user_address.user_address_csv_s3_repository import (
|
||||
UserAddressCsvS3Repository,
|
||||
|
|
@ -27,7 +28,7 @@ def _upload_csv(
|
|||
return repo._csv_client.save_rows(rows, key) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_load_batch_concatenates_three_address_lines(
|
||||
def test_load_batch_parses_address_postcode_and_reference(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
rows = [
|
||||
|
|
@ -35,7 +36,7 @@ def test_load_batch_concatenates_three_address_lines(
|
|||
"Address 1": "1 High Street",
|
||||
"Address 2": "Flat 2",
|
||||
"Address 3": "Townville",
|
||||
"Postcode": "sw1a 1aa",
|
||||
"postcode": "sw1a 1aa",
|
||||
"Internal Reference": "REF-001",
|
||||
}
|
||||
]
|
||||
|
|
@ -58,7 +59,7 @@ def test_load_batch_uses_only_address_1_when_others_missing(
|
|||
"Address 1": "10 Cardiff Road",
|
||||
"Address 2": "",
|
||||
"Address 3": "",
|
||||
"Postcode": "CF10 1AA",
|
||||
"postcode": "CF10 1AA",
|
||||
"Internal Reference": "REF-002",
|
||||
}
|
||||
]
|
||||
|
|
@ -80,7 +81,7 @@ def test_load_batch_handles_missing_internal_reference(
|
|||
"Address 1": "5 Park Lane",
|
||||
"Address 2": "",
|
||||
"Address 3": "",
|
||||
"Postcode": "M1 1AA",
|
||||
"postcode": "M1 1AA",
|
||||
"Internal Reference": "",
|
||||
}
|
||||
]
|
||||
|
|
@ -94,16 +95,67 @@ def test_load_batch_handles_missing_internal_reference(
|
|||
assert addresses[0].internal_reference is None
|
||||
|
||||
|
||||
def test_load_batch_captures_full_source_row(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# A raw EPC-export-shaped row: the splitter must preserve every column,
|
||||
# not just the ones it parses into UserAddress fields.
|
||||
row = {
|
||||
"Asset Reference": "511",
|
||||
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
|
||||
"postcode": "BB12 7BX",
|
||||
"Property Type": "House: End Terrace",
|
||||
"SAP Score": "69",
|
||||
}
|
||||
uri = _upload_csv(repo, [row], "uploads/epc.csv")
|
||||
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
assert addresses[0].source_row == row
|
||||
|
||||
|
||||
def test_load_batch_raises_when_postcode_column_absent(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}]
|
||||
uri = _upload_csv(repo, rows, "uploads/no-postcode.csv")
|
||||
|
||||
with pytest.raises(ValueError, match="no 'postcode' column"):
|
||||
repo.load_batch(uri)
|
||||
|
||||
|
||||
def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
row = {
|
||||
"Asset Reference": "511",
|
||||
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
|
||||
"postcode": " BB12 7BX",
|
||||
"Property Type": "House: End Terrace",
|
||||
}
|
||||
uri = _upload_csv(repo, [row], "uploads/epc.csv")
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
saved_uri = repo.save_batch(addresses, "tasks/passthrough")
|
||||
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert len(saved_rows) == 1
|
||||
saved = saved_rows[0]
|
||||
# Every original column survives, byte-for-byte.
|
||||
for column, value in row.items():
|
||||
assert saved[column] == value
|
||||
# Plus the one appended column the downstream address2uprn stage groups on.
|
||||
assert saved["postcode_clean"] == "BB127BX"
|
||||
|
||||
|
||||
def test_save_batch_returns_uri_under_path_prefix(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
from domain.addresses.user_address import UserAddress
|
||||
|
||||
addresses = [
|
||||
UserAddress(
|
||||
user_address="1 High Street, Flat 2, Townville",
|
||||
user_address="1 High Street",
|
||||
postcode="SW1A 1AA",
|
||||
internal_reference="REF-001",
|
||||
source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -113,59 +165,42 @@ def test_save_batch_returns_uri_under_path_prefix(
|
|||
assert uri.endswith(".csv")
|
||||
|
||||
|
||||
def test_save_then_reload_round_trip_preserves_values(
|
||||
def test_save_then_reload_round_trip_preserves_columns(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
from domain.addresses.user_address import UserAddress
|
||||
|
||||
# save_batch writes the splitter's compact schema
|
||||
# (user_address/postcode/internal_reference); load_batch reads the
|
||||
# canonical upload schema. To round-trip through the repo we re-upload
|
||||
# the saved CSV under the upload schema's column names.
|
||||
original = [
|
||||
UserAddress(
|
||||
user_address="1 High Street",
|
||||
postcode="SW1A 1AA",
|
||||
internal_reference="REF-001",
|
||||
),
|
||||
UserAddress(
|
||||
user_address="2 Low Street",
|
||||
postcode="XY9 8ZW",
|
||||
internal_reference=None,
|
||||
),
|
||||
]
|
||||
|
||||
saved_uri = repo.save_batch(original, "tasks/round-trip")
|
||||
|
||||
# Re-shape the saved CSV into the canonical upload schema for reload.
|
||||
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
|
||||
upload_rows: list[dict[str, str]] = [
|
||||
rows = [
|
||||
{
|
||||
"Address 1": row["user_address"],
|
||||
"Address 2": "",
|
||||
"Address 3": "",
|
||||
"Postcode": row["postcode"],
|
||||
"Internal Reference": row["internal_reference"],
|
||||
}
|
||||
for row in saved_rows
|
||||
"Address 1": "1 High Street",
|
||||
"postcode": "SW1A 1AA",
|
||||
"Internal Reference": "REF-001",
|
||||
},
|
||||
{
|
||||
"Address 1": "2 Low Street",
|
||||
"postcode": "XY9 8ZW",
|
||||
"Internal Reference": "",
|
||||
},
|
||||
]
|
||||
upload_uri = _upload_csv(repo, upload_rows, "uploads/round-trip.csv")
|
||||
uri = _upload_csv(repo, rows, "uploads/round-trip.csv")
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
reloaded = repo.load_batch(upload_uri)
|
||||
saved_uri = repo.save_batch(addresses, "tasks/round-trip")
|
||||
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert reloaded == original
|
||||
# Original columns come back verbatim; postcode_clean is the only addition.
|
||||
assert [
|
||||
{k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows
|
||||
] == rows
|
||||
assert [r["postcode_clean"] for r in saved_rows] == ["SW1A1AA", "XY98ZW"]
|
||||
|
||||
|
||||
def test_save_batch_uses_unique_filename_per_call(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
from domain.addresses.user_address import UserAddress
|
||||
|
||||
addresses = [
|
||||
UserAddress(
|
||||
user_address="1 High Street",
|
||||
postcode="SW1A 1AA",
|
||||
internal_reference="REF-001",
|
||||
source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ to the wrapped function — so the handler can compose its own use-case
|
|||
orchestrator that shares the session.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator, Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -13,6 +14,8 @@ from typing import Any
|
|||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from domain.tasks.subtasks import SubTaskStatus
|
||||
|
|
@ -142,3 +145,111 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
|
|||
persisted_child = harness.subtasks.get(child_ids[0])
|
||||
assert persisted_child.task_id == task.id
|
||||
assert persisted_child.status is SubTaskStatus.WAITING
|
||||
|
||||
|
||||
def test_subtask_handler_logs_subtask_lifecycle_on_success(
|
||||
harness: Harness, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Start and completion are logged at INFO so a successful invocation
|
||||
leaves a CloudWatch breadcrumb (not just the Lambda runtime lines)."""
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
assert f"Running subtask {subtask.id}" in caplog.text
|
||||
assert f"Subtask {subtask.id} completed" in caplog.text
|
||||
|
||||
|
||||
def test_subtask_handler_logs_exception_on_failure(
|
||||
harness: Harness, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""A failing subtask is logged at ERROR with the traceback attached,
|
||||
before the exception propagates for the Lambda runtime to surface."""
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
failures = [r for r in caplog.records if r.levelno == logging.ERROR]
|
||||
assert any(
|
||||
f"Subtask {subtask.id} failed" in r.getMessage() for r in failures
|
||||
)
|
||||
assert any(r.exc_info is not None for r in failures)
|
||||
|
||||
|
||||
def test_subtask_handler_records_cloudwatch_url_on_subtask(
|
||||
harness: Harness, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""With the AWS Lambda runtime's log env vars present, a CloudWatch deep
|
||||
link is built and persisted on the SubTask."""
|
||||
monkeypatch.setenv("AWS_REGION", "eu-west-2")
|
||||
monkeypatch.setenv(
|
||||
"AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter"
|
||||
)
|
||||
monkeypatch.setenv(
|
||||
"AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123"
|
||||
)
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
saved_url = harness.subtasks.get(subtask.id).cloud_logs_url
|
||||
assert saved_url is not None
|
||||
assert saved_url.startswith(
|
||||
"https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
|
||||
)
|
||||
# Log group / stream are console-encoded ("/" -> "$252F").
|
||||
assert "$252Faws$252Flambda$252Fpostcode-splitter" in saved_url
|
||||
assert "$255B$2524LATEST$255D" in saved_url
|
||||
|
||||
|
||||
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
|
||||
harness: Harness, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Outside a real Lambda (e.g. the local RIE) the runtime log env vars
|
||||
are absent, so cloud_logs_url is left unset rather than storing junk."""
|
||||
for var in (
|
||||
"AWS_REGION",
|
||||
"AWS_LAMBDA_LOG_GROUP_NAME",
|
||||
"AWS_LAMBDA_LOG_STREAM_NAME",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
assert harness.subtasks.get(subtask.id).cloud_logs_url is None
|
||||
|
|
|
|||
|
|
@ -1,18 +1,32 @@
|
|||
"""@subtask_handler decorator for Lambdas that operate on existing SubTasks.
|
||||
|
||||
Translates an AWS Lambda invocation (SQS-shaped or direct) into
|
||||
TaskOrchestrator.run_subtask(...) calls.
|
||||
TaskOrchestrator.run_subtask(...) calls, emitting an INFO log line for each
|
||||
subtask's start and completion and a logged exception on failure. Those lines
|
||||
land in CloudWatch via the Lambda runtime's stdout/stderr capture.
|
||||
|
||||
Each subtask also records ``cloud_logs_url`` -- a deep link to this
|
||||
invocation's CloudWatch log stream -- so an operator can jump from a SubTask
|
||||
row straight to its logs. It is built from the environment variables the AWS
|
||||
Lambda runtime sets, so it is populated only on real Lambda invocations and
|
||||
left unset under the local RIE (which does not export them).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import AbstractContextManager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
from utilities.aws_lambda.default_orchestrator import default_orchestrator
|
||||
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
|
||||
|
||||
|
||||
|
|
@ -33,14 +47,26 @@ def subtask_handler(
|
|||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def wrapper(event: dict[str, Any], context: Any) -> None:
|
||||
cloud_logs_url = _cloudwatch_url()
|
||||
with factory() as orchestrator:
|
||||
for record in _records(event):
|
||||
body = _parse_body(record)
|
||||
trigger = SubtaskTriggerBody.model_validate(body)
|
||||
orchestrator.run_subtask(
|
||||
trigger.sub_task_id,
|
||||
work=lambda body=body, o=orchestrator: func(body, context, o),
|
||||
)
|
||||
logger.info("Running subtask %s", trigger.sub_task_id)
|
||||
try:
|
||||
orchestrator.run_subtask(
|
||||
trigger.sub_task_id,
|
||||
work=lambda body=body, o=orchestrator: func(
|
||||
body, context, o
|
||||
),
|
||||
cloud_logs_url=cloud_logs_url,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Subtask %s failed", trigger.sub_task_id
|
||||
)
|
||||
raise
|
||||
logger.info("Subtask %s completed", trigger.sub_task_id)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
@ -65,3 +91,34 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
|
|||
if isinstance(raw_records, list):
|
||||
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
|
||||
return [event]
|
||||
|
||||
|
||||
def _console_encode(value: str) -> str:
|
||||
"""Encode a value for a CloudWatch console deep link.
|
||||
|
||||
The console expects URL-encoding with the percent signs themselves
|
||||
re-encoded as ``$25`` -- e.g. ``/`` becomes ``%2F`` becomes ``$252F``.
|
||||
"""
|
||||
return quote(value, safe="").replace("%", "$25")
|
||||
|
||||
|
||||
def _cloudwatch_url() -> Optional[str]:
|
||||
"""Build a CloudWatch console URL for this invocation's log stream.
|
||||
|
||||
Sourced entirely from the environment variables the AWS Lambda runtime
|
||||
sets -- ``AWS_REGION``, ``AWS_LAMBDA_LOG_GROUP_NAME`` and
|
||||
``AWS_LAMBDA_LOG_STREAM_NAME``. Returns None when any is absent, which is
|
||||
the case outside a real Lambda (the local RIE does not export them) -- so
|
||||
``SubTask.cloud_logs_url`` is left unset rather than storing a link that
|
||||
points nowhere.
|
||||
"""
|
||||
region = os.environ.get("AWS_REGION")
|
||||
log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
|
||||
log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
|
||||
if not (region and log_group and log_stream):
|
||||
return None
|
||||
return (
|
||||
f"https://{region}.console.aws.amazon.com/cloudwatch/home"
|
||||
f"?region={region}#logsV2:log-groups/log-group/"
|
||||
f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue