mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
98 lines
3 KiB
Python
98 lines
3 KiB
Python
"""Tests for sample.sample() — random + filterable selection over the EPC flat-register CSV.

sample() is the entry point of the training-data pipeline: it produces a thin DataFrame
of certificate rows that downstream stages (fetch -> build_features -> write_parquet)
operate on. Filtering supports excluding obviously-wrong cohorts before paying the
per-cert fetch cost.
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from ml_training_data.sample import sample
|
|
|
|
|
|
def _write_csv(path: Path, rows: list[dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as a CSV file without an index column.

    Args:
        path: Destination file for the CSV.
        rows: One dict per record; the rows are loaded into a DataFrame
            before being serialised.
    """
    pd.DataFrame(rows).to_csv(path, index=False)
|
|
|
|
|
|
def test_sample_returns_n_rows_when_no_filter(tmp_path: Path) -> None:
    """Without filters, sample() yields exactly ``n`` rows and keeps every CSV column."""
    # Arrange: a 100-row register, all rows sharing property type and postcode.
    register_path = tmp_path / "register.csv"
    records = [
        {
            "certificate_number": f"CN-{i:04d}",
            "property_type": "House",
            "postcode": "AB1 2CD",
        }
        for i in range(100)
    ]
    _write_csv(register_path, records)

    # Act
    sampled = sample(csv_path=register_path, n=10, seed=42)

    # Assert
    expected_columns = {"certificate_number", "property_type", "postcode"}
    assert len(sampled) == 10
    assert set(sampled.columns) == expected_columns
|
|
|
|
|
|
def test_sample_is_deterministic_with_same_seed(tmp_path: Path) -> None:
    """Two sample() calls with the same seed return the same certificates in the same order."""
    # Arrange: a 200-row register with unique certificate numbers.
    register_path = tmp_path / "register.csv"
    records = [
        {"certificate_number": f"CN-{i:04d}", "property_type": "House"} for i in range(200)
    ]
    _write_csv(register_path, records)

    # Act: draw twice with identical arguments.
    first_draw = sample(csv_path=register_path, n=20, seed=7)
    second_draw = sample(csv_path=register_path, n=20, seed=7)

    # Assert: list equality checks both membership and ordering.
    first_ids = first_draw["certificate_number"].tolist()
    second_ids = second_draw["certificate_number"].tolist()
    assert first_ids == second_ids
|
|
|
|
|
|
def test_sample_filter_selects_only_matching_rows(tmp_path: Path) -> None:
    """A ``property_type`` filter restricts the sample to rows with an allowed value."""
    # Arrange: 50 houses followed by 50 flats.
    register_path = tmp_path / "register.csv"
    houses: list[dict[str, str]] = [
        {"certificate_number": f"H-{i:03d}", "property_type": "House"} for i in range(50)
    ]
    flats: list[dict[str, str]] = [
        {"certificate_number": f"F-{i:03d}", "property_type": "Flat"} for i in range(50)
    ]
    _write_csv(register_path, houses + flats)

    # Act
    sampled = sample(
        csv_path=register_path,
        n=30,
        seed=1,
        filters={"property_type": ["House"]},
    )

    # Assert: the requested size is met and no flat is included.
    assert len(sampled) == 30
    assert sampled["property_type"].eq("House").all()
|
|
|
|
|
|
def test_sample_returns_fewer_than_n_when_filter_too_strict(tmp_path: Path) -> None:
    """When fewer than ``n`` rows pass the filter, sample() returns only the matching rows."""
    # Arrange: just 3 bungalows among 50 houses.
    register_path = tmp_path / "register.csv"
    bungalows: list[dict[str, str]] = [
        {"certificate_number": f"BG-{i}", "property_type": "Bungalow"} for i in range(3)
    ]
    houses: list[dict[str, str]] = [
        {"certificate_number": f"H-{i:03d}", "property_type": "House"} for i in range(50)
    ]
    _write_csv(register_path, bungalows + houses)

    # Act: request far more rows than the filter can match.
    sampled = sample(
        csv_path=register_path,
        n=100,
        seed=1,
        filters={"property_type": ["Bungalow"]},
    )

    # Assert: the result is capped at the 3 matching rows.
    assert len(sampled) == 3
    assert sampled["property_type"].eq("Bungalow").all()
|
|
|
|
|
|
def test_sample_raises_when_filter_column_missing(tmp_path: Path) -> None:
    """Filtering on a column that is absent from the CSV raises ``KeyError``."""
    # Arrange: a one-row register that has no "nonexistent" column.
    register_path = tmp_path / "register.csv"
    _write_csv(register_path, [{"certificate_number": "CN-001", "property_type": "House"}])
    unknown_column_filter = {"nonexistent": ["x"]}

    # Act / Assert
    with pytest.raises(KeyError):
        sample(csv_path=register_path, n=1, seed=1, filters=unknown_column_filter)
|