"""Tests for sample.sample() — random + filterable selection over the EPC flat-register CSV. sample() is the entry point of the training-data pipeline: it produces a thin DataFrame of certificate rows that downstream stages (fetch -> build_features -> write_parquet) operate on. Filtering supports excluding obviously-wrong cohorts before paying the per-cert fetch cost. """ from pathlib import Path import pandas as pd import pytest from ml_training_data.sample import sample def _write_csv(path: Path, rows: list[dict[str, str]]) -> None: df = pd.DataFrame(rows) df.to_csv(path, index=False) def test_sample_returns_n_rows_when_no_filter(tmp_path: Path) -> None: # Arrange csv = tmp_path / "register.csv" rows = [ {"certificate_number": f"CN-{i:04d}", "property_type": "House", "postcode": "AB1 2CD"} for i in range(100) ] _write_csv(csv, rows) # Act out = sample(csv_path=csv, n=10, seed=42) # Assert assert len(out) == 10 assert set(out.columns) == {"certificate_number", "property_type", "postcode"} def test_sample_is_deterministic_with_same_seed(tmp_path: Path) -> None: # Arrange csv = tmp_path / "register.csv" _write_csv( csv, [{"certificate_number": f"CN-{i:04d}", "property_type": "House"} for i in range(200)], ) # Act first = sample(csv_path=csv, n=20, seed=7) second = sample(csv_path=csv, n=20, seed=7) # Assert assert first["certificate_number"].tolist() == second["certificate_number"].tolist() def test_sample_filter_selects_only_matching_rows(tmp_path: Path) -> None: # Arrange csv = tmp_path / "register.csv" rows: list[dict[str, str]] = [] for i in range(50): rows.append({"certificate_number": f"H-{i:03d}", "property_type": "House"}) for i in range(50): rows.append({"certificate_number": f"F-{i:03d}", "property_type": "Flat"}) _write_csv(csv, rows) # Act out = sample(csv_path=csv, n=30, seed=1, filters={"property_type": ["House"]}) # Assert assert len(out) == 30 assert (out["property_type"] == "House").all() def test_sample_returns_fewer_than_n_when_filter_too_strict(tmp_path: Path) -> None: # Arrange csv = tmp_path / "register.csv" rows: list[dict[str, str]] = [] for i in range(3): rows.append({"certificate_number": f"BG-{i}", "property_type": "Bungalow"}) for i in range(50): rows.append({"certificate_number": f"H-{i:03d}", "property_type": "House"}) _write_csv(csv, rows) # Act out = sample(csv_path=csv, n=100, seed=1, filters={"property_type": ["Bungalow"]}) # Assert assert len(out) == 3 assert (out["property_type"] == "Bungalow").all() def test_sample_raises_when_filter_column_missing(tmp_path: Path) -> None: # Arrange csv = tmp_path / "register.csv" _write_csv(csv, [{"certificate_number": "CN-001", "property_type": "House"}]) # Act / Assert with pytest.raises(KeyError): sample(csv_path=csv, n=1, seed=1, filters={"nonexistent": ["x"]})