mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
299 lines
9.5 KiB
Python
299 lines
9.5 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
from collections.abc import Iterator
|
||
from dataclasses import dataclass
|
||
from typing import Any, cast
|
||
|
||
import boto3
|
||
import pytest
|
||
from moto import mock_aws
|
||
from sqlalchemy import Engine
|
||
from sqlmodel import Session
|
||
|
||
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
|
||
from infrastructure.csv_s3_client import CsvS3Client
|
||
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
|
||
from orchestration.task_orchestrator import TaskOrchestrator
|
||
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
||
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
||
from repositories.user_address.user_address_csv_s3_repository import (
|
||
UserAddressCsvS3Repository,
|
||
)
|
||
|
||
# Mocked S3 bucket holding both the uploaded input CSV and the per-child
# batch CSVs the splitter writes.
BUCKET: str = "splitter-bucket"
# Region used for every boto3 client and exported as AWS_DEFAULT_REGION.
REGION: str = "us-east-1"
|
||
|
||
|
||
def _make_boto_client(service_name: str) -> Any:
    """Return a low-level boto3 client for ``service_name`` in ``REGION``.

    ``boto3.client`` is routed through an ``Any``-typed local so pyright does
    not propagate boto3's unknown types into the tests.
    """
    make_client: Any = boto3.client  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    client: Any = make_client(service_name, region_name=REGION)
    return client
|
||
|
||
|
||
@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]:  # pyright: ignore[reportUnusedFunction]
    """Replace the AWS credential/region env vars with dummy values per test.

    The previous value of each variable is snapshotted first. After the test
    (even on failure) each variable is restored, or removed again if it was
    unset beforehand.
    """
    overrides = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": REGION,
    }
    saved: dict[str, str | None] = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for name, old_value in saved.items():
            if old_value is not None:
                os.environ[name] = old_value
            else:
                # The variable did not exist before the test: drop it again.
                os.environ.pop(name, None)
|
||
|
||
|
||
@dataclass
class Harness:
    """Collaborators shared by the postcode-splitter integration tests.

    Built by the ``harness`` fixture. It holds the orchestrator under test,
    the task plumbing it writes to, and raw handles on the mocked S3/SQS
    resources so tests can inspect side effects directly.
    """

    # Orchestrator under test (configured with max_batch_size=3).
    splitter: PostcodeSplitterOrchestrator
    # Used by tests to create the parent task/subtask pair.
    task_orchestrator: TaskOrchestrator
    # Repository used to reload persisted child subtasks by id.
    subtasks: SubTaskPostgresRepository
    # CSV client bound to the mocked bucket; uploads input, reads batches.
    csv_client: CsvS3Client
    # Raw boto3 SQS client, used to drain dispatched messages.
    boto_sqs: Any
    # URL of the mocked "address2uprn-queue".
    queue_url: str
    # User-address repository handed to the splitter.
    repo: UserAddressCsvS3Repository
