Model/tests/orchestration/test_postcode_splitter_orchestrator.py
2026-05-20 14:00:19 +00:00

299 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import os
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, cast
import boto3
import pytest
from moto import mock_aws
from sqlalchemy import Engine
from sqlmodel import Session
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
from infrastructure.csv_s3_client import CsvS3Client
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from repositories.user_address.user_address_csv_s3_repository import (
UserAddressCsvS3Repository,
)
# Name of the S3 bucket the ``harness`` fixture creates under moto; the
# tests also expect it at the start of every child batch URI.
BUCKET = "splitter-bucket"
# Region passed to every boto3 client and exported as AWS_DEFAULT_REGION.
REGION = "us-east-1"
def _make_boto_client(service_name: str) -> Any:
    """Return a boto3 client for ``service_name`` bound to ``REGION``.

    ``boto3.client`` is routed through an ``Any``-typed local because
    pyright reports it as unknown (hence the targeted ignore below).
    """
    make_client: Any = boto3.client  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    client = make_client(service_name, region_name=REGION)
    return client
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]:  # pyright: ignore[reportUnusedFunction]
    """Export fake AWS credentials and the test region for every test.

    The previous value of each variable is snapshotted first. On teardown
    it is restored, or the variable is removed again if it was unset
    before, so the process environment is left as it was found.
    """
    overrides: dict[str, str] = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": REGION,
    }
    # Snapshot before mutating; ``None`` marks "was not set".
    saved: dict[str, str | None] = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for name, original in saved.items():
            if original is not None:
                os.environ[name] = original
            else:
                os.environ.pop(name, None)
@dataclass
class Harness:
    """Collaborators wired together by the ``harness`` fixture for one test.

    Built inside a moto ``mock_aws()`` context and an open SQLModel
    ``Session``.
    """

    # Splitter under test (its ``split_and_dispatch`` is what the tests call).
    splitter: PostcodeSplitterOrchestrator
    # Orchestrator the splitter was built with; tests use it to create the
    # parent task/subtask.
    task_orchestrator: TaskOrchestrator
    # Subtask repository used to read back persisted child subtasks.
    subtasks: SubTaskPostgresRepository
    # S3 CSV client used to upload the input fixture and read batch files.
    csv_client: CsvS3Client
    # Raw boto3 SQS client (untyped) used to drain the dispatch queue.
    boto_sqs: Any
    # URL of the moto queue the splitter's queue client publishes to.
    queue_url: str
    # User-address repository passed to the splitter as ``user_address_repo``.
    repo: UserAddressCsvS3Repository
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Yield a ``Harness`` wired to moto-backed S3/SQS and a DB session.

    The bucket, the queue and every AWS client exist only inside a single
    ``mock_aws()`` context that spans the test. The SQLModel ``Session`` on
    ``db_engine`` stays open until the test finishes. The splitter uses
    ``max_batch_size=3``, which the CSV fixture's expected batching assumes.
    """
    with mock_aws():
        # AWS side: create the bucket and queue, then the clients wrapping them.
        s3 = _make_boto_client("s3")
        s3.create_bucket(Bucket=BUCKET)
        sqs = _make_boto_client("sqs")
        created_queue: dict[str, Any] = sqs.create_queue(
            QueueName="address2uprn-queue"
        )
        url = cast(str, created_queue["QueueUrl"])

        csv_s3 = CsvS3Client(s3, BUCKET)
        address_repo = UserAddressCsvS3Repository(csv_s3, BUCKET)
        publisher = Address2UprnQueueClient(sqs, url)

        # DB side: task and subtask repositories share one session.
        with Session(db_engine) as session:
            task_store = TaskPostgresRepository(session=session)
            subtask_store = SubTaskPostgresRepository(session=session)
            orchestrator = TaskOrchestrator(
                task_repo=task_store, subtask_repo=subtask_store
            )
            yield Harness(
                splitter=PostcodeSplitterOrchestrator(
                    task_orchestrator=orchestrator,
                    user_address_repo=address_repo,
                    queue_client=publisher,
                    max_batch_size=3,
                ),
                task_orchestrator=orchestrator,
                subtasks=subtask_store,
                csv_client=csv_s3,
                boto_sqs=sqs,
                queue_url=url,
                repo=address_repo,
            )
def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
# Three postcode groups:
# AA1 1AA × 2 (within cap)
# BB2 2BB × 4 (oversize: > max_batch_size=3)
# CC3 3CC × 1 (final flush)
# Expected batching with cap=3 and the algorithm in
# ``iter_postcode_grouped_batches``:
# batch 1: [AA1 1AA × 2] (flushed because oversize follows)
# batch 2: [BB2 2BB × 4] (oversize own batch)
# batch 3: [CC3 3CC × 1] (final flush)
rows: list[dict[str, str]] = []
rows.extend(
{
"Address 1": f"{i} High St",
"Address 2": "",
"Address 3": "",
"postcode": "AA1 1AA",
"Internal Reference": f"AA-{i}",
}
for i in range(1, 3)
)
rows.extend(
{
"Address 1": f"{i} Long Road",
"Address 2": "",
"Address 3": "",
"postcode": "BB2 2BB",
"Internal Reference": f"BB-{i}",
}
for i in range(1, 5)
)
rows.append(
{
"Address 1": "1 Final Way",
"Address 2": "",
"Address 3": "",
"postcode": "CC3 3CC",
"Internal Reference": "CC-1",
}
)
return csv_client.save_rows(rows, "uploads/input.csv")
def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]:
bodies: list[dict[str, Any]] = []
while True:
received: dict[str, Any] = boto_sqs.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=0
)
messages = cast(list[dict[str, Any]], received.get("Messages", []))
if not messages:
break
for message in messages:
bodies.append(cast(dict[str, Any], json.loads(message["Body"])))
boto_sqs.delete_message(
QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"]
)
return bodies
def test_split_and_dispatch_creates_three_children_for_fixture(
    harness: Harness,
) -> None:
    """Splitting the fixture CSV yields three distinct children of the parent task."""
    # arrange
    parent_task, parent_subtask = (
        harness.task_orchestrator.create_task_with_subtask(
            task_source="manual:postcode-splitter-int"
        )
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # One child per expected postcode batch, with no duplicate ids and none
    # reusing the parent subtask's id.
    assert len(child_ids) == 3
    assert len(set(child_ids)) == 3
    assert parent_subtask.id not in child_ids
    # Every child is persisted and attached to the parent task. Child status
    # is not checked by this test.
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        assert child.task_id == parent_task.id
def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
    harness: Harness,
) -> None:
    """Each child's stored inputs carry the parent task id and its own batch URI."""
    # arrange
    parent_task, parent_subtask = (
        harness.task_orchestrator.create_task_with_subtask(
            task_source="manual:postcode-splitter-int"
        )
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # Pin the child count so the per-child loop below cannot pass vacuously
    # when no children are returned.
    assert len(child_ids) == 3
    # Batch files are expected under a prefix keyed by parent task and
    # subtask. The prefix is the same for every child, so build it once.
    prefix = (
        f"s3://{BUCKET}/ara_postcode_splitter_batches/"
        f"{parent_task.id}/{parent_subtask.id}/"
    )
    batch_uris: list[str] = []
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        assert child.inputs is not None
        assert child.inputs["task_id"] == str(parent_task.id)
        batch_uri = child.inputs["s3_uri"]
        assert isinstance(batch_uri, str)
        assert batch_uri.startswith(prefix)
        assert batch_uri.endswith(".csv")
        batch_uris.append(batch_uri)
    # Every child must point at its own batch file.
    assert len(set(batch_uris)) == len(batch_uris)
def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
    harness: Harness,
) -> None:
    """Each child subtask gets exactly one queue message mirroring its inputs."""
    # arrange
    parent_task, parent_subtask = (
        harness.task_orchestrator.create_task_with_subtask(
            task_source="manual:postcode-splitter-int"
        )
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # Pin three distinct children. Without this, zero children plus an empty
    # queue would pass vacuously, and duplicate ids could mask a duplicated
    # message in the one-to-one check below.
    assert len(child_ids) == len(set(child_ids)) == 3
    bodies = _drain_queue(harness.boto_sqs, harness.queue_url)
    assert len(bodies) == len(child_ids)
    # Index messages by their "sub_task_id". With the length check above,
    # this ensures exactly one message per child. Each message's
    # task_id/s3_uri must agree with that child's persisted SubTask inputs.
    bodies_by_child = {body["sub_task_id"]: body for body in bodies}
    assert set(bodies_by_child) == {str(cid) for cid in child_ids}
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        body = bodies_by_child[str(cid)]
        assert child.inputs is not None
        assert body == {
            "task_id": str(parent_task.id),
            "sub_task_id": str(cid),
            "s3_uri": child.inputs["s3_uri"],
        }
def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
    harness: Harness,
) -> None:
    """Returned child ids follow the postcode-batching order AA, BB, CC."""
    # arrange
    parent_task, parent_subtask = (
        harness.task_orchestrator.create_task_with_subtask(
            task_source="manual:postcode-splitter-int"
        )
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # Read back each child's batch file and collect its distinct
    # ``postcode_clean`` values. The sequence must match the batching
    # algorithm: the AA batch first, the oversize BB batch second and the
    # CC final flush third.
    observed: list[set[str]] = []
    for child_id in child_ids:
        inputs = harness.subtasks.get(child_id).inputs
        assert inputs is not None
        batch_rows = harness.csv_client.read_rows(inputs["s3_uri"])
        observed.append({r["postcode_clean"] for r in batch_rows})
    expected: list[set[str]] = [{"AA11AA"}, {"BB22BB"}, {"CC33CC"}]
    assert observed == expected