mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-30 13:10:47 +00:00
feat(scripts): add full AraFirstRunPipeline local runner
scripts/run_first_run_e2e.py runs the real Ingestion -> Baseline -> Modelling pipeline against the DB by composing build_first_run_pipeline + dispatch_first_run with the live source clients (the Lambda handler can't run locally — its _source_clients_from_env still raises, #1136). Unlike run_modelling_e2e it runs real ingestion (persists EPC/spatial/solar) and has no inspect-only mode, so it's gated behind --confirm (preview otherwise); measure scoping comes only from the Scenario's exclusions (the pipeline threads no --measures), and the modelling batch is all-or-nothing, both documented. Extract the shared env/engine/S3 plumbing into scripts/e2e_common.py (public load_env/build_engine/s3_parquet_reader) so both runners share one source and neither imports the other's privates. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
694cdd9c23
commit
ea72ee97bf
3 changed files with 241 additions and 54 deletions
68
scripts/e2e_common.py
Normal file
68
scripts/e2e_common.py
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
"""Shared configuration + client plumbing for the local e2e runner scripts
|
||||||
|
(``run_modelling_e2e`` and ``run_first_run_e2e``).
|
||||||
|
|
||||||
|
Loads ``backend/.env`` and builds the DB engine from the FastAPI-layer ``DB_*``
|
||||||
|
vars (the ``infrastructure/postgres`` layer reads ``POSTGRES_*``, which the .env
|
||||||
|
does not carry), plus an S3-backed ``ParquetReader`` for the geospatial
|
||||||
|
repository. Secrets live in the .env and the ambient ``~/.aws`` profile; this
|
||||||
|
module never hard-codes them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import Engine, create_engine
|
||||||
|
|
||||||
|
from repositories.geospatial.geospatial_s3_repository import ParquetReader
|
||||||
|
|
||||||
|
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
ENV_PATH = _REPO_ROOT / "backend" / ".env"
|
||||||
|
|
||||||
|
|
||||||
|
def load_env(path: Path = ENV_PATH) -> None:
|
||||||
|
"""Load `KEY=value` lines from `backend/.env` into the environment (without
|
||||||
|
overriding anything already set), so the DB creds + API tokens are present."""
|
||||||
|
if not path.exists():
|
||||||
|
return
|
||||||
|
for raw in path.read_text(encoding="utf-8").splitlines():
|
||||||
|
line = raw.strip()
|
||||||
|
if not line or line.startswith("#") or "=" not in line:
|
||||||
|
continue
|
||||||
|
key, value = line.split("=", 1)
|
||||||
|
os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
|
||||||
|
|
||||||
|
|
||||||
|
def db_url() -> str:
|
||||||
|
"""The connection string from the FastAPI-layer `DB_*` env vars."""
|
||||||
|
env = os.environ
|
||||||
|
return (
|
||||||
|
f"postgresql+psycopg2://{env['DB_USERNAME']}:{env['DB_PASSWORD']}"
|
||||||
|
f"@{env['DB_HOST']}:{env['DB_PORT']}/{env['DB_NAME']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_engine() -> Engine:
|
||||||
|
"""A connection-pooled engine to the target DB (DB_* creds)."""
|
||||||
|
return create_engine(
|
||||||
|
db_url(), pool_pre_ping=True, connect_args={"connect_timeout": 10}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def s3_parquet_reader(bucket: str) -> ParquetReader:
|
||||||
|
"""A `ParquetReader` (key -> DataFrame) backed by `bucket` in S3, for the
|
||||||
|
`GeospatialS3Repository`. AWS creds come from the ambient `~/.aws` profile;
|
||||||
|
pyarrow reads the parquet bytes (s3fs is not installed here)."""
|
||||||
|
# boto3 ships only partial type stubs, so the client is an untyped boundary.
|
||||||
|
client = cast(Any, boto3.client("s3")) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
def read(key: str) -> pd.DataFrame:
|
||||||
|
body = cast(bytes, client.get_object(Bucket=bucket, Key=key)["Body"].read())
|
||||||
|
return pd.read_parquet(io.BytesIO(body))
|
||||||
|
|
||||||
|
return read
|
||||||
162
scripts/run_first_run_e2e.py
Normal file
162
scripts/run_first_run_e2e.py
Normal file
|
|
@ -0,0 +1,162 @@
|
||||||
|
"""Run the **full** ``AraFirstRunPipeline`` (Ingestion → Baseline → Modelling)
|
||||||
|
end-to-end against the real database, locally.
|
||||||
|
|
||||||
|
This is the production pipeline the ``ara_first_run`` Lambda runs, driven from a
|
||||||
|
shell instead of an SQS event. The Lambda ``handler`` itself cannot run locally —
|
||||||
|
``applications/ara_first_run/handler.py::_source_clients_from_env`` deliberately
|
||||||
|
raises until the deploy/Terraform wiring lands (#1136). So this script composes
|
||||||
|
the same pipeline directly via the existing ``build_first_run_pipeline`` seam,
|
||||||
|
supplying the three source clients that ``run_modelling_e2e`` already proves out
|
||||||
|
(EPC API, geospatial S3, Google Solar), then calls ``dispatch_first_run``.
|
||||||
|
|
||||||
|
How it differs from ``run_modelling_e2e``:
|
||||||
|
* It runs the **real Ingestion stage** — fetches each Property's EPC by UPRN,
|
||||||
|
resolves spatial + Google Solar, and **persists** them (``epc_property`` /
|
||||||
|
``property_details_spatial`` / ``solar``) — then Baseline, then Modelling.
|
||||||
|
``run_modelling_e2e`` does ingestion inline and only models.
|
||||||
|
* **There is no inspect-only mode**: the stages persist as they go (ADR-0012),
|
||||||
|
so any run writes to the DB. This script is gated behind ``--confirm``; without
|
||||||
|
it the script previews what it would do and exits.
|
||||||
|
* **The modelling batch is all-or-nothing**: each stage commits once per batch,
|
||||||
|
so one Property raising aborts the whole batch (no per-Property recovery like
|
||||||
|
``run_modelling_e2e``). Make sure the inputs are clean first.
|
||||||
|
|
||||||
|
Measure scoping comes **only from the Scenario's exclusions** — the pipeline
|
||||||
|
threads no ``--measures`` override (issue #1130). So if the live ``material``
|
||||||
|
catalogue cannot price/represent a measure a Property is eligible for (today:
|
||||||
|
``secondary_heating_removal``, absent from the ``material.type`` enum), that
|
||||||
|
Property's modelling raises and aborts the batch. Exclude it on the Scenario
|
||||||
|
first, e.g.::
|
||||||
|
|
||||||
|
UPDATE scenario SET exclusions = '{secondary_heating_removal}' WHERE id = 1266;
|
||||||
|
|
||||||
|
EPC Prediction (ADR-0031) is left **off** — its Landlord-Override attributes
|
||||||
|
reader is not wired here, so an EPC-less Property is not gap-filled.
|
||||||
|
|
||||||
|
Config + secrets are loaded exactly as ``run_modelling_e2e`` does: ``backend/.env``
|
||||||
|
for the DB creds (``DB_*``), the EPC Bearer token (``OPEN_EPC_API_TOKEN``), the
|
||||||
|
Google Solar key (``GOOGLE_SOLAR_API_KEY``) and the S3 bucket (``DATA_BUCKET``);
|
||||||
|
AWS creds from the ambient ``~/.aws`` profile. Run from the worktree root::
|
||||||
|
|
||||||
|
# preview only (no writes): print what would run, then exit
|
||||||
|
python -m scripts.run_first_run_e2e --scenario-ids 1266 --portfolio-id 785 \
|
||||||
|
709634 709635 709636
|
||||||
|
# actually run the full pipeline and persist (Ingestion -> Baseline -> Modelling)
|
||||||
|
python -m scripts.run_first_run_e2e --scenario-ids 1266 --portfolio-id 785 \
|
||||||
|
--confirm 709634 709635 709636
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
sys.path.insert(0, str(_REPO_ROOT)) # worktree root first — avoid the import trap
|
||||||
|
|
||||||
|
from applications.ara_first_run.ara_first_run_trigger_body import ( # noqa: E402
|
||||||
|
AraFirstRunTriggerBody,
|
||||||
|
)
|
||||||
|
from applications.ara_first_run.handler import ( # noqa: E402
|
||||||
|
build_first_run_pipeline,
|
||||||
|
dispatch_first_run,
|
||||||
|
)
|
||||||
|
from infrastructure.epc_client.epc_client_service import EpcClientService # noqa: E402
|
||||||
|
from infrastructure.solar.google_solar_api_client import ( # noqa: E402
|
||||||
|
GoogleSolarApiClient,
|
||||||
|
)
|
||||||
|
from repositories.geospatial.geospatial_s3_repository import ( # noqa: E402
|
||||||
|
GeospatialS3Repository,
|
||||||
|
)
|
||||||
|
from repositories.postgres_unit_of_work import PostgresUnitOfWork # noqa: E402
|
||||||
|
from scripts.e2e_common import ( # noqa: E402
|
||||||
|
ENV_PATH,
|
||||||
|
build_engine,
|
||||||
|
load_env,
|
||||||
|
s3_parquet_reader,
|
||||||
|
)
|
||||||
|
from sqlmodel import Session # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ids(raw: str) -> list[int]:
|
||||||
|
"""Parse a comma-separated id list (e.g. ``--scenario-ids 1266,1270``)."""
|
||||||
|
return [int(token.strip()) for token in raw.split(",") if token.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"property_ids", type=int, nargs="+", help="Property ids to run"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scenario-ids",
|
||||||
|
required=True,
|
||||||
|
help="comma-separated Scenario ids to model against (exclusions come "
|
||||||
|
"from each Scenario)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--portfolio-id", type=int, required=True, help="portfolio id for the run"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--confirm",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="actually run the pipeline and WRITE to the DB (default: preview only)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
scenario_ids = _parse_ids(args.scenario_ids)
|
||||||
|
|
||||||
|
load_env(ENV_PATH)
|
||||||
|
engine = build_engine()
|
||||||
|
|
||||||
|
body = AraFirstRunTriggerBody(
|
||||||
|
# task/sub_task drive the Lambda SubTask lifecycle only; running the
|
||||||
|
# pipeline directly bypasses the @subtask_handler decorator, so synthetic
|
||||||
|
# ids satisfy validation without touching the task tables.
|
||||||
|
task_id=uuid4(),
|
||||||
|
sub_task_id=uuid4(),
|
||||||
|
portfolio_id=args.portfolio_id,
|
||||||
|
property_ids=args.property_ids,
|
||||||
|
scenario_ids=scenario_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"full AraFirstRunPipeline (Ingestion -> Baseline -> Modelling) · "
|
||||||
|
f"{len(args.property_ids)} propertie(s) · scenarios {scenario_ids} · "
|
||||||
|
f"portfolio {args.portfolio_id}"
|
||||||
|
)
|
||||||
|
if not args.confirm:
|
||||||
|
print(
|
||||||
|
"\nPREVIEW ONLY — no writes. This run WOULD fetch + persist EPC/"
|
||||||
|
"spatial/solar, rebaseline, and model+persist Plans for:\n"
|
||||||
|
f" properties: {args.property_ids}\n"
|
||||||
|
"Re-run with --confirm to execute. NOTE: the modelling batch is "
|
||||||
|
"all-or-nothing; ensure each Scenario excludes any measure the live "
|
||||||
|
"catalogue cannot price (e.g. secondary_heating_removal)."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
epc_fetcher = EpcClientService(os.environ["OPEN_EPC_API_TOKEN"])
|
||||||
|
geospatial_repo = GeospatialS3Repository(
|
||||||
|
s3_parquet_reader(os.environ["DATA_BUCKET"])
|
||||||
|
)
|
||||||
|
solar_fetcher = GoogleSolarApiClient(os.environ["GOOGLE_SOLAR_API_KEY"])
|
||||||
|
|
||||||
|
pipeline = build_first_run_pipeline(
|
||||||
|
unit_of_work=lambda: PostgresUnitOfWork(lambda: Session(engine)),
|
||||||
|
epc_fetcher=epc_fetcher,
|
||||||
|
geospatial_repo=geospatial_repo,
|
||||||
|
solar_fetcher=solar_fetcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("running... (Ingestion -> Baseline -> Modelling, persisting per stage)\n")
|
||||||
|
dispatch_first_run(body.model_dump(), pipeline=pipeline)
|
||||||
|
print("done — EPC/spatial/solar + Baseline + Plans persisted for the batch.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -53,14 +53,10 @@ Google leg.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import io
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional
|
||||||
|
|
||||||
import boto3
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
sys.path.insert(0, str(_REPO_ROOT)) # worktree root first — avoid the import trap
|
sys.path.insert(0, str(_REPO_ROOT)) # worktree root first — avoid the import trap
|
||||||
|
|
@ -84,7 +80,6 @@ from infrastructure.solar.google_solar_api_client import ( # noqa: E402
|
||||||
)
|
)
|
||||||
from repositories.geospatial.geospatial_s3_repository import ( # noqa: E402
|
from repositories.geospatial.geospatial_s3_repository import ( # noqa: E402
|
||||||
GeospatialS3Repository,
|
GeospatialS3Repository,
|
||||||
ParquetReader,
|
|
||||||
)
|
)
|
||||||
from repositories.product.product_postgres_repository import ( # noqa: E402
|
from repositories.product.product_postgres_repository import ( # noqa: E402
|
||||||
ProductPostgresRepository,
|
ProductPostgresRepository,
|
||||||
|
|
@ -93,51 +88,20 @@ from repositories.postgres_unit_of_work import PostgresUnitOfWork # noqa: E402
|
||||||
from repositories.scenario.scenario_postgres_repository import ( # noqa: E402
|
from repositories.scenario.scenario_postgres_repository import ( # noqa: E402
|
||||||
ScenarioPostgresRepository,
|
ScenarioPostgresRepository,
|
||||||
)
|
)
|
||||||
from sqlalchemy import Engine, create_engine, text # noqa: E402
|
from scripts.e2e_common import ( # noqa: E402
|
||||||
|
ENV_PATH,
|
||||||
|
build_engine,
|
||||||
|
load_env,
|
||||||
|
s3_parquet_reader,
|
||||||
|
)
|
||||||
|
from sqlalchemy import Engine, text # noqa: E402
|
||||||
from sqlmodel import Session # noqa: E402
|
from sqlmodel import Session # noqa: E402
|
||||||
|
|
||||||
_ENV_PATH = _REPO_ROOT / "backend" / ".env"
|
|
||||||
_MARKDOWN_PATH = Path("modelling_e2e.md")
|
_MARKDOWN_PATH = Path("modelling_e2e.md")
|
||||||
_CSV_PATH = Path("modelling_e2e.csv")
|
_CSV_PATH = Path("modelling_e2e.csv")
|
||||||
_CANDIDATES_CSV_PATH = Path("modelling_e2e_candidates.csv")
|
_CANDIDATES_CSV_PATH = Path("modelling_e2e_candidates.csv")
|
||||||
|
|
||||||
|
|
||||||
def _load_env(path: Path) -> None:
|
|
||||||
"""Load `KEY=value` lines from `backend/.env` into the environment (without
|
|
||||||
overriding anything already set), so the DB creds + EPC token are present."""
|
|
||||||
if not path.exists():
|
|
||||||
return
|
|
||||||
for raw in path.read_text(encoding="utf-8").splitlines():
|
|
||||||
line = raw.strip()
|
|
||||||
if not line or line.startswith("#") or "=" not in line:
|
|
||||||
continue
|
|
||||||
key, value = line.split("=", 1)
|
|
||||||
os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
|
|
||||||
|
|
||||||
|
|
||||||
def _db_url() -> str:
|
|
||||||
"""The connection string from the FastAPI-layer `DB_*` env vars."""
|
|
||||||
env = os.environ
|
|
||||||
return (
|
|
||||||
f"postgresql+psycopg2://{env['DB_USERNAME']}:{env['DB_PASSWORD']}"
|
|
||||||
f"@{env['DB_HOST']}:{env['DB_PORT']}/{env['DB_NAME']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _s3_parquet_reader(bucket: str) -> ParquetReader:
|
|
||||||
"""A `ParquetReader` (key -> DataFrame) backed by `bucket` in S3, for the
|
|
||||||
`GeospatialS3Repository`. AWS creds come from the ambient `~/.aws` profile;
|
|
||||||
pyarrow reads the parquet bytes (s3fs is not installed here)."""
|
|
||||||
# boto3 ships only partial type stubs, so the client is an untyped boundary.
|
|
||||||
client = cast(Any, boto3.client("s3")) # pyright: ignore[reportUnknownMemberType]
|
|
||||||
|
|
||||||
def read(key: str) -> pd.DataFrame:
|
|
||||||
body = cast(bytes, client.get_object(Bucket=bucket, Key=key)["Body"].read())
|
|
||||||
return pd.read_parquet(io.BytesIO(body))
|
|
||||||
|
|
||||||
return read
|
|
||||||
|
|
||||||
|
|
||||||
def _spatial_for(repo: GeospatialS3Repository, uprn: int) -> Optional[SpatialReference]:
|
def _spatial_for(repo: GeospatialS3Repository, uprn: int) -> Optional[SpatialReference]:
|
||||||
"""The UPRN's spatial reference (coordinates + planning protections), or
|
"""The UPRN's spatial reference (coordinates + planning protections), or
|
||||||
None when S3 doesn't cover it — a missing reference must not abort the run,
|
None when S3 doesn't cover it — a missing reference must not abort the run,
|
||||||
|
|
@ -166,13 +130,6 @@ def _solar_insights_for(
|
||||||
return None # no Google solar coverage at this point — model without it
|
return None # no Google solar coverage at this point — model without it
|
||||||
|
|
||||||
|
|
||||||
def _engine() -> Engine:
|
|
||||||
"""A connection-pooled engine to DevAssessmentModelDB (DB_* creds)."""
|
|
||||||
return create_engine(
|
|
||||||
_db_url(), pool_pre_ping=True, connect_args={"connect_timeout": 10}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _uprns_for(engine: Engine, property_ids: list[int]) -> dict[int, Optional[int]]:
|
def _uprns_for(engine: Engine, property_ids: list[int]) -> dict[int, Optional[int]]:
|
||||||
"""Read each Property's UPRN from the DB (read-only)."""
|
"""Read each Property's UPRN from the DB (read-only)."""
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
|
|
@ -389,14 +346,14 @@ def main() -> None:
|
||||||
if args.persist and (args.scenario_id is None or args.portfolio_id is None):
|
if args.persist and (args.scenario_id is None or args.portfolio_id is None):
|
||||||
parser.error("--persist requires --scenario-id and --portfolio-id")
|
parser.error("--persist requires --scenario-id and --portfolio-id")
|
||||||
|
|
||||||
_load_env(_ENV_PATH)
|
load_env(ENV_PATH)
|
||||||
# The new gov EPC API (Bearer) authenticates with OPEN_EPC_API_TOKEN — the
|
# The new gov EPC API (Bearer) authenticates with OPEN_EPC_API_TOKEN — the
|
||||||
# name is misleading; EPC_AUTH_TOKEN is dead (403). Verified against the
|
# name is misleading; EPC_AUTH_TOKEN is dead (403). Verified against the
|
||||||
# /api/domestic/search endpoint.
|
# /api/domestic/search endpoint.
|
||||||
epc_client = EpcClientService(os.environ["OPEN_EPC_API_TOKEN"])
|
epc_client = EpcClientService(os.environ["OPEN_EPC_API_TOKEN"])
|
||||||
geospatial = GeospatialS3Repository(_s3_parquet_reader(os.environ["DATA_BUCKET"]))
|
geospatial = GeospatialS3Repository(s3_parquet_reader(os.environ["DATA_BUCKET"]))
|
||||||
solar_client = GoogleSolarApiClient(os.environ["GOOGLE_SOLAR_API_KEY"])
|
solar_client = GoogleSolarApiClient(os.environ["GOOGLE_SOLAR_API_KEY"])
|
||||||
engine = _engine()
|
engine = build_engine()
|
||||||
cli_considered = _resolve_considered(
|
cli_considered = _resolve_considered(
|
||||||
_parse_measures(args.measures), _parse_measures(args.exclude_measures)
|
_parse_measures(args.measures), _parse_measures(args.exclude_measures)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue