Merge pull request #1106 from Hestia-Homes/claude/Model-p3

Refactor postcode_splitter into the DDD layout (project #3)
This commit is contained in:
Jun-te Kim 2026-05-20 15:01:29 +01:00 committed by GitHub
commit 00f0cb5442
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 2142 additions and 51 deletions

1
.gitignore vendored
View file

@ -121,6 +121,7 @@ celerybeat.pid
# Environments
.env
.env.local
.venv
env/
venv/

View file

@ -23,7 +23,7 @@ Invoke `/ubiquitous-language` in any session to extract new terms from the conve
|------|------------|------------------|
| **UPRN** | Unique Property Reference Number — the government-issued permanent identifier for a physical address in the UK. | "property ID", "address ID", "code" |
| **Postcode** | A UK postal code used to group nearby addresses; the primary search key for finding EPC records. | "zip code", "postal code" |
| **User Address** | A free-text address string provided by a user or imported from a customer dataset, before any normalisation or matching. | "user input", "raw address", "user_inputed_address" |
| **User Address** | A structured dataclass (`domain.addresses.user_address.UserAddress`) capturing a customer-supplied address: a free-text `user_address` line, a canonical `postcode` (sanitised on construction), and an optional `internal_reference`. The bare string sense -- the raw free-text address line as it arrives from upstream ingestion, before being wrapped -- remains valid when discussing CSV columns, API payloads, or other upstream contexts; in domain code, prefer the dataclass. | "user input", "raw address", "user_inputed_address" |
| **Dwelling** | A single residential unit that can hold an EPC — a house, flat, or maisonette. | "property", "unit", "home" |
## Address Matching
@ -72,7 +72,7 @@ Invoke `/ubiquitous-language` in any session to extract new terms from the conve
## Flagged ambiguities
- **"address"** appears as both the raw **User Address** (free-text from customer data) and a structured field on an **EPC Search Result** (normalised address lines). Always qualify: "user address" vs "EPC address" or "address line 1".
- **"address"** appears as both the raw **User Address** (free-text from customer data, or the structured `UserAddress` dataclass that wraps it) and a structured field on an **EPC Search Result** (normalised address lines). Always qualify: "user address" vs "EPC address" or "address line 1". Within `domain/`, **User Address** specifically means the `UserAddress` dataclass; in upstream ingestion contexts (CSV columns, SQS payloads) it can still mean the raw string sense.
- **"score"** is used for the `AddressMatch.score()` function output, the `lexiscore` DataFrame column, and informally in conversation. Prefer **Lexiscore** in domain discussions; reserve "score" for method-level code comments.
- **"user_inputed_address"** in `backend/address2UPRN/main.py` is a misspelling and a synonym for **User Address** — the canonical term. New code should use `user_address`.
- **"EPC"** is overloaded as both the document (an Energy Performance Certificate) and the rating band letter. Use **EPC** for the document and **EPC Band** for the letter.

0
applications/__init__.py Normal file
View file

View file

@ -0,0 +1,34 @@
FROM public.ecr.aws/lambda/python:3.11
# Postgres host/port/database are baked into the image at build time from
# the deploy workflow's --build-arg values (GitHub Actions DEV_DB_* secrets),
# mirroring backend/postcode_splitter/handler/Dockerfile. They map onto the
# POSTGRES_* names PostgresConfig.from_env reads. Username/password are NOT
# baked in -- Terraform injects those as Lambda env vars from Secrets Manager.
ARG DEV_DB_HOST
ARG DEV_DB_PORT
ARG DEV_DB_NAME
ENV POSTGRES_HOST=${DEV_DB_HOST}
ENV POSTGRES_PORT=${DEV_DB_PORT}
ENV POSTGRES_DATABASE=${DEV_DB_NAME}
WORKDIR /var/task
COPY applications/postcode_splitter/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy the layered source the handler imports from. The new splitter pulls
# only DDD-shaped packages — no pandas, no legacy backend/.
COPY domain/ domain/
COPY infrastructure/ infrastructure/
COPY orchestration/ orchestration/
COPY repositories/ repositories/
COPY utilities/ utilities/
COPY applications/ applications/
# Place the handler at the Lambda task root so the runtime can resolve
# ``main.handler`` without an extra package prefix.
COPY applications/postcode_splitter/handler.py /var/task/main.py
CMD ["main.handler"]

View file

@ -0,0 +1,52 @@
from __future__ import annotations
import os
from typing import Any
import boto3
from applications.postcode_splitter.postcode_splitter_trigger_body import (
PostcodeSplitterTriggerBody,
)
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
from infrastructure.csv_s3_client import CsvS3Client
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.user_address.user_address_csv_s3_repository import (
UserAddressCsvS3Repository,
)
from utilities.aws_lambda.subtask_handler import subtask_handler
@subtask_handler()
def handler(
body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator
) -> dict[str, list[str]]:
trigger = PostcodeSplitterTriggerBody.model_validate(body)
bucket = os.environ["S3_BUCKET_NAME"]
queue_url = os.environ["ADDRESS2UPRN_QUEUE_URL"]
# boto3.client is overloaded per-service in the installed stubs; cast
# to Any so the strict-mode checker treats it as opaque.
boto3_client: Any = boto3.client # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
boto_s3: Any = boto3_client("s3")
boto_sqs: Any = boto3_client("sqs")
csv_client = CsvS3Client(boto_s3, bucket)
user_address_repo = UserAddressCsvS3Repository(csv_client, bucket)
queue_client = Address2UprnQueueClient(boto_sqs, queue_url)
splitter = PostcodeSplitterOrchestrator(
task_orchestrator=task_orchestrator,
user_address_repo=user_address_repo,
queue_client=queue_client,
)
child_ids = splitter.split_and_dispatch(
parent_task_id=trigger.task_id,
parent_subtask_id=trigger.sub_task_id,
input_s3_uri=trigger.s3_uri,
)
return {"child_subtask_ids": [str(cid) for cid in child_ids]}

View file

@ -0,0 +1,34 @@
# Local-test environment for the postcode_splitter Lambda.
#
# cp .env.local.example .env.local then fill in the values below.
#
# .env.local is gitignored. The container hits REAL AWS and a REAL Postgres,
# so every value here points at infrastructure that actually exists.
#
# NOTE: the new DDD code uses different env var names than the repo root
# .env. The mapping (root .env name -> var here) is given per section.
# Keep comments on their own lines — docker-compose's env_file parser folds a
# trailing "# ..." into the value.
# --- Postgres (orchestration/default_orchestrator -> PostgresConfig.from_env) ---
# POSTGRES_HOST <- DB_HOST, PORT <- DB_PORT, USERNAME <- DB_USERNAME,
# PASSWORD <- DB_PASSWORD, DATABASE <- DB_NAME.
POSTGRES_HOST=
POSTGRES_PORT=5432
POSTGRES_USERNAME=
POSTGRES_PASSWORD=
POSTGRES_DATABASE=
# POSTGRES_DRIVER=psycopg2 (optional; defaults to psycopg2)
# --- Handler config (applications/postcode_splitter/handler.py) ---
# S3_BUCKET_NAME: bucket holding the input address CSV (root .env: DATA_BUCKET).
# ADDRESS2UPRN_QUEUE_URL: SQS queue the splitter fans batches out to; not in
# the root .env (Terraform sets it in prod).
S3_BUCKET_NAME=
ADDRESS2UPRN_QUEUE_URL=
# --- AWS credentials for boto3 (S3 + SQS clients) ---
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_DEFAULT_REGION=eu-west-2
# AWS_SESSION_TOKEN= (only if using temporary/SSO credentials)

View file

@ -0,0 +1,9 @@
services:
postcode-splitter:
build:
context: ../../../
dockerfile: applications/postcode_splitter/Dockerfile
ports:
- "9001:8080"
env_file:
- .env.local

View file

@ -0,0 +1,28 @@
#!/usr/bin/env python3
import json
import requests
HOST = "localhost"
PORT = "9001"
LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"
payload = {
"Records": [
{
"body": json.dumps(
{
"task_id": "f4b3332f-c0cc-481f-96a5-d39860a647cf",
"sub_task_id": "14c042de-40c4-473b-8cd8-72c983a94a8d",
"s3_uri": "s3://retrofit-data-dev/ara_raw_inputs/calico/Calico Homes Full list EPC Properties(Sheet2) (1) (1).csv",
}
)
}
]
}
response = requests.post(LAMBDA_URL, json=payload)
print("Status code:", response.status_code)
print("Response:")
print(response.text)

View file

@ -0,0 +1,12 @@
#!/usr/bin/env bash
set -euo pipefail
cd "$(dirname "$0")"
if [ ! -f .env.local ]; then
cp .env.local.example .env.local
echo "Created .env.local from the template — fill it in, then re-run." >&2
exit 1
fi
docker compose build --no-cache
docker compose up --force-recreate

View file

@ -0,0 +1,11 @@
from uuid import UUID
from pydantic import BaseModel, ConfigDict
class PostcodeSplitterTriggerBody(BaseModel):
model_config = ConfigDict(extra="allow")
task_id: UUID
sub_task_id: UUID
s3_uri: str

View file

@ -0,0 +1,4 @@
boto3
pydantic
sqlmodel
psycopg2-binary

View file

@ -8,4 +8,5 @@ boto3==1.35.44
sqlmodel
sqlalchemy==2.0.36
psycopg2-binary==2.9.10
pydantic-settings==2.6.0
pydantic-settings==2.6.0
httpx

View file

@ -40,20 +40,6 @@ module "lambda" {
LOG_LEVEL = "info"
DB_USERNAME = local.db_credentials.db_assessment_model_username
DB_PASSWORD = local.db_credentials.db_assessment_model_password
GOOGLE_SOLAR_API_KEY = "test"
SAP_PREDICTIONS_BUCKET = "test"
CARBON_PREDICTIONS_BUCKET = "test"
HEAT_PREDICTIONS_BUCKET = "test"
HEATING_KWH_PREDICTIONS_BUCKET = "test"
HOTWATER_KWH_PREDICTIONS_BUCKET = "test"
API_KEY = "test"
ENVIRONMENT = "test"
SECRET_KEY = "test"
PLAN_TRIGGER_BUCKET = "test"
DATA_BUCKET = "test"
EPC_AUTH_TOKEN = "test"
ENGINE_SQS_URL = "test"
ENERGY_ASSESSMENTS_BUCKET = "test"
ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url
S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name
},

View file

View file

@ -0,0 +1,51 @@
from __future__ import annotations
from collections.abc import Iterable, Iterator
from domain.addresses.user_address import UserAddress
from domain.postcode import Postcode
def iter_postcode_grouped_batches(
addresses: Iterable[UserAddress],
*,
max_batch_size: int = 500,
) -> Iterator[list[UserAddress]]:
if max_batch_size < 1:
raise ValueError("max_batch_size must be >= 1")
groups = _group_by_postcode_in_order(addresses)
buffer: list[UserAddress] = []
for group in groups.values():
group_len = len(group)
# Oversize single-Postcode group: flush buffer first, then dispatch
# the group as its own batch. Mirrors the legacy
# ``if group_len >= batch_size`` branch.
if group_len >= max_batch_size:
if buffer:
yield buffer
buffer = []
yield group
continue
# Adding this group would overflow: flush buffer before appending.
if len(buffer) + group_len > max_batch_size:
yield buffer
buffer = []
buffer.extend(group)
# Final flush.
if buffer:
yield buffer
def _group_by_postcode_in_order(
addresses: Iterable[UserAddress],
) -> dict[Postcode, list[UserAddress]]:
groups: dict[Postcode, list[UserAddress]] = {}
for address in addresses:
groups.setdefault(address.postcode, []).append(address)
return groups

View file

@ -0,0 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from domain.postcode import Postcode
def _empty_source_row() -> dict[str, str]:
return {}
@dataclass(frozen=True)
class UserAddress:
user_address: str
postcode: Postcode
internal_reference: Optional[str] = None
source_row: dict[str, str] = field(default_factory=_empty_source_row, compare=False)

15
domain/postcode.py Normal file
View file

@ -0,0 +1,15 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class Postcode:
value: str
def __post_init__(self) -> None:
# Frozen dataclass: bypass the descriptor with object.__setattr__.
object.__setattr__(self, "value", "".join(self.value.split()).upper())
def __str__(self) -> str:
return self.value

View file

@ -0,0 +1,20 @@
from uuid import UUID
from infrastructure.sqs_client import SqsClient
class Address2UprnQueueClient(SqsClient):
def publish(
self,
*,
parent_task_id: UUID,
child_subtask_id: UUID,
s3_uri: str,
) -> str:
return self.send(
{
"task_id": str(parent_task_id),
"sub_task_id": str(child_subtask_id),
"s3_uri": s3_uri,
}
)

View file

@ -0,0 +1,28 @@
import csv
from io import StringIO
from infrastructure.s3_client import S3Client
from infrastructure.s3_uri import parse_s3_uri
class CsvS3Client(S3Client):
def read_rows(self, s3_uri: str) -> list[dict[str, str]]:
bucket, key = parse_s3_uri(s3_uri)
if bucket != self.bucket:
raise ValueError(
f"s3_uri bucket {bucket!r} does not match client bucket {self.bucket!r}"
)
raw = self.get_object(key)
text = raw.decode("utf-8-sig")
reader = csv.DictReader(StringIO(text))
return [dict(row) for row in reader]
def save_rows(self, rows: list[dict[str, str]], key: str) -> str:
if not rows:
raise ValueError("Cannot save an empty rows list: header is unknown")
buffer = StringIO()
fieldnames = list(rows[0].keys())
writer = csv.DictWriter(buffer, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
return self.put_object(key, buffer.getvalue().encode("utf-8"))

View file

@ -0,0 +1,22 @@
from typing import Any
class S3Client:
def __init__(self, boto_s3_client: Any, bucket: str) -> None:
self._client = boto_s3_client
self._bucket = bucket
@property
def bucket(self) -> str:
return self._bucket
def get_object(self, key: str) -> bytes:
response: dict[str, Any] = self._client.get_object(
Bucket=self._bucket, Key=key
)
body: bytes = response["Body"].read()
return body
def put_object(self, key: str, body: bytes) -> str:
self._client.put_object(Bucket=self._bucket, Key=key, Body=body)
return f"s3://{self._bucket}/{key}"

25
infrastructure/s3_uri.py Normal file
View file

@ -0,0 +1,25 @@
from urllib.parse import unquote
def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
if s3_uri.startswith("s3://"):
parts = s3_uri[len("s3://") :].split("/", 1)
if len(parts) < 2 or not parts[0] or not parts[1]:
raise ValueError("S3 URI must include both a bucket and a key")
return parts[0], parts[1]
if "?" not in s3_uri:
raise ValueError(f"Not an s3:// URI and has no query string: {s3_uri!r}")
base, query = s3_uri.split("?", 1)
if "/s3/object/" not in base:
raise ValueError(f"Console URL has no '/s3/object/' segment: {s3_uri!r}")
bucket = base.split("/s3/object/", 1)[1]
params: dict[str, str] = {}
for item in query.split("&"):
if "=" in item:
name, value = item.split("=", 1)
params[name] = value
key = unquote(params.get("prefix", ""))
return bucket, key

View file

@ -0,0 +1,20 @@
import json
from typing import Any
class SqsClient:
def __init__(self, boto_sqs_client: Any, queue_url: str) -> None:
self._client = boto_sqs_client
self._queue_url = queue_url
@property
def queue_url(self) -> str:
return self._queue_url
def send(self, body: dict[str, Any]) -> str:
response: dict[str, Any] = self._client.send_message(
QueueUrl=self._queue_url,
MessageBody=json.dumps(body),
)
message_id: str = response["MessageId"]
return message_id

View file

@ -0,0 +1,55 @@
from __future__ import annotations
from uuid import UUID
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
from orchestration.task_orchestrator import TaskOrchestrator
from domain.addresses.postcode_batching import iter_postcode_grouped_batches
from repositories.user_address.user_address_repository import UserAddressRepository
class PostcodeSplitterOrchestrator:
def __init__(
self,
task_orchestrator: TaskOrchestrator,
user_address_repo: UserAddressRepository,
queue_client: Address2UprnQueueClient,
max_batch_size: int = 500,
) -> None:
self._task_orchestrator = task_orchestrator
self._user_address_repo = user_address_repo
self._queue_client = queue_client
self._max_batch_size = max_batch_size
def split_and_dispatch(
self,
*,
parent_task_id: UUID,
parent_subtask_id: UUID,
input_s3_uri: str,
) -> list[UUID]:
addresses = self._user_address_repo.load_batch(input_s3_uri)
path_prefix = (
f"ara_postcode_splitter_batches/{parent_task_id}/{parent_subtask_id}"
)
child_ids: list[UUID] = []
for batch in iter_postcode_grouped_batches(
addresses, max_batch_size=self._max_batch_size
):
batch_uri = self._user_address_repo.save_batch(batch, path_prefix)
child = self._task_orchestrator.create_child_subtask(
parent_task_id,
inputs={
"task_id": str(parent_task_id),
"s3_uri": batch_uri,
},
)
self._queue_client.publish(
parent_task_id=parent_task_id,
child_subtask_id=child.id,
s3_uri=batch_uri,
)
child_ids.append(child.id)
return child_ids

View file

@ -48,6 +48,16 @@ class TaskOrchestrator:
self._subtasks.create(subtask)
return task, subtask
def create_child_subtask(
self,
parent_task_id: UUID,
*,
inputs: Optional[dict[str, Any]] = None,
) -> SubTask:
subtask = SubTask.create(task_id=parent_task_id, inputs=inputs)
self._subtasks.create(subtask)
return subtask
def start_subtask(
self, subtask_id: UUID, cloud_logs_url: Optional[str] = None
) -> SubTask:

View file

View file

@ -0,0 +1,63 @@
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from typing import Optional
from domain.addresses.user_address import UserAddress
from domain.postcode import Postcode
from infrastructure.csv_s3_client import CsvS3Client
from repositories.user_address.user_address_repository import UserAddressRepository
_ADDRESS_COLUMNS: tuple[str, str, str] = ("Address 1", "Address 2", "Address 3")
_POSTCODE_COLUMN: str = "postcode"
_INTERNAL_REFERENCE_COLUMN: str = "Internal Reference"
_POSTCODE_CLEAN_COLUMN: str = "postcode_clean"
class UserAddressCsvS3Repository(UserAddressRepository):
def __init__(self, csv_client: CsvS3Client, bucket: str) -> None:
self._csv_client = csv_client
self._bucket = bucket
def load_batch(self, s3_uri: str) -> list[UserAddress]:
rows = self._csv_client.read_rows(s3_uri)
if rows and _POSTCODE_COLUMN not in rows[0]:
raise ValueError(
f"Input CSV {s3_uri} has no {_POSTCODE_COLUMN!r} column; "
f"columns present: {sorted(rows[0])}"
)
addresses: list[UserAddress] = []
for row in rows:
parts = [
row[col].strip()
for col in _ADDRESS_COLUMNS
if col in row and row[col].strip()
]
user_address = ", ".join(parts)
postcode = row.get(_POSTCODE_COLUMN, "")
raw_ref = row.get(_INTERNAL_REFERENCE_COLUMN, "").strip()
internal_reference: Optional[str] = raw_ref or None
addresses.append(
UserAddress(
user_address=user_address,
postcode=Postcode(postcode),
internal_reference=internal_reference,
source_row=row,
)
)
return addresses
def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str:
rows: list[dict[str, str]] = [
{**addr.source_row, _POSTCODE_CLEAN_COLUMN: str(addr.postcode)}
for addr in addresses
]
# TODO: [New Starter Task] file_name generation can be standardised
# and also easier to read, test for future implementation. Buiild that!
filename = (
f"{datetime.now(timezone.utc).isoformat()}_{uuid.uuid4().hex[:8]}.csv"
)
key = f"{path_prefix.rstrip('/')}/{filename}"
return self._csv_client.save_rows(rows, key)

View file

@ -0,0 +1,13 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from domain.addresses.user_address import UserAddress
class UserAddressRepository(ABC):
@abstractmethod
def load_batch(self, s3_uri: str) -> list[UserAddress]: ...
@abstractmethod
def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: ...

View file

@ -9,4 +9,5 @@ hubspot-api-client
fuzzywuzzy
pymupdf
playwright==1.58.0
msal
msal
moto[s3,sqs]

48
tests/conftest.py Normal file
View file

@ -0,0 +1,48 @@
"""Shared pytest fixtures for the ``tests/`` tree.
Provides an ephemeral PostgreSQL engine for tests that exercise SQLModel
repositories. PostgreSQL has no true in-memory mode; ``pytest-postgresql``
starts a real, throwaway server in a temp directory (the process is started
once per session and a fresh database is created/dropped per test). That is
the closest equivalent to "in-memory" and matches production behaviour far
better than SQLite (enums, JSONB, constraint semantics, etc.).
"""
from __future__ import annotations
import glob
from collections.abc import Iterator
from typing import Any
import pytest
from psycopg import Connection
from pytest_postgresql import factories
from sqlalchemy import Engine
from sqlmodel import SQLModel, create_engine
# Importing the SQLModel row modules registers their tables on
# SQLModel.metadata so ``create_all`` builds the full schema. Imports look
# unused; they aren't.
# pg_ctl ships under a versioned path and is not on PATH in the dev container.
_PG_CTL = next(iter(sorted(glob.glob("/usr/lib/postgresql/*/bin/pg_ctl"))), "pg_ctl")
postgresql_proc = factories.postgresql_proc(
executable=_PG_CTL
) # pyright: ignore[reportUnknownMemberType]
postgresql = factories.postgresql("postgresql_proc")
@pytest.fixture
def db_engine(postgresql: Connection[Any]) -> Iterator[Engine]:
"""A SQLModel engine bound to a fresh, ephemeral PostgreSQL database."""
info = postgresql.info
url = f"postgresql+psycopg://{info.user}:@{info.host}:{info.port}/{info.dbname}"
engine = create_engine(url)
SQLModel.metadata.create_all(engine)
try:
yield engine
finally:
SQLModel.metadata.drop_all(engine)
engine.dispose()

View file

View file

@ -0,0 +1,118 @@
import pytest
from domain.addresses.postcode_batching import iter_postcode_grouped_batches
from domain.addresses.user_address import UserAddress
from domain.postcode import Postcode
def _addrs(postcode: str, n: int) -> list[UserAddress]:
return [
UserAddress(
user_address=f"{i} {postcode} Street", postcode=Postcode(postcode)
)
for i in range(n)
]
def test_empty_input_yields_no_batches() -> None:
# act / assert
assert list(iter_postcode_grouped_batches([])) == []
def test_single_batch_under_cap() -> None:
# arrange
addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2)
# act
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500))
# assert
assert len(batches) == 1
assert batches[0] == addrs
def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None:
# Two groups whose total exactly equals the cap pack into a single
# batch -- no premature flush.
# arrange
addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2)
# act
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5))
# assert
assert len(batches) == 1
assert len(batches[0]) == 5
def test_flush_on_overflow_before_adding_next_postcode() -> None:
# Cap is 5. First group fills 3 slots; second group of 3 would overflow,
# so the buffer is flushed first and the next group starts a fresh batch.
# arrange
addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 3)
# act
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5))
# assert
assert len(batches) == 2
assert [str(a.postcode) for a in batches[0]] == ["AA11AA"] * 3
assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 3
def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None:
# An oversize single-postcode group goes out as one batch larger than
# the cap -- the cap never splits a postcode.
# arrange
addrs = _addrs("AA1 1AA", 7)
# act
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5))
# assert
assert len(batches) == 1
assert len(batches[0]) == 7
def test_oversize_group_flushes_existing_buffer_first() -> None:
# Mirrors the legacy ``if buffer: flush`` branch when an oversize group
# is encountered: buffered work must not be lost or interleaved.
# arrange
small = _addrs("AA1 1AA", 2)
big = _addrs("BB2 2BB", 7)
tail = _addrs("CC3 3CC", 1)
# act
batches = list(
iter_postcode_grouped_batches(small + big + tail, max_batch_size=5)
)
# assert
assert len(batches) == 3
assert [str(a.postcode) for a in batches[0]] == ["AA11AA", "AA11AA"]
assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 7
assert [str(a.postcode) for a in batches[2]] == ["CC33CC"]
def test_final_flush_yields_remaining_buffer() -> None:
# No overflow ever happens, but the trailing buffer must still come out.
# arrange
addrs = _addrs("AA1 1AA", 2) + _addrs("BB2 2BB", 2)
# act
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500))
# assert
assert batches == [addrs]
def test_postcode_grouping_preserves_first_seen_order() -> None:
# Interleaved input must still group by postcode and emit in first-seen
# order -- never alphabetical.
# arrange
a1, a2 = _addrs("ZZ9 9ZZ", 2)
b1, b2 = _addrs("AA1 1AA", 2)
# act
batches = list(iter_postcode_grouped_batches([a1, b1, a2, b2]))
# assert
assert len(batches) == 1
assert [str(a.postcode) for a in batches[0]] == [
"ZZ99ZZ",
"ZZ99ZZ",
"AA11AA",
"AA11AA",
]
def test_invalid_max_batch_size_raises() -> None:
# act / assert
with pytest.raises(ValueError, match="max_batch_size"):
list(iter_postcode_grouped_batches([], max_batch_size=0))

View file

@ -0,0 +1,98 @@
import dataclasses
import pytest
from domain.addresses.user_address import UserAddress
from domain.postcode import Postcode
def test_user_address_holds_postcode_value_object() -> None:
# act
addr = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa"))
# assert
assert addr.postcode == Postcode("SW1A1AA")
def test_user_address_preserves_user_address_verbatim() -> None:
# The free-text user_address string is intentionally NOT normalised --
# only the postcode is canonicalised, and that happens inside Postcode.
# act
addr = UserAddress(
user_address=" 1 The Street ", postcode=Postcode("SW1A1AA")
)
# assert
assert addr.user_address == " 1 The Street "
def test_user_address_internal_reference_defaults_to_none() -> None:
# act
addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
# assert
assert addr.internal_reference is None
def test_user_address_internal_reference_accepted() -> None:
# act
addr = UserAddress(
user_address="1 The Street",
postcode=Postcode("SW1A1AA"),
internal_reference="cust-42",
)
# assert
assert addr.internal_reference == "cust-42"
def test_user_address_is_frozen() -> None:
# arrange
addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
# act / assert
with pytest.raises(dataclasses.FrozenInstanceError):
addr.postcode = Postcode("OTHER") # type: ignore[misc]
def test_user_address_equality_uses_canonical_postcode() -> None:
# Postcode sanitises eagerly, so addresses built from different surface
# forms of the same postcode compare equal.
# arrange
a = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa"))
b = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
# act / assert
assert a == b
def test_user_address_source_row_defaults_to_empty_dict() -> None:
# act
addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
# assert
assert addr.source_row == {}
def test_user_address_carries_source_row() -> None:
# arrange
row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"}
# act
addr = UserAddress(
user_address="1 The Street",
postcode=Postcode("SW1A 1AA"),
source_row=row,
)
# assert
assert addr.source_row == row
def test_user_address_equality_ignores_source_row() -> None:
# source_row is excluded from equality (and hashing): identity stays
# defined by the parsed fields.
# arrange
a = UserAddress(
user_address="1 The Street",
postcode=Postcode("SW1A1AA"),
source_row={"x": "1"},
)
b = UserAddress(
user_address="1 The Street",
postcode=Postcode("SW1A1AA"),
source_row={"y": "2"},
)
# act / assert
assert a == b

