mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
tests framework completed
This commit is contained in:
parent
d0cf3d14ad
commit
dc159e0b45
19 changed files with 336 additions and 44 deletions
|
|
@ -53,6 +53,9 @@ class UserAddressCsvS3Repository(UserAddressRepository):
|
|||
{**addr.source_row, _POSTCODE_CLEAN_COLUMN: str(addr.postcode)}
|
||||
for addr in addresses
|
||||
]
|
||||
|
||||
# TODO: [New Starter Task] file_name generation can be standardised
|
||||
# and also easier to read, test for future implementation. Buiild that!
|
||||
filename = (
|
||||
f"{datetime.now(timezone.utc).isoformat()}_{uuid.uuid4().hex[:8]}.csv"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,9 +7,7 @@ from domain.addresses.user_address import UserAddress
|
|||
|
||||
class UserAddressRepository(ABC):
|
||||
@abstractmethod
|
||||
def load_batch(self, s3_uri: str) -> list[UserAddress]:
|
||||
...
|
||||
def load_batch(self, s3_uri: str) -> list[UserAddress]: ...
|
||||
|
||||
@abstractmethod
|
||||
def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str:
|
||||
...
|
||||
def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: ...
|
||||
|
|
|
|||
48
tests/conftest.py
Normal file
48
tests/conftest.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Shared pytest fixtures for the ``tests/`` tree.
|
||||
|
||||
Provides an ephemeral PostgreSQL engine for tests that exercise SQLModel
|
||||
repositories. PostgreSQL has no true in-memory mode; ``pytest-postgresql``
|
||||
starts a real, throwaway server in a temp directory (the process is started
|
||||
once per session and a fresh database is created/dropped per test). That is
|
||||
the closest equivalent to "in-memory" and matches production behaviour far
|
||||
better than SQLite (enums, JSONB, constraint semantics, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from psycopg import Connection
|
||||
from pytest_postgresql import factories
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
|
||||
# Importing the SQLModel row modules registers their tables on
|
||||
# SQLModel.metadata so ``create_all`` builds the full schema. Imports look
|
||||
# unused; they aren't.
|
||||
|
||||
|
||||
# pg_ctl ships under a versioned path and is not on PATH in the dev container.
|
||||
_PG_CTL = next(iter(sorted(glob.glob("/usr/lib/postgresql/*/bin/pg_ctl"))), "pg_ctl")
|
||||
|
||||
postgresql_proc = factories.postgresql_proc(
|
||||
executable=_PG_CTL
|
||||
) # pyright: ignore[reportUnknownMemberType]
|
||||
postgresql = factories.postgresql("postgresql_proc")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine(postgresql: Connection[Any]) -> Iterator[Engine]:
|
||||
"""A SQLModel engine bound to a fresh, ephemeral PostgreSQL database."""
|
||||
info = postgresql.info
|
||||
url = f"postgresql+psycopg://{info.user}:@{info.host}:{info.port}/{info.dbname}"
|
||||
engine = create_engine(url)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
SQLModel.metadata.drop_all(engine)
|
||||
engine.dispose()
|
||||
|
|
@ -15,12 +15,16 @@ def _addrs(postcode: str, n: int) -> list[UserAddress]:
|
|||
|
||||
|
||||
def test_empty_input_yields_no_batches() -> None:
|
||||
# act / assert
|
||||
assert list(iter_postcode_grouped_batches([])) == []
|
||||
|
||||
|
||||
def test_single_batch_under_cap() -> None:
|
||||
# arrange
|
||||
addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2)
|
||||
# act
|
||||
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500))
|
||||
# assert
|
||||
assert len(batches) == 1
|
||||
assert batches[0] == addrs
|
||||
|
||||
|
|
@ -28,8 +32,11 @@ def test_single_batch_under_cap() -> None:
|
|||
def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None:
|
||||
# Two groups whose total exactly equals the cap pack into a single
|
||||
# batch -- no premature flush.
|
||||
# arrange
|
||||
addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2)
|
||||
# act
|
||||
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5))
|
||||
# assert
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 5
|
||||
|
||||
|
|
@ -37,8 +44,11 @@ def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None:
|
|||
def test_flush_on_overflow_before_adding_next_postcode() -> None:
|
||||
# Cap is 5. First group fills 3 slots; second group of 3 would overflow,
|
||||
# so the buffer is flushed first and the next group starts a fresh batch.
|
||||
# arrange
|
||||
addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 3)
|
||||
# act
|
||||
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5))
|
||||
# assert
|
||||
assert len(batches) == 2
|
||||
assert [str(a.postcode) for a in batches[0]] == ["AA11AA"] * 3
|
||||
assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 3
|
||||
|
|
@ -47,8 +57,11 @@ def test_flush_on_overflow_before_adding_next_postcode() -> None:
|
|||
def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None:
|
||||
# An oversize single-postcode group goes out as one batch larger than
|
||||
# the cap -- the cap never splits a postcode.
|
||||
# arrange
|
||||
addrs = _addrs("AA1 1AA", 7)
|
||||
# act
|
||||
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5))
|
||||
# assert
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 7
|
||||
|
||||
|
|
@ -56,12 +69,15 @@ def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None:
|
|||
def test_oversize_group_flushes_existing_buffer_first() -> None:
|
||||
# Mirrors the legacy ``if buffer: flush`` branch when an oversize group
|
||||
# is encountered: buffered work must not be lost or interleaved.
|
||||
# arrange
|
||||
small = _addrs("AA1 1AA", 2)
|
||||
big = _addrs("BB2 2BB", 7)
|
||||
tail = _addrs("CC3 3CC", 1)
|
||||
# act
|
||||
batches = list(
|
||||
iter_postcode_grouped_batches(small + big + tail, max_batch_size=5)
|
||||
)
|
||||
# assert
|
||||
assert len(batches) == 3
|
||||
assert [str(a.postcode) for a in batches[0]] == ["AA11AA", "AA11AA"]
|
||||
assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 7
|
||||
|
|
@ -70,17 +86,23 @@ def test_oversize_group_flushes_existing_buffer_first() -> None:
|
|||
|
||||
def test_final_flush_yields_remaining_buffer() -> None:
|
||||
# No overflow ever happens, but the trailing buffer must still come out.
|
||||
# arrange
|
||||
addrs = _addrs("AA1 1AA", 2) + _addrs("BB2 2BB", 2)
|
||||
# act
|
||||
batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500))
|
||||
# assert
|
||||
assert batches == [addrs]
|
||||
|
||||
|
||||
def test_postcode_grouping_preserves_first_seen_order() -> None:
|
||||
# Interleaved input must still group by postcode and emit in first-seen
|
||||
# order -- never alphabetical.
|
||||
# arrange
|
||||
a1, a2 = _addrs("ZZ9 9ZZ", 2)
|
||||
b1, b2 = _addrs("AA1 1AA", 2)
|
||||
# act
|
||||
batches = list(iter_postcode_grouped_batches([a1, b1, a2, b2]))
|
||||
# assert
|
||||
assert len(batches) == 1
|
||||
assert [str(a.postcode) for a in batches[0]] == [
|
||||
"ZZ99ZZ",
|
||||
|
|
@ -91,5 +113,6 @@ def test_postcode_grouping_preserves_first_seen_order() -> None:
|
|||
|
||||
|
||||
def test_invalid_max_batch_size_raises() -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="max_batch_size"):
|
||||
list(iter_postcode_grouped_batches([], max_batch_size=0))
|
||||
|
|
|
|||
|
|
@ -7,35 +7,45 @@ from domain.postcode import Postcode
|
|||
|
||||
|
||||
def test_user_address_holds_postcode_value_object() -> None:
|
||||
# act
|
||||
addr = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa"))
|
||||
# assert
|
||||
assert addr.postcode == Postcode("SW1A1AA")
|
||||
|
||||
|
||||
def test_user_address_preserves_user_address_verbatim() -> None:
|
||||
# The free-text user_address string is intentionally NOT normalised --
|
||||
# only the postcode is canonicalised, and that happens inside Postcode.
|
||||
# act
|
||||
addr = UserAddress(
|
||||
user_address=" 1 The Street ", postcode=Postcode("SW1A1AA")
|
||||
)
|
||||
# assert
|
||||
assert addr.user_address == " 1 The Street "
|
||||
|
||||
|
||||
def test_user_address_internal_reference_defaults_to_none() -> None:
|
||||
# act
|
||||
addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
|
||||
# assert
|
||||
assert addr.internal_reference is None
|
||||
|
||||
|
||||
def test_user_address_internal_reference_accepted() -> None:
|
||||
# act
|
||||
addr = UserAddress(
|
||||
user_address="1 The Street",
|
||||
postcode=Postcode("SW1A1AA"),
|
||||
internal_reference="cust-42",
|
||||
)
|
||||
# assert
|
||||
assert addr.internal_reference == "cust-42"
|
||||
|
||||
|
||||
def test_user_address_is_frozen() -> None:
|
||||
# arrange
|
||||
addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
|
||||
# act / assert
|
||||
with pytest.raises(dataclasses.FrozenInstanceError):
|
||||
addr.postcode = Postcode("OTHER") # type: ignore[misc]
|
||||
|
||||
|
|
@ -43,29 +53,37 @@ def test_user_address_is_frozen() -> None:
|
|||
def test_user_address_equality_uses_canonical_postcode() -> None:
|
||||
# Postcode sanitises eagerly, so addresses built from different surface
|
||||
# forms of the same postcode compare equal.
|
||||
# arrange
|
||||
a = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa"))
|
||||
b = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
|
||||
# act / assert
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_user_address_source_row_defaults_to_empty_dict() -> None:
|
||||
# act
|
||||
addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA"))
|
||||
# assert
|
||||
assert addr.source_row == {}
|
||||
|
||||
|
||||
def test_user_address_carries_source_row() -> None:
|
||||
# arrange
|
||||
row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"}
|
||||
# act
|
||||
addr = UserAddress(
|
||||
user_address="1 The Street",
|
||||
postcode=Postcode("SW1A 1AA"),
|
||||
source_row=row,
|
||||
)
|
||||
# assert
|
||||
assert addr.source_row == row
|
||||
|
||||
|
||||
def test_user_address_equality_ignores_source_row() -> None:
|
||||
# source_row is excluded from equality (and hashing): identity stays
|
||||
# defined by the parsed fields.
|
||||
# arrange
|
||||
a = UserAddress(
|
||||
user_address="1 The Street",
|
||||
postcode=Postcode("SW1A1AA"),
|
||||
|
|
@ -76,4 +94,5 @@ def test_user_address_equality_ignores_source_row() -> None:
|
|||
postcode=Postcode("SW1A1AA"),
|
||||
source_row={"y": "2"},
|
||||
)
|
||||
# act / assert
|
||||
assert a == b
|
||||
|
|
|
|||
|
|
@ -6,10 +6,13 @@ from domain.tasks.subtasks import SubTask, SubTaskStatus
|
|||
|
||||
|
||||
def test_create_subtask_starts_waiting() -> None:
|
||||
# arrange
|
||||
task_id = uuid4()
|
||||
|
||||
# act
|
||||
st = SubTask.create(task_id=task_id, inputs={"foo": "bar"})
|
||||
|
||||
# assert
|
||||
assert st.task_id == task_id
|
||||
assert st.status is SubTaskStatus.WAITING
|
||||
assert st.inputs == {"foo": "bar"}
|
||||
|
|
@ -19,57 +22,74 @@ def test_create_subtask_starts_waiting() -> None:
|
|||
|
||||
|
||||
def test_start_transitions_to_in_progress_and_sets_cloud_logs_url() -> None:
|
||||
# arrange
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
|
||||
# act
|
||||
st.start(cloud_logs_url="https://example/log")
|
||||
|
||||
# assert
|
||||
assert st.status is SubTaskStatus.IN_PROGRESS
|
||||
assert st.cloud_logs_url == "https://example/log"
|
||||
assert st.job_started is not None
|
||||
|
||||
|
||||
def test_start_is_idempotent_from_in_progress() -> None:
|
||||
# arrange
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
st.start()
|
||||
first_start = st.job_started
|
||||
|
||||
# act
|
||||
st.start(cloud_logs_url="https://other")
|
||||
|
||||
# assert
|
||||
assert st.status is SubTaskStatus.IN_PROGRESS
|
||||
assert st.job_started == first_start # not overwritten
|
||||
assert st.cloud_logs_url == "https://other"
|
||||
|
||||
|
||||
def test_start_rejects_from_terminal_status() -> None:
|
||||
# arrange
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
st.complete()
|
||||
# act / assert
|
||||
with pytest.raises(ValueError):
|
||||
st.start()
|
||||
|
||||
|
||||
def test_complete_marks_outputs_and_job_completed() -> None:
|
||||
# arrange
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
st.start()
|
||||
|
||||
# act
|
||||
st.complete({"uprn": "123"})
|
||||
|
||||
# assert
|
||||
assert st.status is SubTaskStatus.COMPLETE
|
||||
assert st.outputs == {"result": {"uprn": "123"}}
|
||||
assert st.job_completed is not None
|
||||
|
||||
|
||||
def test_complete_without_result_leaves_outputs_unset() -> None:
|
||||
# arrange
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
# act
|
||||
st.complete()
|
||||
# assert
|
||||
assert st.outputs is None
|
||||
|
||||
|
||||
def test_fail_records_error_in_outputs() -> None:
|
||||
# arrange
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
err = RuntimeError("boom")
|
||||
|
||||
# act
|
||||
st.fail(err)
|
||||
|
||||
# assert
|
||||
assert st.status is SubTaskStatus.FAILED
|
||||
assert st.outputs == {"error": "boom"}
|
||||
assert st.job_completed is not None
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ from domain.tasks.tasks import Source, Task, TaskStatus
|
|||
|
||||
|
||||
def test_create_task_starts_waiting() -> None:
|
||||
# Arrange / Act
|
||||
# arrange / act
|
||||
t = Task.create(
|
||||
task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
# assert
|
||||
assert t.status is TaskStatus.WAITING
|
||||
assert t.source is Source.PORTFOLIO
|
||||
assert t.source_id == "abc-123"
|
||||
|
|
@ -19,86 +19,113 @@ def test_create_task_starts_waiting() -> None:
|
|||
|
||||
|
||||
def test_create_task_rejects_blank_task_source() -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="task_source"):
|
||||
Task.create(task_source=" ")
|
||||
|
||||
|
||||
def test_start_transitions_to_in_progress() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
# act
|
||||
t.start()
|
||||
# assert
|
||||
assert t.status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_complete_marks_job_completed() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
t.start()
|
||||
# act
|
||||
t.complete()
|
||||
# assert
|
||||
assert t.status is TaskStatus.COMPLETE
|
||||
assert t.job_completed is not None
|
||||
|
||||
|
||||
def test_fail_marks_job_completed() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
# act
|
||||
t.fail()
|
||||
# assert
|
||||
assert t.status is TaskStatus.FAILED
|
||||
assert t.job_completed is not None
|
||||
|
||||
|
||||
def test_start_rejects_from_terminal_status() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
t.complete()
|
||||
# act / assert
|
||||
with pytest.raises(ValueError):
|
||||
t.start()
|
||||
|
||||
|
||||
def test_recalculate_with_empty_statuses_is_noop() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
original_status = t.status
|
||||
original_completed = t.job_completed
|
||||
|
||||
# act
|
||||
t.recalculate_from_subtasks([])
|
||||
|
||||
# assert
|
||||
assert t.status is original_status
|
||||
assert t.job_completed is original_completed
|
||||
|
||||
|
||||
def test_recalculate_all_waiting_keeps_waiting() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
t.start() # task moved to IN_PROGRESS earlier
|
||||
t.complete() # then COMPLETE, with job_completed set
|
||||
|
||||
# act
|
||||
t.recalculate_from_subtasks([SubTaskStatus.WAITING, SubTaskStatus.WAITING])
|
||||
|
||||
# assert
|
||||
assert t.status is TaskStatus.WAITING
|
||||
assert t.job_completed is None
|
||||
|
||||
|
||||
def test_recalculate_any_in_progress_marks_in_progress() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
|
||||
# act
|
||||
t.recalculate_from_subtasks(
|
||||
[SubTaskStatus.WAITING, SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE]
|
||||
)
|
||||
|
||||
# assert
|
||||
assert t.status is TaskStatus.IN_PROGRESS
|
||||
assert t.job_completed is None
|
||||
|
||||
|
||||
def test_recalculate_all_complete_marks_complete() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
|
||||
# act
|
||||
t.recalculate_from_subtasks([SubTaskStatus.COMPLETE, SubTaskStatus.COMPLETE])
|
||||
|
||||
# assert
|
||||
assert t.status is TaskStatus.COMPLETE
|
||||
assert t.job_completed is not None
|
||||
|
||||
|
||||
def test_recalculate_any_failed_marks_failed_even_with_others() -> None:
|
||||
# arrange
|
||||
t = Task.create(task_source="manual:test")
|
||||
|
||||
# act
|
||||
t.recalculate_from_subtasks(
|
||||
[SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE, SubTaskStatus.FAILED]
|
||||
)
|
||||
|
||||
# assert
|
||||
assert t.status is TaskStatus.FAILED
|
||||
assert t.job_completed is not None
|
||||
|
|
|
|||
|
|
@ -6,43 +6,54 @@ from domain.postcode import Postcode
|
|||
|
||||
|
||||
def test_postcode_uppercases() -> None:
|
||||
# act / assert
|
||||
assert Postcode("sw1a1aa").value == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_strips_internal_spaces() -> None:
|
||||
# act / assert
|
||||
assert Postcode("sw1a 1aa").value == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_strips_leading_and_trailing_whitespace() -> None:
|
||||
# act / assert
|
||||
assert Postcode(" sw1a 1aa ").value == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_strips_tabs_and_newlines() -> None:
|
||||
# CSV ingestion occasionally introduces stray whitespace characters; the
|
||||
# canonical form must absorb them just like literal spaces.
|
||||
# act / assert
|
||||
assert Postcode("sw1a\t1aa\n").value == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_construction_is_idempotent() -> None:
|
||||
# arrange
|
||||
once = Postcode("sw1a 1aa")
|
||||
# act / assert
|
||||
assert Postcode(once.value).value == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_empty_string() -> None:
|
||||
# act / assert
|
||||
assert Postcode("").value == ""
|
||||
|
||||
|
||||
def test_postcode_str_returns_canonical_value() -> None:
|
||||
# act / assert
|
||||
assert str(Postcode("sw1a 1aa")) == "SW1A1AA"
|
||||
|
||||
|
||||
def test_postcode_equality_ignores_surface_form() -> None:
|
||||
# Differing case / whitespace sanitise to the same canonical value, so
|
||||
# the value objects compare equal.
|
||||
# act / assert
|
||||
assert Postcode("sw1a 1aa") == Postcode("SW1A1AA")
|
||||
|
||||
|
||||
def test_postcode_is_frozen() -> None:
|
||||
# arrange
|
||||
postcode = Postcode("SW1A1AA")
|
||||
# act / assert
|
||||
with pytest.raises(dataclasses.FrozenInstanceError):
|
||||
postcode.value = "OTHER" # type: ignore[misc]
|
||||
|
|
|
|||
|
|
@ -28,12 +28,15 @@ def queue_setup() -> Iterator[tuple[Address2UprnQueueClient, Any, str]]:
|
|||
def test_publish_returns_message_id(
|
||||
queue_setup: tuple[Address2UprnQueueClient, Any, str],
|
||||
) -> None:
|
||||
# arrange
|
||||
client, _boto, _url = queue_setup
|
||||
# act
|
||||
message_id = client.publish(
|
||||
parent_task_id=uuid4(),
|
||||
child_subtask_id=uuid4(),
|
||||
s3_uri="s3://my-bucket/path/to/chunk.csv",
|
||||
)
|
||||
# assert
|
||||
assert isinstance(message_id, str)
|
||||
assert message_id
|
||||
|
||||
|
|
@ -41,17 +44,20 @@ def test_publish_returns_message_id(
|
|||
def test_publish_body_uses_typed_shape(
|
||||
queue_setup: tuple[Address2UprnQueueClient, Any, str],
|
||||
) -> None:
|
||||
# arrange
|
||||
client, boto_client, queue_url = queue_setup
|
||||
parent_id = uuid4()
|
||||
child_id = uuid4()
|
||||
s3_uri = "s3://my-bucket/path/to/chunk.csv"
|
||||
|
||||
# act
|
||||
client.publish(
|
||||
parent_task_id=parent_id,
|
||||
child_subtask_id=child_id,
|
||||
s3_uri=s3_uri,
|
||||
)
|
||||
|
||||
# assert
|
||||
received: dict[str, Any] = boto_client.receive_message(
|
||||
QueueUrl=queue_url, MaxNumberOfMessages=1
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,26 +18,34 @@ def csv_client() -> Iterator[CsvS3Client]:
|
|||
|
||||
|
||||
def test_save_rows_returns_s3_uri(csv_client: CsvS3Client) -> None:
|
||||
# arrange
|
||||
rows = [{"address": "1 High St", "postcode": "AB1 2CD"}]
|
||||
# act
|
||||
uri = csv_client.save_rows(rows, "uploads/addresses.csv")
|
||||
# assert
|
||||
assert uri == f"s3://{BUCKET}/uploads/addresses.csv"
|
||||
|
||||
|
||||
def test_round_trip_preserves_rows(csv_client: CsvS3Client) -> None:
|
||||
# arrange
|
||||
rows = [
|
||||
{"address": "1 High St", "postcode": "AB1 2CD"},
|
||||
{"address": "2 Low St", "postcode": "XY9 8ZW"},
|
||||
]
|
||||
# act
|
||||
uri = csv_client.save_rows(rows, "uploads/addresses.csv")
|
||||
fetched = csv_client.read_rows(uri)
|
||||
# assert
|
||||
assert fetched == rows
|
||||
|
||||
|
||||
def test_save_rows_rejects_empty_list(csv_client: CsvS3Client) -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
csv_client.save_rows([], "uploads/empty.csv")
|
||||
|
||||
|
||||
def test_read_rows_rejects_wrong_bucket(csv_client: CsvS3Client) -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="does not match client bucket"):
|
||||
csv_client.read_rows("s3://other-bucket/uploads/addresses.csv")
|
||||
|
|
|
|||
|
|
@ -18,14 +18,19 @@ def s3_client() -> Iterator[S3Client]:
|
|||
|
||||
|
||||
def test_put_object_returns_s3_uri(s3_client: S3Client) -> None:
|
||||
# act
|
||||
uri = s3_client.put_object("folder/data.bin", b"payload")
|
||||
# assert
|
||||
assert uri == f"s3://{BUCKET}/folder/data.bin"
|
||||
|
||||
|
||||
def test_get_object_returns_bytes_written_by_put_object(s3_client: S3Client) -> None:
|
||||
# arrange
|
||||
s3_client.put_object("round/trip.bin", b"hello world")
|
||||
# act / assert
|
||||
assert s3_client.get_object("round/trip.bin") == b"hello world"
|
||||
|
||||
|
||||
def test_bucket_property_exposes_configured_bucket(s3_client: S3Client) -> None:
|
||||
# act / assert
|
||||
assert s3_client.bucket == BUCKET
|
||||
|
|
|
|||
|
|
@ -4,29 +4,37 @@ from infrastructure.s3_uri import parse_s3_uri
|
|||
|
||||
|
||||
def test_parses_simple_s3_uri() -> None:
|
||||
# act / assert
|
||||
assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv")
|
||||
|
||||
|
||||
def test_parses_s3_uri_with_nested_key() -> None:
|
||||
# act
|
||||
bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv")
|
||||
# assert
|
||||
assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv")
|
||||
|
||||
|
||||
def test_rejects_s3_uri_without_key() -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="bucket and a key"):
|
||||
parse_s3_uri("s3://my-bucket")
|
||||
|
||||
|
||||
def test_rejects_s3_uri_with_empty_key() -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="bucket and a key"):
|
||||
parse_s3_uri("s3://my-bucket/")
|
||||
|
||||
|
||||
def test_parses_console_url_prefix() -> None:
|
||||
# arrange
|
||||
url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv"
|
||||
# act / assert
|
||||
assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv")
|
||||
|
||||
|
||||
def test_rejects_unparseable_string() -> None:
|
||||
# act / assert
|
||||
with pytest.raises(ValueError):
|
||||
parse_s3_uri("not-a-uri-at-all")
|
||||
|
|
|
|||
|
|
@ -19,17 +19,23 @@ def sqs_setup() -> Iterator[tuple[SqsClient, Any, str]]:
|
|||
|
||||
|
||||
def test_send_returns_message_id(sqs_setup: tuple[SqsClient, Any, str]) -> None:
|
||||
# arrange
|
||||
client, _boto, _url = sqs_setup
|
||||
# act
|
||||
message_id = client.send({"hello": "world"})
|
||||
# assert
|
||||
assert isinstance(message_id, str)
|
||||
assert message_id
|
||||
|
||||
|
||||
def test_send_json_serialises_body(sqs_setup: tuple[SqsClient, Any, str]) -> None:
|
||||
# arrange
|
||||
client, boto_client, queue_url = sqs_setup
|
||||
body = {"hello": "world", "count": 3}
|
||||
# act
|
||||
client.send(body)
|
||||
|
||||
# assert
|
||||
received: dict[str, Any] = boto_client.receive_message(
|
||||
QueueUrl=queue_url, MaxNumberOfMessages=1
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from typing import Any, cast
|
|||
import boto3
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
|
||||
from infrastructure.csv_s3_client import CsvS3Client
|
||||
|
|
@ -65,7 +66,7 @@ class Harness:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def harness() -> Iterator[Harness]:
|
||||
def harness(db_engine: Engine) -> Iterator[Harness]:
|
||||
with mock_aws():
|
||||
# Infra: S3 + SQS
|
||||
boto_s3 = _make_boto_client("s3")
|
||||
|
|
@ -78,10 +79,8 @@ def harness() -> Iterator[Harness]:
|
|||
repo = UserAddressCsvS3Repository(csv_client, BUCKET)
|
||||
queue_client = Address2UprnQueueClient(boto_sqs, queue_url)
|
||||
|
||||
# DB: in-memory SQLite TaskOrchestrator
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
# DB: ephemeral PostgreSQL TaskOrchestrator
|
||||
with Session(db_engine) as session:
|
||||
task_repo = TaskPostgresRepository(session=session)
|
||||
subtask_repo = SubTaskPostgresRepository(session=session)
|
||||
task_orchestrator = TaskOrchestrator(
|
||||
|
|
@ -169,6 +168,7 @@ def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]:
|
|||
def test_split_and_dispatch_creates_three_children_for_fixture(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
parent_task, parent_subtask = (
|
||||
harness.task_orchestrator.create_task_with_subtask(
|
||||
task_source="manual:postcode-splitter-int"
|
||||
|
|
@ -176,12 +176,14 @@ def test_split_and_dispatch_creates_three_children_for_fixture(
|
|||
)
|
||||
input_uri = _upload_fixture_csv(harness.csv_client)
|
||||
|
||||
# act
|
||||
child_ids = harness.splitter.split_and_dispatch(
|
||||
parent_task_id=parent_task.id,
|
||||
parent_subtask_id=parent_subtask.id,
|
||||
input_s3_uri=input_uri,
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(child_ids) == 3
|
||||
# All child ids are unique and persisted as WAITING children of the
|
||||
# parent task.
|
||||
|
|
@ -194,6 +196,7 @@ def test_split_and_dispatch_creates_three_children_for_fixture(
|
|||
def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
parent_task, parent_subtask = (
|
||||
harness.task_orchestrator.create_task_with_subtask(
|
||||
task_source="manual:postcode-splitter-int"
|
||||
|
|
@ -201,12 +204,14 @@ def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
|
|||
)
|
||||
input_uri = _upload_fixture_csv(harness.csv_client)
|
||||
|
||||
# act
|
||||
child_ids = harness.splitter.split_and_dispatch(
|
||||
parent_task_id=parent_task.id,
|
||||
parent_subtask_id=parent_subtask.id,
|
||||
input_s3_uri=input_uri,
|
||||
)
|
||||
|
||||
# assert
|
||||
for cid in child_ids:
|
||||
child = harness.subtasks.get(cid)
|
||||
assert child.inputs is not None
|
||||
|
|
@ -224,6 +229,7 @@ def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
|
|||
def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
parent_task, parent_subtask = (
|
||||
harness.task_orchestrator.create_task_with_subtask(
|
||||
task_source="manual:postcode-splitter-int"
|
||||
|
|
@ -231,12 +237,14 @@ def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
|
|||
)
|
||||
input_uri = _upload_fixture_csv(harness.csv_client)
|
||||
|
||||
# act
|
||||
child_ids = harness.splitter.split_and_dispatch(
|
||||
parent_task_id=parent_task.id,
|
||||
parent_subtask_id=parent_subtask.id,
|
||||
input_s3_uri=input_uri,
|
||||
)
|
||||
|
||||
# assert
|
||||
bodies = _drain_queue(harness.boto_sqs, harness.queue_url)
|
||||
assert len(bodies) == len(child_ids)
|
||||
|
||||
|
|
@ -258,6 +266,7 @@ def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
|
|||
def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
parent_task, parent_subtask = (
|
||||
harness.task_orchestrator.create_task_with_subtask(
|
||||
task_source="manual:postcode-splitter-int"
|
||||
|
|
@ -265,12 +274,14 @@ def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
|
|||
)
|
||||
input_uri = _upload_fixture_csv(harness.csv_client)
|
||||
|
||||
# act
|
||||
child_ids = harness.splitter.split_and_dispatch(
|
||||
parent_task_id=parent_task.id,
|
||||
parent_subtask_id=parent_subtask.id,
|
||||
input_s3_uri=input_uri,
|
||||
)
|
||||
|
||||
# assert
|
||||
# Re-load each child's saved batch and inspect the postcode_clean column
|
||||
# to confirm the dispatch order matches the postcode-batching algorithm:
|
||||
# AA-batch first, BB oversize batch second, CC final-flush third.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from collections.abc import Iterator
|
|||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from domain.tasks.subtasks import SubTask, SubTaskStatus
|
||||
from domain.tasks.tasks import Source, TaskStatus
|
||||
|
|
@ -19,10 +20,8 @@ class Harness:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def harness() -> Iterator[Harness]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
def harness(db_engine: Engine) -> Iterator[Harness]:
|
||||
with Session(db_engine) as session:
|
||||
tasks = TaskPostgresRepository(session=session)
|
||||
subtasks = SubTaskPostgresRepository(session=session)
|
||||
yield Harness(
|
||||
|
|
@ -35,6 +34,7 @@ def harness() -> Iterator[Harness]:
|
|||
def test_create_task_with_subtask_creates_both_in_waiting(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# act
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test",
|
||||
inputs={"foo": "bar"},
|
||||
|
|
@ -42,6 +42,7 @@ def test_create_task_with_subtask_creates_both_in_waiting(
|
|||
source_id="abc",
|
||||
)
|
||||
|
||||
# assert
|
||||
assert task.status is TaskStatus.WAITING
|
||||
assert subtask.status is SubTaskStatus.WAITING
|
||||
assert subtask.task_id == task.id
|
||||
|
|
@ -49,27 +50,33 @@ def test_create_task_with_subtask_creates_both_in_waiting(
|
|||
|
||||
|
||||
def test_start_subtask_cascades_to_in_progress(harness: Harness) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
# act
|
||||
started = harness.orchestrator.start_subtask(
|
||||
subtask.id, cloud_logs_url="https://example/log"
|
||||
)
|
||||
|
||||
# assert
|
||||
assert started.status is SubTaskStatus.IN_PROGRESS
|
||||
assert started.cloud_logs_url == "https://example/log"
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
harness.orchestrator.start_subtask(subtask.id)
|
||||
|
||||
# act
|
||||
harness.orchestrator.complete_subtask(subtask.id, {"value": 42})
|
||||
|
||||
# assert
|
||||
done_subtask = harness.subtasks.get(subtask.id)
|
||||
done_task = harness.tasks.get(task.id)
|
||||
assert done_subtask.outputs == {"result": {"value": 42}}
|
||||
|
|
@ -78,12 +85,15 @@ def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
|
|||
|
||||
|
||||
def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
# act
|
||||
harness.orchestrator.fail_subtask(subtask.id, RuntimeError("boom"))
|
||||
|
||||
# assert
|
||||
failed_subtask = harness.subtasks.get(subtask.id)
|
||||
failed_task = harness.tasks.get(task.id)
|
||||
assert failed_subtask.outputs == {"error": "boom"}
|
||||
|
|
@ -93,42 +103,51 @@ def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
|
|||
def test_failed_subtask_locks_task_failed_even_with_others_complete(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, first = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
second = SubTask.create(task_id=task.id)
|
||||
harness.subtasks.create(second)
|
||||
|
||||
# act
|
||||
harness.orchestrator.complete_subtask(first.id)
|
||||
harness.orchestrator.fail_subtask(second.id, RuntimeError("nope"))
|
||||
|
||||
# assert
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
|
||||
|
||||
|
||||
def test_mixed_complete_and_in_progress_keeps_task_in_progress(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, first = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
second = SubTask.create(task_id=task.id)
|
||||
harness.subtasks.create(second)
|
||||
|
||||
# act
|
||||
harness.orchestrator.complete_subtask(first.id)
|
||||
harness.orchestrator.start_subtask(second.id)
|
||||
|
||||
# assert
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
|
||||
|
||||
|
||||
def test_run_subtask_happy_path_returns_result_and_cascades_complete(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
# act
|
||||
result = harness.orchestrator.run_subtask(subtask.id, work=lambda: {"answer": 42})
|
||||
|
||||
# assert
|
||||
assert result == {"answer": 42}
|
||||
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
|
||||
|
|
@ -137,16 +156,19 @@ def test_run_subtask_happy_path_returns_result_and_cascades_complete(
|
|||
def test_create_child_subtask_adds_waiting_child_without_changing_parent_status(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, first = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
harness.orchestrator.start_subtask(first.id)
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS
|
||||
|
||||
# act
|
||||
child = harness.orchestrator.create_child_subtask(
|
||||
task.id, inputs={"split": "a"}
|
||||
)
|
||||
|
||||
# assert
|
||||
persisted_child = harness.subtasks.get(child.id)
|
||||
assert persisted_child.task_id == task.id
|
||||
assert persisted_child.status is SubTaskStatus.WAITING
|
||||
|
|
@ -159,6 +181,7 @@ def test_create_child_subtask_adds_waiting_child_without_changing_parent_status(
|
|||
def test_run_subtask_failing_work_marks_failed_and_reraises(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -166,6 +189,7 @@ def test_run_subtask_failing_work_marks_failed_and_reraises(
|
|||
def boom() -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
harness.orchestrator.run_subtask(subtask.id, work=boom)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +1,40 @@
|
|||
from collections.abc import Iterator
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
# Importing the SQLModel row modules registers their tables in
|
||||
# SQLModel.metadata so create_all builds both. Imports look unused; they aren't.
|
||||
import infrastructure.postgres.subtask_table # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
import infrastructure.postgres.task_table # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
from domain.tasks.subtasks import SubTask, SubTaskStatus
|
||||
from domain.tasks.tasks import Task
|
||||
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
||||
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Iterator[Session]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as s:
|
||||
def session(db_engine: Engine) -> Iterator[Session]:
|
||||
with Session(db_engine) as s:
|
||||
yield s
|
||||
|
||||
|
||||
def _persisted_task_id(session: Session) -> UUID:
|
||||
"""Create a parent Task row so SubTask FK constraints are satisfied."""
|
||||
task = Task.create(task_source="manual:test")
|
||||
TaskPostgresRepository(session=session).create(task)
|
||||
return task.id
|
||||
|
||||
|
||||
def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_id = uuid4()
|
||||
task_id = _persisted_task_id(session)
|
||||
st = SubTask.create(task_id=task_id, inputs={"address": "68 Glendon Way"})
|
||||
|
||||
# act
|
||||
repo.create(st)
|
||||
fetched = repo.get(st.id)
|
||||
|
||||
# assert
|
||||
assert fetched.id == st.id
|
||||
assert fetched.task_id == task_id
|
||||
assert fetched.status is SubTaskStatus.WAITING
|
||||
|
|
@ -36,16 +43,21 @@ def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None:
|
|||
|
||||
|
||||
def test_save_persists_status_and_outputs(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
st = SubTask.create(task_id=uuid4())
|
||||
st = SubTask.create(task_id=_persisted_task_id(session))
|
||||
repo.create(st)
|
||||
|
||||
# act
|
||||
st.start(cloud_logs_url="https://example/log")
|
||||
repo.save(st)
|
||||
# assert
|
||||
assert repo.get(st.id).status is SubTaskStatus.IN_PROGRESS
|
||||
|
||||
# act
|
||||
st.complete({"uprn": "123"})
|
||||
repo.save(st)
|
||||
# assert
|
||||
done = repo.get(st.id)
|
||||
assert done.status is SubTaskStatus.COMPLETE
|
||||
assert done.outputs == {"result": {"uprn": "123"}}
|
||||
|
|
@ -54,16 +66,19 @@ def test_save_persists_status_and_outputs(session: Session) -> None:
|
|||
|
||||
|
||||
def test_list_by_task_filters_by_task_id(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
task_a = uuid4()
|
||||
task_b = uuid4()
|
||||
task_a = _persisted_task_id(session)
|
||||
task_b = _persisted_task_id(session)
|
||||
repo.create(SubTask.create(task_id=task_a))
|
||||
repo.create(SubTask.create(task_id=task_a))
|
||||
repo.create(SubTask.create(task_id=task_b))
|
||||
|
||||
# act
|
||||
a_results = repo.list_by_task(task_a)
|
||||
b_results = repo.list_by_task(task_b)
|
||||
|
||||
# assert
|
||||
assert len(a_results) == 2
|
||||
assert len(b_results) == 1
|
||||
assert all(s.task_id == task_a for s in a_results)
|
||||
|
|
@ -71,11 +86,15 @@ def test_list_by_task_filters_by_task_id(session: Session) -> None:
|
|||
|
||||
|
||||
def test_list_by_task_returns_empty_for_unknown_task(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
# act / assert
|
||||
assert repo.list_by_task(uuid4()) == []
|
||||
|
||||
|
||||
def test_get_missing_raises(session: Session) -> None:
|
||||
# arrange
|
||||
repo = SubTaskPostgresRepository(session=session)
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
repo.get(uuid4())
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from collections.abc import Iterator
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from domain.tasks.tasks import Source, Task, TaskStatus
|
||||
from infrastructure.postgres.task_table import TaskRow
|
||||
|
|
@ -10,25 +11,23 @@ from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Iterator[Session]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as s:
|
||||
def session(db_engine: Engine) -> Iterator[Session]:
|
||||
with Session(db_engine) as s:
|
||||
yield s
|
||||
|
||||
|
||||
def test_create_and_get_round_trip(session: Session) -> None:
|
||||
# Arrange
|
||||
# arrange
|
||||
repo = TaskPostgresRepository(session=session)
|
||||
t = Task.create(
|
||||
task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123"
|
||||
)
|
||||
|
||||
# Act
|
||||
# act
|
||||
repo.create(t)
|
||||
fetched = repo.get(t.id)
|
||||
|
||||
# Assert
|
||||
# assert
|
||||
assert fetched.id == t.id
|
||||
assert fetched.status is TaskStatus.WAITING
|
||||
assert fetched.source is Source.PORTFOLIO
|
||||
|
|
@ -36,33 +35,43 @@ def test_create_and_get_round_trip(session: Session) -> None:
|
|||
|
||||
|
||||
def test_save_persists_status_transition(session: Session) -> None:
|
||||
# arrange
|
||||
repo = TaskPostgresRepository(session=session)
|
||||
t = Task.create(task_source="manual:test")
|
||||
repo.create(t)
|
||||
|
||||
# act
|
||||
t.start()
|
||||
repo.save(t)
|
||||
# assert
|
||||
assert repo.get(t.id).status is TaskStatus.IN_PROGRESS
|
||||
|
||||
# act
|
||||
t.complete()
|
||||
repo.save(t)
|
||||
# assert
|
||||
done = repo.get(t.id)
|
||||
assert done.status is TaskStatus.COMPLETE
|
||||
assert done.job_completed is not None
|
||||
|
||||
|
||||
def test_get_missing_raises(session: Session) -> None:
|
||||
# arrange
|
||||
repo = TaskPostgresRepository(session=session)
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
repo.get(uuid4())
|
||||
|
||||
|
||||
def test_get_normalises_legacy_capitalised_status(session: Session) -> None:
|
||||
# Existing rows written by backend code use "In Progress" (capitalised).
|
||||
# arrange
|
||||
repo = TaskPostgresRepository(session=session)
|
||||
row = TaskRow(task_source="manual:test", status="In Progress")
|
||||
session.add(row)
|
||||
session.commit()
|
||||
|
||||
# act
|
||||
fetched = repo.get(row.id)
|
||||
# assert
|
||||
assert fetched.status is TaskStatus.IN_PROGRESS
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def _upload_csv(
|
|||
def test_load_batch_parses_address_postcode_and_reference(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
rows = [
|
||||
{
|
||||
"Address 1": "1 High Street",
|
||||
|
|
@ -43,8 +44,10 @@ def test_load_batch_parses_address_postcode_and_reference(
|
|||
]
|
||||
uri = _upload_csv(repo, rows, "uploads/full.csv")
|
||||
|
||||
# act
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
# assert
|
||||
assert len(addresses) == 1
|
||||
address = addresses[0]
|
||||
assert address.user_address == "1 High Street, Flat 2, Townville"
|
||||
|
|
@ -55,6 +58,7 @@ def test_load_batch_parses_address_postcode_and_reference(
|
|||
def test_load_batch_uses_only_address_1_when_others_missing(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
rows = [
|
||||
{
|
||||
"Address 1": "10 Cardiff Road",
|
||||
|
|
@ -66,8 +70,10 @@ def test_load_batch_uses_only_address_1_when_others_missing(
|
|||
]
|
||||
uri = _upload_csv(repo, rows, "uploads/address1-only.csv")
|
||||
|
||||
# act
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
# assert
|
||||
assert len(addresses) == 1
|
||||
assert addresses[0].user_address == "10 Cardiff Road"
|
||||
assert addresses[0].postcode == Postcode("CF101AA")
|
||||
|
|
@ -77,6 +83,7 @@ def test_load_batch_uses_only_address_1_when_others_missing(
|
|||
def test_load_batch_handles_missing_internal_reference(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
rows = [
|
||||
{
|
||||
"Address 1": "5 Park Lane",
|
||||
|
|
@ -88,8 +95,10 @@ def test_load_batch_handles_missing_internal_reference(
|
|||
]
|
||||
uri = _upload_csv(repo, rows, "uploads/no-ref.csv")
|
||||
|
||||
# act
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
# assert
|
||||
assert len(addresses) == 1
|
||||
assert addresses[0].user_address == "5 Park Lane"
|
||||
assert addresses[0].postcode == Postcode("M11AA")
|
||||
|
|
@ -101,6 +110,7 @@ def test_load_batch_captures_full_source_row(
|
|||
) -> None:
|
||||
# A raw EPC-export-shaped row: the splitter must preserve every column,
|
||||
# not just the ones it parses into UserAddress fields.
|
||||
# arrange
|
||||
row = {
|
||||
"Asset Reference": "511",
|
||||
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
|
||||
|
|
@ -110,17 +120,21 @@ def test_load_batch_captures_full_source_row(
|
|||
}
|
||||
uri = _upload_csv(repo, [row], "uploads/epc.csv")
|
||||
|
||||
# act
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
# assert
|
||||
assert addresses[0].source_row == row
|
||||
|
||||
|
||||
def test_load_batch_raises_when_postcode_column_absent(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}]
|
||||
uri = _upload_csv(repo, rows, "uploads/no-postcode.csv")
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(ValueError, match="no 'postcode' column"):
|
||||
repo.load_batch(uri)
|
||||
|
||||
|
|
@ -128,6 +142,7 @@ def test_load_batch_raises_when_postcode_column_absent(
|
|||
def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
row = {
|
||||
"Asset Reference": "511",
|
||||
"Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX",
|
||||
|
|
@ -137,9 +152,11 @@ def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
|
|||
uri = _upload_csv(repo, [row], "uploads/epc.csv")
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
# act
|
||||
saved_uri = repo.save_batch(addresses, "tasks/passthrough")
|
||||
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# assert
|
||||
assert len(saved_rows) == 1
|
||||
saved = saved_rows[0]
|
||||
# Every original column survives, byte-for-byte.
|
||||
|
|
@ -152,6 +169,7 @@ def test_save_batch_passes_through_all_columns_and_appends_postcode_clean(
|
|||
def test_save_batch_returns_uri_under_path_prefix(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
addresses = [
|
||||
UserAddress(
|
||||
user_address="1 High Street",
|
||||
|
|
@ -160,8 +178,10 @@ def test_save_batch_returns_uri_under_path_prefix(
|
|||
),
|
||||
]
|
||||
|
||||
# act
|
||||
uri = repo.save_batch(addresses, "tasks/abc/batches")
|
||||
|
||||
# assert
|
||||
assert uri.startswith(f"s3://{BUCKET}/tasks/abc/batches/")
|
||||
assert uri.endswith(".csv")
|
||||
|
||||
|
|
@ -169,6 +189,7 @@ def test_save_batch_returns_uri_under_path_prefix(
|
|||
def test_save_then_reload_round_trip_preserves_columns(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
rows = [
|
||||
{
|
||||
"Address 1": "1 High Street",
|
||||
|
|
@ -184,9 +205,11 @@ def test_save_then_reload_round_trip_preserves_columns(
|
|||
uri = _upload_csv(repo, rows, "uploads/round-trip.csv")
|
||||
addresses = repo.load_batch(uri)
|
||||
|
||||
# act
|
||||
saved_uri = repo.save_batch(addresses, "tasks/round-trip")
|
||||
saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# assert
|
||||
# Original columns come back verbatim; postcode_clean is the only addition.
|
||||
assert [
|
||||
{k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows
|
||||
|
|
@ -197,6 +220,7 @@ def test_save_then_reload_round_trip_preserves_columns(
|
|||
def test_save_batch_uses_unique_filename_per_call(
|
||||
repo: UserAddressCsvS3Repository,
|
||||
) -> None:
|
||||
# arrange
|
||||
addresses = [
|
||||
UserAddress(
|
||||
user_address="1 High Street",
|
||||
|
|
@ -205,7 +229,9 @@ def test_save_batch_uses_unique_filename_per_call(
|
|||
),
|
||||
]
|
||||
|
||||
# act
|
||||
uri_1 = repo.save_batch(addresses, "tasks/uniqueness")
|
||||
uri_2 = repo.save_batch(addresses, "tasks/uniqueness")
|
||||
|
||||
# assert
|
||||
assert uri_1 != uri_2
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ from typing import Any
|
|||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session
|
||||
|
||||
from domain.tasks.subtasks import SubTaskStatus
|
||||
from domain.tasks.tasks import TaskStatus
|
||||
|
|
@ -30,10 +31,8 @@ class Harness:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def harness() -> Iterator[Harness]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
def harness(db_engine: Engine) -> Iterator[Harness]:
|
||||
with Session(db_engine) as session:
|
||||
tasks = TaskPostgresRepository(session=session)
|
||||
subtasks = SubTaskPostgresRepository(session=session)
|
||||
yield Harness(
|
||||
|
|
@ -50,6 +49,7 @@ def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
|
|||
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
_, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -64,8 +64,10 @@ def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
|
|||
received["context"] = context
|
||||
received["orchestrator"] = orchestrator
|
||||
|
||||
# act
|
||||
handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel")
|
||||
|
||||
# assert
|
||||
assert received["orchestrator"] is harness.orchestrator
|
||||
assert received["context"] == "ctx-sentinel"
|
||||
assert received["body"]["sub_task_id"] == str(subtask.id)
|
||||
|
|
@ -74,6 +76,7 @@ def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
|
|||
def test_subtask_handler_completes_parent_subtask_on_success(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -84,8 +87,10 @@ def test_subtask_handler_completes_parent_subtask_on_success(
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
# act
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
# assert
|
||||
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
|
||||
|
||||
|
|
@ -93,6 +98,7 @@ def test_subtask_handler_completes_parent_subtask_on_success(
|
|||
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -103,6 +109,7 @@ def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
|
|||
) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
|
|
@ -113,6 +120,7 @@ def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
|
|||
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -126,8 +134,10 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
|
|||
child = orchestrator.create_child_subtask(task.id, inputs={"split": 1})
|
||||
child_ids.append(child.id)
|
||||
|
||||
# act
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
# assert
|
||||
assert len(child_ids) == 1
|
||||
persisted_child = harness.subtasks.get(child_ids[0])
|
||||
assert persisted_child.task_id == task.id
|
||||
|
|
@ -137,6 +147,7 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
|
|||
def test_subtask_handler_logs_subtask_lifecycle_on_success(
|
||||
harness: Harness, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -147,9 +158,11 @@ def test_subtask_handler_logs_subtask_lifecycle_on_success(
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
# act
|
||||
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
# assert
|
||||
assert f"Running subtask {subtask.id}" in caplog.text
|
||||
assert f"Subtask {subtask.id} completed" in caplog.text
|
||||
|
||||
|
|
@ -157,6 +170,7 @@ def test_subtask_handler_logs_subtask_lifecycle_on_success(
|
|||
def test_subtask_handler_logs_exception_on_failure(
|
||||
harness: Harness, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
# arrange
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
|
@ -167,6 +181,7 @@ def test_subtask_handler_logs_exception_on_failure(
|
|||
) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
# act / assert
|
||||
with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
|
@ -181,6 +196,7 @@ def test_subtask_handler_logs_exception_on_failure(
|
|||
def test_subtask_handler_records_cloudwatch_url_on_subtask(
|
||||
harness: Harness, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
# arrange
|
||||
monkeypatch.setenv("AWS_REGION", "eu-west-2")
|
||||
monkeypatch.setenv(
|
||||
"AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter"
|
||||
|
|
@ -198,8 +214,10 @@ def test_subtask_handler_records_cloudwatch_url_on_subtask(
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
# act
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
# assert
|
||||
saved_url = harness.subtasks.get(subtask.id).cloud_logs_url
|
||||
assert saved_url is not None
|
||||
assert saved_url.startswith(
|
||||
|
|
@ -213,6 +231,7 @@ def test_subtask_handler_records_cloudwatch_url_on_subtask(
|
|||
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
|
||||
harness: Harness, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
# arrange
|
||||
for var in (
|
||||
"AWS_REGION",
|
||||
"AWS_LAMBDA_LOG_GROUP_NAME",
|
||||
|
|
@ -229,6 +248,8 @@ def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
# act
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
# assert
|
||||
assert harness.subtasks.get(subtask.id).cloud_logs_url is None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue