# Model/services/ml_training_data/tests/unit/test_sample.py
# Khalim Conn-Kowlessar eb42cb88a1 slice 14a: ml_training_data pkg + sample.py (CSV filter + random sample)
# Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
# 2026-05-16 17:39:43 +00:00
#
# 98 lines
# 3 KiB
# Python

"""Tests for sample.sample() — random + filterable selection over the EPC flat-register CSV.
sample() is the entry point of the training-data pipeline: it produces a thin DataFrame
of certificate rows that downstream stages (fetch -> build_features -> write_parquet)
operate on. Filtering supports excluding obviously-wrong cohorts before paying the
per-cert fetch cost.
"""
from pathlib import Path
import pandas as pd
import pytest
from ml_training_data.sample import sample
def _write_csv(path: Path, rows: list[dict[str, str]]) -> None:
df = pd.DataFrame(rows)
df.to_csv(path, index=False)
def test_sample_returns_n_rows_when_no_filter(tmp_path: Path) -> None:
    """Without filters, sample() returns exactly n distinct rows taken from the CSV."""
    # Arrange
    csv = tmp_path / "register.csv"
    rows = [
        {"certificate_number": f"CN-{i:04d}", "property_type": "House", "postcode": "AB1 2CD"}
        for i in range(100)
    ]
    _write_csv(csv, rows)
    source_ids = {row["certificate_number"] for row in rows}
    # Act
    out = sample(csv_path=csv, n=10, seed=42)
    # Assert
    assert len(out) == 10
    assert set(out.columns) == {"certificate_number", "property_type", "postcode"}
    # Each sampled certificate incurs a per-cert fetch downstream, so a repeat
    # would be wasted work; every pick must also be a real register row.
    picked = out["certificate_number"].tolist()
    assert len(set(picked)) == len(picked)
    assert set(picked) <= source_ids
def test_sample_is_deterministic_with_same_seed(tmp_path: Path) -> None:
    """The same seed reproduces the same ordered selection; a different seed does not."""
    # Arrange
    csv = tmp_path / "register.csv"
    _write_csv(
        csv,
        [{"certificate_number": f"CN-{i:04d}", "property_type": "House"} for i in range(200)],
    )
    # Act
    first = sample(csv_path=csv, n=20, seed=7)
    second = sample(csv_path=csv, n=20, seed=7)
    other = sample(csv_path=csv, n=20, seed=8)
    # Assert
    assert len(first) == 20
    assert first["certificate_number"].tolist() == second["certificate_number"].tolist()
    # Guard against an implementation that ignores the seed (e.g. always taking
    # the first n rows), which would satisfy the equality above trivially.
    assert first["certificate_number"].tolist() != other["certificate_number"].tolist()
def test_sample_filter_selects_only_matching_rows(tmp_path: Path) -> None:
    """A property_type filter restricts the sample to the requested cohort."""
    # Arrange: a register of 50 houses followed by 50 flats.
    register = tmp_path / "register.csv"
    houses = [{"certificate_number": f"H-{k:03d}", "property_type": "House"} for k in range(50)]
    flats = [{"certificate_number": f"F-{k:03d}", "property_type": "Flat"} for k in range(50)]
    _write_csv(register, houses + flats)
    # Act
    picked = sample(csv_path=register, n=30, seed=1, filters={"property_type": ["House"]})
    # Assert
    assert len(picked) == 30
    assert picked["property_type"].eq("House").all()
def test_sample_returns_fewer_than_n_when_filter_too_strict(tmp_path: Path) -> None:
    """When fewer than n rows match, sample() returns every match exactly once."""
    # Arrange
    csv = tmp_path / "register.csv"
    rows: list[dict[str, str]] = []
    for i in range(3):
        rows.append({"certificate_number": f"BG-{i}", "property_type": "Bungalow"})
    for i in range(50):
        rows.append({"certificate_number": f"H-{i:03d}", "property_type": "House"})
    _write_csv(csv, rows)
    # Act
    out = sample(csv_path=csv, n=100, seed=1, filters={"property_type": ["Bungalow"]})
    # Assert
    assert len(out) == 3
    assert (out["property_type"] == "Bungalow").all()
    # Length + cohort alone would still accept a repeated bungalow in place of a
    # missing one; pin the exact certificates.
    assert sorted(out["certificate_number"]) == ["BG-0", "BG-1", "BG-2"]
def test_sample_raises_when_filter_column_missing(tmp_path: Path) -> None:
    """Filtering on a column the CSV does not contain surfaces as a KeyError."""
    register = tmp_path / "register.csv"
    _write_csv(register, [{"certificate_number": "CN-001", "property_type": "House"}])
    bad_filters = {"nonexistent": ["x"]}

    with pytest.raises(KeyError):
        sample(csv_path=register, n=1, seed=1, filters=bad_filters)