View file

@ -6,10 +6,13 @@ from domain.tasks.subtasks import SubTask, SubTaskStatus
def test_create_subtask_starts_waiting() -> None:
# arrange
task_id = uuid4()
# act
st = SubTask.create(task_id=task_id, inputs={"foo": "bar"})
# assert
assert st.task_id == task_id
assert st.status is SubTaskStatus.WAITING
assert st.inputs == {"foo": "bar"}
@ -19,57 +22,74 @@ def test_create_subtask_starts_waiting() -> None:
def test_start_transitions_to_in_progress_and_sets_cloud_logs_url() -> None:
# arrange
st = SubTask.create(task_id=uuid4())
# act
st.start(cloud_logs_url="https://example/log")
# assert
assert st.status is SubTaskStatus.IN_PROGRESS
assert st.cloud_logs_url == "https://example/log"
assert st.job_started is not None
def test_start_is_idempotent_from_in_progress() -> None:
# arrange
st = SubTask.create(task_id=uuid4())
st.start()
first_start = st.job_started
# act
st.start(cloud_logs_url="https://other")
# assert
assert st.status is SubTaskStatus.IN_PROGRESS
assert st.job_started == first_start # not overwritten
assert st.cloud_logs_url == "https://other"
def test_start_rejects_from_terminal_status() -> None:
# arrange
st = SubTask.create(task_id=uuid4())
st.complete()
# act / assert
with pytest.raises(ValueError):
st.start()
def test_complete_marks_outputs_and_job_completed() -> None:
# arrange
st = SubTask.create(task_id=uuid4())
st.start()
# act
st.complete({"uprn": "123"})
# assert
assert st.status is SubTaskStatus.COMPLETE
assert st.outputs == {"result": {"uprn": "123"}}
assert st.job_completed is not None
def test_complete_without_result_leaves_outputs_unset() -> None:
# arrange
st = SubTask.create(task_id=uuid4())
# act
st.complete()
# assert
assert st.outputs is None
def test_fail_records_error_in_outputs() -> None:
# arrange
st = SubTask.create(task_id=uuid4())
err = RuntimeError("boom")
# act
st.fail(err)
# assert
assert st.status is SubTaskStatus.FAILED
assert st.outputs == {"error": "boom"}
assert st.job_completed is not None

