mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
Merge pull request #1106 from Hestia-Homes/claude/Model-p3
Refactor postcode_splitter into the DDD layout (project #3)
This commit is contained in:
commit
00f0cb5442
54 changed files with 2142 additions and 51 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -121,6 +121,7 @@ celerybeat.pid
|
|||
|
||||
# Environments
|
||||
.env
|
||||
.env.local
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ Invoke `/ubiquitous-language` in any session to extract new terms from the conve
|
|||
|------|------------|------------------|
|
||||
| **UPRN** | Unique Property Reference Number — the government-issued permanent identifier for a physical address in the UK. | "property ID", "address ID", "code" |
|
||||
| **Postcode** | A UK postal code used to group nearby addresses; the primary search key for finding EPC records. | "zip code", "postal code" |
|
||||
| **User Address** | A free-text address string provided by a user or imported from a customer dataset, before any normalisation or matching. | "user input", "raw address", "user_inputed_address" |
|
||||
| **User Address** | A structured dataclass (`domain.addresses.user_address.UserAddress`) capturing a customer-supplied address: a free-text `user_address` line, a canonical `postcode` (sanitised on construction), and an optional `internal_reference`. The bare string sense -- the raw free-text address line as it arrives from upstream ingestion, before being wrapped -- remains valid when discussing CSV columns, API payloads, or other upstream contexts; in domain code, prefer the dataclass. | "user input", "raw address", "user_inputed_address" |
|
||||
| **Dwelling** | A single residential unit that can hold an EPC — a house, flat, or maisonette. | "property", "unit", "home" |
|
||||
|
||||
## Address Matching
|
||||
|
|
@ -72,7 +72,7 @@ Invoke `/ubiquitous-language` in any session to extract new terms from the conve
|
|||
|
||||
## Flagged ambiguities
|
||||
|
||||
- **"address"** appears as both the raw **User Address** (free-text from customer data) and a structured field on an **EPC Search Result** (normalised address lines). Always qualify: "user address" vs "EPC address" or "address line 1".
|
||||
- **"address"** appears as both the raw **User Address** (free-text from customer data, or the structured `UserAddress` dataclass that wraps it) and a structured field on an **EPC Search Result** (normalised address lines). Always qualify: "user address" vs "EPC address" or "address line 1". Within `domain/`, **User Address** specifically means the `UserAddress` dataclass; in upstream ingestion contexts (CSV columns, SQS payloads) it can still mean the raw string sense.
|
||||
- **"score"** is used for the `AddressMatch.score()` function output, the `lexiscore` DataFrame column, and informally in conversation. Prefer **Lexiscore** in domain discussions; reserve "score" for method-level code comments.
|
||||
- **"user_inputed_address"** in `backend/address2UPRN/main.py` is a misspelling and a synonym for **User Address** — the canonical term. New code should use `user_address`.
|
||||
- **"EPC"** is overloaded as both the document (an Energy Performance Certificate) and the rating band letter. Use **EPC** for the document and **EPC Band** for the letter.
|
||||
|
|
|
|||
0
applications/__init__.py
Normal file
0
applications/__init__.py
Normal file
34
applications/postcode_splitter/Dockerfile
Normal file
34
applications/postcode_splitter/Dockerfile
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
FROM public.ecr.aws/lambda/python:3.11
|
||||
|
||||
# Postgres host/port/database are baked into the image at build time from
|
||||
# the deploy workflow's --build-arg values (GitHub Actions DEV_DB_* secrets),
|
||||
# mirroring backend/postcode_splitter/handler/Dockerfile. They map onto the
|
||||
# POSTGRES_* names PostgresConfig.from_env reads. Username/password are NOT
|
||||
# baked in -- Terraform injects those as Lambda env vars from Secrets Manager.
|
||||
ARG DEV_DB_HOST
|
||||
ARG DEV_DB_PORT
|
||||
ARG DEV_DB_NAME
|
||||
|
||||
ENV POSTGRES_HOST=${DEV_DB_HOST}
|
||||
ENV POSTGRES_PORT=${DEV_DB_PORT}
|
||||
ENV POSTGRES_DATABASE=${DEV_DB_NAME}
|
||||
|
||||
WORKDIR /var/task
|
||||
|
||||
COPY applications/postcode_splitter/requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy the layered source the handler imports from. The new splitter pulls
|
||||
# only DDD-shaped packages — no pandas, no legacy backend/.
|
||||
COPY domain/ domain/
|
||||
COPY infrastructure/ infrastructure/
|
||||
COPY orchestration/ orchestration/
|
||||
COPY repositories/ repositories/
|
||||
COPY utilities/ utilities/
|
||||
COPY applications/ applications/
|
||||
|
||||
# Place the handler at the Lambda task root so the runtime can resolve
|
||||
# ``main.handler`` without an extra package prefix.
|
||||
COPY applications/postcode_splitter/handler.py /var/task/main.py
|
||||
|
||||
CMD ["main.handler"]
|
||||
0
applications/postcode_splitter/__init__.py
Normal file
0
applications/postcode_splitter/__init__.py
Normal file
52
applications/postcode_splitter/handler.py
Normal file
52
applications/postcode_splitter/handler.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
from applications.postcode_splitter.postcode_splitter_trigger_body import (
|
||||
PostcodeSplitterTriggerBody,
|
||||
)
|
||||
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
from repositories.user_address.user_address_csv_s3_repository import (
|
||||
UserAddressCsvS3Repository,
|
||||
)
|
||||
from utilities.aws_lambda.subtask_handler import subtask_handler
|
||||
|
||||
|
||||
@subtask_handler()
def handler(
    body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator
) -> dict[str, list[str]]:
    """Lambda entry point: split one input CSV and fan the batches out.

    Validates ``body`` as a ``PostcodeSplitterTriggerBody``, reads the
    ``S3_BUCKET_NAME`` and ``ADDRESS2UPRN_QUEUE_URL`` environment variables,
    wires the S3- and SQS-backed collaborators together and delegates the
    work to ``PostcodeSplitterOrchestrator.split_and_dispatch``.

    Args:
        body: Trigger payload; must validate as ``PostcodeSplitterTriggerBody``.
        context: Lambda context object; not used here.
        task_orchestrator: Forwarded unchanged to the splitter orchestrator.

    Returns:
        ``{"child_subtask_ids": [...]}`` with every child id passed through
        ``str``.

    Raises:
        KeyError: If either environment variable is not set.
    """
    request = PostcodeSplitterTriggerBody.model_validate(body)

    env = os.environ
    data_bucket = env["S3_BUCKET_NAME"]
    address2uprn_queue = env["ADDRESS2UPRN_QUEUE_URL"]

    # boto3.client is overloaded per-service in the installed stubs; going
    # through an ``Any``-typed alias keeps the strict-mode checker quiet.
    make_client: Any = boto3.client  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    s3: Any = make_client("s3")
    sqs: Any = make_client("sqs")

    orchestrator = PostcodeSplitterOrchestrator(
        task_orchestrator=task_orchestrator,
        user_address_repo=UserAddressCsvS3Repository(
            CsvS3Client(s3, data_bucket), data_bucket
        ),
        queue_client=Address2UprnQueueClient(sqs, address2uprn_queue),
    )

    dispatched = orchestrator.split_and_dispatch(
        parent_task_id=request.task_id,
        parent_subtask_id=request.sub_task_id,
        input_s3_uri=request.s3_uri,
    )

    return {"child_subtask_ids": list(map(str, dispatched))}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Local-test environment for the postcode_splitter Lambda.
|
||||
#
|
||||
# cp .env.local.example .env.local then fill in the values below.
|
||||
#
|
||||
# .env.local is gitignored. The container hits REAL AWS and a REAL Postgres,
|
||||
# so every value here points at infrastructure that actually exists.
|
||||
#
|
||||
# NOTE: the new DDD code uses different env var names than the repo root
|
||||
# .env. The mapping (root .env name -> var here) is given per section.
|
||||
# Keep comments on their own lines — docker-compose's env_file parser folds a
|
||||
# trailing "# ..." into the value.
|
||||
|
||||
# --- Postgres (orchestration/default_orchestrator -> PostgresConfig.from_env) ---
|
||||
# POSTGRES_HOST <- DB_HOST, PORT <- DB_PORT, USERNAME <- DB_USERNAME,
|
||||
# PASSWORD <- DB_PASSWORD, DATABASE <- DB_NAME.
|
||||
POSTGRES_HOST=
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USERNAME=
|
||||
POSTGRES_PASSWORD=
|
||||
POSTGRES_DATABASE=
|
||||
# POSTGRES_DRIVER=psycopg2 (optional; defaults to psycopg2)
|
||||
|
||||
# --- Handler config (applications/postcode_splitter/handler.py) ---
|
||||
# S3_BUCKET_NAME: bucket holding the input address CSV (root .env: DATA_BUCKET).
|
||||
# ADDRESS2UPRN_QUEUE_URL: SQS queue the splitter fans batches out to; not in
|
||||
# the root .env (Terraform sets it in prod).
|
||||
S3_BUCKET_NAME=
|
||||
ADDRESS2UPRN_QUEUE_URL=
|
||||
|
||||
# --- AWS credentials for boto3 (S3 + SQS clients) ---
|
||||
AWS_ACCESS_KEY_ID=
|
||||
AWS_SECRET_ACCESS_KEY=
|
||||
AWS_DEFAULT_REGION=eu-west-2
|
||||
# AWS_SESSION_TOKEN= (only if using temporary/SSO credentials)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
services:
|
||||
postcode-splitter:
|
||||
build:
|
||||
context: ../../../
|
||||
dockerfile: applications/postcode_splitter/Dockerfile
|
||||
ports:
|
||||
- "9001:8080"
|
||||
env_file:
|
||||
- .env.local
|
||||
28
applications/postcode_splitter/local_handler/invoke_local_lambda.py
Executable file
28
applications/postcode_splitter/local_handler/invoke_local_lambda.py
Executable file
|
|
@ -0,0 +1,28 @@
|
|||
#!/usr/bin/env python3
"""Post a sample SQS-shaped event to the locally running postcode_splitter Lambda.

The target is the invocation endpoint on ``localhost:9001`` (the host port
mapped to the container's 8080 in the local docker-compose file). The HTTP
status code and the raw response body are printed to stdout.
"""
import json

import requests

HOST = "localhost"
PORT = "9001"

LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations"

# (connect, read) timeouts in seconds. ``requests`` never times out unless
# told to, so without this the script hangs forever when the container is
# not up. The read timeout is generous because the invocation is synchronous
# and only answers once the handler has finished.
REQUEST_TIMEOUT = (5, 900)

payload = {
    "Records": [
        {
            "body": json.dumps(
                {
                    "task_id": "f4b3332f-c0cc-481f-96a5-d39860a647cf",
                    "sub_task_id": "14c042de-40c4-473b-8cd8-72c983a94a8d",
                    "s3_uri": "s3://retrofit-data-dev/ara_raw_inputs/calico/Calico Homes Full list EPC Properties(Sheet2) (1) (1).csv",
                }
            )
        }
    ]
}

response = requests.post(LAMBDA_URL, json=payload, timeout=REQUEST_TIMEOUT)

print("Status code:", response.status_code)
print("Response:")
print(response.text)
|
||||
12
applications/postcode_splitter/local_handler/run_local.sh
Executable file
12
applications/postcode_splitter/local_handler/run_local.sh
Executable file
|
|
@ -0,0 +1,12 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
if [ ! -f .env.local ]; then
|
||||
cp .env.local.example .env.local
|
||||
echo "Created .env.local from the template — fill it in, then re-run." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker compose build --no-cache
|
||||
docker compose up --force-recreate
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class PostcodeSplitterTriggerBody(BaseModel):
    """Validated shape of the message body that triggers the postcode splitter.

    The handler passes ``task_id``, ``sub_task_id`` and ``s3_uri`` to
    ``split_and_dispatch`` as the parent task id, the parent subtask id and
    the input CSV location respectively.
    """

    # extra="allow": keys beyond the three declared below are accepted
    # rather than rejected.
    model_config = ConfigDict(extra="allow")

    # Identifier of the parent task.
    task_id: UUID
    # Identifier of the subtask this invocation is running as.
    sub_task_id: UUID
    # Location of the input address CSV.
    s3_uri: str
|
||||
4
applications/postcode_splitter/requirements.txt
Normal file
4
applications/postcode_splitter/requirements.txt
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
boto3
|
||||
pydantic
|
||||
sqlmodel
|
||||
psycopg2-binary
|
||||
|
|
@ -8,4 +8,5 @@ boto3==1.35.44
|
|||
sqlmodel
|
||||
sqlalchemy==2.0.36
|
||||
psycopg2-binary==2.9.10
|
||||
pydantic-settings==2.6.0
|
||||
pydantic-settings==2.6.0
|
||||
httpx
|
||||
|
|
@ -40,20 +40,6 @@ module "lambda" {
|
|||
LOG_LEVEL = "info"
|
||||
DB_USERNAME = local.db_credentials.db_assessment_model_username
|
||||
DB_PASSWORD = local.db_credentials.db_assessment_model_password
|
||||
GOOGLE_SOLAR_API_KEY = "test"
|
||||
SAP_PREDICTIONS_BUCKET = "test"
|
||||
CARBON_PREDICTIONS_BUCKET = "test"
|
||||
HEAT_PREDICTIONS_BUCKET = "test"
|
||||
HEATING_KWH_PREDICTIONS_BUCKET = "test"
|
||||
HOTWATER_KWH_PREDICTIONS_BUCKET = "test"
|
||||
API_KEY = "test"
|
||||
ENVIRONMENT = "test"
|
||||
SECRET_KEY = "test"
|
||||
PLAN_TRIGGER_BUCKET = "test"
|
||||
DATA_BUCKET = "test"
|
||||
EPC_AUTH_TOKEN = "test"
|
||||
ENGINE_SQS_URL = "test"
|
||||
ENERGY_ASSESSMENTS_BUCKET = "test"
|
||||
ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url
|
||||
S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name
|
||||
},
|
||||
|
|
|
|||
0
domain/addresses/__init__.py
Normal file
0
domain/addresses/__init__.py
Normal file
51
domain/addresses/postcode_batching.py
Normal file
51
domain/addresses/postcode_batching.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
from domain.addresses.user_address import UserAddress
|
||||
from domain.postcode import Postcode
|
||||
|
||||
|
||||
def iter_postcode_grouped_batches(
    addresses: Iterable[UserAddress],
    *,
    max_batch_size: int = 500,
) -> Iterator[list[UserAddress]]:
    """Yield batches of addresses, never splitting a postcode across batches.

    Addresses are grouped by ``postcode`` in first-seen order, then whole
    groups are packed greedily into batches of at most ``max_batch_size``
    addresses. A single postcode group with ``max_batch_size`` or more
    addresses is emitted on its own as one batch, which may exceed the cap.

    Args:
        addresses: Addresses to batch; each must expose a hashable ``postcode``.
        max_batch_size: Soft cap on addresses per batch; must be >= 1.

    Returns:
        A lazy iterator of batches. ``addresses`` is only consumed once
        iteration starts.

    Raises:
        ValueError: Immediately at call time if ``max_batch_size`` < 1.
    """
    # Validate here, not inside the generator: a generator body does not run
    # until the first ``next()``, which would delay the error to wherever the
    # result happens to be iterated.
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")
    return _iter_grouped_batches(addresses, max_batch_size)


def _iter_grouped_batches(
    addresses: Iterable[UserAddress],
    max_batch_size: int,
) -> Iterator[list[UserAddress]]:
    """Generator behind ``iter_postcode_grouped_batches`` (cap already validated)."""
    groups = _group_by_postcode_in_order(addresses)

    buffer: list[UserAddress] = []
    for group in groups.values():
        group_len = len(group)

        # Oversize single-Postcode group: flush buffer first, then dispatch
        # the group as its own batch. Mirrors the legacy
        # ``if group_len >= batch_size`` branch.
        if group_len >= max_batch_size:
            if buffer:
                yield buffer
                buffer = []
            yield group
            continue

        # Adding this group would overflow: flush buffer before appending.
        # The buffer cannot be empty here, because an empty buffer only
        # overflows for a group the branch above has already handled.
        if len(buffer) + group_len > max_batch_size:
            yield buffer
            buffer = []

        buffer.extend(group)

    # Final flush.
    if buffer:
        yield buffer


def _group_by_postcode_in_order(
    addresses: Iterable[UserAddress],
) -> dict[Postcode, list[UserAddress]]:
    """Group addresses by postcode, keeping first-seen postcode order.

    Relies on ``dict`` preserving insertion order; addresses within a group
    keep their input order.
    """
    groups: dict[Postcode, list[UserAddress]] = {}
    for address in addresses:
        groups.setdefault(address.postcode, []).append(address)
    return groups
|
||||
18
domain/addresses/user_address.py
Normal file
18
domain/addresses/user_address.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from domain.postcode import Postcode
|
||||
|
||||
|
||||
def _empty_source_row() -> dict[str, str]:
    """Default factory for ``UserAddress.source_row``: a fresh empty dict."""
    return {}


@dataclass(frozen=True)
class UserAddress:
    """A customer-supplied address paired with its postcode.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # Free-text address line.
    user_address: str
    # Postcode value object for this address.
    postcode: Postcode
    # Optional customer-side reference; None when absent.
    internal_reference: Optional[str] = None
    # Extra per-address key/value data; compare=False keeps it out of
    # equality comparisons between UserAddress instances.
    source_row: dict[str, str] = field(default_factory=_empty_source_row, compare=False)
|
||||
15
domain/postcode.py
Normal file
15
domain/postcode.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Postcode:
    """Postcode value object, canonicalised on construction.

    The stored ``value`` has every run of whitespace removed and is
    upper-cased, so ``Postcode("sw1a 1aa") == Postcode("SW1A1AA")``.
    """

    value: str

    def __post_init__(self) -> None:
        """Replace the raw ``value`` with its canonical form."""
        without_whitespace = "".join(self.value.split())
        canonical = without_whitespace.upper()
        # The dataclass is frozen, so normal attribute assignment raises;
        # object.__setattr__ writes the field directly.
        object.__setattr__(self, "value", canonical)

    def __str__(self) -> str:
        """Return the canonical postcode string."""
        return self.value
|
||||
20
infrastructure/address2uprn_queue_client.py
Normal file
20
infrastructure/address2uprn_queue_client.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from uuid import UUID
|
||||
|
||||
from infrastructure.sqs_client import SqsClient
|
||||
|
||||
|
||||
class Address2UprnQueueClient(SqsClient):
    """SQS client that publishes address2uprn work messages."""

    def publish(
        self,
        *,
        parent_task_id: UUID,
        child_subtask_id: UUID,
        s3_uri: str,
    ) -> str:
        """Send one work message and return what ``send`` returns.

        The message body carries ``task_id`` and ``sub_task_id`` (both ids
        rendered with ``str``) plus the ``s3_uri`` unchanged.
        """
        message = {
            "task_id": str(parent_task_id),
            "sub_task_id": str(child_subtask_id),
            "s3_uri": s3_uri,
        }
        return self.send(message)
|
||||
28
infrastructure/csv_s3_client.py
Normal file
28
infrastructure/csv_s3_client.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import csv
|
||||
from io import StringIO
|
||||
|
||||
from infrastructure.s3_client import S3Client
|
||||
from infrastructure.s3_uri import parse_s3_uri
|
||||
|
||||
|
||||
class CsvS3Client(S3Client):
    """CSV reading and writing on top of ``S3Client``."""

    def read_rows(self, s3_uri: str) -> list[dict[str, str]]:
        """Download the CSV at ``s3_uri`` and return one dict per data row.

        The payload is decoded as UTF-8 with an optional BOM stripped, and
        the first line is used as the header.

        Raises:
            ValueError: If ``s3_uri`` points at a bucket other than the one
                this client is bound to (or cannot be parsed at all).
        """
        bucket, key = parse_s3_uri(s3_uri)
        if bucket != self.bucket:
            raise ValueError(
                f"s3_uri bucket {bucket!r} does not match client bucket {self.bucket!r}"
            )
        text = self.get_object(key).decode("utf-8-sig")
        # restval="": a row with fewer cells than the header would otherwise
        # get ``None`` for the missing columns, breaking the declared
        # ``dict[str, str]`` contract and any caller that calls str methods.
        reader = csv.DictReader(StringIO(text), restval="")
        return [dict(row) for row in reader]

    def save_rows(self, rows: list[dict[str, str]], key: str) -> str:
        """Write ``rows`` as a UTF-8 CSV under ``key`` and return ``put_object``'s result.

        The header is the union of keys over all rows, in first-seen order;
        rows lacking a column get an empty cell there.

        Raises:
            ValueError: If ``rows`` is empty (there is no header to write).
        """
        if not rows:
            raise ValueError("Cannot save an empty rows list: header is unknown")
        # Taking the header from rows[0] alone makes DictWriter raise as soon
        # as a later row carries a key the first one lacks.
        fieldnames = list(dict.fromkeys(name for row in rows for name in row))
        buffer = StringIO()
        writer = csv.DictWriter(buffer, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
        return self.put_object(key, buffer.getvalue().encode("utf-8"))
|
||||
22
infrastructure/s3_client.py
Normal file
22
infrastructure/s3_client.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from typing import Any
|
||||
|
||||
|
||||
class S3Client:
    """Thin wrapper that binds a boto-style S3 client to a single bucket."""

    def __init__(self, boto_s3_client: Any, bucket: str) -> None:
        self._bucket = bucket
        self._client = boto_s3_client

    @property
    def bucket(self) -> str:
        """Name of the bucket every call on this client targets."""
        return self._bucket

    def get_object(self, key: str) -> bytes:
        """Return the full contents of ``key`` in the bound bucket."""
        target = {"Bucket": self._bucket, "Key": key}
        reply: dict[str, Any] = self._client.get_object(**target)
        stream = reply["Body"]
        content: bytes = stream.read()
        return content

    def put_object(self, key: str, body: bytes) -> str:
        """Store ``body`` under ``key`` and return the object's ``s3://`` URI."""
        self._client.put_object(Bucket=self._bucket, Key=key, Body=body)
        return "s3://" + self._bucket + "/" + key
|
||||
25
infrastructure/s3_uri.py
Normal file
25
infrastructure/s3_uri.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from urllib.parse import unquote
|
||||
|
||||
|
||||
def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
    """Split an S3 location into ``(bucket, key)``.

    Two input shapes are accepted:

    * ``s3://bucket/key`` -- the key is returned verbatim.
    * A console-style URL containing ``/s3/object/<bucket>?...&prefix=<key>``
      -- the ``prefix`` query parameter is percent-decoded to give the key.

    Raises:
        ValueError: If the input matches neither shape, or if the bucket or
            the key comes out empty.
    """
    if s3_uri.startswith("s3://"):
        bucket, _, key = s3_uri[len("s3://") :].partition("/")
        if not bucket or not key:
            raise ValueError("S3 URI must include both a bucket and a key")
        return bucket, key

    if "?" not in s3_uri:
        raise ValueError(f"Not an s3:// URI and has no query string: {s3_uri!r}")
    base, query = s3_uri.split("?", 1)

    if "/s3/object/" not in base:
        raise ValueError(f"Console URL has no '/s3/object/' segment: {s3_uri!r}")
    bucket = base.split("/s3/object/", 1)[1]

    # Hand-parse the query string; when a name repeats, the last one wins.
    params: dict[str, str] = {}
    for item in query.split("&"):
        if "=" in item:
            name, value = item.split("=", 1)
            params[name] = value
    key = unquote(params.get("prefix", ""))

    # Apply the same completeness rule as the s3:// branch instead of handing
    # an empty bucket or key on to the S3 call.
    if not bucket or not key:
        raise ValueError(
            f"Console URL must include both a bucket and a 'prefix' key: {s3_uri!r}"
        )
    return bucket, key
|
||||
20
infrastructure/sqs_client.py
Normal file
20
infrastructure/sqs_client.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SqsClient:
    """Thin wrapper that binds a boto-style SQS client to a single queue."""

    def __init__(self, boto_sqs_client: Any, queue_url: str) -> None:
        self._queue_url = queue_url
        self._client = boto_sqs_client

    @property
    def queue_url(self) -> str:
        """URL of the queue every message is sent to."""
        return self._queue_url

    def send(self, body: dict[str, Any]) -> str:
        """JSON-encode ``body``, send it, and return the reply's ``MessageId``."""
        serialised = json.dumps(body)
        reply: dict[str, Any] = self._client.send_message(
            QueueUrl=self._queue_url,
            MessageBody=serialised,
        )
        sent_id: str = reply["MessageId"]
        return sent_id
|
||||
55
orchestration/postcode_splitter_orchestrator.py
Normal file
55
orchestration/postcode_splitter_orchestrator.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
from domain.addresses.postcode_batching import iter_postcode_grouped_batches
|
||||
from repositories.user_address.user_address_repository import UserAddressRepository
|
||||
|
||||
|
||||
class PostcodeSplitterOrchestrator:
    """Splits an input address file into postcode batches and dispatches each."""

    def __init__(
        self,
        task_orchestrator: TaskOrchestrator,
        user_address_repo: UserAddressRepository,
        queue_client: Address2UprnQueueClient,
        max_batch_size: int = 500,
    ) -> None:
        self._max_batch_size = max_batch_size
        self._queue_client = queue_client
        self._user_address_repo = user_address_repo
        self._task_orchestrator = task_orchestrator

    def split_and_dispatch(
        self,
        *,
        parent_task_id: UUID,
        parent_subtask_id: UUID,
        input_s3_uri: str,
    ) -> list[UUID]:
        """Batch the addresses at ``input_s3_uri`` and publish one message per batch.

        For every batch, in this order: the batch is saved under
        ``ara_postcode_splitter_batches/<parent_task_id>/<parent_subtask_id>``,
        a child subtask is created on the parent task, and a message naming
        that child and the saved batch is published.

        Returns:
            The ids of the created child subtasks, in dispatch order.
        """
        loaded = self._user_address_repo.load_batch(input_s3_uri)
        path_prefix = (
            f"ara_postcode_splitter_batches/{parent_task_id}/{parent_subtask_id}"
        )
        batches = iter_postcode_grouped_batches(
            loaded, max_batch_size=self._max_batch_size
        )

        dispatched: list[UUID] = []
        for batch in batches:
            # Save first so the child subtask and the message can both refer
            # to the stored batch.
            batch_uri = self._user_address_repo.save_batch(batch, path_prefix)
            child_inputs = {
                "task_id": str(parent_task_id),
                "s3_uri": batch_uri,
            }
            child = self._task_orchestrator.create_child_subtask(
                parent_task_id, inputs=child_inputs
            )
            self._queue_client.publish(
                parent_task_id=parent_task_id,
                child_subtask_id=child.id,
                s3_uri=batch_uri,
            )
            dispatched.append(child.id)

        return dispatched
|
||||
|
|
@ -48,6 +48,16 @@ class TaskOrchestrator:
|
|||
self._subtasks.create(subtask)
|
||||
return task, subtask
|
||||
|
||||
def create_child_subtask(
    self,
    parent_task_id: UUID,
    *,
    inputs: Optional[dict[str, Any]] = None,
) -> SubTask:
    """Create a subtask under ``parent_task_id``, persist it, and return it.

    ``inputs`` is handed to ``SubTask.create`` unchanged.
    """
    child = SubTask.create(task_id=parent_task_id, inputs=inputs)
    self._subtasks.create(child)
    return child
|
||||
|
||||
def start_subtask(
|
||||
self, subtask_id: UUID, cloud_logs_url: Optional[str] = None
|
||||
) -> SubTask:
|
||||
|
|
|
|||
0
repositories/user_address/__init__.py
Normal file
0
repositories/user_address/__init__.py
Normal file
63
repositories/user_address/user_address_csv_s3_repository.py
Normal file
63
repositories/user_address/user_address_csv_s3_repository.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from domain.addresses.user_address import UserAddress
|
||||
from domain.postcode import Postcode
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
from repositories.user_address.user_address_repository import UserAddressRepository
|
||||
|
||||
_ADDRESS_COLUMNS: tuple[str, str, str] = ("Address 1", "Address 2", "Address 3")
|
||||
_POSTCODE_COLUMN: str = "postcode"
|
||||
_INTERNAL_REFERENCE_COLUMN: str = "Internal Reference"
|
||||
_POSTCODE_CLEAN_COLUMN: str = "postcode_clean"
|
||||
|
||||
|
||||
class UserAddressCsvS3Repository(UserAddressRepository):
    """``UserAddressRepository`` backed by CSV files read and written via ``CsvS3Client``."""

    def __init__(self, csv_client: CsvS3Client, bucket: str) -> None:
        self._csv_client = csv_client
        self._bucket = bucket

    @staticmethod
    def _cell(row: dict[str, str], column: str) -> str:
        """Return ``row[column]``, or ``""`` when the column is absent or empty.

        A CSV row shorter than the header can surface a missing cell as
        ``None``; that is folded into the empty string here too.
        """
        return row.get(column) or ""

    def load_batch(self, s3_uri: str) -> list[UserAddress]:
        """Read the CSV at ``s3_uri`` and build one ``UserAddress`` per row.

        The address line is the non-blank "Address 1".."Address 3" cells,
        stripped and joined with ", ". A blank "Internal Reference" becomes
        ``None``. The original row is kept on ``source_row``.

        Raises:
            ValueError: If there is at least one row and the first row has
                no ``postcode`` column.
        """
        rows = self._csv_client.read_rows(s3_uri)
        if rows and _POSTCODE_COLUMN not in rows[0]:
            raise ValueError(
                f"Input CSV {s3_uri} has no {_POSTCODE_COLUMN!r} column; "
                f"columns present: {sorted(rows[0])}"
            )
        addresses: list[UserAddress] = []
        for row in rows:
            stripped_parts = (self._cell(row, col).strip() for col in _ADDRESS_COLUMNS)
            user_address = ", ".join(part for part in stripped_parts if part)
            raw_ref = self._cell(row, _INTERNAL_REFERENCE_COLUMN).strip()
            internal_reference: Optional[str] = raw_ref or None
            addresses.append(
                UserAddress(
                    user_address=user_address,
                    postcode=Postcode(self._cell(row, _POSTCODE_COLUMN)),
                    internal_reference=internal_reference,
                    source_row=row,
                )
            )
        return addresses

    def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str:
        """Write ``addresses`` as a CSV under ``path_prefix`` and return ``save_rows``'s result.

        Each output row is the address's ``source_row`` plus a
        ``postcode_clean`` column holding ``str(addr.postcode)``.
        """
        rows: list[dict[str, str]] = [
            {**addr.source_row, _POSTCODE_CLEAN_COLUMN: str(addr.postcode)}
            for addr in addresses
        ]

        # TODO: [New Starter Task] file_name generation can be standardised
        # and made easier to read and test for future implementations. Build that!
        filename = (
            f"{datetime.now(timezone.utc).isoformat()}_{uuid.uuid4().hex[:8]}.csv"
        )
        key = f"{path_prefix.rstrip('/')}/{filename}"
        return self._csv_client.save_rows(rows, key)
|
||||
13
repositories/user_address/user_address_repository.py
Normal file
13
repositories/user_address/user_address_repository.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from domain.addresses.user_address import UserAddress
|
||||
|
||||
|
||||
class UserAddressRepository(ABC):
    """Abstract interface for loading and saving batches of ``UserAddress``."""

    @abstractmethod
    def load_batch(self, s3_uri: str) -> list[UserAddress]:
        """Load every address stored at ``s3_uri``."""
        ...

    @abstractmethod
    def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str:
        """Persist ``addresses`` under ``path_prefix`` and return a string locating them."""
        ...
|
||||
|
|
@ -9,4 +9,5 @@ hubspot-api-client
|
|||
fuzzywuzzy
|
||||
pymupdf
|
||||
playwright==1.58.0
|
||||
msal
|
||||
msal
|
||||
moto[s3,sqs]
|
||||
48
tests/conftest.py
Normal file
48
tests/conftest.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Shared pytest fixtures for the ``tests/`` tree.
|
||||
|
||||
Provides an ephemeral PostgreSQL engine for tests that exercise SQLModel
|
||||
repositories. PostgreSQL has no true in-memory mode; ``pytest-postgresql``
|
||||
starts a real, throwaway server in a temp directory (the process is started
|
||||
once per session and a fresh database is created/dropped per test). That is
|
||||
the closest equivalent to "in-memory" and matches production behaviour far
|
||||
better than SQLite (enums, JSONB, constraint semantics, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from psycopg import Connection
|
||||
from pytest_postgresql import factories
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
|
||||
# Importing the SQLModel row modules registers their tables on
|
||||
# SQLModel.metadata so ``create_all`` builds the full schema. Imports look
|
||||
# unused; they aren't.
|
||||
|
||||
|
||||
# pg_ctl ships under a versioned path and is not on PATH in the dev container.
|
||||
_PG_CTL = next(iter(sorted(glob.glob("/usr/lib/postgresql/*/bin/pg_ctl"))), "pg_ctl")
|
||||
|
||||
postgresql_proc = factories.postgresql_proc(
|
||||
executable=_PG_CTL
|
||||
) # pyright: ignore[reportUnknownMemberType]
|
||||
postgresql = factories.postgresql("postgresql_proc")
|
||||
|
||||
|
||||
@pytest.fixture
def db_engine(postgresql: Connection[Any]) -> Iterator[Engine]:
    """A SQLModel engine bound to a fresh, ephemeral PostgreSQL database.

    Creates every table registered on ``SQLModel.metadata`` before the test,
    then drops them and disposes of the engine afterwards.
    """
    info = postgresql.info
    url = f"postgresql+psycopg://{info.user}:@{info.host}:{info.port}/{info.dbname}"
    engine = create_engine(url)
    try:
        # Schema creation sits inside the ``try`` so that a failure here
        # still reaches the cleanup below.
        SQLModel.metadata.create_all(engine)
        yield engine
    finally:
        try:
            SQLModel.metadata.drop_all(engine)
        finally:
            # Always release pooled connections, even if the drop fails.
            engine.dispose()
|
||||
0
tests/domain/addresses/__init__.py
Normal file
0
tests/domain/addresses/__init__.py
Normal file
118
tests/domain/addresses/test_postcode_batching.py
Normal file
118
tests/domain/addresses/test_postcode_batching.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
|
||||
from domain.addresses.postcode_batching import iter_postcode_grouped_batches
|
||||
from domain.addresses.user_address import UserAddress
|
||||
from domain.postcode import Postcode
|
||||
|
||||
|
||||
def _addrs(postcode: str, n: int) -> list[UserAddress]:
    """Build ``n`` addresses on ``postcode``, numbered 0..n-1 in the address line."""
    built: list[UserAddress] = []
    for i in range(n):
        line = f"{i} {postcode} Street"
        built.append(UserAddress(user_address=line, postcode=Postcode(postcode)))
    return built
|
||||
|
||||
|
||||
def test_empty_input_yields_no_batches() -> None:
    """No addresses in means no batches out."""
    # act
    produced = list(iter_postcode_grouped_batches([]))
    # assert
    assert produced == []
|
||||
|
||||
|
||||
def test_single_batch_under_cap() -> None:
    """Five addresses over two postcodes fit in one batch when the cap is 500."""
    # arrange
    first_postcode = _addrs("AA1 1AA", 3)
    second_postcode = _addrs("BB2 2BB", 2)
    everything = first_postcode + second_postcode
    # act
    produced = list(iter_postcode_grouped_batches(everything, max_batch_size=500))
    # assert
    assert len(produced) == 1
    assert produced[0] == everything
|
||||
|
||||
|
||||
def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None:
    """Groups totalling exactly the cap share one batch -- no early flush."""
    # arrange
    cap = 5
    everything = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2)
    # act
    produced = list(iter_postcode_grouped_batches(everything, max_batch_size=cap))
    # assert
    assert len(produced) == 1
    (only_batch,) = produced
    assert len(only_batch) == 5
|
||||
|
||||
|
||||
def test_flush_on_overflow_before_adding_next_postcode() -> None:
    """With a cap of 5, a second group of 3 cannot join a buffer already holding 3.

    The buffer is flushed first and the second group opens a fresh batch.
    """
    # arrange
    everything = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 3)
    # act
    produced = list(iter_postcode_grouped_batches(everything, max_batch_size=5))
    # assert
    assert len(produced) == 2
    first_codes = [str(a.postcode) for a in produced[0]]
    second_codes = [str(a.postcode) for a in produced[1]]
    assert first_codes == ["AA11AA"] * 3
    assert second_codes == ["BB22BB"] * 3