|
||
|
||
|
||
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Yield a ``Harness`` wired to moto-mocked S3/SQS and a DB session.

    All AWS calls made during the test go through ``mock_aws()``. The
    ``Session`` bound to ``db_engine`` stays open for the whole test and is
    closed when the fixture resumes. ``db_engine`` presumably comes from a
    conftest fixture providing an ephemeral Postgres engine (confirm).
    """
    with mock_aws():
        # Mocked AWS side: a bucket for CSVs and a queue for dispatch.
        s3_client = _make_boto_client("s3")
        s3_client.create_bucket(Bucket=BUCKET)
        sqs_client = _make_boto_client("sqs")
        create_resp: dict[str, Any] = sqs_client.create_queue(
            QueueName="address2uprn-queue"
        )
        dispatch_queue_url = cast(str, create_resp["QueueUrl"])

        csv_io = CsvS3Client(s3_client, BUCKET)
        address_repo = UserAddressCsvS3Repository(csv_io, BUCKET)
        dispatcher = Address2UprnQueueClient(sqs_client, dispatch_queue_url)

        # Database side: both repositories share one session.
        with Session(db_engine) as session:
            tasks_repo = TaskPostgresRepository(session=session)
            subtasks_repo = SubTaskPostgresRepository(session=session)
            orchestrator = TaskOrchestrator(
                task_repo=tasks_repo, subtask_repo=subtasks_repo
            )

            yield Harness(
                splitter=PostcodeSplitterOrchestrator(
                    task_orchestrator=orchestrator,
                    user_address_repo=address_repo,
                    queue_client=dispatcher,
                    max_batch_size=3,
                ),
                task_orchestrator=orchestrator,
                subtasks=subtasks_repo,
                csv_client=csv_io,
                boto_sqs=sqs_client,
                queue_url=dispatch_queue_url,
                repo=address_repo,
            )
|
||
|
||
|
||
def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
    """Save the seven-row input CSV to ``uploads/input.csv`` and return its URI.

    Rows come in three postcode groups, in this order:

    * ``AA1 1AA`` x 2 -- within the cap;
    * ``BB2 2BB`` x 4 -- oversize (more than ``max_batch_size=3``);
    * ``CC3 3CC`` x 1 -- trailing group.

    With a cap of 3, ``iter_postcode_grouped_batches`` is expected to emit
    three batches:

    * ``[AA x 2]`` -- flushed because the oversize group follows;
    * ``[BB x 4]`` -- the oversize group in its own batch;
    * ``[CC x 1]`` -- the final flush.
    """

    def make_row(line1: str, postcode: str, reference: str) -> dict[str, str]:
        # Address lines 2 and 3 are left blank for every fixture row.
        return {
            "Address 1": line1,
            "Address 2": "",
            "Address 3": "",
            "postcode": postcode,
            "Internal Reference": reference,
        }

    rows: list[dict[str, str]] = [
        *(make_row(f"{n} High St", "AA1 1AA", f"AA-{n}") for n in range(1, 3)),
        *(make_row(f"{n} Long Road", "BB2 2BB", f"BB-{n}") for n in range(1, 5)),
        make_row("1 Final Way", "CC3 3CC", "CC-1"),
    ]
    return csv_client.save_rows(rows, "uploads/input.csv")
|
||
|
||
|
||
def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]:
    """Receive, JSON-decode and delete every message on ``queue_url``.

    Polls in batches of up to 10 with ``WaitTimeSeconds=0`` and stops at the
    first receive that returns no messages. Each body is decoded before its
    message is deleted. Returns the decoded bodies in the order received.
    """
    collected: list[dict[str, Any]] = []

    def receive_batch() -> list[dict[str, Any]]:
        response: dict[str, Any] = boto_sqs.receive_message(
            QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=0
        )
        return cast(list[dict[str, Any]], response.get("Messages", []))

    while batch := receive_batch():
        for msg in batch:
            collected.append(cast(dict[str, Any], json.loads(msg["Body"])))
            boto_sqs.delete_message(
                QueueUrl=queue_url, ReceiptHandle=msg["ReceiptHandle"]
            )
    return collected
|
||
|
||
|
||
def test_split_and_dispatch_creates_three_children_for_fixture(
    harness: Harness,
) -> None:
    """The seven-row fixture is split into three distinct child subtasks."""
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # One child per expected postcode batch (AA, oversize BB, final CC).
    assert len(child_ids) == 3
    # Child ids are unique, and each persisted child belongs to the parent
    # task. NOTE: only task membership is checked here; the children's
    # status is not asserted by this test.
    assert len(set(child_ids)) == 3
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        assert child.task_id == parent_task.id
|
||
|
||
|
||
def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
    harness: Harness,
) -> None:
    """Each child's inputs hold the parent task id and its batch CSV URI.

    The batch URI must live under
    ``s3://<bucket>/ara_postcode_splitter_batches/<task>/<subtask>/``
    and end in ``.csv``.
    """
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)
    expected_prefix = (
        f"s3://{BUCKET}/ara_postcode_splitter_batches/"
        f"{parent_task.id}/{parent_subtask.id}/"
    )

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    for cid in child_ids:
        inputs = harness.subtasks.get(cid).inputs
        assert inputs is not None
        assert inputs["task_id"] == str(parent_task.id)
        uri = inputs["s3_uri"]
        assert isinstance(uri, str)
        assert uri.startswith(expected_prefix)
        assert uri.endswith(".csv")
|
||
|
||
|
||
def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
    harness: Harness,
) -> None:
    """Exactly one queue message is published per child, matching its inputs."""
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    bodies = _drain_queue(harness.boto_sqs, harness.queue_url)
    assert len(bodies) == len(child_ids)

    # Index the published messages by their "sub_task_id" field. Each body
    # must then equal the parent task id, the child id and the child's
    # persisted "s3_uri" input.
    body_for: dict[str, dict[str, Any]] = {}
    for body in bodies:
        body_for[body["sub_task_id"]] = body
    assert set(body_for) == {str(cid) for cid in child_ids}

    for cid in child_ids:
        child = harness.subtasks.get(cid)
        published = body_for[str(cid)]
        assert child.inputs is not None
        expected = {
            "task_id": str(parent_task.id),
            "sub_task_id": str(cid),
            "s3_uri": child.inputs["s3_uri"],
        }
        assert published == expected
|
||
|
||
|
||
def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
    harness: Harness,
) -> None:
    """Returned child ids follow the order in which batches were produced."""
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # Reload each child's saved batch and collect the distinct values of its
    # "postcode_clean" column. The sequence must be the AA batch first, the
    # oversize BB batch second and the final-flush CC batch third.
    observed: list[set[str]] = []
    for cid in child_ids:
        inputs = harness.subtasks.get(cid).inputs
        assert inputs is not None
        batch_rows = harness.csv_client.read_rows(inputs["s3_uri"])
        observed.append({row["postcode_clean"] for row in batch_rows})

    expected_order = [{"AA11AA"}, {"BB22BB"}, {"CC33CC"}]
    assert observed == expected_order
|