View file

@ -5,12 +5,12 @@ from domain.tasks.tasks import Source, Task, TaskStatus
def test_create_task_starts_waiting() -> None:
# Arrange / Act
# arrange / act
t = Task.create(
task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123"
)
# Assert
# assert
assert t.status is TaskStatus.WAITING
assert t.source is Source.PORTFOLIO
assert t.source_id == "abc-123"
@ -19,86 +19,113 @@ def test_create_task_starts_waiting() -> None:
def test_create_task_rejects_blank_task_source() -> None:
# act / assert
with pytest.raises(ValueError, match="task_source"):
Task.create(task_source=" ")
def test_start_transitions_to_in_progress() -> None:
# arrange
t = Task.create(task_source="manual:test")
# act
t.start()
# assert
assert t.status is TaskStatus.IN_PROGRESS
def test_complete_marks_job_completed() -> None:
# arrange
t = Task.create(task_source="manual:test")
t.start()
# act
t.complete()
# assert
assert t.status is TaskStatus.COMPLETE
assert t.job_completed is not None
def test_fail_marks_job_completed() -> None:
# arrange
t = Task.create(task_source="manual:test")
# act
t.fail()
# assert
assert t.status is TaskStatus.FAILED
assert t.job_completed is not None
def test_start_rejects_from_terminal_status() -> None:
# arrange
t = Task.create(task_source="manual:test")
t.complete()
# act / assert
with pytest.raises(ValueError):
t.start()
def test_recalculate_with_empty_statuses_is_noop() -> None:
# arrange
t = Task.create(task_source="manual:test")
original_status = t.status
original_completed = t.job_completed
# act
t.recalculate_from_subtasks([])
# assert
assert t.status is original_status
assert t.job_completed is original_completed
def test_recalculate_all_waiting_keeps_waiting() -> None:
# arrange
t = Task.create(task_source="manual:test")
t.start() # task moved to IN_PROGRESS earlier
t.complete() # then COMPLETE, with job_completed set
# act
t.recalculate_from_subtasks([SubTaskStatus.WAITING, SubTaskStatus.WAITING])
# assert
assert t.status is TaskStatus.WAITING
assert t.job_completed is None
def test_recalculate_any_in_progress_marks_in_progress() -> None:
# arrange
t = Task.create(task_source="manual:test")
# act
t.recalculate_from_subtasks(
[SubTaskStatus.WAITING, SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE]
)
# assert
assert t.status is TaskStatus.IN_PROGRESS
assert t.job_completed is None
def test_recalculate_all_complete_marks_complete() -> None:
# arrange
t = Task.create(task_source="manual:test")
# act
t.recalculate_from_subtasks([SubTaskStatus.COMPLETE, SubTaskStatus.COMPLETE])
# assert
assert t.status is TaskStatus.COMPLETE
assert t.job_completed is not None
def test_recalculate_any_failed_marks_failed_even_with_others() -> None:
# arrange
t = Task.create(task_source="manual:test")
# act
t.recalculate_from_subtasks(
[SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE, SubTaskStatus.FAILED]
)
# assert
assert t.status is TaskStatus.FAILED
assert t.job_completed is not None

View file

@ -0,0 +1,59 @@
import dataclasses
import pytest
from domain.postcode import Postcode
def test_postcode_uppercases() -> None:
# act / assert
assert Postcode("sw1a1aa").value == "SW1A1AA"
def test_postcode_strips_internal_spaces() -> None:
# act / assert
assert Postcode("sw1a 1aa").value == "SW1A1AA"
def test_postcode_strips_leading_and_trailing_whitespace() -> None:
# act / assert
assert Postcode(" sw1a 1aa ").value == "SW1A1AA"
def test_postcode_strips_tabs_and_newlines() -> None:
# CSV ingestion occasionally introduces stray whitespace characters; the
# canonical form must absorb them just like literal spaces.
# act / assert
assert Postcode("sw1a\t1aa\n").value == "SW1A1AA"
def test_postcode_construction_is_idempotent() -> None:
# arrange
once = Postcode("sw1a 1aa")
# act / assert
assert Postcode(once.value).value == "SW1A1AA"
def test_postcode_empty_string() -> None:
# act / assert
assert Postcode("").value == ""
def test_postcode_str_returns_canonical_value() -> None:
# act / assert
assert str(Postcode("sw1a 1aa")) == "SW1A1AA"
def test_postcode_equality_ignores_surface_form() -> None:
# Differing case / whitespace sanitise to the same canonical value, so
# the value objects compare equal.
# act / assert
assert Postcode("sw1a 1aa") == Postcode("SW1A1AA")
def test_postcode_is_frozen() -> None:
# arrange
postcode = Postcode("SW1A1AA")
# act / assert
with pytest.raises(dataclasses.FrozenInstanceError):
postcode.value = "OTHER" # type: ignore[misc]

View file

@ -0,0 +1,10 @@
from typing import Any
import boto3
REGION = "us-east-1"
def make_boto_client(service_name: str) -> Any:
factory: Any = boto3.client # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
return factory(service_name, region_name=REGION)

View file

@ -0,0 +1,28 @@
import os
from collections.abc import Iterator
from typing import Optional
import pytest
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]: # pyright: ignore[reportUnusedFunction]
keys = (
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_DEFAULT_REGION",
)
prev: dict[str, Optional[str]] = {k: os.environ.get(k) for k in keys}
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
try:
yield
finally:
for k, v in prev.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v

View file

@ -0,0 +1,71 @@
import json
from collections.abc import Iterator
from typing import Any, cast
from uuid import uuid4
import pytest
from moto import mock_aws
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
from tests.infrastructure import make_boto_client
@pytest.fixture
def queue_setup() -> Iterator[tuple[Address2UprnQueueClient, Any, str]]:
with mock_aws():
boto_client = make_boto_client("sqs")
queue: dict[str, Any] = boto_client.create_queue(
QueueName="address2uprn-queue"
)
queue_url = cast(str, queue["QueueUrl"])
yield (
Address2UprnQueueClient(boto_client, queue_url),
boto_client,
queue_url,
)
def test_publish_returns_message_id(
queue_setup: tuple[Address2UprnQueueClient, Any, str],
) -> None:
# arrange
client, _boto, _url = queue_setup
# act
message_id = client.publish(
parent_task_id=uuid4(),
child_subtask_id=uuid4(),
s3_uri="s3://my-bucket/path/to/chunk.csv",
)
# assert
assert isinstance(message_id, str)
assert message_id
def test_publish_body_uses_typed_shape(
queue_setup: tuple[Address2UprnQueueClient, Any, str],
) -> None:
# arrange
client, boto_client, queue_url = queue_setup
parent_id = uuid4()
child_id = uuid4()
s3_uri = "s3://my-bucket/path/to/chunk.csv"
# act
client.publish(
parent_task_id=parent_id,
child_subtask_id=child_id,
s3_uri=s3_uri,
)
# assert
received: dict[str, Any] = boto_client.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=1
)
messages: list[dict[str, Any]] = received["Messages"]
assert len(messages) == 1
body = json.loads(messages[0]["Body"])
assert body == {
"task_id": str(parent_id),
"sub_task_id": str(child_id),
"s3_uri": s3_uri,
}

View file

@ -0,0 +1,51 @@
from collections.abc import Iterator
import pytest
from moto import mock_aws
from infrastructure.csv_s3_client import CsvS3Client
from tests.infrastructure import make_boto_client
BUCKET = "csv-bucket"
@pytest.fixture
def csv_client() -> Iterator[CsvS3Client]:
with mock_aws():
boto_client = make_boto_client("s3")
boto_client.create_bucket(Bucket=BUCKET)
yield CsvS3Client(boto_client, BUCKET)
def test_save_rows_returns_s3_uri(csv_client: CsvS3Client) -> None:
# arrange
rows = [{"address": "1 High St", "postcode": "AB1 2CD"}]
# act
uri = csv_client.save_rows(rows, "uploads/addresses.csv")
# assert
assert uri == f"s3://{BUCKET}/uploads/addresses.csv"
def test_round_trip_preserves_rows(csv_client: CsvS3Client) -> None:
# arrange
rows = [
{"address": "1 High St", "postcode": "AB1 2CD"},
{"address": "2 Low St", "postcode": "XY9 8ZW"},
]
# act
uri = csv_client.save_rows(rows, "uploads/addresses.csv")
fetched = csv_client.read_rows(uri)
# assert
assert fetched == rows
def test_save_rows_rejects_empty_list(csv_client: CsvS3Client) -> None:
# act / assert
with pytest.raises(ValueError, match="empty"):
csv_client.save_rows([], "uploads/empty.csv")
def test_read_rows_rejects_wrong_bucket(csv_client: CsvS3Client) -> None:
# act / assert
with pytest.raises(ValueError, match="does not match client bucket"):
csv_client.read_rows("s3://other-bucket/uploads/addresses.csv")

View file

@ -0,0 +1,36 @@
from collections.abc import Iterator
import pytest
from moto import mock_aws
from infrastructure.s3_client import S3Client
from tests.infrastructure import make_boto_client
BUCKET = "test-bucket"
@pytest.fixture
def s3_client() -> Iterator[S3Client]:
with mock_aws():
boto_client = make_boto_client("s3")
boto_client.create_bucket(Bucket=BUCKET)
yield S3Client(boto_client, BUCKET)
def test_put_object_returns_s3_uri(s3_client: S3Client) -> None:
# act
uri = s3_client.put_object("folder/data.bin", b"payload")
# assert
assert uri == f"s3://{BUCKET}/folder/data.bin"
def test_get_object_returns_bytes_written_by_put_object(s3_client: S3Client) -> None:
# arrange
s3_client.put_object("round/trip.bin", b"hello world")
# act / assert
assert s3_client.get_object("round/trip.bin") == b"hello world"
def test_bucket_property_exposes_configured_bucket(s3_client: S3Client) -> None:
# act / assert
assert s3_client.bucket == BUCKET

View file

@ -0,0 +1,40 @@
import pytest
from infrastructure.s3_uri import parse_s3_uri
def test_parses_simple_s3_uri() -> None:
# act / assert
assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv")
def test_parses_s3_uri_with_nested_key() -> None:
# act
bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv")
# assert
assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv")
def test_rejects_s3_uri_without_key() -> None:
# act / assert
with pytest.raises(ValueError, match="bucket and a key"):
parse_s3_uri("s3://my-bucket")
def test_rejects_s3_uri_with_empty_key() -> None:
# act / assert
with pytest.raises(ValueError, match="bucket and a key"):
parse_s3_uri("s3://my-bucket/")
def test_parses_console_url_prefix() -> None:
# arrange
url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv"
# act / assert
assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv")
def test_rejects_unparseable_string() -> None:
# act / assert
with pytest.raises(ValueError):
parse_s3_uri("not-a-uri-at-all")

View file

@ -0,0 +1,44 @@
import json
from collections.abc import Iterator
from typing import Any, cast
import pytest
from moto import mock_aws
from infrastructure.sqs_client import SqsClient
from tests.infrastructure import make_boto_client
@pytest.fixture
def sqs_setup() -> Iterator[tuple[SqsClient, Any, str]]:
with mock_aws():
boto_client = make_boto_client("sqs")
queue: dict[str, Any] = boto_client.create_queue(QueueName="test-queue")
queue_url = cast(str, queue["QueueUrl"])
yield SqsClient(boto_client, queue_url), boto_client, queue_url
def test_send_returns_message_id(sqs_setup: tuple[SqsClient, Any, str]) -> None:
# arrange
client, _boto, _url = sqs_setup
# act
message_id = client.send({"hello": "world"})
# assert
assert isinstance(message_id, str)
assert message_id
def test_send_json_serialises_body(sqs_setup: tuple[SqsClient, Any, str]) -> None:
# arrange
client, boto_client, queue_url = sqs_setup
body = {"hello": "world", "count": 3}
# act
client.send(body)
# assert
received: dict[str, Any] = boto_client.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=1
)
messages: list[dict[str, Any]] = received["Messages"]
assert len(messages) == 1
assert json.loads(messages[0]["Body"]) == body

View file

@ -0,0 +1,299 @@
from __future__ import annotations
import json
import os
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, cast
import boto3
import pytest
from moto import mock_aws
from sqlalchemy import Engine
from sqlmodel import Session
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
from infrastructure.csv_s3_client import CsvS3Client
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from repositories.user_address.user_address_csv_s3_repository import (
UserAddressCsvS3Repository,
)
BUCKET = "splitter-bucket"
REGION = "us-east-1"
def _make_boto_client(service_name: str) -> Any:
factory: Any = boto3.client # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
return factory(service_name, region_name=REGION)
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]: # pyright: ignore[reportUnusedFunction]
keys = (
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_DEFAULT_REGION",
)
prev: dict[str, Any] = {k: os.environ.get(k) for k in keys}
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = REGION
try:
yield
finally:
for k, v in prev.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
@dataclass
class Harness:
splitter: PostcodeSplitterOrchestrator
task_orchestrator: TaskOrchestrator
subtasks: SubTaskPostgresRepository
csv_client: CsvS3Client
boto_sqs: Any
queue_url: str
repo: UserAddressCsvS3Repository
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
with mock_aws():
# Infra: S3 + SQS
boto_s3 = _make_boto_client("s3")
boto_s3.create_bucket(Bucket=BUCKET)
boto_sqs = _make_boto_client("sqs")
queue: dict[str, Any] = boto_sqs.create_queue(QueueName="address2uprn-queue")
queue_url = cast(str, queue["QueueUrl"])
csv_client = CsvS3Client(boto_s3, BUCKET)
repo = UserAddressCsvS3Repository(csv_client, BUCKET)
queue_client = Address2UprnQueueClient(boto_sqs, queue_url)
# DB: ephemeral PostgreSQL TaskOrchestrator
with Session(db_engine) as session:
task_repo = TaskPostgresRepository(session=session)
subtask_repo = SubTaskPostgresRepository(session=session)
task_orchestrator = TaskOrchestrator(
task_repo=task_repo, subtask_repo=subtask_repo
)
splitter = PostcodeSplitterOrchestrator(
task_orchestrator=task_orchestrator,
user_address_repo=repo,
queue_client=queue_client,
max_batch_size=3,
)
yield Harness(
splitter=splitter,
task_orchestrator=task_orchestrator,
subtasks=subtask_repo,
csv_client=csv_client,
boto_sqs=boto_sqs,
queue_url=queue_url,
repo=repo,
)
def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
# Three postcode groups:
# AA1 1AA × 2 (within cap)
# BB2 2BB × 4 (oversize: > max_batch_size=3)
# CC3 3CC × 1 (final flush)
# Expected batching with cap=3 and the algorithm in
# ``iter_postcode_grouped_batches``:
# batch 1: [AA1 1AA × 2] (flushed because oversize follows)
# batch 2: [BB2 2BB × 4] (oversize own batch)
# batch 3: [CC3 3CC × 1] (final flush)
rows: list[dict[str, str]] = []
rows.extend(
{
"Address 1": f"{i} High St",
"Address 2": "",
"Address 3": "",
"postcode": "AA1 1AA",
"Internal Reference": f"AA-{i}",
}
for i in range(1, 3)
)
rows.extend(
{
"Address 1": f"{i} Long Road",
"Address 2": "",
"Address 3": "",
"postcode": "BB2 2BB",
"Internal Reference": f"BB-{i}",
}
for i in range(1, 5)
)
rows.append(
{
"Address 1": "1 Final Way",
"Address 2": "",
"Address 3": "",
"postcode": "CC3 3CC",
"Internal Reference": "CC-1",
}
)
return csv_client.save_rows(rows, "uploads/input.csv")
def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]:
bodies: list[dict[str, Any]] = []
while True:
received: dict[str, Any] = boto_sqs.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=0
)
messages = cast(list[dict[str, Any]], received.get("Messages", []))
if not messages:
break
for message in messages:
bodies.append(cast(dict[str, Any], json.loads(message["Body"])))
boto_sqs.delete_message(
QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"]
)
return bodies
def test_split_and_dispatch_creates_three_children_for_fixture(
harness: Harness,
) -> None:
# arrange
parent_task, parent_subtask = (
harness.task_orchestrator.create_task_with_subtask(
task_source="manual:postcode-splitter-int"
)
)
input_uri = _upload_fixture_csv(harness.csv_client)
# act
child_ids = harness.splitter.split_and_dispatch(
parent_task_id=parent_task.id,
parent_subtask_id=parent_subtask.id,
input_s3_uri=input_uri,
)
# assert
assert len(child_ids) == 3
# All child ids are unique and persisted as WAITING children of the
# parent task.
assert len(set(child_ids)) == 3
for cid in child_ids:
child = harness.subtasks.get(cid)
assert child.task_id == parent_task.id
def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
harness: Harness,
) -> None:
# arrange
parent_task, parent_subtask = (
harness.task_orchestrator.create_task_with_subtask(
task_source="manual:postcode-splitter-int"
)
)
input_uri = _upload_fixture_csv(harness.csv_client)
# act
child_ids = harness.splitter.split_and_dispatch(
parent_task_id=parent_task.id,
parent_subtask_id=parent_subtask.id,
input_s3_uri=input_uri,
)
# assert
for cid in child_ids:
child = harness.subtasks.get(cid)
assert child.inputs is not None
assert child.inputs["task_id"] == str(parent_task.id)
batch_uri = child.inputs["s3_uri"]
assert isinstance(batch_uri, str)
prefix = (
f"s3://{BUCKET}/ara_postcode_splitter_batches/"
f"{parent_task.id}/{parent_subtask.id}/"
)
assert batch_uri.startswith(prefix)
assert batch_uri.endswith(".csv")
def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
harness: Harness,
) -> None:
# arrange
parent_task, parent_subtask = (
harness.task_orchestrator.create_task_with_subtask(
task_source="manual:postcode-splitter-int"
)
)
input_uri = _upload_fixture_csv(harness.csv_client)
# act
child_ids = harness.splitter.split_and_dispatch(
parent_task_id=parent_task.id,
parent_subtask_id=parent_subtask.id,
input_s3_uri=input_uri,
)
# assert
bodies = _drain_queue(harness.boto_sqs, harness.queue_url)
assert len(bodies) == len(child_ids)
# Match queue messages against persisted child inputs by child_subtask_id;
# the message body's task_id/s3_uri must agree with the SubTask inputs.
bodies_by_child = {body["sub_task_id"]: body for body in bodies}
assert set(bodies_by_child.keys()) == {str(cid) for cid in child_ids}
for cid in child_ids:
child = harness.subtasks.get(cid)
body = bodies_by_child[str(cid)]
assert child.inputs is not None
assert body == {
"task_id": str(parent_task.id),
"sub_task_id": str(cid),
"s3_uri": child.inputs["s3_uri"],
}
def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
harness: Harness,
) -> None:
# arrange
parent_task, parent_subtask = (
harness.task_orchestrator.create_task_with_subtask(
task_source="manual:postcode-splitter-int"
)
)
input_uri = _upload_fixture_csv(harness.csv_client)
# act
child_ids = harness.splitter.split_and_dispatch(
parent_task_id=parent_task.id,
parent_subtask_id=parent_subtask.id,
input_s3_uri=input_uri,
)
# assert
# Re-load each child's saved batch and inspect the postcode_clean column
# to confirm the dispatch order matches the postcode-batching algorithm:
# AA-batch first, BB oversize batch second, CC final-flush third.
postcodes_per_batch: list[set[str]] = []
for cid in child_ids:
child = harness.subtasks.get(cid)
assert child.inputs is not None
rows = harness.csv_client.read_rows(child.inputs["s3_uri"])
postcodes_per_batch.append({row["postcode_clean"] for row in rows})
assert postcodes_per_batch == [
{"AA11AA"},
{"BB22BB"},
{"CC33CC"},
]

View file

@ -2,7 +2,8 @@ from collections.abc import Iterator
from dataclasses import dataclass
import pytest
from sqlmodel import Session, SQLModel, create_engine
from sqlalchemy import Engine
from sqlmodel import Session
from domain.tasks.subtasks import SubTask, SubTaskStatus
from domain.tasks.tasks import Source, TaskStatus
@ -19,10 +20,8 @@ class Harness:
@pytest.fixture
def harness() -> Iterator[Harness]:
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
def harness(db_engine: Engine) -> Iterator[Harness]:
with Session(db_engine) as session:
tasks = TaskPostgresRepository(session=session)
subtasks = SubTaskPostgresRepository(session=session)
yield Harness(
@ -35,6 +34,7 @@ def harness() -> Iterator[Harness]:
def test_create_task_with_subtask_creates_both_in_waiting(
harness: Harness,
) -> None:
# act
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test",
inputs={"foo": "bar"},
@ -42,6 +42,7 @@ def test_create_task_with_subtask_creates_both_in_waiting(
source_id="abc",
)
# assert
assert task.status is TaskStatus.WAITING
assert subtask.status is SubTaskStatus.WAITING
assert subtask.task_id == task.id
@ -49,27 +50,33 @@ def test_create_task_with_subtask_creates_both_in_waiting(
def test_start_subtask_cascades_to_in_progress(harness: Harness) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
# act
started = harness.orchestrator.start_subtask(
subtask.id, cloud_logs_url="https://example/log"
)
# assert
assert started.status is SubTaskStatus.IN_PROGRESS
assert started.cloud_logs_url == "https://example/log"
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
harness.orchestrator.start_subtask(subtask.id)
# act
harness.orchestrator.complete_subtask(subtask.id, {"value": 42})
# assert
done_subtask = harness.subtasks.get(subtask.id)
done_task = harness.tasks.get(task.id)
assert done_subtask.outputs == {"result": {"value": 42}}
@ -78,12 +85,15 @@ def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
# act
harness.orchestrator.fail_subtask(subtask.id, RuntimeError("boom"))
# assert
failed_subtask = harness.subtasks.get(subtask.id)
failed_task = harness.tasks.get(task.id)
assert failed_subtask.outputs == {"error": "boom"}
@ -93,50 +103,85 @@ def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
def test_failed_subtask_locks_task_failed_even_with_others_complete(
harness: Harness,
) -> None:
# arrange
task, first = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
second = SubTask.create(task_id=task.id)
harness.subtasks.create(second)
# act
harness.orchestrator.complete_subtask(first.id)
harness.orchestrator.fail_subtask(second.id, RuntimeError("nope"))
# assert
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
def test_mixed_complete_and_in_progress_keeps_task_in_progress(
harness: Harness,
) -> None:
# arrange
task, first = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
second = SubTask.create(task_id=task.id)
harness.subtasks.create(second)
# act
harness.orchestrator.complete_subtask(first.id)
harness.orchestrator.start_subtask(second.id)
# assert
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
def test_run_subtask_happy_path_returns_result_and_cascades_complete(
harness: Harness,
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
# act
result = harness.orchestrator.run_subtask(subtask.id, work=lambda: {"answer": 42})
# assert
assert result == {"answer": 42}
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
def test_create_child_subtask_adds_waiting_child_without_changing_parent_status(
harness: Harness,
) -> None:
# arrange
task, first = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
harness.orchestrator.start_subtask(first.id)
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
# act
child = harness.orchestrator.create_child_subtask(
task.id, inputs={"split": "a"}
)
# assert
persisted_child = harness.subtasks.get(child.id)
assert persisted_child.task_id == task.id
assert persisted_child.status is SubTaskStatus.WAITING
assert persisted_child.inputs == {"split": "a"}
assert persisted_child.id != first.id
# Cascade is a no-op: parent stays IN_PROGRESS.
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
def test_run_subtask_failing_work_marks_failed_and_reraises(
harness: Harness,
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@ -144,6 +189,7 @@ def test_run_subtask_failing_work_marks_failed_and_reraises(
def boom() -> None:
raise RuntimeError("boom")
# act / assert
with pytest.raises(RuntimeError, match="boom"):
harness.orchestrator.run_subtask(subtask.id, work=boom)

View file

@ -1,33 +1,40 @@
from collections.abc import Iterator
from uuid import uuid4
from uuid import UUID, uuid4
import pytest
from sqlmodel import Session, SQLModel, create_engine
from sqlalchemy import Engine
from sqlmodel import Session
# Importing the SQLModel row modules registers their tables in
# SQLModel.metadata so create_all builds both. Imports look unused; they aren't.
import infrastructure.postgres.subtask_table # noqa: F401 # pyright: ignore[reportUnusedImport]
import infrastructure.postgres.task_table # noqa: F401 # pyright: ignore[reportUnusedImport]
from domain.tasks.subtasks import SubTask, SubTaskStatus
from domain.tasks.tasks import Task
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
@pytest.fixture
def session() -> Iterator[Session]:
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as s:
def session(db_engine: Engine) -> Iterator[Session]:
with Session(db_engine) as s:
yield s
def _persisted_task_id(session: Session) -> UUID:
"""Create a parent Task row so SubTask FK constraints are satisfied."""
task = Task.create(task_source="manual:test")
TaskPostgresRepository(session=session).create(task)
return task.id
def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None:
# arrange
repo = SubTaskPostgresRepository(session=session)
task_id = uuid4()
task_id = _persisted_task_id(session)
st = SubTask.create(task_id=task_id, inputs={"address": "68 Glendon Way"})
# act
repo.create(st)
fetched = repo.get(st.id)
# assert
assert fetched.id == st.id
assert fetched.task_id == task_id
assert fetched.status is SubTaskStatus.WAITING
@ -36,16 +43,21 @@ def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None:
def test_save_persists_status_and_outputs(session: Session) -> None:
# arrange
repo = SubTaskPostgresRepository(session=session)
st = SubTask.create(task_id=uuid4())
st = SubTask.create(task_id=_persisted_task_id(session))
repo.create(st)
# act
st.start(cloud_logs_url="https://example/log")
repo.save(st)
# assert
assert repo.get(st.id).status is SubTaskStatus.IN_PROGRESS
# act
st.complete({"uprn": "123"})
repo.save(st)
# assert
done = repo.get(st.id)
assert done.status is SubTaskStatus.COMPLETE
assert done.outputs == {"result": {"uprn": "123"}}
@ -54,16 +66,19 @@ def test_save_persists_status_and_outputs(session: Session) -> None:
def test_list_by_task_filters_by_task_id(session: Session) -> None:
# arrange
repo = SubTaskPostgresRepository(session=session)
task_a = uuid4()
task_b = uuid4()
task_a = _persisted_task_id(session)
task_b = _persisted_task_id(session)
repo.create(SubTask.create(task_id=task_a))
repo.create(SubTask.create(task_id=task_a))
repo.create(SubTask.create(task_id=task_b))
# act
a_results = repo.list_by_task(task_a)
b_results = repo.list_by_task(task_b)
# assert
assert len(a_results) == 2
assert len(b_results) == 1
assert all(s.task_id == task_a for s in a_results)
@ -71,11 +86,15 @@ def test_list_by_task_filters_by_task_id(session: Session) -> None:
def test_list_by_task_returns_empty_for_unknown_task(session: Session) -> None:
# arrange
repo = SubTaskPostgresRepository(session=session)
# act / assert
assert repo.list_by_task(uuid4()) == []
def test_get_missing_raises(session: Session) -> None:
# arrange
repo = SubTaskPostgresRepository(session=session)
# act / assert
with pytest.raises(ValueError, match="not found"):
repo.get(uuid4())

View file

@ -2,7 +2,8 @@ from collections.abc import Iterator
from uuid import uuid4
import pytest
from sqlmodel import Session, SQLModel, create_engine
from sqlalchemy import Engine
from sqlmodel import Session
from domain.tasks.tasks import Source, Task, TaskStatus
from infrastructure.postgres.task_table import TaskRow
@ -10,25 +11,23 @@ from repositories.tasks.task_postgres_repository import TaskPostgresRepository
@pytest.fixture
def session() -> Iterator[Session]:
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as s:
def session(db_engine: Engine) -> Iterator[Session]:
with Session(db_engine) as s:
yield s
def test_create_and_get_round_trip(session: Session) -> None:
# Arrange
# arrange
repo = TaskPostgresRepository(session=session)
t = Task.create(
task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123"
)
# Act
# act
repo.create(t)
fetched = repo.get(t.id)
# Assert
# assert
assert fetched.id == t.id
assert fetched.status is TaskStatus.WAITING
assert fetched.source is Source.PORTFOLIO
@ -36,33 +35,43 @@ def test_create_and_get_round_trip(session: Session) -> None:
def test_save_persists_status_transition(session: Session) -> None:
# arrange
repo = TaskPostgresRepository(session=session)
t = Task.create(task_source="manual:test")
repo.create(t)
# act
t.start()
repo.save(t)
# assert
assert repo.get(t.id).status is TaskStatus.IN_PROGRESS
# act
t.complete()
repo.save(t)
# assert
done = repo.get(t.id)
assert done.status is TaskStatus.COMPLETE
assert done.job_completed is not None
def test_get_missing_raises(session: Session) -> None:
# arrange
repo = TaskPostgresRepository(session=session)
# act / assert
with pytest.raises(ValueError, match="not found"):
repo.get(uuid4())
def test_get_normalises_legacy_capitalised_status(session: Session) -> None:
# Existing rows written by backend code use "In Progress" (capitalised).
# arrange
repo = TaskPostgresRepository(session=session)
row = TaskRow(task_source="manual:test", status="In Progress")
session.add(row)
session.commit()
# act
fetched = repo.get(row.id)
# assert
assert fetched.status is TaskStatus.IN_PROGRESS

View file

@ -0,0 +1,28 @@
import os
from collections.abc import Iterator
from typing import Optional
import pytest
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]: # pyright: ignore[reportUnusedFunction]
keys = (
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_DEFAULT_REGION",
)
prev: dict[str, Optional[str]] = {k: os.environ.get(k) for k in keys}
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
try:
yield
finally:
for k, v in prev.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v

View file

@ -0,0 +1,237 @@
from collections.abc import Iterator
import pytest
from moto import mock_aws
from domain.addresses.user_address import UserAddress
from domain.postcode import Postcode
from infrastructure.csv_s3_client import CsvS3Client
from repositories.user_address.user_address_csv_s3_repository import (
UserAddressCsvS3Repository,
)
from tests.infrastructure import make_boto_client
BUCKET = "user-address-bucket"
@pytest.fixture
def repo() -> Iterator[UserAddressCsvS3Repository]:
with mock_aws():
boto_client = make_boto_client("s3")
boto_client.create_bucket(Bucket=BUCKET)
csv_client = CsvS3Client(boto_client, BUCKET)
yield UserAddressCsvS3Repository(csv_client, BUCKET)
def _upload_csv(
repo: UserAddressCsvS3Repository, rows: list[dict[str, str]], key: str
) -> str:
return repo._csv_client.save_rows(rows, key) # pyright: ignore[reportPrivateUsage]
def test_load_batch_parses_address_postcode_and_reference(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
rows = [
{
"Address 1": "1 High Street",
"Address 2": "Flat 2",
"Address 3": "Townville",
"postcode": "sw1a 1aa",
"Internal Reference": "REF-001",
}
]
uri = _upload_csv(repo, rows, "uploads/full.csv")
# act
addresses = repo.load_batch(uri)
# assert
assert len(addresses) == 1
address = addresses[0]
assert address.user_address == "1 High Street, Flat 2, Townville"
assert address.postcode == Postcode("SW1A1AA")
assert address.internal_reference == "REF-001"
def test_load_batch_uses_only_address_1_when_others_missing(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
rows = [
{
"Address 1": "10 Cardiff Road",
"Address 2": "",
"Address 3": "",
"postcode": "CF10 1AA",
"Internal Reference": "REF-002",
}
]
uri = _upload_csv(repo, rows, "uploads/address1-only.csv")
# act
addresses = repo.load_batch(uri)
# assert
assert len(addresses) == 1
assert addresses[0].user_address == "10 Cardiff Road"
assert addresses[0].postcode == Postcode("CF101AA")
assert addresses[0].internal_reference == "REF-002"
def test_load_batch_handles_missing_internal_reference(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
rows = [
{
"Address 1": "5 Park Lane",
"Address 2": "",
"Address 3": "",
"postcode": "M1 1AA",
"Internal Reference": "",
}
]
uri = _upload_csv(repo, rows, "uploads/no-ref.csv")
# act
addresses = repo.load_batch(uri)
# assert
assert len(addresses) == 1
assert addresses[0].user_address == "5 Park Lane"
assert addresses[0].postcode == Postcode("M11AA")
assert addresses[0].internal_reference is None
def test_load_batch_captures_full_source_row(
repo: UserAddressCsvS3Repository,
) -> None:
# A raw EPC-export-shaped row: the splitter must preserve every column,
# not just the ones it parses into UserAddress fields.
# arrange
row = {
"Asset Reference": "511",
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
"postcode": "BB12 7BX",
"Property Type": "House: End Terrace",
"SAP Score": "69",
}
uri = _upload_csv(repo, [row], "uploads/epc.csv")
# act
addresses = repo.load_batch(uri)
# assert
assert addresses[0].source_row == row
def test_load_batch_raises_when_postcode_column_absent(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}]
uri = _upload_csv(repo, rows, "uploads/no-postcode.csv")
# act / assert
with pytest.raises(ValueError, match="no 'postcode' column"):
repo.load_batch(uri)
def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
row = {
"Asset Reference": "511",
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
"postcode": " BB12 7BX",
"Property Type": "House: End Terrace",
}
uri = _upload_csv(repo, [row], "uploads/epc.csv")
addresses = repo.load_batch(uri)
# act
saved_uri = repo.save_batch(addresses, "tasks/passthrough")
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
# assert
assert len(saved_rows) == 1
saved = saved_rows[0]
# Every original column survives, byte-for-byte.
for column, value in row.items():
assert saved[column] == value
# Plus the one appended column the downstream address2uprn stage groups on.
assert saved["postcode_clean"] == "BB127BX"
def test_save_batch_returns_uri_under_path_prefix(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
addresses = [
UserAddress(
user_address="1 High Street",
postcode=Postcode("SW1A 1AA"),
source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
),
]
# act
uri = repo.save_batch(addresses, "tasks/abc/batches")
# assert
assert uri.startswith(f"s3://{BUCKET}/tasks/abc/batches/")
assert uri.endswith(".csv")
def test_save_then_reload_round_trip_preserves_columns(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
rows = [
{
"Address 1": "1 High Street",
"postcode": "SW1A 1AA",
"Internal Reference": "REF-001",
},
{
"Address 1": "2 Low Street",
"postcode": "XY9 8ZW",
"Internal Reference": "",
},
]
uri = _upload_csv(repo, rows, "uploads/round-trip.csv")
addresses = repo.load_batch(uri)
# act
saved_uri = repo.save_batch(addresses, "tasks/round-trip")
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
# assert
# Original columns come back verbatim; postcode_clean is the only addition.
assert [
{k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows
] == rows
assert [r["postcode_clean"] for r in saved_rows] == ["SW1A1AA", "XY98ZW"]
def test_save_batch_uses_unique_filename_per_call(
repo: UserAddressCsvS3Repository,
) -> None:
# arrange
addresses = [
UserAddress(
user_address="1 High Street",
postcode=Postcode("SW1A 1AA"),
source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
),
]
# act
uri_1 = repo.save_batch(addresses, "tasks/uniqueness")
uri_2 = repo.save_batch(addresses, "tasks/uniqueness")
# assert
assert uri_1 != uri_2

View file

View file

View file

@ -0,0 +1,255 @@
import logging
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from uuid import UUID
import pytest
from sqlalchemy import Engine
from sqlmodel import Session
from domain.tasks.subtasks import SubTaskStatus
from domain.tasks.tasks import TaskStatus
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from utilities.aws_lambda.subtask_handler import subtask_handler
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"
@dataclass
class Harness:
orchestrator: TaskOrchestrator
tasks: TaskPostgresRepository
subtasks: SubTaskPostgresRepository
@contextmanager
def factory(self) -> Generator[TaskOrchestrator, None, None]:
yield self.orchestrator
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
with Session(db_engine) as session:
tasks = TaskPostgresRepository(session=session)
subtasks = SubTaskPostgresRepository(session=session)
yield Harness(
orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
tasks=tasks,
subtasks=subtasks,
)
def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
return {"task_id": str(task_id), "sub_task_id": str(subtask_id)}
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
harness: Harness,
) -> None:
# arrange
_, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
received: dict[str, Any] = {}
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
received["body"] = body
received["context"] = context
received["orchestrator"] = orchestrator
# act
handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel")
# assert
assert received["orchestrator"] is harness.orchestrator
assert received["context"] == "ctx-sentinel"
assert received["body"]["sub_task_id"] == str(subtask.id)
def test_subtask_handler_completes_parent_subtask_on_success(
harness: Harness,
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
# act
handler(_direct_event(task.id, subtask.id), context=None)
# assert
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
harness: Harness,
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
raise RuntimeError("boom")
# act / assert
with pytest.raises(RuntimeError, match="boom"):
handler(_direct_event(task.id, subtask.id), context=None)
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
harness: Harness,
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
child_ids: list[UUID] = []
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
child = orchestrator.create_child_subtask(task.id, inputs={"split": 1})
child_ids.append(child.id)
# act
handler(_direct_event(task.id, subtask.id), context=None)
# assert
assert len(child_ids) == 1
persisted_child = harness.subtasks.get(child_ids[0])
assert persisted_child.task_id == task.id
assert persisted_child.status is SubTaskStatus.WAITING
def test_subtask_handler_logs_subtask_lifecycle_on_success(
harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
# act
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
handler(_direct_event(task.id, subtask.id), context=None)
# assert
assert f"Running subtask {subtask.id}" in caplog.text
assert f"Subtask {subtask.id} completed" in caplog.text
def test_subtask_handler_logs_exception_on_failure(
harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
raise RuntimeError("boom")
# act / assert
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
with pytest.raises(RuntimeError, match="boom"):
handler(_direct_event(task.id, subtask.id), context=None)
failures = [r for r in caplog.records if r.levelno == logging.ERROR]
assert any(
f"Subtask {subtask.id} failed" in r.getMessage() for r in failures
)
assert any(r.exc_info is not None for r in failures)
def test_subtask_handler_records_cloudwatch_url_on_subtask(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
# arrange
monkeypatch.setenv("AWS_REGION", "eu-west-2")
monkeypatch.setenv(
"AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter"
)
monkeypatch.setenv(
"AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123"
)
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
# act
handler(_direct_event(task.id, subtask.id), context=None)
# assert
saved_url = harness.subtasks.get(subtask.id).cloud_logs_url
assert saved_url is not None
assert saved_url.startswith(
"https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
)
# Log group / stream are console-encoded ("/" -> "$252F").
assert "$252Faws$252Flambda$252Fpostcode-splitter" in saved_url
assert "$255B$2524LATEST$255D" in saved_url
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
# arrange
for var in (
"AWS_REGION",
"AWS_LAMBDA_LOG_GROUP_NAME",
"AWS_LAMBDA_LOG_STREAM_NAME",
):
monkeypatch.delenv(var, raising=False)
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
@subtask_handler(orchestrator_cm=harness.factory)
def handler(
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
) -> None:
return None
# act
handler(_direct_event(task.id, subtask.id), context=None)
# assert
assert harness.subtasks.get(subtask.id).cloud_logs_url is None

View file

@ -5,14 +5,20 @@ TaskOrchestrator.run_subtask(...) calls.
"""
import json
import logging
import os
from contextlib import AbstractContextManager
from functools import wraps
from typing import Any, Callable, Optional, cast
from urllib.parse import quote
from utilities.aws_lambda.default_orchestrator import default_orchestrator
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
from orchestration.task_orchestrator import TaskOrchestrator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
@ -33,14 +39,26 @@ def subtask_handler(
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(event: dict[str, Any], context: Any) -> None:
cloud_logs_url = _cloudwatch_url()
with factory() as orchestrator:
for record in _records(event):
body = _parse_body(record)
trigger = SubtaskTriggerBody.model_validate(body)
orchestrator.run_subtask(
trigger.sub_task_id,
work=lambda body=body: func(body, context),
)
logger.info("Running subtask %s", trigger.sub_task_id)
try:
orchestrator.run_subtask(
trigger.sub_task_id,
work=lambda body=body, o=orchestrator: func(
body, context, o
),
cloud_logs_url=cloud_logs_url,
)
except Exception:
logger.exception(
"Subtask %s failed", trigger.sub_task_id
)
raise
logger.info("Subtask %s completed", trigger.sub_task_id)
return wrapper
@ -65,3 +83,20 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
if isinstance(raw_records, list):
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
return [event]
def _console_encode(value: str) -> str:
return quote(value, safe="").replace("%", "$25")
def _cloudwatch_url() -> Optional[str]:
region = os.environ.get("AWS_REGION")
log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
if not (region and log_group and log_stream):
return None
return (
f"https://{region}.console.aws.amazon.com/cloudwatch/home"
f"?region={region}#logsV2:log-groups/log-group/"
f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}"
)