|
||||
|
||||
|
||||
def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None:
    """A postcode with more addresses than the cap still goes out as one batch.

    The cap never splits a postcode.
    """
    # arrange
    oversize = _addrs("AA1 1AA", 7)
    # act
    produced = list(iter_postcode_grouped_batches(oversize, max_batch_size=5))
    # assert
    assert len(produced) == 1
    assert len(produced[0]) == 7
|
||||
|
||||
|
||||
def test_oversize_group_flushes_existing_buffer_first() -> None:
    """An oversize group flushes the pending buffer before going out itself.

    Buffered work must be neither lost nor interleaved with the big group.
    """
    # arrange
    everything = _addrs("AA1 1AA", 2) + _addrs("BB2 2BB", 7) + _addrs("CC3 3CC", 1)
    # act
    produced = list(iter_postcode_grouped_batches(everything, max_batch_size=5))
    # assert
    assert len(produced) == 3
    codes_per_batch = [[str(a.postcode) for a in batch] for batch in produced]
    assert codes_per_batch[0] == ["AA11AA", "AA11AA"]
    assert codes_per_batch[1] == ["BB22BB"] * 7
    assert codes_per_batch[2] == ["CC33CC"]
|
||||
|
||||
|
||||
def test_final_flush_yields_remaining_buffer() -> None:
    """The trailing buffer is emitted even though no overflow ever occurs."""
    # arrange
    everything = _addrs("AA1 1AA", 2) + _addrs("BB2 2BB", 2)
    # act
    produced = list(iter_postcode_grouped_batches(everything, max_batch_size=500))
    # assert
    expected = [everything]
    assert produced == expected
