diff --git a/.gitignore b/.gitignore index 888d527a..9e5df0c7 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ celerybeat.pid # Environments .env +.env.local .venv env/ venv/ diff --git a/UBIQUITOUS_LANGUAGE.md b/UBIQUITOUS_LANGUAGE.md index 1765cbc8..c3074c02 100644 --- a/UBIQUITOUS_LANGUAGE.md +++ b/UBIQUITOUS_LANGUAGE.md @@ -23,7 +23,7 @@ Invoke `/ubiquitous-language` in any session to extract new terms from the conve |------|------------|------------------| | **UPRN** | Unique Property Reference Number — the government-issued permanent identifier for a physical address in the UK. | "property ID", "address ID", "code" | | **Postcode** | A UK postal code used to group nearby addresses; the primary search key for finding EPC records. | "zip code", "postal code" | -| **User Address** | A free-text address string provided by a user or imported from a customer dataset, before any normalisation or matching. | "user input", "raw address", "user_inputed_address" | +| **User Address** | A structured dataclass (`domain.addresses.user_address.UserAddress`) capturing a customer-supplied address: a free-text `user_address` line, a canonical `postcode` (a `Postcode` value object that sanitises its value on construction), an optional `internal_reference`, and the originating `source_row` (excluded from equality). The bare string sense -- the raw free-text address line as it arrives from upstream ingestion, before being wrapped -- remains valid when discussing CSV columns, API payloads, or other upstream contexts; in domain code, prefer the dataclass. | "user input", "raw address", "user_inputed_address" | | **Dwelling** | A single residential unit that can hold an EPC — a house, flat, or maisonette. | "property", "unit", "home" | ## Address Matching @@ -72,7 +72,7 @@ Invoke `/ubiquitous-language` in any session to extract new terms from the conve ## Flagged ambiguities -- **"address"** appears as both the raw **User Address** (free-text from customer data) and a structured field on an **EPC Search Result** (normalised address lines). 
Always qualify: "user address" vs "EPC address" or "address line 1". +- **"address"** appears as both the raw **User Address** (free-text from customer data, or the structured `UserAddress` dataclass that wraps it) and a structured field on an **EPC Search Result** (normalised address lines). Always qualify: "user address" vs "EPC address" or "address line 1". Within `domain/`, **User Address** specifically means the `UserAddress` dataclass; in upstream ingestion contexts (CSV columns, SQS payloads) it can still mean the raw string sense. - **"score"** is used for the `AddressMatch.score()` function output, the `lexiscore` DataFrame column, and informally in conversation. Prefer **Lexiscore** in domain discussions; reserve "score" for method-level code comments. - **"user_inputed_address"** in `backend/address2UPRN/main.py` is a misspelling and a synonym for **User Address** — the canonical term. New code should use `user_address`. - **"EPC"** is overloaded as both the document (an Energy Performance Certificate) and the rating band letter. Use **EPC** for the document and **EPC Band** for the letter. diff --git a/applications/__init__.py b/applications/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/applications/postcode_splitter/Dockerfile b/applications/postcode_splitter/Dockerfile new file mode 100644 index 00000000..aea1f914 --- /dev/null +++ b/applications/postcode_splitter/Dockerfile @@ -0,0 +1,34 @@ +FROM public.ecr.aws/lambda/python:3.11 + +# Postgres host/port/database are baked into the image at build time from +# the deploy workflow's --build-arg values (GitHub Actions DEV_DB_* secrets), +# mirroring backend/postcode_splitter/handler/Dockerfile. They map onto the +# POSTGRES_* names PostgresConfig.from_env reads. Username/password are NOT +# baked in -- Terraform injects those as Lambda env vars from Secrets Manager. 
+ARG DEV_DB_HOST +ARG DEV_DB_PORT +ARG DEV_DB_NAME + +ENV POSTGRES_HOST=${DEV_DB_HOST} +ENV POSTGRES_PORT=${DEV_DB_PORT} +ENV POSTGRES_DATABASE=${DEV_DB_NAME} + +WORKDIR /var/task + +COPY applications/postcode_splitter/requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the layered source the handler imports from. The new splitter pulls +# only DDD-shaped packages — no pandas, no legacy backend/. +COPY domain/ domain/ +COPY infrastructure/ infrastructure/ +COPY orchestration/ orchestration/ +COPY repositories/ repositories/ +COPY utilities/ utilities/ +COPY applications/ applications/ + +# Place the handler at the Lambda task root so the runtime can resolve +# ``main.handler`` without an extra package prefix. +COPY applications/postcode_splitter/handler.py /var/task/main.py + +CMD ["main.handler"] diff --git a/applications/postcode_splitter/__init__.py b/applications/postcode_splitter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/applications/postcode_splitter/handler.py b/applications/postcode_splitter/handler.py new file mode 100644 index 00000000..9fb3ca6a --- /dev/null +++ b/applications/postcode_splitter/handler.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import os +from typing import Any + +import boto3 + +from applications.postcode_splitter.postcode_splitter_trigger_body import ( + PostcodeSplitterTriggerBody, +) +from infrastructure.address2uprn_queue_client import Address2UprnQueueClient +from infrastructure.csv_s3_client import CsvS3Client +from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator +from orchestration.task_orchestrator import TaskOrchestrator +from repositories.user_address.user_address_csv_s3_repository import ( + UserAddressCsvS3Repository, +) +from utilities.aws_lambda.subtask_handler import subtask_handler + + +@subtask_handler() +def handler( + body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator +) -> dict[str, 
list[str]]: + trigger = PostcodeSplitterTriggerBody.model_validate(body) + + bucket = os.environ["S3_BUCKET_NAME"] + queue_url = os.environ["ADDRESS2UPRN_QUEUE_URL"] + + # boto3.client is overloaded per-service in the installed stubs; cast + # to Any so the strict-mode checker treats it as opaque. + boto3_client: Any = boto3.client # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + boto_s3: Any = boto3_client("s3") + boto_sqs: Any = boto3_client("sqs") + + csv_client = CsvS3Client(boto_s3, bucket) + user_address_repo = UserAddressCsvS3Repository(csv_client, bucket) + queue_client = Address2UprnQueueClient(boto_sqs, queue_url) + + splitter = PostcodeSplitterOrchestrator( + task_orchestrator=task_orchestrator, + user_address_repo=user_address_repo, + queue_client=queue_client, + ) + + child_ids = splitter.split_and_dispatch( + parent_task_id=trigger.task_id, + parent_subtask_id=trigger.sub_task_id, + input_s3_uri=trigger.s3_uri, + ) + + return {"child_subtask_ids": [str(cid) for cid in child_ids]} diff --git a/applications/postcode_splitter/local_handler/.env.local.example b/applications/postcode_splitter/local_handler/.env.local.example new file mode 100644 index 00000000..28fa8390 --- /dev/null +++ b/applications/postcode_splitter/local_handler/.env.local.example @@ -0,0 +1,34 @@ +# Local-test environment for the postcode_splitter Lambda. +# +# cp .env.local.example .env.local then fill in the values below. +# +# .env.local is gitignored. The container hits REAL AWS and a REAL Postgres, +# so every value here points at infrastructure that actually exists. +# +# NOTE: the new DDD code uses different env var names than the repo root +# .env. The mapping (root .env name -> var here) is given per section. +# Keep comments on their own lines — docker-compose's env_file parser folds a +# trailing "# ..." into the value. 
+ +# --- Postgres (orchestration/default_orchestrator -> PostgresConfig.from_env) --- +# POSTGRES_HOST <- DB_HOST, PORT <- DB_PORT, USERNAME <- DB_USERNAME, +# PASSWORD <- DB_PASSWORD, DATABASE <- DB_NAME. +POSTGRES_HOST= +POSTGRES_PORT=5432 +POSTGRES_USERNAME= +POSTGRES_PASSWORD= +POSTGRES_DATABASE= +# POSTGRES_DRIVER=psycopg2 (optional; defaults to psycopg2) + +# --- Handler config (applications/postcode_splitter/handler.py) --- +# S3_BUCKET_NAME: bucket holding the input address CSV (root .env: DATA_BUCKET). +# ADDRESS2UPRN_QUEUE_URL: SQS queue the splitter fans batches out to; not in +# the root .env (Terraform sets it in prod). +S3_BUCKET_NAME= +ADDRESS2UPRN_QUEUE_URL= + +# --- AWS credentials for boto3 (S3 + SQS clients) --- +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_DEFAULT_REGION=eu-west-2 +# AWS_SESSION_TOKEN= (only if using temporary/SSO credentials) diff --git a/applications/postcode_splitter/local_handler/docker-compose.yml b/applications/postcode_splitter/local_handler/docker-compose.yml new file mode 100644 index 00000000..68af1c40 --- /dev/null +++ b/applications/postcode_splitter/local_handler/docker-compose.yml @@ -0,0 +1,9 @@ +services: + postcode-splitter: + build: + context: ../../../ + dockerfile: applications/postcode_splitter/Dockerfile + ports: + - "9001:8080" + env_file: + - .env.local diff --git a/applications/postcode_splitter/local_handler/invoke_local_lambda.py b/applications/postcode_splitter/local_handler/invoke_local_lambda.py new file mode 100755 index 00000000..21fa9b9e --- /dev/null +++ b/applications/postcode_splitter/local_handler/invoke_local_lambda.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +import json +import requests + +HOST = "localhost" +PORT = "9001" + +LAMBDA_URL = f"http://{HOST}:{PORT}/2015-03-31/functions/function/invocations" + +payload = { + "Records": [ + { + "body": json.dumps( + { + "task_id": "f4b3332f-c0cc-481f-96a5-d39860a647cf", + "sub_task_id": "14c042de-40c4-473b-8cd8-72c983a94a8d", + "s3_uri": 
"s3://retrofit-data-dev/ara_raw_inputs/calico/Calico Homes Full list EPC Properties(Sheet2) (1) (1).csv", + } + ) + } + ] +} + +response = requests.post(LAMBDA_URL, json=payload) + +print("Status code:", response.status_code) +print("Response:") +print(response.text) diff --git a/applications/postcode_splitter/local_handler/run_local.sh b/applications/postcode_splitter/local_handler/run_local.sh new file mode 100755 index 00000000..345b60ee --- /dev/null +++ b/applications/postcode_splitter/local_handler/run_local.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")" + +if [ ! -f .env.local ]; then + cp .env.local.example .env.local + echo "Created .env.local from the template — fill it in, then re-run." >&2 + exit 1 +fi + +docker compose build --no-cache +docker compose up --force-recreate diff --git a/applications/postcode_splitter/postcode_splitter_trigger_body.py b/applications/postcode_splitter/postcode_splitter_trigger_body.py new file mode 100644 index 00000000..4c33f4a4 --- /dev/null +++ b/applications/postcode_splitter/postcode_splitter_trigger_body.py @@ -0,0 +1,11 @@ +from uuid import UUID + +from pydantic import BaseModel, ConfigDict + + +class PostcodeSplitterTriggerBody(BaseModel): + model_config = ConfigDict(extra="allow") + + task_id: UUID + sub_task_id: UUID + s3_uri: str diff --git a/applications/postcode_splitter/requirements.txt b/applications/postcode_splitter/requirements.txt new file mode 100644 index 00000000..6a85a255 --- /dev/null +++ b/applications/postcode_splitter/requirements.txt @@ -0,0 +1,4 @@ +boto3 +pydantic +sqlmodel +psycopg2-binary diff --git a/backend/address2UPRN/handler/requirements.txt b/backend/address2UPRN/handler/requirements.txt index 6ef41b2d..02aaefba 100644 --- a/backend/address2UPRN/handler/requirements.txt +++ b/backend/address2UPRN/handler/requirements.txt @@ -8,4 +8,5 @@ boto3==1.35.44 sqlmodel sqlalchemy==2.0.36 psycopg2-binary==2.9.10 -pydantic-settings==2.6.0 \ No newline at end of 
file +pydantic-settings==2.6.0 +httpx \ No newline at end of file diff --git a/deployment/terraform/lambda/postcodeSplitter/main.tf b/deployment/terraform/lambda/postcodeSplitter/main.tf index 94c5cd4e..325f7dc7 100644 --- a/deployment/terraform/lambda/postcodeSplitter/main.tf +++ b/deployment/terraform/lambda/postcodeSplitter/main.tf @@ -40,20 +40,6 @@ module "lambda" { LOG_LEVEL = "info" DB_USERNAME = local.db_credentials.db_assessment_model_username DB_PASSWORD = local.db_credentials.db_assessment_model_password - GOOGLE_SOLAR_API_KEY = "test" - SAP_PREDICTIONS_BUCKET = "test" - CARBON_PREDICTIONS_BUCKET = "test" - HEAT_PREDICTIONS_BUCKET = "test" - HEATING_KWH_PREDICTIONS_BUCKET = "test" - HOTWATER_KWH_PREDICTIONS_BUCKET = "test" - API_KEY = "test" - ENVIRONMENT = "test" - SECRET_KEY = "test" - PLAN_TRIGGER_BUCKET = "test" - DATA_BUCKET = "test" - EPC_AUTH_TOKEN = "test" - ENGINE_SQS_URL = "test" - ENERGY_ASSESSMENTS_BUCKET = "test" ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name }, diff --git a/domain/addresses/__init__.py b/domain/addresses/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/domain/addresses/postcode_batching.py b/domain/addresses/postcode_batching.py new file mode 100644 index 00000000..44e4d967 --- /dev/null +++ b/domain/addresses/postcode_batching.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator + +from domain.addresses.user_address import UserAddress +from domain.postcode import Postcode + + +def iter_postcode_grouped_batches( + addresses: Iterable[UserAddress], + *, + max_batch_size: int = 500, +) -> Iterator[list[UserAddress]]: + if max_batch_size < 1: + raise ValueError("max_batch_size must be >= 1") + + groups = _group_by_postcode_in_order(addresses) + + buffer: list[UserAddress] = [] + for group in groups.values(): + 
group_len = len(group) + + # Oversize single-Postcode group: flush buffer first, then dispatch + # the group as its own batch. Mirrors the legacy + # ``if group_len >= batch_size`` branch. + if group_len >= max_batch_size: + if buffer: + yield buffer + buffer = [] + yield group + continue + + # Adding this group would overflow: flush buffer before appending. + if len(buffer) + group_len > max_batch_size: + yield buffer + buffer = [] + + buffer.extend(group) + + # Final flush. + if buffer: + yield buffer + + +def _group_by_postcode_in_order( + addresses: Iterable[UserAddress], +) -> dict[Postcode, list[UserAddress]]: + groups: dict[Postcode, list[UserAddress]] = {} + for address in addresses: + groups.setdefault(address.postcode, []).append(address) + return groups diff --git a/domain/addresses/user_address.py b/domain/addresses/user_address.py new file mode 100644 index 00000000..9a28751b --- /dev/null +++ b/domain/addresses/user_address.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +from domain.postcode import Postcode + + +def _empty_source_row() -> dict[str, str]: + return {} + + +@dataclass(frozen=True) +class UserAddress: + user_address: str + postcode: Postcode + internal_reference: Optional[str] = None + source_row: dict[str, str] = field(default_factory=_empty_source_row, compare=False) diff --git a/domain/postcode.py b/domain/postcode.py new file mode 100644 index 00000000..8e4e7c79 --- /dev/null +++ b/domain/postcode.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Postcode: + value: str + + def __post_init__(self) -> None: + # Frozen dataclass: bypass the descriptor with object.__setattr__. 
+ object.__setattr__(self, "value", "".join(self.value.split()).upper()) + + def __str__(self) -> str: + return self.value diff --git a/infrastructure/address2uprn_queue_client.py b/infrastructure/address2uprn_queue_client.py new file mode 100644 index 00000000..314e981f --- /dev/null +++ b/infrastructure/address2uprn_queue_client.py @@ -0,0 +1,20 @@ +from uuid import UUID + +from infrastructure.sqs_client import SqsClient + + +class Address2UprnQueueClient(SqsClient): + def publish( + self, + *, + parent_task_id: UUID, + child_subtask_id: UUID, + s3_uri: str, + ) -> str: + return self.send( + { + "task_id": str(parent_task_id), + "sub_task_id": str(child_subtask_id), + "s3_uri": s3_uri, + } + ) diff --git a/infrastructure/csv_s3_client.py b/infrastructure/csv_s3_client.py new file mode 100644 index 00000000..055d1ce3 --- /dev/null +++ b/infrastructure/csv_s3_client.py @@ -0,0 +1,28 @@ +import csv +from io import StringIO + +from infrastructure.s3_client import S3Client +from infrastructure.s3_uri import parse_s3_uri + + +class CsvS3Client(S3Client): + def read_rows(self, s3_uri: str) -> list[dict[str, str]]: + bucket, key = parse_s3_uri(s3_uri) + if bucket != self.bucket: + raise ValueError( + f"s3_uri bucket {bucket!r} does not match client bucket {self.bucket!r}" + ) + raw = self.get_object(key) + text = raw.decode("utf-8-sig") + reader = csv.DictReader(StringIO(text)) + return [dict(row) for row in reader] + + def save_rows(self, rows: list[dict[str, str]], key: str) -> str: + if not rows: + raise ValueError("Cannot save an empty rows list: header is unknown") + buffer = StringIO() + fieldnames = list(rows[0].keys()) + writer = csv.DictWriter(buffer, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + return self.put_object(key, buffer.getvalue().encode("utf-8")) diff --git a/infrastructure/s3_client.py b/infrastructure/s3_client.py new file mode 100644 index 00000000..a789fcc2 --- /dev/null +++ b/infrastructure/s3_client.py @@ -0,0 +1,22 
@@ +from typing import Any + + +class S3Client: + def __init__(self, boto_s3_client: Any, bucket: str) -> None: + self._client = boto_s3_client + self._bucket = bucket + + @property + def bucket(self) -> str: + return self._bucket + + def get_object(self, key: str) -> bytes: + response: dict[str, Any] = self._client.get_object( + Bucket=self._bucket, Key=key + ) + body: bytes = response["Body"].read() + return body + + def put_object(self, key: str, body: bytes) -> str: + self._client.put_object(Bucket=self._bucket, Key=key, Body=body) + return f"s3://{self._bucket}/{key}" diff --git a/infrastructure/s3_uri.py b/infrastructure/s3_uri.py new file mode 100644 index 00000000..1dd5d967 --- /dev/null +++ b/infrastructure/s3_uri.py @@ -0,0 +1,25 @@ +from urllib.parse import unquote + + +def parse_s3_uri(s3_uri: str) -> tuple[str, str]: + if s3_uri.startswith("s3://"): + parts = s3_uri[len("s3://") :].split("/", 1) + if len(parts) < 2 or not parts[0] or not parts[1]: + raise ValueError("S3 URI must include both a bucket and a key") + return parts[0], parts[1] + + if "?" 
not in s3_uri: + raise ValueError(f"Not an s3:// URI and has no query string: {s3_uri!r}") + base, query = s3_uri.split("?", 1) + + if "/s3/object/" not in base: + raise ValueError(f"Console URL has no '/s3/object/' segment: {s3_uri!r}") + bucket = base.split("/s3/object/", 1)[1] + + params: dict[str, str] = {} + for item in query.split("&"): + if "=" in item: + name, value = item.split("=", 1) + params[name] = value + key = unquote(params.get("prefix", "")) + return bucket, key diff --git a/infrastructure/sqs_client.py b/infrastructure/sqs_client.py new file mode 100644 index 00000000..6fe8dd2e --- /dev/null +++ b/infrastructure/sqs_client.py @@ -0,0 +1,20 @@ +import json +from typing import Any + + +class SqsClient: + def __init__(self, boto_sqs_client: Any, queue_url: str) -> None: + self._client = boto_sqs_client + self._queue_url = queue_url + + @property + def queue_url(self) -> str: + return self._queue_url + + def send(self, body: dict[str, Any]) -> str: + response: dict[str, Any] = self._client.send_message( + QueueUrl=self._queue_url, + MessageBody=json.dumps(body), + ) + message_id: str = response["MessageId"] + return message_id diff --git a/orchestration/postcode_splitter_orchestrator.py b/orchestration/postcode_splitter_orchestrator.py new file mode 100644 index 00000000..36f4b515 --- /dev/null +++ b/orchestration/postcode_splitter_orchestrator.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from uuid import UUID + +from infrastructure.address2uprn_queue_client import Address2UprnQueueClient +from orchestration.task_orchestrator import TaskOrchestrator +from domain.addresses.postcode_batching import iter_postcode_grouped_batches +from repositories.user_address.user_address_repository import UserAddressRepository + + +class PostcodeSplitterOrchestrator: + def __init__( + self, + task_orchestrator: TaskOrchestrator, + user_address_repo: UserAddressRepository, + queue_client: Address2UprnQueueClient, + max_batch_size: int = 500, + ) -> None: + 
self._task_orchestrator = task_orchestrator + self._user_address_repo = user_address_repo + self._queue_client = queue_client + self._max_batch_size = max_batch_size + + def split_and_dispatch( + self, + *, + parent_task_id: UUID, + parent_subtask_id: UUID, + input_s3_uri: str, + ) -> list[UUID]: + addresses = self._user_address_repo.load_batch(input_s3_uri) + path_prefix = ( + f"ara_postcode_splitter_batches/{parent_task_id}/{parent_subtask_id}" + ) + + child_ids: list[UUID] = [] + for batch in iter_postcode_grouped_batches( + addresses, max_batch_size=self._max_batch_size + ): + batch_uri = self._user_address_repo.save_batch(batch, path_prefix) + child = self._task_orchestrator.create_child_subtask( + parent_task_id, + inputs={ + "task_id": str(parent_task_id), + "s3_uri": batch_uri, + }, + ) + self._queue_client.publish( + parent_task_id=parent_task_id, + child_subtask_id=child.id, + s3_uri=batch_uri, + ) + child_ids.append(child.id) + + return child_ids diff --git a/orchestration/task_orchestrator.py b/orchestration/task_orchestrator.py index 6c67d1ce..ebb71a32 100644 --- a/orchestration/task_orchestrator.py +++ b/orchestration/task_orchestrator.py @@ -48,6 +48,16 @@ class TaskOrchestrator: self._subtasks.create(subtask) return task, subtask + def create_child_subtask( + self, + parent_task_id: UUID, + *, + inputs: Optional[dict[str, Any]] = None, + ) -> SubTask: + subtask = SubTask.create(task_id=parent_task_id, inputs=inputs) + self._subtasks.create(subtask) + return subtask + def start_subtask( self, subtask_id: UUID, cloud_logs_url: Optional[str] = None ) -> SubTask: diff --git a/repositories/user_address/__init__.py b/repositories/user_address/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/repositories/user_address/user_address_csv_s3_repository.py b/repositories/user_address/user_address_csv_s3_repository.py new file mode 100644 index 00000000..058fd5a5 --- /dev/null +++ b/repositories/user_address/user_address_csv_s3_repository.py 
@@ -0,0 +1,63 @@ +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Optional + +from domain.addresses.user_address import UserAddress +from domain.postcode import Postcode +from infrastructure.csv_s3_client import CsvS3Client +from repositories.user_address.user_address_repository import UserAddressRepository + +_ADDRESS_COLUMNS: tuple[str, str, str] = ("Address 1", "Address 2", "Address 3") +_POSTCODE_COLUMN: str = "postcode" +_INTERNAL_REFERENCE_COLUMN: str = "Internal Reference" +_POSTCODE_CLEAN_COLUMN: str = "postcode_clean" + + +class UserAddressCsvS3Repository(UserAddressRepository): + def __init__(self, csv_client: CsvS3Client, bucket: str) -> None: + self._csv_client = csv_client + self._bucket = bucket + + def load_batch(self, s3_uri: str) -> list[UserAddress]: + rows = self._csv_client.read_rows(s3_uri) + if rows and _POSTCODE_COLUMN not in rows[0]: + raise ValueError( + f"Input CSV {s3_uri} has no {_POSTCODE_COLUMN!r} column; " + f"columns present: {sorted(rows[0])}" + ) + addresses: list[UserAddress] = [] + for row in rows: + parts = [ + row[col].strip() + for col in _ADDRESS_COLUMNS + if col in row and row[col].strip() + ] + user_address = ", ".join(parts) + postcode = row.get(_POSTCODE_COLUMN, "") + raw_ref = row.get(_INTERNAL_REFERENCE_COLUMN, "").strip() + internal_reference: Optional[str] = raw_ref or None + addresses.append( + UserAddress( + user_address=user_address, + postcode=Postcode(postcode), + internal_reference=internal_reference, + source_row=row, + ) + ) + return addresses + + def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: + rows: list[dict[str, str]] = [ + {**addr.source_row, _POSTCODE_CLEAN_COLUMN: str(addr.postcode)} + for addr in addresses + ] + + # TODO: [New Starter Task] filename generation should be standardised + # and made easier to read and test in a future implementation. Build that! 
+ filename = ( + f"{datetime.now(timezone.utc).isoformat()}_{uuid.uuid4().hex[:8]}.csv" + ) + key = f"{path_prefix.rstrip('/')}/{filename}" + return self._csv_client.save_rows(rows, key) diff --git a/repositories/user_address/user_address_repository.py b/repositories/user_address/user_address_repository.py new file mode 100644 index 00000000..b2c0f866 --- /dev/null +++ b/repositories/user_address/user_address_repository.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from domain.addresses.user_address import UserAddress + + +class UserAddressRepository(ABC): + @abstractmethod + def load_batch(self, s3_uri: str) -> list[UserAddress]: ... + + @abstractmethod + def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: ... diff --git a/test.requirements.txt b/test.requirements.txt index 7fdd7dc4..26125034 100644 --- a/test.requirements.txt +++ b/test.requirements.txt @@ -9,4 +9,5 @@ hubspot-api-client fuzzywuzzy pymupdf playwright==1.58.0 -msal \ No newline at end of file +msal +moto[s3,sqs] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..0a246372 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,48 @@ +"""Shared pytest fixtures for the ``tests/`` tree. + +Provides an ephemeral PostgreSQL engine for tests that exercise SQLModel +repositories. PostgreSQL has no true in-memory mode; ``pytest-postgresql`` +starts a real, throwaway server in a temp directory (the process is started +once per session and a fresh database is created/dropped per test). That is +the closest equivalent to "in-memory" and matches production behaviour far +better than SQLite (enums, JSONB, constraint semantics, etc.). 
+""" + +from __future__ import annotations + +import glob +from collections.abc import Iterator +from typing import Any + +import pytest +from psycopg import Connection +from pytest_postgresql import factories +from sqlalchemy import Engine +from sqlmodel import SQLModel, create_engine + +# Importing the SQLModel row modules registers their tables on +# SQLModel.metadata so ``create_all`` builds the full schema. Imports look +# unused; they aren't. + + +# pg_ctl ships under a versioned path and is not on PATH in the dev container. +_PG_CTL = next(iter(sorted(glob.glob("/usr/lib/postgresql/*/bin/pg_ctl"))), "pg_ctl") + +postgresql_proc = factories.postgresql_proc( + executable=_PG_CTL +) # pyright: ignore[reportUnknownMemberType] +postgresql = factories.postgresql("postgresql_proc") + + +@pytest.fixture +def db_engine(postgresql: Connection[Any]) -> Iterator[Engine]: + """A SQLModel engine bound to a fresh, ephemeral PostgreSQL database.""" + info = postgresql.info + url = f"postgresql+psycopg://{info.user}:@{info.host}:{info.port}/{info.dbname}" + engine = create_engine(url) + SQLModel.metadata.create_all(engine) + try: + yield engine + finally: + SQLModel.metadata.drop_all(engine) + engine.dispose() diff --git a/tests/domain/addresses/__init__.py b/tests/domain/addresses/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/domain/addresses/test_postcode_batching.py b/tests/domain/addresses/test_postcode_batching.py new file mode 100644 index 00000000..8ffcf1b5 --- /dev/null +++ b/tests/domain/addresses/test_postcode_batching.py @@ -0,0 +1,118 @@ +import pytest + +from domain.addresses.postcode_batching import iter_postcode_grouped_batches +from domain.addresses.user_address import UserAddress +from domain.postcode import Postcode + + +def _addrs(postcode: str, n: int) -> list[UserAddress]: + return [ + UserAddress( + user_address=f"{i} {postcode} Street", postcode=Postcode(postcode) + ) + for i in range(n) + ] + + +def 
test_empty_input_yields_no_batches() -> None: + # act / assert + assert list(iter_postcode_grouped_batches([])) == [] + + +def test_single_batch_under_cap() -> None: + # arrange + addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2) + # act + batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500)) + # assert + assert len(batches) == 1 + assert batches[0] == addrs + + +def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None: + # Two groups whose total exactly equals the cap pack into a single + # batch -- no premature flush. + # arrange + addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2) + # act + batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5)) + # assert + assert len(batches) == 1 + assert len(batches[0]) == 5 + + +def test_flush_on_overflow_before_adding_next_postcode() -> None: + # Cap is 5. First group fills 3 slots; second group of 3 would overflow, + # so the buffer is flushed first and the next group starts a fresh batch. + # arrange + addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 3) + # act + batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5)) + # assert + assert len(batches) == 2 + assert [str(a.postcode) for a in batches[0]] == ["AA11AA"] * 3 + assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 3 + + +def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None: + # An oversize single-postcode group goes out as one batch larger than + # the cap -- the cap never splits a postcode. + # arrange + addrs = _addrs("AA1 1AA", 7) + # act + batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5)) + # assert + assert len(batches) == 1 + assert len(batches[0]) == 7 + + +def test_oversize_group_flushes_existing_buffer_first() -> None: + # Mirrors the legacy ``if buffer: flush`` branch when an oversize group + # is encountered: buffered work must not be lost or interleaved. 
+ # arrange + small = _addrs("AA1 1AA", 2) + big = _addrs("BB2 2BB", 7) + tail = _addrs("CC3 3CC", 1) + # act + batches = list( + iter_postcode_grouped_batches(small + big + tail, max_batch_size=5) + ) + # assert + assert len(batches) == 3 + assert [str(a.postcode) for a in batches[0]] == ["AA11AA", "AA11AA"] + assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 7 + assert [str(a.postcode) for a in batches[2]] == ["CC33CC"] + + +def test_final_flush_yields_remaining_buffer() -> None: + # No overflow ever happens, but the trailing buffer must still come out. + # arrange + addrs = _addrs("AA1 1AA", 2) + _addrs("BB2 2BB", 2) + # act + batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500)) + # assert + assert batches == [addrs] + + +def test_postcode_grouping_preserves_first_seen_order() -> None: + # Interleaved input must still group by postcode and emit in first-seen + # order -- never alphabetical. + # arrange + a1, a2 = _addrs("ZZ9 9ZZ", 2) + b1, b2 = _addrs("AA1 1AA", 2) + # act + batches = list(iter_postcode_grouped_batches([a1, b1, a2, b2])) + # assert + assert len(batches) == 1 + assert [str(a.postcode) for a in batches[0]] == [ + "ZZ99ZZ", + "ZZ99ZZ", + "AA11AA", + "AA11AA", + ] + + +def test_invalid_max_batch_size_raises() -> None: + # act / assert + with pytest.raises(ValueError, match="max_batch_size"): + list(iter_postcode_grouped_batches([], max_batch_size=0)) diff --git a/tests/domain/addresses/test_user_address.py b/tests/domain/addresses/test_user_address.py new file mode 100644 index 00000000..8d092df3 --- /dev/null +++ b/tests/domain/addresses/test_user_address.py @@ -0,0 +1,98 @@ +import dataclasses + +import pytest + +from domain.addresses.user_address import UserAddress +from domain.postcode import Postcode + + +def test_user_address_holds_postcode_value_object() -> None: + # act + addr = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa")) + # assert + assert addr.postcode == Postcode("SW1A1AA") + + 
+def test_user_address_preserves_user_address_verbatim() -> None: + # The free-text user_address string is intentionally NOT normalised -- + # only the postcode is canonicalised, and that happens inside Postcode. + # act + addr = UserAddress( + user_address=" 1 The Street ", postcode=Postcode("SW1A1AA") + ) + # assert + assert addr.user_address == " 1 The Street " + + +def test_user_address_internal_reference_defaults_to_none() -> None: + # act + addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # assert + assert addr.internal_reference is None + + +def test_user_address_internal_reference_accepted() -> None: + # act + addr = UserAddress( + user_address="1 The Street", + postcode=Postcode("SW1A1AA"), + internal_reference="cust-42", + ) + # assert + assert addr.internal_reference == "cust-42" + + +def test_user_address_is_frozen() -> None: + # arrange + addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # act / assert + with pytest.raises(dataclasses.FrozenInstanceError): + addr.postcode = Postcode("OTHER") # type: ignore[misc] + + +def test_user_address_equality_uses_canonical_postcode() -> None: + # Postcode sanitises eagerly, so addresses built from different surface + # forms of the same postcode compare equal. 
+ # arrange + a = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa")) + b = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # act / assert + assert a == b + + +def test_user_address_source_row_defaults_to_empty_dict() -> None: + # act + addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # assert + assert addr.source_row == {} + + +def test_user_address_carries_source_row() -> None: + # arrange + row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"} + # act + addr = UserAddress( + user_address="1 The Street", + postcode=Postcode("SW1A 1AA"), + source_row=row, + ) + # assert + assert addr.source_row == row + + +def test_user_address_equality_ignores_source_row() -> None: + # source_row is excluded from equality (and hashing): identity stays + # defined by the parsed fields. + # arrange + a = UserAddress( + user_address="1 The Street", + postcode=Postcode("SW1A1AA"), + source_row={"x": "1"}, + ) + b = UserAddress( + user_address="1 The Street", + postcode=Postcode("SW1A1AA"), + source_row={"y": "2"}, + ) + # act / assert + assert a == b diff --git a/tests/domain/tasks/test_subtasks.py b/tests/domain/tasks/test_subtasks.py index 2721d38f..8cee4496 100644 --- a/tests/domain/tasks/test_subtasks.py +++ b/tests/domain/tasks/test_subtasks.py @@ -6,10 +6,13 @@ from domain.tasks.subtasks import SubTask, SubTaskStatus def test_create_subtask_starts_waiting() -> None: + # arrange task_id = uuid4() + # act st = SubTask.create(task_id=task_id, inputs={"foo": "bar"}) + # assert assert st.task_id == task_id assert st.status is SubTaskStatus.WAITING assert st.inputs == {"foo": "bar"} @@ -19,57 +22,74 @@ def test_create_subtask_starts_waiting() -> None: def test_start_transitions_to_in_progress_and_sets_cloud_logs_url() -> None: + # arrange st = SubTask.create(task_id=uuid4()) + # act st.start(cloud_logs_url="https://example/log") + # assert assert st.status is SubTaskStatus.IN_PROGRESS 
assert st.cloud_logs_url == "https://example/log" assert st.job_started is not None def test_start_is_idempotent_from_in_progress() -> None: + # arrange st = SubTask.create(task_id=uuid4()) st.start() first_start = st.job_started + # act st.start(cloud_logs_url="https://other") + # assert assert st.status is SubTaskStatus.IN_PROGRESS assert st.job_started == first_start # not overwritten assert st.cloud_logs_url == "https://other" def test_start_rejects_from_terminal_status() -> None: + # arrange st = SubTask.create(task_id=uuid4()) st.complete() + # act / assert with pytest.raises(ValueError): st.start() def test_complete_marks_outputs_and_job_completed() -> None: + # arrange st = SubTask.create(task_id=uuid4()) st.start() + # act st.complete({"uprn": "123"}) + # assert assert st.status is SubTaskStatus.COMPLETE assert st.outputs == {"result": {"uprn": "123"}} assert st.job_completed is not None def test_complete_without_result_leaves_outputs_unset() -> None: + # arrange st = SubTask.create(task_id=uuid4()) + # act st.complete() + # assert assert st.outputs is None def test_fail_records_error_in_outputs() -> None: + # arrange st = SubTask.create(task_id=uuid4()) err = RuntimeError("boom") + # act st.fail(err) + # assert assert st.status is SubTaskStatus.FAILED assert st.outputs == {"error": "boom"} assert st.job_completed is not None diff --git a/tests/domain/tasks/test_tasks.py b/tests/domain/tasks/test_tasks.py index f30c0aa1..ba82412b 100644 --- a/tests/domain/tasks/test_tasks.py +++ b/tests/domain/tasks/test_tasks.py @@ -5,12 +5,12 @@ from domain.tasks.tasks import Source, Task, TaskStatus def test_create_task_starts_waiting() -> None: - # Arrange / Act + # arrange / act t = Task.create( task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123" ) - # Assert + # assert assert t.status is TaskStatus.WAITING assert t.source is Source.PORTFOLIO assert t.source_id == "abc-123" @@ -19,86 +19,113 @@ def test_create_task_starts_waiting() -> None: def 
test_create_task_rejects_blank_task_source() -> None: + # act / assert with pytest.raises(ValueError, match="task_source"): Task.create(task_source=" ") def test_start_transitions_to_in_progress() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.start() + # assert assert t.status is TaskStatus.IN_PROGRESS def test_complete_marks_job_completed() -> None: + # arrange t = Task.create(task_source="manual:test") t.start() + # act t.complete() + # assert assert t.status is TaskStatus.COMPLETE assert t.job_completed is not None def test_fail_marks_job_completed() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.fail() + # assert assert t.status is TaskStatus.FAILED assert t.job_completed is not None def test_start_rejects_from_terminal_status() -> None: + # arrange t = Task.create(task_source="manual:test") t.complete() + # act / assert with pytest.raises(ValueError): t.start() def test_recalculate_with_empty_statuses_is_noop() -> None: + # arrange t = Task.create(task_source="manual:test") original_status = t.status original_completed = t.job_completed + # act t.recalculate_from_subtasks([]) + # assert assert t.status is original_status assert t.job_completed is original_completed def test_recalculate_all_waiting_keeps_waiting() -> None: + # arrange t = Task.create(task_source="manual:test") t.start() # task moved to IN_PROGRESS earlier t.complete() # then COMPLETE, with job_completed set + # act t.recalculate_from_subtasks([SubTaskStatus.WAITING, SubTaskStatus.WAITING]) + # assert assert t.status is TaskStatus.WAITING assert t.job_completed is None def test_recalculate_any_in_progress_marks_in_progress() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.recalculate_from_subtasks( [SubTaskStatus.WAITING, SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE] ) + # assert assert t.status is TaskStatus.IN_PROGRESS assert t.job_completed is None def test_recalculate_all_complete_marks_complete() -> None: 
+ # arrange t = Task.create(task_source="manual:test") + # act t.recalculate_from_subtasks([SubTaskStatus.COMPLETE, SubTaskStatus.COMPLETE]) + # assert assert t.status is TaskStatus.COMPLETE assert t.job_completed is not None def test_recalculate_any_failed_marks_failed_even_with_others() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.recalculate_from_subtasks( [SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE, SubTaskStatus.FAILED] ) + # assert assert t.status is TaskStatus.FAILED assert t.job_completed is not None diff --git a/tests/domain/test_postcode.py b/tests/domain/test_postcode.py new file mode 100644 index 00000000..f7ce9015 --- /dev/null +++ b/tests/domain/test_postcode.py @@ -0,0 +1,59 @@ +import dataclasses + +import pytest + +from domain.postcode import Postcode + + +def test_postcode_uppercases() -> None: + # act / assert + assert Postcode("sw1a1aa").value == "SW1A1AA" + + +def test_postcode_strips_internal_spaces() -> None: + # act / assert + assert Postcode("sw1a 1aa").value == "SW1A1AA" + + +def test_postcode_strips_leading_and_trailing_whitespace() -> None: + # act / assert + assert Postcode(" sw1a 1aa ").value == "SW1A1AA" + + +def test_postcode_strips_tabs_and_newlines() -> None: + # CSV ingestion occasionally introduces stray whitespace characters; the + # canonical form must absorb them just like literal spaces. 
+ # act / assert + assert Postcode("sw1a\t1aa\n").value == "SW1A1AA" + + +def test_postcode_construction_is_idempotent() -> None: + # arrange + once = Postcode("sw1a 1aa") + # act / assert + assert Postcode(once.value).value == "SW1A1AA" + + +def test_postcode_empty_string() -> None: + # act / assert + assert Postcode("").value == "" + + +def test_postcode_str_returns_canonical_value() -> None: + # act / assert + assert str(Postcode("sw1a 1aa")) == "SW1A1AA" + + +def test_postcode_equality_ignores_surface_form() -> None: + # Differing case / whitespace sanitise to the same canonical value, so + # the value objects compare equal. + # act / assert + assert Postcode("sw1a 1aa") == Postcode("SW1A1AA") + + +def test_postcode_is_frozen() -> None: + # arrange + postcode = Postcode("SW1A1AA") + # act / assert + with pytest.raises(dataclasses.FrozenInstanceError): + postcode.value = "OTHER" # type: ignore[misc] diff --git a/tests/infrastructure/__init__.py b/tests/infrastructure/__init__.py new file mode 100644 index 00000000..f5ad62d0 --- /dev/null +++ b/tests/infrastructure/__init__.py @@ -0,0 +1,10 @@ +from typing import Any + +import boto3 + +REGION = "us-east-1" + + +def make_boto_client(service_name: str) -> Any: + factory: Any = boto3.client # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return factory(service_name, region_name=REGION) diff --git a/tests/infrastructure/conftest.py b/tests/infrastructure/conftest.py new file mode 100644 index 00000000..25c1ac3b --- /dev/null +++ b/tests/infrastructure/conftest.py @@ -0,0 +1,28 @@ +import os +from collections.abc import Iterator +from typing import Optional + +import pytest + + +@pytest.fixture(autouse=True) +def _aws_creds() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] + keys = ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_DEFAULT_REGION", + ) + prev: dict[str, Optional[str]] = {k: os.environ.get(k) for k in keys} + 
os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + try: + yield + finally: + for k, v in prev.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v diff --git a/tests/infrastructure/test_address2uprn_queue_client.py b/tests/infrastructure/test_address2uprn_queue_client.py new file mode 100644 index 00000000..c8e89ece --- /dev/null +++ b/tests/infrastructure/test_address2uprn_queue_client.py @@ -0,0 +1,71 @@ +import json +from collections.abc import Iterator +from typing import Any, cast +from uuid import uuid4 + +import pytest +from moto import mock_aws + +from infrastructure.address2uprn_queue_client import Address2UprnQueueClient +from tests.infrastructure import make_boto_client + + +@pytest.fixture +def queue_setup() -> Iterator[tuple[Address2UprnQueueClient, Any, str]]: + with mock_aws(): + boto_client = make_boto_client("sqs") + queue: dict[str, Any] = boto_client.create_queue( + QueueName="address2uprn-queue" + ) + queue_url = cast(str, queue["QueueUrl"]) + yield ( + Address2UprnQueueClient(boto_client, queue_url), + boto_client, + queue_url, + ) + + +def test_publish_returns_message_id( + queue_setup: tuple[Address2UprnQueueClient, Any, str], +) -> None: + # arrange + client, _boto, _url = queue_setup + # act + message_id = client.publish( + parent_task_id=uuid4(), + child_subtask_id=uuid4(), + s3_uri="s3://my-bucket/path/to/chunk.csv", + ) + # assert + assert isinstance(message_id, str) + assert message_id + + +def test_publish_body_uses_typed_shape( + queue_setup: tuple[Address2UprnQueueClient, Any, str], +) -> None: + # arrange + client, boto_client, queue_url = queue_setup + parent_id = uuid4() + child_id = uuid4() + s3_uri = "s3://my-bucket/path/to/chunk.csv" + + # act + client.publish( + parent_task_id=parent_id, + child_subtask_id=child_id, + s3_uri=s3_uri, + ) + + # assert + received: 
dict[str, Any] = boto_client.receive_message( + QueueUrl=queue_url, MaxNumberOfMessages=1 + ) + messages: list[dict[str, Any]] = received["Messages"] + assert len(messages) == 1 + body = json.loads(messages[0]["Body"]) + assert body == { + "task_id": str(parent_id), + "sub_task_id": str(child_id), + "s3_uri": s3_uri, + } diff --git a/tests/infrastructure/test_csv_s3_client.py b/tests/infrastructure/test_csv_s3_client.py new file mode 100644 index 00000000..30e27164 --- /dev/null +++ b/tests/infrastructure/test_csv_s3_client.py @@ -0,0 +1,51 @@ +from collections.abc import Iterator + +import pytest +from moto import mock_aws + +from infrastructure.csv_s3_client import CsvS3Client +from tests.infrastructure import make_boto_client + +BUCKET = "csv-bucket" + + +@pytest.fixture +def csv_client() -> Iterator[CsvS3Client]: + with mock_aws(): + boto_client = make_boto_client("s3") + boto_client.create_bucket(Bucket=BUCKET) + yield CsvS3Client(boto_client, BUCKET) + + +def test_save_rows_returns_s3_uri(csv_client: CsvS3Client) -> None: + # arrange + rows = [{"address": "1 High St", "postcode": "AB1 2CD"}] + # act + uri = csv_client.save_rows(rows, "uploads/addresses.csv") + # assert + assert uri == f"s3://{BUCKET}/uploads/addresses.csv" + + +def test_round_trip_preserves_rows(csv_client: CsvS3Client) -> None: + # arrange + rows = [ + {"address": "1 High St", "postcode": "AB1 2CD"}, + {"address": "2 Low St", "postcode": "XY9 8ZW"}, + ] + # act + uri = csv_client.save_rows(rows, "uploads/addresses.csv") + fetched = csv_client.read_rows(uri) + # assert + assert fetched == rows + + +def test_save_rows_rejects_empty_list(csv_client: CsvS3Client) -> None: + # act / assert + with pytest.raises(ValueError, match="empty"): + csv_client.save_rows([], "uploads/empty.csv") + + +def test_read_rows_rejects_wrong_bucket(csv_client: CsvS3Client) -> None: + # act / assert + with pytest.raises(ValueError, match="does not match client bucket"): + 
csv_client.read_rows("s3://other-bucket/uploads/addresses.csv") diff --git a/tests/infrastructure/test_s3_client.py b/tests/infrastructure/test_s3_client.py new file mode 100644 index 00000000..67db4f58 --- /dev/null +++ b/tests/infrastructure/test_s3_client.py @@ -0,0 +1,36 @@ +from collections.abc import Iterator + +import pytest +from moto import mock_aws + +from infrastructure.s3_client import S3Client +from tests.infrastructure import make_boto_client + +BUCKET = "test-bucket" + + +@pytest.fixture +def s3_client() -> Iterator[S3Client]: + with mock_aws(): + boto_client = make_boto_client("s3") + boto_client.create_bucket(Bucket=BUCKET) + yield S3Client(boto_client, BUCKET) + + +def test_put_object_returns_s3_uri(s3_client: S3Client) -> None: + # act + uri = s3_client.put_object("folder/data.bin", b"payload") + # assert + assert uri == f"s3://{BUCKET}/folder/data.bin" + + +def test_get_object_returns_bytes_written_by_put_object(s3_client: S3Client) -> None: + # arrange + s3_client.put_object("round/trip.bin", b"hello world") + # act / assert + assert s3_client.get_object("round/trip.bin") == b"hello world" + + +def test_bucket_property_exposes_configured_bucket(s3_client: S3Client) -> None: + # act / assert + assert s3_client.bucket == BUCKET diff --git a/tests/infrastructure/test_s3_uri.py b/tests/infrastructure/test_s3_uri.py new file mode 100644 index 00000000..32fd710f --- /dev/null +++ b/tests/infrastructure/test_s3_uri.py @@ -0,0 +1,40 @@ +import pytest + +from infrastructure.s3_uri import parse_s3_uri + + +def test_parses_simple_s3_uri() -> None: + # act / assert + assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv") + + +def test_parses_s3_uri_with_nested_key() -> None: + # act + bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv") + # assert + assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv") + + +def test_rejects_s3_uri_without_key() -> None: + # act / assert + with pytest.raises(ValueError, 
match="bucket and a key"): + parse_s3_uri("s3://my-bucket") + + +def test_rejects_s3_uri_with_empty_key() -> None: + # act / assert + with pytest.raises(ValueError, match="bucket and a key"): + parse_s3_uri("s3://my-bucket/") + + +def test_parses_console_url_prefix() -> None: + # arrange + url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv" + # act / assert + assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv") + + +def test_rejects_unparseable_string() -> None: + # act / assert + with pytest.raises(ValueError): + parse_s3_uri("not-a-uri-at-all") diff --git a/tests/infrastructure/test_sqs_client.py b/tests/infrastructure/test_sqs_client.py new file mode 100644 index 00000000..44186bbb --- /dev/null +++ b/tests/infrastructure/test_sqs_client.py @@ -0,0 +1,44 @@ +import json +from collections.abc import Iterator +from typing import Any, cast + +import pytest +from moto import mock_aws + +from infrastructure.sqs_client import SqsClient +from tests.infrastructure import make_boto_client + + +@pytest.fixture +def sqs_setup() -> Iterator[tuple[SqsClient, Any, str]]: + with mock_aws(): + boto_client = make_boto_client("sqs") + queue: dict[str, Any] = boto_client.create_queue(QueueName="test-queue") + queue_url = cast(str, queue["QueueUrl"]) + yield SqsClient(boto_client, queue_url), boto_client, queue_url + + +def test_send_returns_message_id(sqs_setup: tuple[SqsClient, Any, str]) -> None: + # arrange + client, _boto, _url = sqs_setup + # act + message_id = client.send({"hello": "world"}) + # assert + assert isinstance(message_id, str) + assert message_id + + +def test_send_json_serialises_body(sqs_setup: tuple[SqsClient, Any, str]) -> None: + # arrange + client, boto_client, queue_url = sqs_setup + body = {"hello": "world", "count": 3} + # act + client.send(body) + + # assert + received: dict[str, Any] = boto_client.receive_message( + QueueUrl=queue_url, MaxNumberOfMessages=1 + ) + messages: list[dict[str, Any]] = 
received["Messages"] + assert len(messages) == 1 + assert json.loads(messages[0]["Body"]) == body diff --git a/tests/orchestration/test_postcode_splitter_orchestrator.py b/tests/orchestration/test_postcode_splitter_orchestrator.py new file mode 100644 index 00000000..a718ffbc --- /dev/null +++ b/tests/orchestration/test_postcode_splitter_orchestrator.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +import json +import os +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, cast + +import boto3 +import pytest +from moto import mock_aws +from sqlalchemy import Engine +from sqlmodel import Session + +from infrastructure.address2uprn_queue_client import Address2UprnQueueClient +from infrastructure.csv_s3_client import CsvS3Client +from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator +from orchestration.task_orchestrator import TaskOrchestrator +from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository +from repositories.tasks.task_postgres_repository import TaskPostgresRepository +from repositories.user_address.user_address_csv_s3_repository import ( + UserAddressCsvS3Repository, +) + +BUCKET = "splitter-bucket" +REGION = "us-east-1" + + +def _make_boto_client(service_name: str) -> Any: + factory: Any = boto3.client # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return factory(service_name, region_name=REGION) + + +@pytest.fixture(autouse=True) +def _aws_creds() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] + keys = ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_DEFAULT_REGION", + ) + prev: dict[str, Any] = {k: os.environ.get(k) for k in keys} + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = REGION + try: + yield + finally: + for k, v in prev.items(): + if 
v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +@dataclass +class Harness: + splitter: PostcodeSplitterOrchestrator + task_orchestrator: TaskOrchestrator + subtasks: SubTaskPostgresRepository + csv_client: CsvS3Client + boto_sqs: Any + queue_url: str + repo: UserAddressCsvS3Repository + + +@pytest.fixture +def harness(db_engine: Engine) -> Iterator[Harness]: + with mock_aws(): + # Infra: S3 + SQS + boto_s3 = _make_boto_client("s3") + boto_s3.create_bucket(Bucket=BUCKET) + boto_sqs = _make_boto_client("sqs") + queue: dict[str, Any] = boto_sqs.create_queue(QueueName="address2uprn-queue") + queue_url = cast(str, queue["QueueUrl"]) + + csv_client = CsvS3Client(boto_s3, BUCKET) + repo = UserAddressCsvS3Repository(csv_client, BUCKET) + queue_client = Address2UprnQueueClient(boto_sqs, queue_url) + + # DB: TaskOrchestrator backed by the ephemeral PostgreSQL engine + with Session(db_engine) as session: + task_repo = TaskPostgresRepository(session=session) + subtask_repo = SubTaskPostgresRepository(session=session) + task_orchestrator = TaskOrchestrator( + task_repo=task_repo, subtask_repo=subtask_repo + ) + + splitter = PostcodeSplitterOrchestrator( + task_orchestrator=task_orchestrator, + user_address_repo=repo, + queue_client=queue_client, + max_batch_size=3, + ) + + yield Harness( + splitter=splitter, + task_orchestrator=task_orchestrator, + subtasks=subtask_repo, + csv_client=csv_client, + boto_sqs=boto_sqs, + queue_url=queue_url, + repo=repo, + ) + + +def _upload_fixture_csv(csv_client: CsvS3Client) -> str: + # Three postcode groups: + # AA1 1AA × 2 (within cap) + # BB2 2BB × 4 (oversize: > max_batch_size=3) + # CC3 3CC × 1 (final flush) + # Expected batching with cap=3 and the algorithm in + # ``iter_postcode_grouped_batches``: + # batch 1: [AA1 1AA × 2] (flushed because oversize follows) + # batch 2: [BB2 2BB × 4] (oversize own batch) + # batch 3: [CC3 3CC × 1] (final flush) + rows: list[dict[str, str]] = [] + rows.extend( + { + "Address 1": f"{i} High St", + 
"Address 2": "", + "Address 3": "", + "postcode": "AA1 1AA", + "Internal Reference": f"AA-{i}", + } + for i in range(1, 3) + ) + rows.extend( + { + "Address 1": f"{i} Long Road", + "Address 2": "", + "Address 3": "", + "postcode": "BB2 2BB", + "Internal Reference": f"BB-{i}", + } + for i in range(1, 5) + ) + rows.append( + { + "Address 1": "1 Final Way", + "Address 2": "", + "Address 3": "", + "postcode": "CC3 3CC", + "Internal Reference": "CC-1", + } + ) + return csv_client.save_rows(rows, "uploads/input.csv") + + +def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]: + bodies: list[dict[str, Any]] = [] + while True: + received: dict[str, Any] = boto_sqs.receive_message( + QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=0 + ) + messages = cast(list[dict[str, Any]], received.get("Messages", [])) + if not messages: + break + for message in messages: + bodies.append(cast(dict[str, Any], json.loads(message["Body"]))) + boto_sqs.delete_message( + QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"] + ) + return bodies + + +def test_split_and_dispatch_creates_three_children_for_fixture( + harness: Harness, +) -> None: + # arrange + parent_task, parent_subtask = ( + harness.task_orchestrator.create_task_with_subtask( + task_source="manual:postcode-splitter-int" + ) + ) + input_uri = _upload_fixture_csv(harness.csv_client) + + # act + child_ids = harness.splitter.split_and_dispatch( + parent_task_id=parent_task.id, + parent_subtask_id=parent_subtask.id, + input_s3_uri=input_uri, + ) + + # assert + assert len(child_ids) == 3 + # All child ids are unique and persisted as WAITING children of the + # parent task. 
+ assert len(set(child_ids)) == 3 + for cid in child_ids: + child = harness.subtasks.get(cid) + assert child.task_id == parent_task.id + + +def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri( + harness: Harness, +) -> None: + # arrange + parent_task, parent_subtask = ( + harness.task_orchestrator.create_task_with_subtask( + task_source="manual:postcode-splitter-int" + ) + ) + input_uri = _upload_fixture_csv(harness.csv_client) + + # act + child_ids = harness.splitter.split_and_dispatch( + parent_task_id=parent_task.id, + parent_subtask_id=parent_subtask.id, + input_s3_uri=input_uri, + ) + + # assert + for cid in child_ids: + child = harness.subtasks.get(cid) + assert child.inputs is not None + assert child.inputs["task_id"] == str(parent_task.id) + batch_uri = child.inputs["s3_uri"] + assert isinstance(batch_uri, str) + prefix = ( + f"s3://{BUCKET}/ara_postcode_splitter_batches/" + f"{parent_task.id}/{parent_subtask.id}/" + ) + assert batch_uri.startswith(prefix) + assert batch_uri.endswith(".csv") + + +def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids( + harness: Harness, +) -> None: + # arrange + parent_task, parent_subtask = ( + harness.task_orchestrator.create_task_with_subtask( + task_source="manual:postcode-splitter-int" + ) + ) + input_uri = _upload_fixture_csv(harness.csv_client) + + # act + child_ids = harness.splitter.split_and_dispatch( + parent_task_id=parent_task.id, + parent_subtask_id=parent_subtask.id, + input_s3_uri=input_uri, + ) + + # assert + bodies = _drain_queue(harness.boto_sqs, harness.queue_url) + assert len(bodies) == len(child_ids) + + # Match queue messages against persisted child inputs by sub_task_id; + # the message body's task_id/s3_uri must agree with the SubTask inputs. 
+ bodies_by_child = {body["sub_task_id"]: body for body in bodies} + assert set(bodies_by_child.keys()) == {str(cid) for cid in child_ids} + for cid in child_ids: + child = harness.subtasks.get(cid) + body = bodies_by_child[str(cid)] + assert child.inputs is not None + assert body == { + "task_id": str(parent_task.id), + "sub_task_id": str(cid), + "s3_uri": child.inputs["s3_uri"], + } + + +def test_split_and_dispatch_returns_child_ids_in_dispatch_order( + harness: Harness, +) -> None: + # arrange + parent_task, parent_subtask = ( + harness.task_orchestrator.create_task_with_subtask( + task_source="manual:postcode-splitter-int" + ) + ) + input_uri = _upload_fixture_csv(harness.csv_client) + + # act + child_ids = harness.splitter.split_and_dispatch( + parent_task_id=parent_task.id, + parent_subtask_id=parent_subtask.id, + input_s3_uri=input_uri, + ) + + # assert + # Re-load each child's saved batch and inspect the postcode_clean column + # to confirm the dispatch order matches the postcode-batching algorithm: + # AA-batch first, BB oversize batch second, CC final-flush third. 
+ postcodes_per_batch: list[set[str]] = [] + for cid in child_ids: + child = harness.subtasks.get(cid) + assert child.inputs is not None + rows = harness.csv_client.read_rows(child.inputs["s3_uri"]) + postcodes_per_batch.append({row["postcode_clean"] for row in rows}) + + assert postcodes_per_batch == [ + {"AA11AA"}, + {"BB22BB"}, + {"CC33CC"}, + ] diff --git a/tests/orchestration/test_task_orchestrator.py b/tests/orchestration/test_task_orchestrator.py index 1a48127f..ae89991d 100644 --- a/tests/orchestration/test_task_orchestrator.py +++ b/tests/orchestration/test_task_orchestrator.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from dataclasses import dataclass import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session from domain.tasks.subtasks import SubTask, SubTaskStatus from domain.tasks.tasks import Source, TaskStatus @@ -19,10 +20,8 @@ class Harness: @pytest.fixture -def harness() -> Iterator[Harness]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as session: +def harness(db_engine: Engine) -> Iterator[Harness]: + with Session(db_engine) as session: tasks = TaskPostgresRepository(session=session) subtasks = SubTaskPostgresRepository(session=session) yield Harness( @@ -35,6 +34,7 @@ def harness() -> Iterator[Harness]: def test_create_task_with_subtask_creates_both_in_waiting( harness: Harness, ) -> None: + # act task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test", inputs={"foo": "bar"}, @@ -42,6 +42,7 @@ def test_create_task_with_subtask_creates_both_in_waiting( source_id="abc", ) + # assert assert task.status is TaskStatus.WAITING assert subtask.status is SubTaskStatus.WAITING assert subtask.task_id == task.id @@ -49,27 +50,33 @@ def test_create_task_with_subtask_creates_both_in_waiting( def test_start_subtask_cascades_to_in_progress(harness: Harness) -> None: + # arrange task, subtask = 
harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) + # act started = harness.orchestrator.start_subtask( subtask.id, cloud_logs_url="https://example/log" ) + # assert assert started.status is SubTaskStatus.IN_PROGRESS assert started.cloud_logs_url == "https://example/log" assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS def test_complete_subtask_cascades_to_complete(harness: Harness) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) harness.orchestrator.start_subtask(subtask.id) + # act harness.orchestrator.complete_subtask(subtask.id, {"value": 42}) + # assert done_subtask = harness.subtasks.get(subtask.id) done_task = harness.tasks.get(task.id) assert done_subtask.outputs == {"result": {"value": 42}} @@ -78,12 +85,15 @@ def test_complete_subtask_cascades_to_complete(harness: Harness) -> None: def test_fail_subtask_cascades_to_failed(harness: Harness) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) + # act harness.orchestrator.fail_subtask(subtask.id, RuntimeError("boom")) + # assert failed_subtask = harness.subtasks.get(subtask.id) failed_task = harness.tasks.get(task.id) assert failed_subtask.outputs == {"error": "boom"} @@ -93,50 +103,85 @@ def test_fail_subtask_cascades_to_failed(harness: Harness) -> None: def test_failed_subtask_locks_task_failed_even_with_others_complete( harness: Harness, ) -> None: + # arrange task, first = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) second = SubTask.create(task_id=task.id) harness.subtasks.create(second) + # act harness.orchestrator.complete_subtask(first.id) harness.orchestrator.fail_subtask(second.id, RuntimeError("nope")) + # assert assert harness.tasks.get(task.id).status is TaskStatus.FAILED def test_mixed_complete_and_in_progress_keeps_task_in_progress( harness: Harness, ) -> None: + # arrange task, first = 
harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) second = SubTask.create(task_id=task.id) harness.subtasks.create(second) + # act harness.orchestrator.complete_subtask(first.id) harness.orchestrator.start_subtask(second.id) + # assert assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS def test_run_subtask_happy_path_returns_result_and_cascades_complete( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) + # act result = harness.orchestrator.run_subtask(subtask.id, work=lambda: {"answer": 42}) + # assert assert result == {"answer": 42} assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE +def test_create_child_subtask_adds_waiting_child_without_changing_parent_status( + harness: Harness, +) -> None: + # arrange + task, first = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + harness.orchestrator.start_subtask(first.id) + assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS + + # act + child = harness.orchestrator.create_child_subtask( + task.id, inputs={"split": "a"} + ) + + # assert + persisted_child = harness.subtasks.get(child.id) + assert persisted_child.task_id == task.id + assert persisted_child.status is SubTaskStatus.WAITING + assert persisted_child.inputs == {"split": "a"} + assert persisted_child.id != first.id + # Cascade is a no-op: parent stays IN_PROGRESS. 
+ assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS + + def test_run_subtask_failing_work_marks_failed_and_reraises( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -144,6 +189,7 @@ def test_run_subtask_failing_work_marks_failed_and_reraises( def boom() -> None: raise RuntimeError("boom") + # act / assert with pytest.raises(RuntimeError, match="boom"): harness.orchestrator.run_subtask(subtask.id, work=boom) diff --git a/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py b/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py index ac39e089..9cec52ea 100644 --- a/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py +++ b/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py @@ -1,33 +1,40 @@ from collections.abc import Iterator -from uuid import uuid4 +from uuid import UUID, uuid4 import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session -# Importing the SQLModel row modules registers their tables in -# SQLModel.metadata so create_all builds both. Imports look unused; they aren't. 
-import infrastructure.postgres.subtask_table # noqa: F401 # pyright: ignore[reportUnusedImport] -import infrastructure.postgres.task_table # noqa: F401 # pyright: ignore[reportUnusedImport] from domain.tasks.subtasks import SubTask, SubTaskStatus +from domain.tasks.tasks import Task from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository +from repositories.tasks.task_postgres_repository import TaskPostgresRepository @pytest.fixture -def session() -> Iterator[Session]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as s: +def session(db_engine: Engine) -> Iterator[Session]: + with Session(db_engine) as s: yield s +def _persisted_task_id(session: Session) -> UUID: + """Create a parent Task row so SubTask FK constraints are satisfied.""" + task = Task.create(task_source="manual:test") + TaskPostgresRepository(session=session).create(task) + return task.id + + def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) - task_id = uuid4() + task_id = _persisted_task_id(session) st = SubTask.create(task_id=task_id, inputs={"address": "68 Glendon Way"}) + # act repo.create(st) fetched = repo.get(st.id) + # assert assert fetched.id == st.id assert fetched.task_id == task_id assert fetched.status is SubTaskStatus.WAITING @@ -36,16 +43,21 @@ def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None: def test_save_persists_status_and_outputs(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) - st = SubTask.create(task_id=uuid4()) + st = SubTask.create(task_id=_persisted_task_id(session)) repo.create(st) + # act st.start(cloud_logs_url="https://example/log") repo.save(st) + # assert assert repo.get(st.id).status is SubTaskStatus.IN_PROGRESS + # act st.complete({"uprn": "123"}) repo.save(st) + # assert done = repo.get(st.id) assert done.status is 
SubTaskStatus.COMPLETE assert done.outputs == {"result": {"uprn": "123"}} @@ -54,16 +66,19 @@ def test_save_persists_status_and_outputs(session: Session) -> None: def test_list_by_task_filters_by_task_id(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) - task_a = uuid4() - task_b = uuid4() + task_a = _persisted_task_id(session) + task_b = _persisted_task_id(session) repo.create(SubTask.create(task_id=task_a)) repo.create(SubTask.create(task_id=task_a)) repo.create(SubTask.create(task_id=task_b)) + # act a_results = repo.list_by_task(task_a) b_results = repo.list_by_task(task_b) + # assert assert len(a_results) == 2 assert len(b_results) == 1 assert all(s.task_id == task_a for s in a_results) @@ -71,11 +86,15 @@ def test_list_by_task_filters_by_task_id(session: Session) -> None: def test_list_by_task_returns_empty_for_unknown_task(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) + # act / assert assert repo.list_by_task(uuid4()) == [] def test_get_missing_raises(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) + # act / assert with pytest.raises(ValueError, match="not found"): repo.get(uuid4()) diff --git a/tests/repositories/tasks/postgres/test_task_postgres_repository.py b/tests/repositories/tasks/postgres/test_task_postgres_repository.py index 3e1aa226..8a49a861 100644 --- a/tests/repositories/tasks/postgres/test_task_postgres_repository.py +++ b/tests/repositories/tasks/postgres/test_task_postgres_repository.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from uuid import uuid4 import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session from domain.tasks.tasks import Source, Task, TaskStatus from infrastructure.postgres.task_table import TaskRow @@ -10,25 +11,23 @@ from repositories.tasks.task_postgres_repository import TaskPostgresRepository @pytest.fixture -def 
session() -> Iterator[Session]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as s: +def session(db_engine: Engine) -> Iterator[Session]: + with Session(db_engine) as s: yield s def test_create_and_get_round_trip(session: Session) -> None: - # Arrange + # arrange repo = TaskPostgresRepository(session=session) t = Task.create( task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123" ) - # Act + # act repo.create(t) fetched = repo.get(t.id) - # Assert + # assert assert fetched.id == t.id assert fetched.status is TaskStatus.WAITING assert fetched.source is Source.PORTFOLIO @@ -36,33 +35,43 @@ def test_create_and_get_round_trip(session: Session) -> None: def test_save_persists_status_transition(session: Session) -> None: + # arrange repo = TaskPostgresRepository(session=session) t = Task.create(task_source="manual:test") repo.create(t) + # act t.start() repo.save(t) + # assert assert repo.get(t.id).status is TaskStatus.IN_PROGRESS + # act t.complete() repo.save(t) + # assert done = repo.get(t.id) assert done.status is TaskStatus.COMPLETE assert done.job_completed is not None def test_get_missing_raises(session: Session) -> None: + # arrange repo = TaskPostgresRepository(session=session) + # act / assert with pytest.raises(ValueError, match="not found"): repo.get(uuid4()) def test_get_normalises_legacy_capitalised_status(session: Session) -> None: # Existing rows written by backend code use "In Progress" (capitalised). 
+ # arrange repo = TaskPostgresRepository(session=session) row = TaskRow(task_source="manual:test", status="In Progress") session.add(row) session.commit() + # act fetched = repo.get(row.id) + # assert assert fetched.status is TaskStatus.IN_PROGRESS diff --git a/tests/repositories/user_address/__init__.py b/tests/repositories/user_address/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/repositories/user_address/conftest.py b/tests/repositories/user_address/conftest.py new file mode 100644 index 00000000..25c1ac3b --- /dev/null +++ b/tests/repositories/user_address/conftest.py @@ -0,0 +1,28 @@ +import os +from collections.abc import Iterator +from typing import Optional + +import pytest + + +@pytest.fixture(autouse=True) +def _aws_creds() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] + keys = ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_DEFAULT_REGION", + ) + prev: dict[str, Optional[str]] = {k: os.environ.get(k) for k in keys} + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + try: + yield + finally: + for k, v in prev.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v diff --git a/tests/repositories/user_address/test_user_address_csv_s3_repository.py b/tests/repositories/user_address/test_user_address_csv_s3_repository.py new file mode 100644 index 00000000..9ffb250a --- /dev/null +++ b/tests/repositories/user_address/test_user_address_csv_s3_repository.py @@ -0,0 +1,237 @@ +from collections.abc import Iterator + +import pytest +from moto import mock_aws + +from domain.addresses.user_address import UserAddress +from domain.postcode import Postcode +from infrastructure.csv_s3_client import CsvS3Client +from repositories.user_address.user_address_csv_s3_repository import ( + UserAddressCsvS3Repository, +) +from 
tests.infrastructure import make_boto_client + +BUCKET = "user-address-bucket" + + +@pytest.fixture +def repo() -> Iterator[UserAddressCsvS3Repository]: + with mock_aws(): + boto_client = make_boto_client("s3") + boto_client.create_bucket(Bucket=BUCKET) + csv_client = CsvS3Client(boto_client, BUCKET) + yield UserAddressCsvS3Repository(csv_client, BUCKET) + + +def _upload_csv( + repo: UserAddressCsvS3Repository, rows: list[dict[str, str]], key: str +) -> str: + return repo._csv_client.save_rows(rows, key) # pyright: ignore[reportPrivateUsage] + + +def test_load_batch_parses_address_postcode_and_reference( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + rows = [ + { + "Address 1": "1 High Street", + "Address 2": "Flat 2", + "Address 3": "Townville", + "postcode": "sw1a 1aa", + "Internal Reference": "REF-001", + } + ] + uri = _upload_csv(repo, rows, "uploads/full.csv") + + # act + addresses = repo.load_batch(uri) + + # assert + assert len(addresses) == 1 + address = addresses[0] + assert address.user_address == "1 High Street, Flat 2, Townville" + assert address.postcode == Postcode("SW1A1AA") + assert address.internal_reference == "REF-001" + + +def test_load_batch_uses_only_address_1_when_others_missing( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + rows = [ + { + "Address 1": "10 Cardiff Road", + "Address 2": "", + "Address 3": "", + "postcode": "CF10 1AA", + "Internal Reference": "REF-002", + } + ] + uri = _upload_csv(repo, rows, "uploads/address1-only.csv") + + # act + addresses = repo.load_batch(uri) + + # assert + assert len(addresses) == 1 + assert addresses[0].user_address == "10 Cardiff Road" + assert addresses[0].postcode == Postcode("CF101AA") + assert addresses[0].internal_reference == "REF-002" + + +def test_load_batch_handles_missing_internal_reference( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + rows = [ + { + "Address 1": "5 Park Lane", + "Address 2": "", + "Address 3": "", + "postcode": "M1 1AA", + 
"Internal Reference": "", + } + ] + uri = _upload_csv(repo, rows, "uploads/no-ref.csv") + + # act + addresses = repo.load_batch(uri) + + # assert + assert len(addresses) == 1 + assert addresses[0].user_address == "5 Park Lane" + assert addresses[0].postcode == Postcode("M11AA") + assert addresses[0].internal_reference is None + + +def test_load_batch_captures_full_source_row( + repo: UserAddressCsvS3Repository, +) -> None: + # A raw EPC-export-shaped row: the splitter must preserve every column, + # not just the ones it parses into UserAddress fields. + # arrange + row = { + "Asset Reference": "511", + "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX", + "postcode": "BB12 7BX", + "Property Type": "House: End Terrace", + "SAP Score": "69", + } + uri = _upload_csv(repo, [row], "uploads/epc.csv") + + # act + addresses = repo.load_batch(uri) + + # assert + assert addresses[0].source_row == row + + +def test_load_batch_raises_when_postcode_column_absent( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}] + uri = _upload_csv(repo, rows, "uploads/no-postcode.csv") + + # act / assert + with pytest.raises(ValueError, match="no 'postcode' column"): + repo.load_batch(uri) + + +def test_save_batch_passes_through_all_columns_and_appends_postcode_clean( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + row = { + "Asset Reference": "511", + "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX", + "postcode": " BB12 7BX", + "Property Type": "House: End Terrace", + } + uri = _upload_csv(repo, [row], "uploads/epc.csv") + addresses = repo.load_batch(uri) + + # act + saved_uri = repo.save_batch(addresses, "tasks/passthrough") + saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] + + # assert + assert len(saved_rows) == 1 + saved = saved_rows[0] + # Every original column survives, byte-for-byte. 
+ for column, value in row.items(): + assert saved[column] == value + # Plus the one appended column the downstream address2uprn stage groups on. + assert saved["postcode_clean"] == "BB127BX" + + +def test_save_batch_returns_uri_under_path_prefix( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + addresses = [ + UserAddress( + user_address="1 High Street", + postcode=Postcode("SW1A 1AA"), + source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"}, + ), + ] + + # act + uri = repo.save_batch(addresses, "tasks/abc/batches") + + # assert + assert uri.startswith(f"s3://{BUCKET}/tasks/abc/batches/") + assert uri.endswith(".csv") + + +def test_save_then_reload_round_trip_preserves_columns( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + rows = [ + { + "Address 1": "1 High Street", + "postcode": "SW1A 1AA", + "Internal Reference": "REF-001", + }, + { + "Address 1": "2 Low Street", + "postcode": "XY9 8ZW", + "Internal Reference": "", + }, + ] + uri = _upload_csv(repo, rows, "uploads/round-trip.csv") + addresses = repo.load_batch(uri) + + # act + saved_uri = repo.save_batch(addresses, "tasks/round-trip") + saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] + + # assert + # Original columns come back verbatim; postcode_clean is the only addition. 
+ assert [ + {k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows + ] == rows + assert [r["postcode_clean"] for r in saved_rows] == ["SW1A1AA", "XY98ZW"] + + +def test_save_batch_uses_unique_filename_per_call( + repo: UserAddressCsvS3Repository, +) -> None: + # arrange + addresses = [ + UserAddress( + user_address="1 High Street", + postcode=Postcode("SW1A 1AA"), + source_row={"Address 1": "1 High Street", "postcode": "SW1A 1AA"}, + ), + ] + + # act + uri_1 = repo.save_batch(addresses, "tasks/uniqueness") + uri_2 = repo.save_batch(addresses, "tasks/uniqueness") + + # assert + assert uri_1 != uri_2 diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utilities/aws_lambda/__init__.py b/tests/utilities/aws_lambda/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utilities/aws_lambda/test_subtask_handler.py b/tests/utilities/aws_lambda/test_subtask_handler.py new file mode 100644 index 00000000..d671adc4 --- /dev/null +++ b/tests/utilities/aws_lambda/test_subtask_handler.py @@ -0,0 +1,255 @@ +import logging +from collections.abc import Generator, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any +from uuid import UUID + +import pytest +from sqlalchemy import Engine +from sqlmodel import Session + +from domain.tasks.subtasks import SubTaskStatus +from domain.tasks.tasks import TaskStatus +from orchestration.task_orchestrator import TaskOrchestrator +from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository +from repositories.tasks.task_postgres_repository import TaskPostgresRepository +from utilities.aws_lambda.subtask_handler import subtask_handler + +_LOGGER_NAME = "utilities.aws_lambda.subtask_handler" + + +@dataclass +class Harness: + orchestrator: TaskOrchestrator + tasks: TaskPostgresRepository + subtasks: SubTaskPostgresRepository + + @contextmanager + 
def factory(self) -> Generator[TaskOrchestrator, None, None]: + yield self.orchestrator + + +@pytest.fixture +def harness(db_engine: Engine) -> Iterator[Harness]: + with Session(db_engine) as session: + tasks = TaskPostgresRepository(session=session) + subtasks = SubTaskPostgresRepository(session=session) + yield Harness( + orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks), + tasks=tasks, + subtasks=subtasks, + ) + + +def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]: + return {"task_id": str(task_id), "sub_task_id": str(subtask_id)} + + +def test_subtask_handler_injects_orchestrator_as_third_positional_argument( + harness: Harness, +) -> None: + # arrange + _, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + received: dict[str, Any] = {} + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + received["body"] = body + received["context"] = context + received["orchestrator"] = orchestrator + + # act + handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel") + + # assert + assert received["orchestrator"] is harness.orchestrator + assert received["context"] == "ctx-sentinel" + assert received["body"]["sub_task_id"] == str(subtask.id) + + +def test_subtask_handler_completes_parent_subtask_on_success( + harness: Harness, +) -> None: + # arrange + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + # act + handler(_direct_event(task.id, subtask.id), context=None) + + # assert + assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE + assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE + + +def 
test_subtask_handler_marks_parent_failed_and_reraises_on_error( + harness: Harness, +) -> None: + # arrange + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + raise RuntimeError("boom") + + # act / assert + with pytest.raises(RuntimeError, match="boom"): + handler(_direct_event(task.id, subtask.id), context=None) + + assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED + assert harness.tasks.get(task.id).status is TaskStatus.FAILED + + +def test_subtask_handler_injected_orchestrator_can_create_child_subtask( + harness: Harness, +) -> None: + # arrange + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + child_ids: list[UUID] = [] + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + child = orchestrator.create_child_subtask(task.id, inputs={"split": 1}) + child_ids.append(child.id) + + # act + handler(_direct_event(task.id, subtask.id), context=None) + + # assert + assert len(child_ids) == 1 + persisted_child = harness.subtasks.get(child_ids[0]) + assert persisted_child.task_id == task.id + assert persisted_child.status is SubTaskStatus.WAITING + + +def test_subtask_handler_logs_subtask_lifecycle_on_success( + harness: Harness, caplog: pytest.LogCaptureFixture +) -> None: + # arrange + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + # act + with caplog.at_level(logging.INFO, logger=_LOGGER_NAME): + handler(_direct_event(task.id, subtask.id), context=None) + + # assert + assert f"Running subtask 
{subtask.id}" in caplog.text + assert f"Subtask {subtask.id} completed" in caplog.text + + +def test_subtask_handler_logs_exception_on_failure( + harness: Harness, caplog: pytest.LogCaptureFixture +) -> None: + # arrange + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + raise RuntimeError("boom") + + # act / assert + with caplog.at_level(logging.INFO, logger=_LOGGER_NAME): + with pytest.raises(RuntimeError, match="boom"): + handler(_direct_event(task.id, subtask.id), context=None) + + failures = [r for r in caplog.records if r.levelno == logging.ERROR] + assert any( + f"Subtask {subtask.id} failed" in r.getMessage() for r in failures + ) + assert any(r.exc_info is not None for r in failures) + + +def test_subtask_handler_records_cloudwatch_url_on_subtask( + harness: Harness, monkeypatch: pytest.MonkeyPatch +) -> None: + # arrange + monkeypatch.setenv("AWS_REGION", "eu-west-2") + monkeypatch.setenv( + "AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter" + ) + monkeypatch.setenv( + "AWS_LAMBDA_LOG_STREAM_NAME", "2026/05/20/[$LATEST]abc123" + ) + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + # act + handler(_direct_event(task.id, subtask.id), context=None) + + # assert + saved_url = harness.subtasks.get(subtask.id).cloud_logs_url + assert saved_url is not None + assert saved_url.startswith( + "https://eu-west-2.console.aws.amazon.com/cloudwatch/home" + ) + # Log group / stream are console-encoded ("/" -> "$252F"). 
+ assert "$252Faws$252Flambda$252Fpostcode-splitter" in saved_url + assert "$255B$2524LATEST$255D" in saved_url + + +def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda( + harness: Harness, monkeypatch: pytest.MonkeyPatch +) -> None: + # arrange + for var in ( + "AWS_REGION", + "AWS_LAMBDA_LOG_GROUP_NAME", + "AWS_LAMBDA_LOG_STREAM_NAME", + ): + monkeypatch.delenv(var, raising=False) + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + @subtask_handler(orchestrator_cm=harness.factory) + def handler( + body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator + ) -> None: + return None + + # act + handler(_direct_event(task.id, subtask.id), context=None) + + # assert + assert harness.subtasks.get(subtask.id).cloud_logs_url is None diff --git a/utilities/aws_lambda/subtask_handler.py b/utilities/aws_lambda/subtask_handler.py index 64c1daa6..592ffebf 100644 --- a/utilities/aws_lambda/subtask_handler.py +++ b/utilities/aws_lambda/subtask_handler.py @@ -5,14 +5,20 @@ TaskOrchestrator.run_subtask(...) calls. 
""" import json +import logging +import os from contextlib import AbstractContextManager from functools import wraps from typing import Any, Callable, Optional, cast +from urllib.parse import quote from utilities.aws_lambda.default_orchestrator import default_orchestrator from utilities.aws_lambda.subtask_trigger_body import SubtaskTriggerBody from orchestration.task_orchestrator import TaskOrchestrator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]] @@ -33,14 +39,26 @@ def subtask_handler( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> None: + cloud_logs_url = _cloudwatch_url() with factory() as orchestrator: for record in _records(event): body = _parse_body(record) trigger = SubtaskTriggerBody.model_validate(body) - orchestrator.run_subtask( - trigger.sub_task_id, - work=lambda body=body: func(body, context), - ) + logger.info("Running subtask %s", trigger.sub_task_id) + try: + orchestrator.run_subtask( + trigger.sub_task_id, + work=lambda body=body, o=orchestrator: func( + body, context, o + ), + cloud_logs_url=cloud_logs_url, + ) + except Exception: + logger.exception( + "Subtask %s failed", trigger.sub_task_id + ) + raise + logger.info("Subtask %s completed", trigger.sub_task_id) return wrapper @@ -65,3 +83,20 @@ def _records(event: dict[str, Any]) -> list[dict[str, Any]]: if isinstance(raw_records, list): return [r for r in cast(list[Any], raw_records) if isinstance(r, dict)] return [event] + + +def _console_encode(value: str) -> str: + return quote(value, safe="").replace("%", "$25") + + +def _cloudwatch_url() -> Optional[str]: + region = os.environ.get("AWS_REGION") + log_group = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") + log_stream = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") + if not (region and log_group and log_stream): + return None + return ( + 
f"https://{region}.console.aws.amazon.com/cloudwatch/home" + f"?region={region}#logsV2:log-groups/log-group/" + f"{_console_encode(log_group)}/log-events/{_console_encode(log_stream)}" + )