mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
added missing files
This commit is contained in:
parent
fa0c77af78
commit
d338be867b
10 changed files with 592 additions and 0 deletions
3
backend/epc_client/__init__.py
Normal file
3
backend/epc_client/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from backend.epc_client.client import EpcClientService, EpcSearchResult
|
||||
|
||||
__all__ = ["EpcClientService", "EpcSearchResult"]
|
||||
23
backend/epc_client/_retry.py
Normal file
23
backend/epc_client/_retry.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import time
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from backend.epc_client.exceptions import EpcRateLimitError
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def call_with_retry(
|
||||
fn: Callable[[], T],
|
||||
max_retries: int = 5,
|
||||
backoff_base: float = 1.0,
|
||||
backoff_multiplier: float = 2.0,
|
||||
) -> T:
|
||||
last_exc: EpcRateLimitError | None = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return fn()
|
||||
except EpcRateLimitError as exc:
|
||||
last_exc = exc
|
||||
if attempt < max_retries:
|
||||
time.sleep(backoff_base * (backoff_multiplier ** attempt))
|
||||
raise last_exc # type: ignore[misc]
|
||||
175
backend/epc_client/client.py
Normal file
175
backend/epc_client/client.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
# Spec: https://raw.githubusercontent.com/communitiesuk/epb-data-warehouse/main/api/api.yml
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
|
||||
from backend.epc_client.exceptions import EpcApiError, EpcNotFoundError, EpcRateLimitError
|
||||
from backend.epc_client._retry import call_with_retry
|
||||
from datatypes.epc.domain.epc_property_data import EpcPropertyData
|
||||
from datatypes.epc.domain.mapper import EpcPropertyDataMapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpcSearchResult:
|
||||
certificate_number: str
|
||||
address_line_1: str
|
||||
address_line_2: Optional[str]
|
||||
address_line_3: Optional[str]
|
||||
address_line_4: Optional[str]
|
||||
postcode: str
|
||||
post_town: str
|
||||
uprn: Optional[int]
|
||||
current_energy_efficiency_band: str
|
||||
registration_date: str
|
||||
|
||||
def full_address(self) -> str:
|
||||
parts = [
|
||||
self.address_line_1,
|
||||
self.address_line_2,
|
||||
self.address_line_3,
|
||||
self.address_line_4,
|
||||
]
|
||||
return ", ".join(p for p in parts if p)
|
||||
|
||||
|
||||
class EpcClientService:
|
||||
BASE_URL = "https://api.get-energy-performance-data.communities.gov.uk"
|
||||
_MIN_MATCH_SCORE = 0.6
|
||||
|
||||
def __init__(self, auth_token: str) -> None:
|
||||
self._headers = {
|
||||
"Authorization": f"Bearer {auth_token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def get_by_certificate_number(self, cert_num: str) -> EpcPropertyData:
|
||||
raw = call_with_retry(lambda: self._fetch_certificate(cert_num))
|
||||
return EpcPropertyDataMapper.from_api_response(raw)
|
||||
|
||||
def get_by_uprn(self, uprn: int) -> Optional[EpcPropertyData]:
|
||||
results = call_with_retry(lambda: self._search(uprn=uprn))
|
||||
if not results:
|
||||
return None
|
||||
latest = max(results, key=lambda r: r.registration_date)
|
||||
return self.get_by_certificate_number(latest.certificate_number)
|
||||
|
||||
def search_by_postcode(self, postcode: str) -> list[EpcSearchResult]:
|
||||
return call_with_retry(lambda: self._search(postcode=postcode))
|
||||
|
||||
def find_best_match(self, postcode: str, address: str) -> Optional[EpcPropertyData]:
|
||||
from backend.utils.addressMatch import get_uprn_candidates
|
||||
|
||||
candidates = self.search_by_postcode(postcode)
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Round 1: score on addressLine1 only
|
||||
cert_num = self._pick_best_cert(candidates, address, use_full_address=False, fn=get_uprn_candidates)
|
||||
if cert_num:
|
||||
return self._safe_get(cert_num)
|
||||
|
||||
# Round 2: score on all address lines joined
|
||||
cert_num = self._pick_best_cert(candidates, address, use_full_address=True, fn=get_uprn_candidates)
|
||||
if cert_num:
|
||||
return self._safe_get(cert_num)
|
||||
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fetch_certificate(self, cert_num: str) -> dict:
|
||||
resp = httpx.get(
|
||||
f"{self.BASE_URL}/api/certificate",
|
||||
params={"certificate_number": cert_num},
|
||||
headers=self._headers,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
raise EpcNotFoundError(cert_num)
|
||||
if resp.status_code == 429:
|
||||
raise EpcRateLimitError("Rate limited by EPC API")
|
||||
if not resp.is_success:
|
||||
raise EpcApiError(f"EPC API error {resp.status_code}: {resp.text}")
|
||||
return resp.json()["data"]
|
||||
|
||||
def _search(
|
||||
self,
|
||||
postcode: Optional[str] = None,
|
||||
uprn: Optional[int] = None,
|
||||
) -> list[EpcSearchResult]:
|
||||
params: dict[str, str | int] = {}
|
||||
if postcode:
|
||||
params["postcode"] = postcode
|
||||
if uprn is not None:
|
||||
params["uprn"] = uprn
|
||||
|
||||
resp = httpx.get(
|
||||
f"{self.BASE_URL}/api/domestic/search",
|
||||
params=params,
|
||||
headers=self._headers,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return []
|
||||
if resp.status_code == 429:
|
||||
raise EpcRateLimitError("Rate limited by EPC API")
|
||||
if not resp.is_success:
|
||||
raise EpcApiError(f"EPC API error {resp.status_code}: {resp.text}")
|
||||
|
||||
rows = resp.json().get("data", [])
|
||||
return [self._parse_search_result(r) for r in rows]
|
||||
|
||||
@staticmethod
|
||||
def _parse_search_result(row: dict) -> EpcSearchResult:
|
||||
return EpcSearchResult(
|
||||
certificate_number=row["certificateNumber"],
|
||||
address_line_1=row["addressLine1"],
|
||||
address_line_2=row.get("addressLine2"),
|
||||
address_line_3=row.get("addressLine3"),
|
||||
address_line_4=row.get("addressLine4"),
|
||||
postcode=row["postcode"],
|
||||
post_town=row["postTown"],
|
||||
uprn=row.get("uprn"),
|
||||
current_energy_efficiency_band=row["currentEnergyEfficiencyBand"],
|
||||
registration_date=row["registrationDate"],
|
||||
)
|
||||
|
||||
def _pick_best_cert(
|
||||
self,
|
||||
candidates: list[EpcSearchResult],
|
||||
user_address: str,
|
||||
use_full_address: bool,
|
||||
fn: Callable[..., pd.DataFrame],
|
||||
) -> Optional[str]:
|
||||
df = pd.DataFrame([
|
||||
{
|
||||
"address": r.full_address() if use_full_address else r.address_line_1,
|
||||
"uprn": str(r.uprn) if r.uprn is not None else "",
|
||||
"certificate_number": r.certificate_number,
|
||||
}
|
||||
for r in candidates
|
||||
])
|
||||
|
||||
scored = fn(df, user_address=user_address)
|
||||
if scored.empty:
|
||||
return None
|
||||
|
||||
best_score = scored.iloc[0]["lexiscore"]
|
||||
if best_score < self._MIN_MATCH_SCORE:
|
||||
return None
|
||||
|
||||
top = scored[scored["lexirank"] == 1]
|
||||
if len(top) != 1:
|
||||
return None
|
||||
|
||||
return str(top.iloc[0]["certificate_number"])
|
||||
|
||||
def _safe_get(self, cert_num: str) -> Optional[EpcPropertyData]:
|
||||
try:
|
||||
return self.get_by_certificate_number(cert_num)
|
||||
except EpcNotFoundError:
|
||||
return None
|
||||
10
backend/epc_client/exceptions.py
Normal file
10
backend/epc_client/exceptions.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
class EpcApiError(Exception):
|
||||
"""Base for all EPC client errors."""
|
||||
|
||||
|
||||
class EpcNotFoundError(EpcApiError):
|
||||
"""Raised when the API returns 404."""
|
||||
|
||||
|
||||
class EpcRateLimitError(EpcApiError):
|
||||
"""Raised when the API returns 429 and all retries are exhausted."""
|
||||
1
backend/epc_client/requirements.txt
Normal file
1
backend/epc_client/requirements.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
httpx>=0.27.0
|
||||
0
backend/epc_client/tests/__init__.py
Normal file
0
backend/epc_client/tests/__init__.py
Normal file
48
backend/epc_client/tests/conftest.py
Normal file
48
backend/epc_client/tests/conftest.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import json
|
||||
import pathlib
|
||||
import pytest
|
||||
|
||||
from backend.epc_client.client import EpcClientService
|
||||
|
||||
SAMPLES_DIR = pathlib.Path("backend/epc_api/json_samples")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rdsap_21_0_0_cert():
|
||||
return json.loads((SAMPLES_DIR / "RdSAP-Schema-21.0.0/epc.json").read_text())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rdsap_21_0_1_cert():
|
||||
return json.loads((SAMPLES_DIR / "RdSAP-Schema-21.0.1/epc.json").read_text())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def epc_service():
|
||||
return EpcClientService(auth_token="test-token")
|
||||
|
||||
|
||||
def make_search_row(
|
||||
cert_num="CERT-001",
|
||||
address_line_1="1 Test Street",
|
||||
postcode="SW1A 1AA",
|
||||
post_town="London",
|
||||
uprn=100023336956,
|
||||
band="D",
|
||||
registration_date="2024-01-01",
|
||||
address_line_2=None,
|
||||
address_line_3=None,
|
||||
address_line_4=None,
|
||||
):
|
||||
return {
|
||||
"certificateNumber": cert_num,
|
||||
"addressLine1": address_line_1,
|
||||
"addressLine2": address_line_2,
|
||||
"addressLine3": address_line_3,
|
||||
"addressLine4": address_line_4,
|
||||
"postcode": postcode,
|
||||
"postTown": post_town,
|
||||
"uprn": uprn,
|
||||
"currentEnergyEfficiencyBand": band,
|
||||
"registrationDate": registration_date,
|
||||
}
|
||||
224
backend/epc_client/tests/test_client.py
Normal file
224
backend/epc_client/tests/test_client.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
from unittest.mock import MagicMock, patch, call
|
||||
import pytest
|
||||
|
||||
from backend.epc_client.client import EpcClientService, EpcSearchResult
|
||||
from backend.epc_client.exceptions import EpcNotFoundError, EpcRateLimitError
|
||||
from datatypes.epc.domain.epc_property_data import EpcPropertyData
|
||||
from backend.epc_client.tests.conftest import make_search_row
|
||||
|
||||
|
||||
def _mock_response(status_code=200, json_data=None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.is_success = 200 <= status_code < 300
|
||||
resp.json.return_value = json_data or {}
|
||||
resp.text = str(json_data)
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: get_by_certificate_number happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_by_certificate_number_returns_epc_property_data(epc_service, rdsap_21_0_1_cert):
|
||||
cert_response = {"data": rdsap_21_0_1_cert}
|
||||
with patch("httpx.get", return_value=_mock_response(200, cert_response)):
|
||||
result = epc_service.get_by_certificate_number("CERT-001")
|
||||
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: get_by_certificate_number 404 → EpcNotFoundError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_by_certificate_number_404_raises_not_found(epc_service):
|
||||
with patch("httpx.get", return_value=_mock_response(404)):
|
||||
with pytest.raises(EpcNotFoundError):
|
||||
epc_service.get_by_certificate_number("BAD-CERT")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: 429 retried, succeeds on 3rd attempt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_by_certificate_number_retries_on_429_and_succeeds(epc_service, rdsap_21_0_1_cert):
|
||||
cert_response = {"data": rdsap_21_0_1_cert}
|
||||
responses = [
|
||||
_mock_response(429),
|
||||
_mock_response(429),
|
||||
_mock_response(200, cert_response),
|
||||
]
|
||||
with patch("httpx.get", side_effect=responses), patch("time.sleep"):
|
||||
result = epc_service.get_by_certificate_number("CERT-001")
|
||||
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: get_by_uprn empty search → None
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_by_uprn_returns_none_when_no_results(epc_service):
|
||||
with patch("httpx.get", return_value=_mock_response(200, {"data": []})):
|
||||
result = epc_service.get_by_uprn(100023336956)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: get_by_uprn multiple results → fetches latest by registration_date
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_by_uprn_picks_most_recent_certificate(epc_service, rdsap_21_0_1_cert):
|
||||
search_rows = [
|
||||
make_search_row(cert_num="CERT-OLD", registration_date="2022-01-01"),
|
||||
make_search_row(cert_num="CERT-NEW", registration_date="2024-06-01"),
|
||||
make_search_row(cert_num="CERT-MID", registration_date="2023-03-15"),
|
||||
]
|
||||
cert_response = {"data": rdsap_21_0_1_cert}
|
||||
|
||||
def fake_get(url, params=None, **kwargs):
|
||||
if "search" in url:
|
||||
return _mock_response(200, {"data": search_rows})
|
||||
return _mock_response(200, cert_response)
|
||||
|
||||
with patch("httpx.get", side_effect=fake_get) as mock_get:
|
||||
result = epc_service.get_by_uprn(100023336956)
|
||||
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
# Second call must be for the most recent cert
|
||||
cert_call = mock_get.call_args_list[1]
|
||||
assert cert_call.kwargs["params"]["certificate_number"] == "CERT-NEW"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: search_by_postcode returns list[EpcSearchResult]
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_search_by_postcode_returns_results(epc_service):
|
||||
rows = [
|
||||
make_search_row(cert_num="CERT-A", address_line_1="1 High Street"),
|
||||
make_search_row(cert_num="CERT-B", address_line_1="2 High Street"),
|
||||
]
|
||||
with patch("httpx.get", return_value=_mock_response(200, {"data": rows})):
|
||||
results = epc_service.search_by_postcode("SW1A 1AA")
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(isinstance(r, EpcSearchResult) for r in results)
|
||||
assert results[0].certificate_number == "CERT-A"
|
||||
assert results[1].address_line_1 == "2 High Street"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 7: search_by_postcode 404 → empty list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_search_by_postcode_404_returns_empty_list(epc_service):
|
||||
with patch("httpx.get", return_value=_mock_response(404)):
|
||||
results = epc_service.search_by_postcode("ZZ9 9ZZ")
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests 8-10: find_best_match
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_scored_df(rows, scores, ranks):
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(rows)
|
||||
df["lexiscore"] = scores
|
||||
df["lexirank"] = ranks
|
||||
return df.sort_values("lexirank")
|
||||
|
||||
|
||||
def test_find_best_match_round1_clear_winner(epc_service, rdsap_21_0_1_cert):
|
||||
search_rows = [
|
||||
make_search_row(cert_num="CERT-WIN", address_line_1="1 High Street"),
|
||||
make_search_row(cert_num="CERT-LOSE", address_line_1="99 Nowhere Lane"),
|
||||
]
|
||||
cert_response = {"data": rdsap_21_0_1_cert}
|
||||
|
||||
df_rows = [
|
||||
{"address": "1 High Street", "uprn": "100023336956", "certificate_number": "CERT-WIN"},
|
||||
{"address": "99 Nowhere Lane", "uprn": "100023336956", "certificate_number": "CERT-LOSE"},
|
||||
]
|
||||
scored = _make_scored_df(df_rows, [0.9, 0.1], [1, 2])
|
||||
|
||||
def fake_get(url, params=None, **kwargs):
|
||||
if "search" in url:
|
||||
return _mock_response(200, {"data": search_rows})
|
||||
return _mock_response(200, cert_response)
|
||||
|
||||
with patch("httpx.get", side_effect=fake_get), \
|
||||
patch("backend.utils.addressMatch.get_uprn_candidates", return_value=scored):
|
||||
result = epc_service.find_best_match("SW1A 1AA", "1 High Street")
|
||||
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
|
||||
|
||||
def test_find_best_match_round1_ambiguous_round2_resolves(epc_service, rdsap_21_0_1_cert):
|
||||
search_rows = [
|
||||
make_search_row(
|
||||
cert_num="CERT-A", address_line_1="1 High Street",
|
||||
address_line_2="Ground Floor",
|
||||
),
|
||||
make_search_row(
|
||||
cert_num="CERT-B", address_line_1="1 High Street",
|
||||
address_line_2="First Floor",
|
||||
),
|
||||
]
|
||||
cert_response = {"data": rdsap_21_0_1_cert}
|
||||
|
||||
# Round 1: both score equally — ambiguous (two rank-1s)
|
||||
ambiguous = _make_scored_df(
|
||||
[
|
||||
{"address": "1 High Street", "uprn": "111", "certificate_number": "CERT-A"},
|
||||
{"address": "1 High Street", "uprn": "222", "certificate_number": "CERT-B"},
|
||||
],
|
||||
[0.9, 0.9],
|
||||
[1, 1],
|
||||
)
|
||||
# Round 2: CERT-A wins on full address
|
||||
resolved = _make_scored_df(
|
||||
[
|
||||
{"address": "1 High Street, Ground Floor", "uprn": "111", "certificate_number": "CERT-A"},
|
||||
{"address": "1 High Street, First Floor", "uprn": "222", "certificate_number": "CERT-B"},
|
||||
],
|
||||
[0.85, 0.4],
|
||||
[1, 2],
|
||||
)
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def fake_candidates(df, user_address, **kwargs):
|
||||
call_count["n"] += 1
|
||||
return ambiguous if call_count["n"] == 1 else resolved
|
||||
|
||||
def fake_get(url, params=None, **kwargs):
|
||||
if "search" in url:
|
||||
return _mock_response(200, {"data": search_rows})
|
||||
return _mock_response(200, cert_response)
|
||||
|
||||
with patch("httpx.get", side_effect=fake_get), \
|
||||
patch("backend.utils.addressMatch.get_uprn_candidates", side_effect=fake_candidates):
|
||||
result = epc_service.find_best_match("SW1A 1AA", "1 High Street Ground Floor")
|
||||
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
|
||||
|
||||
def test_find_best_match_returns_none_when_no_good_match(epc_service):
|
||||
search_rows = [make_search_row(cert_num="CERT-X", address_line_1="99 Nowhere Lane")]
|
||||
|
||||
low_score = _make_scored_df(
|
||||
[{"address": "99 Nowhere Lane", "uprn": "111", "certificate_number": "CERT-X"}],
|
||||
[0.1],
|
||||
[1],
|
||||
)
|
||||
|
||||
with patch("httpx.get", return_value=_mock_response(200, {"data": search_rows})), \
|
||||
patch("backend.utils.addressMatch.get_uprn_candidates", return_value=low_score):
|
||||
result = epc_service.find_best_match("SW1A 1AA", "1 Completely Different Road")
|
||||
|
||||
assert result is None
|
||||
31
backend/epc_client/tests/test_mapper_dispatcher.py
Normal file
31
backend/epc_client/tests/test_mapper_dispatcher.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import pytest
|
||||
|
||||
from datatypes.epc.domain.mapper import EpcPropertyDataMapper
|
||||
from datatypes.epc.domain.epc_property_data import EpcPropertyData
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: from_api_response with RdSAP-Schema-21.0.0 fixture → EpcPropertyData
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_from_api_response_rdsap_21_0_0(rdsap_21_0_0_cert):
|
||||
result = EpcPropertyDataMapper.from_api_response(rdsap_21_0_0_cert)
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: from_api_response with RdSAP-Schema-21.0.1 fixture → EpcPropertyData
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_from_api_response_rdsap_21_0_1(rdsap_21_0_1_cert):
|
||||
result = EpcPropertyDataMapper.from_api_response(rdsap_21_0_1_cert)
|
||||
assert isinstance(result, EpcPropertyData)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: unknown schema_type → ValueError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_from_api_response_unknown_schema_raises():
|
||||
with pytest.raises(ValueError, match="Unsupported EPC schema"):
|
||||
EpcPropertyDataMapper.from_api_response({"schema_type": "RdSAP-Schema-99.0.0"})
|
||||
77
datatypes/epc/schema/helpers.py
Normal file
77
datatypes/epc/schema/helpers.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import dataclasses
|
||||
import typing
|
||||
from datetime import date
|
||||
from typing import Any, Dict, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||||
"""
|
||||
Recursively convert a plain dict (e.g. from json.loads) into the given
|
||||
dataclass type, using the field type hints to convert nested structures.
|
||||
|
||||
Handles:
|
||||
- Nested dataclasses
|
||||
- List[SomeDataclass]
|
||||
- Optional[X] / Union[X, None]
|
||||
- Union[DataclassType, primitive] (e.g. Union[Measurement, int])
|
||||
- Primitive pass-through for Union[str, int] etc.
|
||||
"""
|
||||
return _from_dict_impl(cls, data) # type: ignore[return-value]
|
||||
|
||||
|
||||
def _from_dict_impl(cls: Any, data: Any) -> Any:
|
||||
hints = typing.get_type_hints(cls)
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
for field in dataclasses.fields(cls): # type: ignore[arg-type]
|
||||
has_default = (
|
||||
field.default is not dataclasses.MISSING
|
||||
or field.default_factory is not dataclasses.MISSING # type: ignore[misc]
|
||||
)
|
||||
if field.name not in data:
|
||||
if has_default:
|
||||
continue
|
||||
raise ValueError(f"{cls.__name__}: missing required field '{field.name}'")
|
||||
|
||||
kwargs[field.name] = _coerce(data[field.name], hints[field.name])
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def _coerce(value: Any, hint: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
origin = typing.get_origin(hint)
|
||||
args = typing.get_args(hint)
|
||||
|
||||
# Union (includes Optional[X] which is Union[X, None])
|
||||
if origin is typing.Union:
|
||||
if value is None:
|
||||
return None
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
# Optional[X] — recurse so List[X] and nested dataclasses are handled
|
||||
return _coerce(value, non_none_args[0])
|
||||
# Multi-type Union (e.g. Union[Measurement, int]): try dataclasses first
|
||||
for arg in non_none_args:
|
||||
if dataclasses.is_dataclass(arg) and isinstance(value, dict):
|
||||
return _from_dict_impl(arg, value)
|
||||
# All remaining args are primitives — return value as-is
|
||||
return value
|
||||
|
||||
# List[X]
|
||||
if origin is list:
|
||||
item_hint = args[0]
|
||||
return [_coerce(item, item_hint) for item in value]
|
||||
|
||||
# Plain dataclass
|
||||
if dataclasses.is_dataclass(hint) and isinstance(value, dict):
|
||||
return _from_dict_impl(hint, value)
|
||||
|
||||
if hint is date and isinstance(value, str):
|
||||
return date.fromisoformat(value)
|
||||
|
||||
return value
|
||||
Loading…
Add table
Reference in a new issue