|
||||
|
||||
|
||||
def test_postcode_grouping_preserves_first_seen_order() -> None:
    """Interleaved input is grouped by postcode in first-seen order.

    The groups must never be re-ordered alphabetically: ``ZZ9 9ZZ`` appears
    first in the input, so it must come first in the output.
    """
    # arrange
    zz_first, zz_second = _addrs("ZZ9 9ZZ", 2)
    aa_first, aa_second = _addrs("AA1 1AA", 2)
    interleaved = [zz_first, aa_first, zz_second, aa_second]
    # act
    result = list(iter_postcode_grouped_batches(interleaved))
    # assert
    assert len(result) == 1
    emitted_postcodes = [str(a.postcode) for a in result[0]]
    assert emitted_postcodes == ["ZZ99ZZ", "ZZ99ZZ", "AA11AA", "AA11AA"]
|
||||
|
||||
|
||||
def test_invalid_max_batch_size_raises() -> None:
    """Consuming batches with ``max_batch_size=0`` raises ``ValueError``."""
    # act / assert
    with pytest.raises(ValueError, match="max_batch_size"):
        batches = iter_postcode_grouped_batches([], max_batch_size=0)
        list(batches)
|
||||
98
tests/domain/addresses/test_user_address.py
Normal file
98
tests/domain/addresses/test_user_address.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from domain.addresses.user_address import UserAddress
|
||||
from domain.postcode import Postcode
|
||||
|
||||
|
||||
def test_user_address_holds_postcode_value_object() -> None:
    """``postcode`` compares equal to the canonical ``Postcode`` value object."""
    # act
    built = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa"))
    # assert
    expected = Postcode("SW1A1AA")
    assert built.postcode == expected
|
||||
|
||||
|
||||
def test_user_address_preserves_user_address_verbatim() -> None:
    """The free-text ``user_address`` line is stored exactly as supplied.

    Only the postcode is canonicalised, and that happens inside ``Postcode``;
    the address line itself is intentionally NOT normalised.
    """
    # arrange
    raw_line = " 1 The Street "
    # act
    built = UserAddress(user_address=raw_line, postcode=Postcode("SW1A1AA"))
    # assert
    assert built.user_address == " 1 The Street "
|
||||
|
||||
|
||||
def test_user_address_internal_reference_defaults_to_none() -> None:
    """``internal_reference`` is ``None`` when not passed to the constructor."""
    # act
    built = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
    # assert
    reference = built.internal_reference
    assert reference is None
|
||||
|
||||
|
||||
def test_user_address_internal_reference_accepted() -> None:
    """An explicit ``internal_reference`` is stored on the address."""
    # arrange
    canonical = Postcode("SW1A1AA")
    # act
    built = UserAddress(
        user_address="1 The Street",
        postcode=canonical,
        internal_reference="cust-42",
    )
    # assert
    assert built.internal_reference == "cust-42"
|
||||
|
||||
|
||||
def test_user_address_is_frozen() -> None:
    """Assigning to a field of a built ``UserAddress`` raises ``FrozenInstanceError``."""
    # arrange
    immutable = UserAddress(
        user_address="1 The Street",
        postcode=Postcode("SW1A1AA"),
    )
    # act / assert
    with pytest.raises(dataclasses.FrozenInstanceError):
        immutable.postcode = Postcode("OTHER")  # type: ignore[misc]
|
||||
|
||||
|
||||
def test_user_address_equality_uses_canonical_postcode() -> None:
    """Addresses built from different surface forms of one postcode are equal.

    ``Postcode`` sanitises eagerly, so ``"sw1a 1aa"`` and ``"SW1A1AA"`` end up
    as the same canonical value before the addresses are compared.
    """
    # arrange
    from_spaced_lowercase = UserAddress(
        user_address="1 The Street", postcode=Postcode("sw1a 1aa")
    )
    from_canonical = UserAddress(
        user_address="1 The Street", postcode=Postcode("SW1A1AA")
    )
    # act / assert
    assert from_spaced_lowercase == from_canonical
|
||||
|
||||
|
||||
def test_user_address_source_row_defaults_to_empty_dict() -> None:
    """``source_row`` is an empty dict when not passed to the constructor."""
    # act
    built = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
    # assert
    default_row = built.source_row
    assert default_row == {}
|
||||
|
||||
|
||||
def test_user_address_carries_source_row() -> None:
    """A supplied ``source_row`` mapping is kept on the address."""
    # arrange
    csv_row = {
        "Address 1": "1 The Street",
        "postcode": "SW1A 1AA",
        "SAP Score": "72",
    }
    # act
    built = UserAddress(
        user_address="1 The Street",
        postcode=Postcode("SW1A 1AA"),
        source_row=csv_row,
    )
    # assert
    assert built.source_row == csv_row
|
||||
|
||||
|
||||
def test_user_address_equality_ignores_source_row() -> None:
    """Two addresses differing only in ``source_row`` compare equal.

    ``source_row`` is excluded from equality (and hashing): identity stays
    defined by the parsed fields.
    """
    # arrange
    left = UserAddress(
        user_address="1 The Street",
        postcode=Postcode("SW1A1AA"),
        source_row={"x": "1"},
    )
    right = UserAddress(
        user_address="1 The Street",
        postcode=Postcode("SW1A1AA"),
        source_row={"y": "2"},
    )
    # act / assert
    assert left == right
|
||||
|
|
@ -6,10 +6,13 @@ from domain.tasks.subtasks import SubTask, SubTaskStatus
|
|||
|
||||
|
||||
def test_create_subtask_starts_waiting() -> None:
|
||||
# arrange
|
||||
task_id = uuid4()
|
||||
|
||||
# act
|
||||
st = SubTask.create(task_id=task_id, inputs={"foo": "bar"})
|
||||
|
||||
# assert
|
||||
assert st.task_id == task_id
|
||||
assert st.status is SubTaskStatus.WAITING
|
||||
assert st.inputs == {"foo": "bar"}
|
||||
|
|
@ -19,57 +22,74 @@ def test_create_subtask_starts_waiting() -> None:
|
|||
|
||||
|
||||
def test_start_transitions_to_in_progress_and_sets_cloud_logs_url() -> None:
    """``start`` moves a new subtask to IN_PROGRESS and records the logs URL."""
    # arrange
    subtask = SubTask.create(task_id=uuid4())
    logs_url = "https://example/log"

    # act
    subtask.start(cloud_logs_url=logs_url)

    # assert
    assert subtask.status is SubTaskStatus.IN_PROGRESS
    assert subtask.cloud_logs_url == "https://example/log"
    assert subtask.job_started is not None
|
||||
|
||||
|
||||
def test_start_is_idempotent_from_in_progress() -> None:
    """A second ``start`` keeps ``job_started`` but updates ``cloud_logs_url``."""
    # arrange
    subtask = SubTask.create(task_id=uuid4())
    subtask.start()
    started_at = subtask.job_started

    # act
    subtask.start(cloud_logs_url="https://other")

    # assert
    assert subtask.status is SubTaskStatus.IN_PROGRESS
    # The original start timestamp must not be overwritten by the second call.
    assert subtask.job_started == started_at
    assert subtask.cloud_logs_url == "https://other"
|
||||
|
||||
|
||||
def test_start_rejects_from_terminal_status() -> None:
    """Calling ``start`` on an already completed subtask raises ``ValueError``."""
    # arrange
    finished = SubTask.create(task_id=uuid4())
    finished.complete()
    # act / assert
    with pytest.raises(ValueError):
        finished.start()
|
||||
|
||||
|
||||
def test_complete_marks_outputs_and_job_completed() -> None:
    """``complete(result)`` stores the result under ``"result"`` and timestamps it."""
    # arrange
    subtask = SubTask.create(task_id=uuid4())
    subtask.start()
    result = {"uprn": "123"}

    # act
    subtask.complete(result)

    # assert
    assert subtask.status is SubTaskStatus.COMPLETE
    assert subtask.outputs == {"result": {"uprn": "123"}}
    assert subtask.job_completed is not None
|
||||
|
||||
|
||||
def test_complete_without_result_leaves_outputs_unset() -> None:
    """``complete()`` with no result leaves ``outputs`` as ``None``."""
    # arrange
    subtask = SubTask.create(task_id=uuid4())
    # act
    subtask.complete()
    # assert
    recorded = subtask.outputs
    assert recorded is None
|
||||
|
||||
|
||||
def test_fail_records_error_in_outputs() -> None:
    """``fail(err)`` marks the subtask FAILED and stores the error message."""
    # arrange
    subtask = SubTask.create(task_id=uuid4())
    failure = RuntimeError("boom")

    # act
    subtask.fail(failure)

    # assert
    assert subtask.status is SubTaskStatus.FAILED
    assert subtask.outputs == {"error": "boom"}
    assert subtask.job_completed is not None
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ from domain.tasks.tasks import Source, Task, TaskStatus
|
|||
|
||||
|
||||
def test_create_task_starts_waiting() -> None:
|
||||
# Arrange / Act
|
||||
# arrange / act
|
||||
t = Task.create(
|
||||
task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
# assert
|
||||
assert t.status is TaskStatus.WAITING
|
||||
assert t.source is Source.PORTFOLIO
|
||||
assert t.source_id == "abc-123"
|
||||
|
|
@ -19,86 +19,113 @@ def test_create_task_starts_waiting() -> None:
|
|||
|
||||
|
||||
def test_create_task_rejects_blank_task_source() -> None:
    """A whitespace-only ``task_source`` is rejected with ``ValueError``."""
    # arrange
    blank_source = " "
    # act / assert
    with pytest.raises(ValueError, match="task_source"):
        Task.create(task_source=blank_source)
|
||||
|
||||
|
||||
def test_start_transitions_to_in_progress() -> None:
    """``start`` moves a freshly created task to IN_PROGRESS."""
    # arrange
    task = Task.create(task_source="manual:test")
    # act
    task.start()
    # assert
    assert task.status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_complete_marks_job_completed() -> None:
    """Completing a started task sets COMPLETE and a ``job_completed`` value."""
    # arrange
    task = Task.create(task_source="manual:test")
    task.start()
    # act
    task.complete()
    # assert
    assert task.status is TaskStatus.COMPLETE
    assert task.job_completed is not None
|
||||
|
||||
|
||||
def test_fail_marks_job_completed() -> None:
    """Failing a task sets FAILED and a ``job_completed`` value."""
    # arrange
    task = Task.create(task_source="manual:test")
    # act
    task.fail()
    # assert
    assert task.status is TaskStatus.FAILED
    assert task.job_completed is not None
|
||||
|
||||
|
||||
def test_start_rejects_from_terminal_status() -> None:
    """Calling ``start`` on an already completed task raises ``ValueError``."""
    # arrange
    finished = Task.create(task_source="manual:test")
    finished.complete()
    # act / assert
    with pytest.raises(ValueError):
        finished.start()
|
||||
|
||||
|
||||
def test_recalculate_with_empty_statuses_is_noop() -> None:
    """Recalculating from an empty status list leaves the task untouched."""
    # arrange
    task = Task.create(task_source="manual:test")
    status_before, completed_before = task.status, task.job_completed

    # act
    task.recalculate_from_subtasks([])

    # assert
    assert task.status is status_before
    assert task.job_completed is completed_before
|
||||
|
||||
|
||||
def test_recalculate_all_waiting_keeps_waiting() -> None:
    """All-WAITING subtasks put the task in WAITING with no ``job_completed``.

    The task is first started and completed so the test shows the
    recalculation overriding an earlier COMPLETE state.
    """
    # arrange
    task = Task.create(task_source="manual:test")
    task.start()  # task moved to IN_PROGRESS earlier
    task.complete()  # then COMPLETE, with job_completed set
    all_waiting = [SubTaskStatus.WAITING] * 2

    # act
    task.recalculate_from_subtasks(all_waiting)

    # assert
    assert task.status is TaskStatus.WAITING
    assert task.job_completed is None
|
||||
|
||||
|
||||
def test_recalculate_any_in_progress_marks_in_progress() -> None:
    """One IN_PROGRESS subtask among others makes the task IN_PROGRESS."""
    # arrange
    task = Task.create(task_source="manual:test")
    mixed_statuses = [
        SubTaskStatus.WAITING,
        SubTaskStatus.IN_PROGRESS,
        SubTaskStatus.COMPLETE,
    ]

    # act
    task.recalculate_from_subtasks(mixed_statuses)

    # assert
    assert task.status is TaskStatus.IN_PROGRESS
    assert task.job_completed is None
|
||||
|
||||
|
||||
def test_recalculate_all_complete_marks_complete() -> None:
    """All-COMPLETE subtasks make the task COMPLETE and set ``job_completed``."""
    # arrange
    task = Task.create(task_source="manual:test")
    all_complete = [SubTaskStatus.COMPLETE] * 2

    # act
    task.recalculate_from_subtasks(all_complete)

    # assert
    assert task.status is TaskStatus.COMPLETE
    assert task.job_completed is not None
|
||||
|
||||
|
||||
def test_recalculate_any_failed_marks_failed_even_with_others() -> None:
    """A single FAILED subtask makes the task FAILED regardless of the rest."""
    # arrange
    task = Task.create(task_source="manual:test")
    statuses_with_failure = [
        SubTaskStatus.IN_PROGRESS,
        SubTaskStatus.COMPLETE,
        SubTaskStatus.FAILED,
    ]

    # act
    task.recalculate_from_subtasks(statuses_with_failure)

    # assert
    assert task.status is TaskStatus.FAILED
    assert task.job_completed is not None
|
||||
|
|
|
|||
59
tests/domain/test_postcode.py
Normal file
59
tests/domain/test_postcode.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from domain.postcode import Postcode
|
||||
|
||||
|
||||
def test_postcode_uppercases() -> None:
    """Lower-case input is upper-cased in the canonical ``value``."""
    # act
    canonical = Postcode("sw1a1aa").value
    # assert
    assert canonical == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_strips_internal_spaces() -> None:
    """A space inside the postcode is removed from the canonical ``value``."""
    # act
    canonical = Postcode("sw1a 1aa").value
    # assert
    assert canonical == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_strips_leading_and_trailing_whitespace() -> None:
    """Surrounding whitespace does not survive into the canonical ``value``."""
    # act
    canonical = Postcode(" sw1a 1aa ").value
    # assert
    assert canonical == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_strips_tabs_and_newlines() -> None:
    """Tabs and newlines are absorbed just like literal spaces.

    CSV ingestion occasionally introduces stray whitespace characters, so the
    canonical form has to tolerate them.
    """
    # act
    canonical = Postcode("sw1a\t1aa\n").value
    # assert
    assert canonical == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_construction_is_idempotent() -> None:
    """Re-wrapping an already canonical value yields the same canonical value."""
    # arrange
    first_pass = Postcode("sw1a 1aa")
    # act
    second_pass = Postcode(first_pass.value)
    # assert
    assert second_pass.value == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_empty_string() -> None:
    """An empty input is accepted and stays an empty ``value``."""
    # act
    canonical = Postcode("").value
    # assert
    assert canonical == ""
|
||||
|
||||
|
||||
def test_postcode_str_returns_canonical_value() -> None:
    """``str(postcode)`` renders the canonical form, not the raw input."""
    # act
    rendered = str(Postcode("sw1a 1aa"))
    # assert
    assert rendered == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_equality_ignores_surface_form() -> None:
    """Postcodes differing only in case / whitespace compare equal.

    Both inputs sanitise to the same canonical value, so the value objects
    are equal.
    """
    # arrange
    spaced_lowercase = Postcode("sw1a 1aa")
    canonical = Postcode("SW1A1AA")
    # act / assert
    assert spaced_lowercase == canonical
|
||||
|
||||
|
||||
def test_postcode_is_frozen() -> None:
    """Assigning to ``value`` on a built ``Postcode`` raises ``FrozenInstanceError``."""
    # arrange
    immutable = Postcode("SW1A1AA")
    # act / assert
    with pytest.raises(dataclasses.FrozenInstanceError):
        immutable.value = "OTHER"  # type: ignore[misc]
|
||||
10
tests/infrastructure/__init__.py
Normal file
10
tests/infrastructure/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
REGION = "us-east-1"
|
||||
|
||||
|
||||
def make_boto_client(service_name: str) -> Any:
    """Return a boto3 client for ``service_name`` pinned to the module ``REGION``.

    ``boto3.client`` is first bound to an ``Any``-typed local so the pyright
    unknown-member / unknown-variable reports are suppressed in one place.
    """
    client_factory: Any = boto3.client  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    return client_factory(service_name, region_name=REGION)
|
||||
28
tests/infrastructure/conftest.py
Normal file
28
tests/infrastructure/conftest.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import os
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]:  # pyright: ignore[reportUnusedFunction]
    """Install dummy AWS credentials for each test, then restore the originals.

    The four AWS environment variables are overwritten with placeholder
    values for the duration of the test. Afterwards each variable is set back
    to its previous value, or removed if it was not set before.
    """
    placeholder_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    # Snapshot first so the restore step knows which variables were absent.
    saved: dict[str, Optional[str]] = {
        name: os.environ.get(name) for name in placeholder_env
    }
    os.environ.update(placeholder_env)
    try:
        yield
    finally:
        for name, original in saved.items():
            if original is not None:
                os.environ[name] = original
            else:
                os.environ.pop(name, None)
|
||||
71
tests/infrastructure/test_address2uprn_queue_client.py
Normal file
71
tests/infrastructure/test_address2uprn_queue_client.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import json
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
|
||||
from tests.infrastructure import make_boto_client
|
||||
|
||||
|
||||
@pytest.fixture
def queue_setup() -> Iterator[tuple[Address2UprnQueueClient, Any, str]]:
    """Yield ``(queue client, raw boto SQS client, queue URL)`` under ``mock_aws``.

    The queue is created inside the mock context, so the yielded objects are
    only usable while the fixture is active.
    """
    with mock_aws():
        sqs = make_boto_client("sqs")
        created: dict[str, Any] = sqs.create_queue(QueueName="address2uprn-queue")
        url = cast(str, created["QueueUrl"])
        client = Address2UprnQueueClient(sqs, url)
        yield client, sqs, url
|
||||
|
||||
|
||||
def test_publish_returns_message_id(
    queue_setup: tuple[Address2UprnQueueClient, Any, str],
) -> None:
    """``publish`` returns a non-empty string message id."""
    # arrange
    publisher = queue_setup[0]
    # act
    message_id = publisher.publish(
        parent_task_id=uuid4(),
        child_subtask_id=uuid4(),
        s3_uri="s3://my-bucket/path/to/chunk.csv",
    )
    # assert
    assert isinstance(message_id, str)
    assert message_id
|
||||
|
||||
|
||||
def test_publish_body_uses_typed_shape(
    queue_setup: tuple[Address2UprnQueueClient, Any, str],
) -> None:
    """The published SQS body is exactly ``task_id`` / ``sub_task_id`` / ``s3_uri``.

    The message is read back through the raw boto client so the assertion
    covers what was actually put on the queue, with ids rendered as strings.
    """
    # arrange
    client, boto_client, queue_url = queue_setup
    parent_id = uuid4()
    child_id = uuid4()
    s3_uri = "s3://my-bucket/path/to/chunk.csv"

    # act
    client.publish(
        parent_task_id=parent_id,
        child_subtask_id=child_id,
        s3_uri=s3_uri,
    )

    # assert
    received: dict[str, Any] = boto_client.receive_message(
        QueueUrl=queue_url, MaxNumberOfMessages=1
    )
    # The response has no "Messages" key when nothing was received; default to
    # an empty list so a missing message fails the length assertion below
    # instead of raising an unrelated KeyError.
    messages: list[dict[str, Any]] = received.get("Messages", [])
    assert len(messages) == 1
    body = json.loads(messages[0]["Body"])
    assert body == {
        "task_id": str(parent_id),
        "sub_task_id": str(child_id),
        "s3_uri": s3_uri,
    }
|
||||
51
tests/infrastructure/test_csv_s3_client.py
Normal file
51
tests/infrastructure/test_csv_s3_client.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
from tests.infrastructure import make_boto_client
|
||||
|
||||
BUCKET = "csv-bucket"
|
||||
|
||||
|
||||
@pytest.fixture
def csv_client() -> Iterator[CsvS3Client]:
    """Yield a ``CsvS3Client`` bound to ``BUCKET``, created under ``mock_aws``."""
    with mock_aws():
        s3 = make_boto_client("s3")
        s3.create_bucket(Bucket=BUCKET)
        client = CsvS3Client(s3, BUCKET)
        yield client
|
||||
|
||||
|
||||
def test_save_rows_returns_s3_uri(csv_client: CsvS3Client) -> None:
    """``save_rows`` returns the ``s3://`` URI of the object it wrote."""
    # arrange
    single_row = [{"address": "1 High St", "postcode": "AB1 2CD"}]
    # act
    returned_uri = csv_client.save_rows(single_row, "uploads/addresses.csv")
    # assert
    expected_uri = f"s3://{BUCKET}/uploads/addresses.csv"
    assert returned_uri == expected_uri
|
||||
|
||||
|
||||
def test_round_trip_preserves_rows(csv_client: CsvS3Client) -> None:
    """Rows written with ``save_rows`` come back unchanged from ``read_rows``."""
    # arrange
    written = [
        {"address": "1 High St", "postcode": "AB1 2CD"},
        {"address": "2 Low St", "postcode": "XY9 8ZW"},
    ]
    # act
    stored_at = csv_client.save_rows(written, "uploads/addresses.csv")
    read_back = csv_client.read_rows(stored_at)
    # assert
    assert read_back == written
|
||||
|
||||
|
||||
def test_save_rows_rejects_empty_list(csv_client: CsvS3Client) -> None:
    """Saving an empty row list raises ``ValueError`` mentioning "empty"."""
    # arrange
    no_rows: list[dict[str, str]] = []
    # act / assert
    with pytest.raises(ValueError, match="empty"):
        csv_client.save_rows(no_rows, "uploads/empty.csv")
|
||||
|
||||
|
||||
def test_read_rows_rejects_wrong_bucket(csv_client: CsvS3Client) -> None:
    """Reading a URI that names another bucket raises ``ValueError``."""
    # arrange
    foreign_uri = "s3://other-bucket/uploads/addresses.csv"
    # act / assert
    with pytest.raises(ValueError, match="does not match client bucket"):
        csv_client.read_rows(foreign_uri)
|
||||
36
tests/infrastructure/test_s3_client.py
Normal file
36
tests/infrastructure/test_s3_client.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from infrastructure.s3_client import S3Client
|
||||
from tests.infrastructure import make_boto_client
|
||||
|
||||
BUCKET = "test-bucket"
|
||||
|
||||
|
||||
@pytest.fixture
def s3_client() -> Iterator[S3Client]:
    """Yield an ``S3Client`` bound to ``BUCKET``, created under ``mock_aws``."""
    with mock_aws():
        s3 = make_boto_client("s3")
        s3.create_bucket(Bucket=BUCKET)
        client = S3Client(s3, BUCKET)
        yield client
|
||||
|
||||
|
||||
def test_put_object_returns_s3_uri(s3_client: S3Client) -> None:
    """``put_object`` returns the ``s3://`` URI of the stored key."""
    # act
    returned_uri = s3_client.put_object("folder/data.bin", b"payload")
    # assert
    expected_uri = f"s3://{BUCKET}/folder/data.bin"
    assert returned_uri == expected_uri
|
||||
|
||||
|
||||
def test_get_object_returns_bytes_written_by_put_object(s3_client: S3Client) -> None:
    """Bytes stored with ``put_object`` are returned unchanged by ``get_object``."""
    # arrange
    s3_client.put_object("round/trip.bin", b"hello world")
    # act
    fetched = s3_client.get_object("round/trip.bin")
    # assert
    assert fetched == b"hello world"
|
||||
|
||||
|
||||
def test_bucket_property_exposes_configured_bucket(s3_client: S3Client) -> None:
    """``bucket`` reports the bucket name the client was constructed with."""
    # act
    configured = s3_client.bucket
    # assert
    assert configured == BUCKET
|
||||
40
tests/infrastructure/test_s3_uri.py
Normal file
40
tests/infrastructure/test_s3_uri.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import pytest
|
||||
|
||||
from infrastructure.s3_uri import parse_s3_uri
|
||||
|
||||
|
||||
def test_parses_simple_s3_uri() -> None:
    """A flat ``s3://bucket/key`` URI splits into ``(bucket, key)``."""
    # act
    parsed = parse_s3_uri("s3://my-bucket/file.csv")
    # assert
    assert parsed == ("my-bucket", "file.csv")
|
||||
|
||||
|
||||
def test_parses_s3_uri_with_nested_key() -> None:
    """Slashes after the bucket stay part of the key."""
    # act
    parsed_bucket, parsed_key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv")
    # assert
    assert parsed_bucket == "my-bucket"
    assert parsed_key == "nested/path/to/file.csv"
|
||||
|
||||
|
||||
def test_rejects_s3_uri_without_key() -> None:
    """A bucket-only URI (no slash, no key) raises ``ValueError``."""
    # arrange
    bucket_only = "s3://my-bucket"
    # act / assert
    with pytest.raises(ValueError, match="bucket and a key"):
        parse_s3_uri(bucket_only)
|
||||
|
||||
|
||||
def test_rejects_s3_uri_with_empty_key() -> None:
    """A URI ending in the bucket's trailing slash (empty key) raises ``ValueError``."""
    # arrange
    trailing_slash_only = "s3://my-bucket/"
    # act / assert
    with pytest.raises(ValueError, match="bucket and a key"):
        parse_s3_uri(trailing_slash_only)
|
||||
|
||||
|
||||
def test_parses_console_url_prefix() -> None:
    """An AWS console object URL is parsed, with ``prefix`` percent-decoded."""
    # arrange
    console_url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv"
    # act
    parsed = parse_s3_uri(console_url)
    # assert
    assert parsed == ("my-bucket", "nested/file.csv")
|
||||
|
||||
|
||||
def test_rejects_unparseable_string() -> None:
    """A string that is neither an S3 URI nor a console URL raises ``ValueError``."""
    # arrange
    garbage = "not-a-uri-at-all"
    # act / assert
    with pytest.raises(ValueError):
        parse_s3_uri(garbage)
|
||||
44
tests/infrastructure/test_sqs_client.py
Normal file
44
tests/infrastructure/test_sqs_client.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import json
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from infrastructure.sqs_client import SqsClient
|
||||
from tests.infrastructure import make_boto_client
|
||||
|
||||
|
||||
@pytest.fixture
def sqs_setup() -> Iterator[tuple[SqsClient, Any, str]]:
    """Yield ``(SqsClient, raw boto SQS client, queue URL)`` under ``mock_aws``."""
    with mock_aws():
        sqs = make_boto_client("sqs")
        created: dict[str, Any] = sqs.create_queue(QueueName="test-queue")
        url = cast(str, created["QueueUrl"])
        client = SqsClient(sqs, url)
        yield client, sqs, url
|
||||
|
||||
|
||||
def test_send_returns_message_id(sqs_setup: tuple[SqsClient, Any, str]) -> None:
    """``send`` returns a non-empty string message id."""
    # arrange
    sender = sqs_setup[0]
    # act
    message_id = sender.send({"hello": "world"})
    # assert
    assert isinstance(message_id, str)
    assert message_id
|
||||
|
||||
|
||||
def test_send_json_serialises_body(sqs_setup: tuple[SqsClient, Any, str]) -> None:
    """The dict passed to ``send`` arrives on the queue as its JSON encoding.

    The message is read back through the raw boto client and JSON-decoded, so
    the comparison is against what was actually enqueued.
    """
    # arrange
    client, boto_client, queue_url = sqs_setup
    body = {"hello": "world", "count": 3}
    # act
    client.send(body)

    # assert
    received: dict[str, Any] = boto_client.receive_message(
        QueueUrl=queue_url, MaxNumberOfMessages=1
    )
    # The response has no "Messages" key when nothing was received; default to
    # an empty list so a missing message fails the length assertion below
    # instead of raising an unrelated KeyError.
    messages: list[dict[str, Any]] = received.get("Messages", [])
    assert len(messages) == 1
    assert json.loads(messages[0]["Body"]) == body
|
||||
299
tests/orchestration/test_postcode_splitter_orchestrator.py
Normal file
299
tests/orchestration/test_postcode_splitter_orchestrator.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
||||
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
||||
from repositories.user_address.user_address_csv_s3_repository import (
|
||||
UserAddressCsvS3Repository,
|
||||
)
|
||||
|
||||
BUCKET = "splitter-bucket"
|
||||
REGION = "us-east-1"
|
||||
|
||||
|
||||
def _make_boto_client(service_name: str) -> Any:
    """Return a boto3 client for ``service_name`` pinned to the module ``REGION``.

    ``boto3.client`` is first bound to an ``Any``-typed local so the pyright
    unknown-member / unknown-variable reports are suppressed in one place.
    """
    client_factory: Any = boto3.client  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    return client_factory(service_name, region_name=REGION)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]:  # pyright: ignore[reportUnusedFunction]
    """Install dummy AWS credentials for each test, then restore the originals.

    The four AWS environment variables are overwritten with placeholder
    values (region taken from ``REGION``) for the duration of the test.
    Afterwards each variable is set back to its previous value, or removed if
    it was not set before.
    """
    placeholder_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": REGION,
    }
    # Snapshot first so the restore step knows which variables were absent.
    saved: dict[str, Any] = {name: os.environ.get(name) for name in placeholder_env}
    os.environ.update(placeholder_env)
    try:
        yield
    finally:
        for name, original in saved.items():
            if original is not None:
                os.environ[name] = original
            else:
                os.environ.pop(name, None)
|
||||
|
||||
|
||||
@dataclass
class Harness:
    """Bundle of the collaborators the ``harness`` fixture hands to each test."""

    # Orchestrator under test.
    splitter: PostcodeSplitterOrchestrator
    # Used by tests to create the parent task / subtask.
    task_orchestrator: TaskOrchestrator
    # Used by tests to re-load persisted child subtasks.
    subtasks: SubTaskPostgresRepository
    # Used by tests to upload the input CSV and read batch CSVs back.
    csv_client: CsvS3Client
    # Raw boto SQS client and queue URL, for draining published messages.
    boto_sqs: Any
    queue_url: str
    # User-address repository wired into the splitter.
    repo: UserAddressCsvS3Repository
|
||||
|
||||
|
||||
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Yield a fully wired ``Harness``: mocked S3 + SQS plus a DB session.

    The splitter is built with ``max_batch_size=3`` so the fixture CSV used
    by the tests exercises the oversize-group path. Everything is only valid
    while the ``mock_aws`` context and the session are open.
    """
    with mock_aws():
        # Infra: S3 + SQS
        s3 = _make_boto_client("s3")
        s3.create_bucket(Bucket=BUCKET)
        sqs = _make_boto_client("sqs")
        created_queue: dict[str, Any] = sqs.create_queue(
            QueueName="address2uprn-queue"
        )
        url = cast(str, created_queue["QueueUrl"])

        csv_store = CsvS3Client(s3, BUCKET)
        address_repo = UserAddressCsvS3Repository(csv_store, BUCKET)
        publisher = Address2UprnQueueClient(sqs, url)

        # DB: ephemeral PostgreSQL TaskOrchestrator
        with Session(db_engine) as session:
            subtask_repo = SubTaskPostgresRepository(session=session)
            orchestrator = TaskOrchestrator(
                task_repo=TaskPostgresRepository(session=session),
                subtask_repo=subtask_repo,
            )
            postcode_splitter = PostcodeSplitterOrchestrator(
                task_orchestrator=orchestrator,
                user_address_repo=address_repo,
                queue_client=publisher,
                max_batch_size=3,
            )

            yield Harness(
                splitter=postcode_splitter,
                task_orchestrator=orchestrator,
                subtasks=subtask_repo,
                csv_client=csv_store,
                boto_sqs=sqs,
                queue_url=url,
                repo=address_repo,
            )
|
||||
|
||||
|
||||
def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
    """Upload the seven-row fixture CSV and return the URI ``save_rows`` gives back.

    Three postcode groups, in this order:
      AA1 1AA × 2 (within cap)
      BB2 2BB × 4 (oversize: > max_batch_size=3)
      CC3 3CC × 1 (final flush)

    Expected batching with cap=3 and the algorithm in
    ``iter_postcode_grouped_batches``:
      batch 1: [AA1 1AA × 2] (flushed because oversize follows)
      batch 2: [BB2 2BB × 4] (oversize own batch)
      batch 3: [CC3 3CC × 1] (final flush)
    """

    def _row(address_1: str, postcode: str, reference: str) -> dict[str, str]:
        # Key order is kept identical for every row.
        return {
            "Address 1": address_1,
            "Address 2": "",
            "Address 3": "",
            "postcode": postcode,
            "Internal Reference": reference,
        }

    within_cap = [_row(f"{i} High St", "AA1 1AA", f"AA-{i}") for i in range(1, 3)]
    oversize = [_row(f"{i} Long Road", "BB2 2BB", f"BB-{i}") for i in range(1, 5)]
    final_flush = [_row("1 Final Way", "CC3 3CC", "CC-1")]

    rows: list[dict[str, str]] = within_cap + oversize + final_flush
    return csv_client.save_rows(rows, "uploads/input.csv")
|
||||
|
||||
|
||||
def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]:
    """Receive, JSON-decode and delete every message currently on the queue.

    Polls in pages of up to 10 messages without waiting and stops at the
    first response that carries no messages. Returns the decoded bodies in
    the order they were received.
    """
    decoded: list[dict[str, Any]] = []
    while True:
        response: dict[str, Any] = boto_sqs.receive_message(
            QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=0
        )
        # "Messages" may be missing from the response; treat that as empty.
        page = cast(list[dict[str, Any]], response.get("Messages", []))
        if not page:
            return decoded
        for msg in page:
            payload = cast(dict[str, Any], json.loads(msg["Body"]))
            decoded.append(payload)
            # Delete after decoding so the next poll does not see it again.
            boto_sqs.delete_message(
                QueueUrl=queue_url, ReceiptHandle=msg["ReceiptHandle"]
            )
|
||||
|
||||
|
||||
def test_split_and_dispatch_creates_three_children_for_fixture(
    harness: Harness,
) -> None:
    """The fixture CSV yields three distinct child subtasks under the parent task."""
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    fixture_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=fixture_uri,
    )

    # assert
    assert len(child_ids) == 3
    # All child ids are unique and each persisted child belongs to the
    # parent task.
    # NOTE(review): child status (presumably WAITING) is not asserted here.
    assert len(set(child_ids)) == 3
    for child_id in child_ids:
        persisted = harness.subtasks.get(child_id)
        assert persisted.task_id == parent_task.id
|
||||
|
||||
|
||||
def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
    harness: Harness,
) -> None:
    """Each child's ``inputs`` hold the parent task id and a batch CSV URI.

    The batch URI must sit under
    ``ara_postcode_splitter_batches/<task id>/<parent subtask id>/`` in the
    test bucket and end in ``.csv``.
    """
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    fixture_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=fixture_uri,
    )

    # assert
    expected_prefix = (
        f"s3://{BUCKET}/ara_postcode_splitter_batches/"
        f"{parent_task.id}/{parent_subtask.id}/"
    )
    for child_id in child_ids:
        persisted = harness.subtasks.get(child_id)
        assert persisted.inputs is not None
        assert persisted.inputs["task_id"] == str(parent_task.id)
        batch_uri = persisted.inputs["s3_uri"]
        assert isinstance(batch_uri, str)
        assert batch_uri.startswith(expected_prefix)
        assert batch_uri.endswith(".csv")
|
||||
|
||||
|
||||
def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
    harness: Harness,
) -> None:
    """Exactly one queue message exists per child, agreeing with its inputs."""
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    fixture_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=fixture_uri,
    )

    # assert
    published = _drain_queue(harness.boto_sqs, harness.queue_url)
    assert len(published) == len(child_ids)

    # Match queue messages against persisted child inputs by child_subtask_id;
    # the message body's task_id/s3_uri must agree with the SubTask inputs.
    published_by_child = {message["sub_task_id"]: message for message in published}
    assert set(published_by_child.keys()) == {str(cid) for cid in child_ids}
    for child_id in child_ids:
        persisted = harness.subtasks.get(child_id)
        message = published_by_child[str(child_id)]
        assert persisted.inputs is not None
        expected_message = {
            "task_id": str(parent_task.id),
            "sub_task_id": str(child_id),
            "s3_uri": persisted.inputs["s3_uri"],
        }
        assert message == expected_message
|
||||
|
||||
|
||||
def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
    harness: Harness,
) -> None:
    """The returned child ids line up with the order the batches were produced."""
    # arrange
    orchestrator = harness.task_orchestrator
    task, subtask = orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    source_uri = _upload_fixture_csv(harness.csv_client)

    # act
    dispatched = harness.splitter.split_and_dispatch(
        parent_task_id=task.id,
        parent_subtask_id=subtask.id,
        input_s3_uri=source_uri,
    )

    # assert
    # Read back the batch saved for each child and collect its
    # postcode_clean values; the expected sequence is the AA batch first,
    # the oversize BB batch second and the final-flush CC batch third.
    observed: list[set[str]] = []
    for child_id in dispatched:
        persisted = harness.subtasks.get(child_id)
        assert persisted.inputs is not None
        batch_rows = harness.csv_client.read_rows(persisted.inputs["s3_uri"])
        observed.append({batch_row["postcode_clean"] for batch_row in batch_rows})

    expected = [{"AA11AA"}, {"BB22BB"}, {"CC33CC"}]
    assert observed == expected
|
||||
|
|
@ -2,7 +2,8 @@ from collections.abc import Iterator
|
|||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from domain.tasks.subtasks import SubTask, SubTaskStatus
|
||||
from domain.tasks.tasks import Source, TaskStatus
|
||||
|
|
@ -19,10 +20,8 @@ class Harness:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def harness() -> Iterator[Harness]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
def harness(db_engine: Engine) -> Iterator[Harness]:
|
||||
with Session(db_engine) as session:
|
||||
tasks = TaskPostgresRepository(session=session)
|
||||
subtasks = SubTaskPostgresRepository(session=session)
|
||||
yield Harness(
|
||||
|
|
@ -35,6 +34,7 @@ def harness() -> Iterator[Harness]:
|
|||
def test_create_task_with_subtask_creates_both_in_waiting(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# act
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test",
|
||||
inputs={"foo": "bar"},
|
||||
|
|
@ -42,6 +42,7 @@ def test_create_task_with_subtask_creates_both_in_waiting(
|
|||
source_id="abc",
|
||||
)
|
||||
|
||||
# assert
|
||||
assert task.status is TaskStatus.WAITING
|
||||
assert subtask.status is SubTaskStatus.WAITING
|
||||
assert subtask.task_id == task.id
|
||||
|
|
@ -49,27 +50,33 @@ def test_create_task_with_subtask_creates_both_in_waiting(
|
|||
|
||||
|
||||
def test_start_subtask_cascades_to_in_progress(harness: Harness) -> None:
    """Starting the only subtask moves both it and its task to IN_PROGRESS."""
    # arrange
    orchestrator = harness.orchestrator
    task, subtask = orchestrator.create_task_with_subtask(task_source="manual:test")

    # act
    running = orchestrator.start_subtask(
        subtask.id, cloud_logs_url="https://example/log"
    )

    # assert
    parent = harness.tasks.get(task.id)
    assert running.status is SubTaskStatus.IN_PROGRESS
    assert running.cloud_logs_url == "https://example/log"
    assert parent.status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
harness.orchestrator.start_subtask(subtask.id)
|
||||
|
||||
# act
|
||||
harness.orchestrator.complete_subtask(subtask.id, {"value": 42})
|
||||
|
||||
# assert
|
||||
done_subtask = harness.subtasks.get(subtask.id)
|
||||
done_task = harness.tasks.get(task.id)
|
||||
assert done_subtask.outputs == {"result": {"value": 42}}
|
||||
|
|
@ -78,12 +85,15 @@ def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
|
|||
|
||||
|
||||
def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
# act
|
||||
harness.orchestrator.fail_subtask(subtask.id, RuntimeError("boom"))
|
||||
|
||||
# assert
|
||||
failed_subtask = harness.subtasks.get(subtask.id)
|
||||
failed_task = harness.tasks.get(task.id)
|
||||
assert failed_subtask.outputs == {"error": "boom"}
|
||||
|
|
@ -93,50 +103,85 @@ def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
|
|||
def test_failed_subtask_locks_task_failed_even_with_others_complete(
    harness: Harness,
) -> None:
    """One failed subtask marks the task FAILED although a sibling completed."""
    # arrange
    orchestrator = harness.orchestrator
    task, completed_one = orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    failing_one = SubTask.create(task_id=task.id)
    harness.subtasks.create(failing_one)

    # act
    orchestrator.complete_subtask(completed_one.id)
    orchestrator.fail_subtask(failing_one.id, RuntimeError("nope"))

    # assert
    parent = harness.tasks.get(task.id)
    assert parent.status is TaskStatus.FAILED
|
||||
|
||||
|
||||
def test_mixed_complete_and_in_progress_keeps_task_in_progress(
    harness: Harness,
) -> None:
    """A task with one complete and one running subtask stays IN_PROGRESS."""
    # arrange
    orchestrator = harness.orchestrator
    task, finished_one = orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    running_one = SubTask.create(task_id=task.id)
    harness.subtasks.create(running_one)

    # act
    orchestrator.complete_subtask(finished_one.id)
    orchestrator.start_subtask(running_one.id)

    # assert
    parent = harness.tasks.get(task.id)
    assert parent.status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_run_subtask_happy_path_returns_result_and_cascades_complete(
    harness: Harness,
) -> None:
    """run_subtask returns the work's result and completes subtask and task."""
    # arrange
    orchestrator = harness.orchestrator
    task, subtask = orchestrator.create_task_with_subtask(task_source="manual:test")

    def produce_answer() -> dict[str, int]:
        return {"answer": 42}

    # act
    outcome = orchestrator.run_subtask(subtask.id, work=produce_answer)

    # assert
    assert outcome == {"answer": 42}
    finished_subtask = harness.subtasks.get(subtask.id)
    assert finished_subtask.status is SubTaskStatus.COMPLETE
    finished_task = harness.tasks.get(task.id)
    assert finished_task.status is TaskStatus.COMPLETE
|
||||
|
||||
|
||||
def test_create_child_subtask_adds_waiting_child_without_changing_parent_status(
    harness: Harness,
) -> None:
    """Adding a child subtask persists it as WAITING and leaves the task as-is."""
    # arrange
    orchestrator = harness.orchestrator
    task, original = orchestrator.create_task_with_subtask(task_source="manual:test")
    orchestrator.start_subtask(original.id)
    assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS

    # act
    created = orchestrator.create_child_subtask(task.id, inputs={"split": "a"})

    # assert
    stored = harness.subtasks.get(created.id)
    assert stored.task_id == task.id
    assert stored.status is SubTaskStatus.WAITING
    assert stored.inputs == {"split": "a"}
    assert stored.id != original.id
    # Cascade is a no-op: parent stays IN_PROGRESS.
    parent = harness.tasks.get(task.id)
    assert parent.status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_run_subtask_failing_work_marks_failed_and_reraises(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -144,6 +189,7 @@ def test_run_subtask_failing_work_marks_failed_and_reraises(
|
|||
def boom() -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
harness.orchestrator.run_subtask(subtask.id, work=boom)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +1,40 @@
|
|||
from collections.abc import Iterator
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
# Importing the SQLModel row modules registers their tables in
|
||||
# SQLModel.metadata so create_all builds both. Imports look unused; they aren't.
|
||||
import infrastructure.postgres.subtask_table # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
import infrastructure.postgres.task_table # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
from domain.tasks.subtasks import SubTask, SubTaskStatus
|
||||
from domain.tasks.tasks import Task
|
||||
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
||||
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Iterator[Session]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as s:
|
||||
def session(db_engine: Engine) -> Iterator[Session]:
|
||||
with Session(db_engine) as s:
|
||||
yield s
|
||||
|
||||
|
||||
def _persisted_task_id(session: Session) -> UUID:
    """Persist a fresh parent Task and return its id.

    SubTask rows reference a Task, so tests create one first to satisfy the
    foreign-key constraint.
    """
    parent = Task.create(task_source="manual:test")
    task_repo = TaskPostgresRepository(session=session)
    task_repo.create(parent)
    return parent.id
|
||||
|
||||
|
||||
def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_id = uuid4()
|
||||
task_id = _persisted_task_id(session)
|
||||
st = SubTask.create(task_id=task_id, inputs={"address": "68 Glendon Way"})
|
||||
|
||||
# act
|
||||
repo.create(st)
|
||||
fetched = repo.get(st.id)
|
||||
|
||||
# assert
|
||||
assert fetched.id == st.id
|
||||
assert fetched.task_id == task_id
|
||||
assert fetched.status is SubTaskStatus.WAITING
|
||||
|
|
@ -36,16 +43,21 @@ def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None:
|
|||
|
||||
|
||||
def test_save_persists_status_and_outputs(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
st = SubTask.create(task_id=_persisted_task_id(session))
|
||||
repo.create(st)
|
||||
|
||||
# act
|
||||
st.start(cloud_logs_url="https://example/log")
|
||||
repo.save(st)
|
||||
# assert
|
||||
assert repo.get(st.id).status is SubTaskStatus.IN_PROGRESS
|
||||
|
||||
# act
|
||||
st.complete({"uprn": "123"})
|
||||
repo.save(st)
|
||||
# assert
|
||||
done = repo.get(st.id)
|
||||
assert done.status is SubTaskStatus.COMPLETE
|
||||
assert done.outputs == {"result": {"uprn": "123"}}
|
||||
|
|
@ -54,16 +66,19 @@ def test_save_persists_status_and_outputs(session: Session) -> None:
|
|||
|
||||
|
||||
def test_list_by_task_filters_by_task_id(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_a = uuid4()
|
||||
task_b = uuid4()
|
||||
task_a = _persisted_task_id(session)
|
||||
task_b = _persisted_task_id(session)
|
||||
repo.create(SubTask.create(task_id=task_a))
|
||||
repo.create(SubTask.create(task_id=task_a))
|
||||
repo.create(SubTask.create(task_id=task_b))
|
||||
|
||||
# act
|
||||
a_results = repo.list_by_task(task_a)
|
||||
b_results = repo.list_by_task(task_b)
|
||||
|
||||
# assert
|
||||
assert len(a_results) == 2
|
||||
assert len(b_results) == 1
|
||||
assert all(s.task_id == task_a for s in a_results)
|
||||
|
|
@ -71,11 +86,15 @@ def test_list_by_task_filters_by_task_id(session: Session) -> None:
|
|||
|
||||
|
||||
def test_list_by_task_returns_empty_for_unknown_task(session: Session) -> None:
    """Listing subtasks for a task id with no rows yields an empty list."""
    # arrange
    subtask_repo = SubTaskPostgresRepository(session=session)
    unknown_task_id = uuid4()
    # act / assert
    found = subtask_repo.list_by_task(unknown_task_id)
    assert found == []
|
||||
|
||||
|
||||
def test_get_missing_raises(session: Session) -> None:
    """Fetching a subtask id that was never stored raises ValueError."""
    # arrange
    subtask_repo = SubTaskPostgresRepository(session=session)
    unknown_id = uuid4()
    # act / assert
    with pytest.raises(ValueError, match="not found"):
        subtask_repo.get(unknown_id)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from collections.abc import Iterator
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from domain.tasks.tasks import Source, Task, TaskStatus
|
||||
from infrastructure.postgres.task_table import TaskRow
|
||||
|
|
@ -10,25 +11,23 @@ from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Iterator[Session]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as s:
|
||||
def session(db_engine: Engine) -> Iterator[Session]:
|
||||
with Session(db_engine) as s:
|
||||
yield s
|
||||
|
||||
|
||||
def test_create_and_get_round_trip(session: Session) -> None:
|
||||
# Arrange
|
||||
# arrange
|
||||
repo = TaskPostgresRepository(session=session)
|
||||
t = Task.create(
|
||||
task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123"
|
||||
)
|
||||
|
||||
# Act
|
||||
# act
|
||||
repo.create(t)
|
||||
fetched = repo.get(t.id)
|
||||
|
||||
# Assert
|
||||
# assert
|
||||
assert fetched.id == t.id
|
||||
assert fetched.status is TaskStatus.WAITING
|
||||
assert fetched.source is Source.PORTFOLIO
|
||||
|
|
@ -36,33 +35,43 @@ def test_create_and_get_round_trip(session: Session) -> None:
|
|||
|
||||
|
||||
def test_save_persists_status_transition(session: Session) -> None:
    """Saving after start() and after complete() persists each new status."""
    # arrange
    task_repo = TaskPostgresRepository(session=session)
    task = Task.create(task_source="manual:test")
    task_repo.create(task)

    # act
    task.start()
    task_repo.save(task)
    # assert
    in_flight = task_repo.get(task.id)
    assert in_flight.status is TaskStatus.IN_PROGRESS

    # act
    task.complete()
    task_repo.save(task)
    # assert
    finished = task_repo.get(task.id)
    assert finished.status is TaskStatus.COMPLETE
    assert finished.job_completed is not None
|
||||
|
||||
|
||||
def test_get_missing_raises(session: Session) -> None:
    """Fetching a task id that was never stored raises ValueError."""
    # arrange
    task_repo = TaskPostgresRepository(session=session)
    unknown_id = uuid4()
    # act / assert
    with pytest.raises(ValueError, match="not found"):
        task_repo.get(unknown_id)
|
||||
|
||||
|
||||
def test_get_normalises_legacy_capitalised_status(session: Session) -> None:
    """A row stored with status "In Progress" loads as TaskStatus.IN_PROGRESS.

    Existing rows written by backend code use the capitalised spelling.
    """
    # arrange
    task_repo = TaskPostgresRepository(session=session)
    legacy_row = TaskRow(task_source="manual:test", status="In Progress")
    session.add(legacy_row)
    session.commit()

    # act
    loaded = task_repo.get(legacy_row.id)
    # assert
    assert loaded.status is TaskStatus.IN_PROGRESS
|
||||
|
|
|
|||
0
tests/repositories/user_address/__init__.py
Normal file
0
tests/repositories/user_address/__init__.py
Normal file
28
tests/repositories/user_address/conftest.py
Normal file
28
tests/repositories/user_address/conftest.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import os
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]:  # pyright: ignore[reportUnusedFunction]
    """Install dummy AWS credentials for each test, then restore the environment.

    Variables that were unset before the test are removed again afterwards;
    variables that had a value get that value back.
    """
    dummy_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    # Snapshot before overwriting so the finally block can undo every change.
    saved: dict[str, Optional[str]] = {
        name: os.environ.get(name) for name in dummy_env
    }
    for name, placeholder in dummy_env.items():
        os.environ[name] = placeholder
    try:
        yield
    finally:
        for name, original in saved.items():
            if original is not None:
                os.environ[name] = original
            else:
                os.environ.pop(name, None)
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from domain.addresses.user_address import UserAddress
|
||||
from domain.postcode import Postcode
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
from repositories.user_address.user_address_csv_s3_repository import (
|
||||
UserAddressCsvS3Repository,
|
||||
)
|
||||
from tests.infrastructure import make_boto_client
|
||||
|
||||
BUCKET = "user-address-bucket"
|
||||
|
||||
|
||||
@pytest.fixture
def repo() -> Iterator[UserAddressCsvS3Repository]:
    """Yield a repository wired to a freshly created bucket inside mock_aws."""
    with mock_aws():
        s3 = make_boto_client("s3")
        s3.create_bucket(Bucket=BUCKET)
        repository = UserAddressCsvS3Repository(CsvS3Client(s3, BUCKET), BUCKET)
        yield repository
|
||||
|
||||
|
||||
def _upload_csv(
    repo: UserAddressCsvS3Repository, rows: list[dict[str, str]], key: str
) -> str:
    """Save ``rows`` under ``key`` via the repository's CSV client.

    Returns whatever ``save_rows`` returns; the tests pass it to
    ``load_batch`` as the batch URI.
    """
    csv_client = repo._csv_client  # pyright: ignore[reportPrivateUsage]
    return csv_client.save_rows(rows, key)
|
||||
|
||||
|
||||
def test_load_batch_parses_address_postcode_and_reference(
    repo: UserAddressCsvS3Repository,
) -> None:
    """A fully populated row yields joined address lines, postcode and reference."""
    # arrange
    full_row = {
        "Address 1": "1 High Street",
        "Address 2": "Flat 2",
        "Address 3": "Townville",
        "postcode": "sw1a 1aa",
        "Internal Reference": "REF-001",
    }
    uri = _upload_csv(repo, [full_row], "uploads/full.csv")

    # act
    loaded = repo.load_batch(uri)

    # assert
    assert len(loaded) == 1
    (parsed,) = loaded
    assert parsed.user_address == "1 High Street, Flat 2, Townville"
    assert parsed.postcode == Postcode("SW1A1AA")
    assert parsed.internal_reference == "REF-001"
|
||||
|
||||
|
||||
def test_load_batch_uses_only_address_1_when_others_missing(
    repo: UserAddressCsvS3Repository,
) -> None:
    """Empty Address 2 / Address 3 cells leave only Address 1 in user_address."""
    # arrange
    sparse_row = {
        "Address 1": "10 Cardiff Road",
        "Address 2": "",
        "Address 3": "",
        "postcode": "CF10 1AA",
        "Internal Reference": "REF-002",
    }
    uri = _upload_csv(repo, [sparse_row], "uploads/address1-only.csv")

    # act
    loaded = repo.load_batch(uri)

    # assert
    assert len(loaded) == 1
    parsed = loaded[0]
    assert parsed.user_address == "10 Cardiff Road"
    assert parsed.postcode == Postcode("CF101AA")
    assert parsed.internal_reference == "REF-002"
|
||||
|
||||
|
||||
def test_load_batch_handles_missing_internal_reference(
    repo: UserAddressCsvS3Repository,
) -> None:
    """An empty Internal Reference cell is loaded as ``None``."""
    # arrange
    unreferenced_row = {
        "Address 1": "5 Park Lane",
        "Address 2": "",
        "Address 3": "",
        "postcode": "M1 1AA",
        "Internal Reference": "",
    }
    uri = _upload_csv(repo, [unreferenced_row], "uploads/no-ref.csv")

    # act
    loaded = repo.load_batch(uri)

    # assert
    assert len(loaded) == 1
    parsed = loaded[0]
    assert parsed.user_address == "5 Park Lane"
    assert parsed.postcode == Postcode("M11AA")
    assert parsed.internal_reference is None
|
||||
|
||||
|
||||
def test_load_batch_captures_full_source_row(
    repo: UserAddressCsvS3Repository,
) -> None:
    """source_row keeps every uploaded column, not only the parsed ones.

    The fixture is shaped like a raw EPC export row: the splitter must
    preserve all of its columns.
    """
    # arrange
    epc_row = {
        "Asset Reference": "511",
        "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
        "postcode": "BB12 7BX",
        "Property Type": "House: End Terrace",
        "SAP Score": "69",
    }
    uri = _upload_csv(repo, [epc_row], "uploads/epc.csv")

    # act
    loaded = repo.load_batch(uri)

    # assert
    first = loaded[0]
    assert first.source_row == epc_row
|
||||
|
||||
|
||||
def test_load_batch_raises_when_postcode_column_absent(
    repo: UserAddressCsvS3Repository,
) -> None:
    """Loading a CSV without a 'postcode' column raises ValueError."""
    # arrange
    row_without_postcode = {"Address 1": "1 High Street", "Property Type": "Flat"}
    uri = _upload_csv(repo, [row_without_postcode], "uploads/no-postcode.csv")

    # act / assert
    with pytest.raises(ValueError, match="no 'postcode' column"):
        repo.load_batch(uri)
|
||||
|
||||
|
||||
def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
    repo: UserAddressCsvS3Repository,
) -> None:
    """save_batch writes every source column unchanged plus postcode_clean."""
    # arrange
    epc_row = {
        "Asset Reference": "511",
        "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
        "postcode": " BB12 7BX",
        "Property Type": "House: End Terrace",
    }
    uri = _upload_csv(repo, [epc_row], "uploads/epc.csv")
    loaded = repo.load_batch(uri)

    # act
    saved_uri = repo.save_batch(loaded, "tasks/passthrough")
    written = repo._csv_client.read_rows(saved_uri)  # pyright: ignore[reportPrivateUsage]

    # assert
    assert len(written) == 1
    written_row = written[0]
    # Every original column survives, byte-for-byte.
    for name in epc_row:
        assert written_row[name] == epc_row[name]
    # Plus the one appended column the downstream address2uprn stage groups on.
    assert written_row["postcode_clean"] == "BB127BX"
|
||||
|
||||
|
||||
def test_save_batch_returns_uri_under_path_prefix(
    repo: UserAddressCsvS3Repository,
) -> None:
    """The saved batch URI sits under the given prefix and ends in .csv."""
    # arrange
    single = UserAddress(
        user_address="1 High Street",
        postcode=Postcode("SW1A 1AA"),
        source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
    )

    # act
    saved_uri = repo.save_batch([single], "tasks/abc/batches")

    # assert
    expected_prefix = f"s3://{BUCKET}/tasks/abc/batches/"
    assert saved_uri.startswith(expected_prefix)
    assert saved_uri.endswith(".csv")
|
||||
|
||||
|
||||
def test_save_then_reload_round_trip_preserves_columns(
    repo: UserAddressCsvS3Repository,
) -> None:
    """Load then save returns the original columns plus postcode_clean."""
    # arrange
    uploaded = [
        {
            "Address 1": "1 High Street",
            "postcode": "SW1A 1AA",
            "Internal Reference": "REF-001",
        },
        {
            "Address 1": "2 Low Street",
            "postcode": "XY9 8ZW",
            "Internal Reference": "",
        },
    ]
    uri = _upload_csv(repo, uploaded, "uploads/round-trip.csv")
    loaded = repo.load_batch(uri)

    # act
    saved_uri = repo.save_batch(loaded, "tasks/round-trip")
    written = repo._csv_client.read_rows(saved_uri)  # pyright: ignore[reportPrivateUsage]

    # assert
    # Original columns come back verbatim; postcode_clean is the only addition.
    without_clean = [
        {name: cell for name, cell in written_row.items() if name != "postcode_clean"}
        for written_row in written
    ]
    assert without_clean == uploaded
    clean_postcodes = [written_row["postcode_clean"] for written_row in written]
    assert clean_postcodes == ["SW1A1AA", "XY98ZW"]
|
||||
|
||||
|
||||
def test_save_batch_uses_unique_filename_per_call(
    repo: UserAddressCsvS3Repository,
) -> None:
    """Saving the same batch twice under one prefix yields two different URIs."""
    # arrange
    single = UserAddress(
        user_address="1 High Street",
        postcode=Postcode("SW1A 1AA"),
        source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"},
    )
    batch = [single]

    # act
    first_uri = repo.save_batch(batch, "tasks/uniqueness")
    second_uri = repo.save_batch(batch, "tasks/uniqueness")

    # assert
    assert first_uri != second_uri
|
||||
0
tests/utilities/__init__.py
Normal file
0
tests/utilities/__init__.py
Normal file
0
tests/utilities/aws_lambda/__init__.py
Normal file
0
tests/utilities/aws_lambda/__init__.py
Normal file
255
tests/utilities/aws_lambda/test_subtask_handler.py
Normal file
255
tests/utilities/aws_lambda/test_subtask_handler.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
import logging
|
||||
from collections.abc import Generator, Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from domain.tasks.subtasks import SubTaskStatus
|
||||
from domain.tasks.tasks import TaskStatus
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
||||
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
||||
from utilities.aws_lambda.subtask_handler import subtask_handler
|
||||
|
||||
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"
|
||||
|
||||
|
||||
@dataclass
class Harness:
    """Bundle of the orchestrator and the two repositories the tests inspect."""

    orchestrator: TaskOrchestrator
    tasks: TaskPostgresRepository
    subtasks: SubTaskPostgresRepository

    @contextmanager
    def factory(self) -> Generator[TaskOrchestrator, None, None]:
        """Context manager yielding this harness's orchestrator.

        Passed to ``subtask_handler`` as ``orchestrator_cm`` so the decorated
        handler receives the same orchestrator the test asserts against.
        """
        orchestrator = self.orchestrator
        yield orchestrator
|
||||
|
||||
|
||||
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Yield a Harness whose repositories share one Session on ``db_engine``."""
    with Session(db_engine) as db_session:
        task_repo = TaskPostgresRepository(session=db_session)
        subtask_repo = SubTaskPostgresRepository(session=db_session)
        orchestrator = TaskOrchestrator(
            task_repo=task_repo, subtask_repo=subtask_repo
        )
        yield Harness(
            orchestrator=orchestrator,
            tasks=task_repo,
            subtasks=subtask_repo,
        )
|
||||
|
||||
|
||||
def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
    """Build the event dict the tests hand to a decorated handler.

    Both ids are stringified and stored under ``task_id`` and ``sub_task_id``.
    """
    return dict(task_id=str(task_id), sub_task_id=str(subtask_id))
|
||||
|
||||
|
||||
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
    harness: Harness,
) -> None:
    """The wrapped handler receives body, context and the factory's orchestrator."""
    # arrange
    _, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    captured: dict[str, Any] = {}

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        captured.update(body=body, context=context, orchestrator=orchestrator)

    # act
    event = _direct_event(subtask.task_id, subtask.id)
    handler(event, context="ctx-sentinel")

    # assert
    assert captured["orchestrator"] is harness.orchestrator
    assert captured["context"] == "ctx-sentinel"
    assert captured["body"]["sub_task_id"] == str(subtask.id)
|
||||
|
||||
|
||||
def test_subtask_handler_completes_parent_subtask_on_success(
    harness: Harness,
) -> None:
    """A handler that returns normally leaves subtask and task COMPLETE."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    finished_subtask = harness.subtasks.get(subtask.id)
    finished_task = harness.tasks.get(task.id)
    assert finished_subtask.status is SubTaskStatus.COMPLETE
    assert finished_task.status is TaskStatus.COMPLETE
|
||||
|
||||
|
||||
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
    harness: Harness,
) -> None:
    """A raising handler propagates the error and leaves subtask and task FAILED."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        raise RuntimeError("boom")

    event = _direct_event(task.id, subtask.id)

    # act / assert
    with pytest.raises(RuntimeError, match="boom"):
        handler(event, context=None)

    failed_subtask = harness.subtasks.get(subtask.id)
    failed_task = harness.tasks.get(task.id)
    assert failed_subtask.status is SubTaskStatus.FAILED
    assert failed_task.status is TaskStatus.FAILED
|
||||
|
||||
|
||||
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
    harness: Harness,
) -> None:
    """A handler can create a child subtask via the injected orchestrator."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    created_ids: list[UUID] = []

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        new_child = orchestrator.create_child_subtask(task.id, inputs={"split": 1})
        created_ids.append(new_child.id)

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    assert len(created_ids) == 1
    stored_child = harness.subtasks.get(created_ids[0])
    assert stored_child.task_id == task.id
    assert stored_child.status is SubTaskStatus.WAITING
|
||||
|
||||
|
||||
def test_subtask_handler_logs_subtask_lifecycle_on_success(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A successful run logs both the 'Running' and the 'completed' message."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    event = _direct_event(task.id, subtask.id)

    # act
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
        handler(event, context=None)

    # assert
    captured_text = caplog.text
    assert f"Running subtask {subtask.id}" in captured_text
    assert f"Subtask {subtask.id} completed" in captured_text
|
||||
|
||||
|
||||
def test_subtask_handler_logs_exception_on_failure(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A raising handler produces an ERROR record carrying exception info."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        raise RuntimeError("boom")

    event = _direct_event(task.id, subtask.id)

    # act / assert
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME), pytest.raises(
        RuntimeError, match="boom"
    ):
        handler(event, context=None)

    error_records = [
        record for record in caplog.records if record.levelno == logging.ERROR
    ]
    expected_fragment = f"Subtask {subtask.id} failed"
    assert any(expected_fragment in record.getMessage() for record in error_records)
    assert any(record.exc_info is not None for record in error_records)
|
||||
|
||||
|
||||
def test_subtask_handler_records_cloudwatch_url_on_subtask(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With the Lambda log env vars set, a CloudWatch URL is saved on the subtask."""
    # arrange
    lambda_env = {
        "AWS_REGION": "eu-west-2",
        "AWS_LAMBDA_LOG_GROUP_NAME": "/aws/lambda/postcode-splitter",
        "AWS_LAMBDA_LOG_STREAM_NAME": "2026/05/20/[$LATEST]abc123",
    }
    for name, value in lambda_env.items():
        monkeypatch.setenv(name, value)
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    logs_url = harness.subtasks.get(subtask.id).cloud_logs_url
    assert logs_url is not None
    console_home = "https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
    assert logs_url.startswith(console_home)
    # Log group / stream are console-encoded ("/" -> "$252F").
    assert "$252Faws$252Flambda$252Fpostcode-splitter" in logs_url
    assert "$255B$2524LATEST$255D" in logs_url
|
||||
|
||||
|
||||
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without the AWS log env vars, no console URL is stored."""
    # arrange
    aws_log_vars = [
        "AWS_REGION",
        "AWS_LAMBDA_LOG_GROUP_NAME",
        "AWS_LAMBDA_LOG_STREAM_NAME",
    ]
    for name in aws_log_vars:
        # raising=False: the variable may already be absent locally.
        monkeypatch.delenv(name, raising=False)
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        # Successful no-op: only the decorator's bookkeeping is under test.
        return

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.cloud_logs_url is None
|
||||
|
|
@ -5,14 +5,20 @@ TaskOrchestrator.run_subtask(...) calls.
|
|||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import AbstractContextManager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
from utilities.aws_lambda.default_orchestrator import default_orchestrator
|
||||
from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
# Set this logger's own level to INFO; the handler wrapper in this module
# logs its "Running subtask" / "completed" messages at INFO.
logger.setLevel(logging.INFO)
|
||||
|
||||
# Zero-argument callable returning a context manager whose `with` target is
# a TaskOrchestrator.
OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
|
||||
|
||||
|
||||
|
|
@ -33,14 +39,26 @@ def subtask_handler(
|
|||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def wrapper(event: dict[str, Any], context: Any) -> None:
|
||||
cloud_logs_url = _cloudwatch_url()
|
||||
with factory() as orchestrator:
|
||||
for record in _records(event):
|
||||
body = _parse_body(record)
|
||||
trigger = SubtaskTriggerBody.model_validate(body)
|
||||
orchestrator.run_subtask(
|
||||
trigger.sub_task_id,
|
||||
work=lambda body=body: func(body, context),
|
||||
)
|
||||
logger.info("Running subtask %s", trigger.sub_task_id)
|
||||
try:
|
||||
orchestrator.run_subtask(
|
||||
trigger.sub_task_id,
|
||||
work=lambda body=body, o=orchestrator: func(
|
||||
body, context, o
|
||||
),
|
||||
cloud_logs_url=cloud_logs_url,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Subtask %s failed", trigger.sub_task_id
|
||||
)
|
||||
raise
|
||||
logger.info("Subtask %s completed", trigger.sub_task_id)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
@ -65,3 +83,20 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]:
|
|||
if isinstance(raw_records, list):
|
||||
return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)]
|
||||
return [event]
|
||||
|
||||
|
||||
def _console_encode(value: str) -> str:
|
||||
return quote(value, safe="").replace("%", "$25")
|
||||
|
||||
|
||||
def _cloudwatch_url() -> Optional[str]:
    """Build a CloudWatch console URL from the AWS log environment variables.

    Reads AWS_REGION, AWS_LAMBDA_LOG_GROUP_NAME and
    AWS_LAMBDA_LOG_STREAM_NAME from the process environment.

    Returns:
        The console URL, or None when any of the three variables is
        missing or empty.
    """
    env = os.environ
    region = env.get("AWS_REGION")
    group = env.get("AWS_LAMBDA_LOG_GROUP_NAME")
    stream = env.get("AWS_LAMBDA_LOG_STREAM_NAME")
    if not region or not group or not stream:
        return None

    console_home = f"https://{region}.console.aws.amazon.com/cloudwatch/home"
    # Group and stream names go through _console_encode before being
    # placed in the URL fragment.
    fragment = (
        "#logsV2:log-groups/log-group/"
        f"{_console_encode(group)}/log-events/{_console_encode(stream)}"
    )
    return f"{console_home}?region={region}{fragment}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue