mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import boto3
|
|
import pandas as pd
|
|
from botocore.config import Config
|
|
from tqdm import tqdm
|
|
|
|
from utils.logger import setup_logger
|
|
|
|
# Logger shared by every function in this script.
logger = setup_logger()

# Directory that main() scans for "domestic-*" local-authority folders, each
# expected to contain a certificates.csv.
SRC_ROOT = Path("/workspaces/home/epc_data")
# Scratch directory where shard_la() writes one "<postcode>.csv" per postcode.
TMP_ROOT = Path("/tmp/epc_postcodes")
# Destination bucket and key prefix; shards are stored as
# "<S3_PREFIX>/<postcode>/data.csv.gz".
S3_BUCKET = "retrofit-data-dev"
S3_PREFIX = "historical_epc"
|
|
|
|
# This script assumes you have downloaded the zip, unzipped it, and are running the script locally.
|
|
|
|
|
|
def sanitise(pc: pd.Series) -> pd.Series:
    """Return the postcodes upper-cased with every space character removed.

    The column is first cast to pandas' nullable ``"string"`` dtype, so
    missing entries stay missing instead of becoming the text ``"NAN"``.
    Only the literal space character is removed; other whitespace is kept.
    """
    as_text = pc.astype("string")
    upper_cased = as_text.str.upper()
    return upper_cased.str.replace(" ", "", regex=False)
|
|
|
|
|
|
def shard_la(la_dir: Path) -> None:
    """Split one local-authority ``certificates.csv`` into per-postcode shards.

    Rows are grouped by their sanitised postcode and appended to
    ``TMP_ROOT/<postcode>.csv``. The header row is written only when the shard
    file does not exist yet, so several local authorities (or repeated runs)
    accumulate into the same file. Rows whose sanitised postcode is missing or
    empty are dropped, and the number dropped is logged as a warning.

    Args:
        la_dir: Folder containing the ``certificates.csv`` to shard.
    """
    source_csv = la_dir / "certificates.csv"
    frame = pd.read_csv(source_csv, low_memory=False)

    # Helper column used for filtering and grouping only; removed before writing.
    frame["POSTCODE_CLEAN"] = sanitise(frame["POSTCODE"])
    rows_read = len(frame)

    usable = frame.dropna(subset=["POSTCODE_CLEAN"])
    usable = usable[usable["POSTCODE_CLEAN"] != ""]

    dropped = rows_read - len(usable)
    if dropped:
        logger.warning(f"{la_dir.name}: dropped {dropped} rows with empty postcode")

    for pc, rows in usable.groupby("POSTCODE_CLEAN", sort=False):
        shard_path = TMP_ROOT / f"{pc}.csv"
        # Decide before writing: only the first write to a shard emits a header.
        needs_header = not shard_path.exists()
        payload = rows.drop(columns=["POSTCODE_CLEAN"])
        payload.to_csv(shard_path, mode="a", header=needs_header, index=False)
|
|
|
|
|
|
def list_existing_keys(s3: Any) -> set[str]:
    """Collect every object key currently stored under ``S3_PREFIX/``.

    Args:
        s3: S3 client providing ``get_paginator`` (``main`` passes a
            ``boto3.client("s3")``).

    Returns:
        The full keys of all objects listed in ``S3_BUCKET`` under the prefix.
    """
    listing = s3.get_paginator("list_objects_v2").paginate(
        Bucket=S3_BUCKET, Prefix=f"{S3_PREFIX}/"
    )

    existing: set[str] = set()
    for page in tqdm(listing, desc="list s3"):
        # Tolerate pages that carry no "Contents" entry.
        existing.update(obj["Key"] for obj in page.get("Contents", []))

    logger.info(f"Found {len(existing)} existing objects under {S3_PREFIX}/")
    return existing
|
|
|
|
|
|
def upload_postcode(path: Path, s3: Any) -> None:
    """Upload one postcode shard to S3 as a gzip-compressed CSV.

    The shard is read back, exact duplicate rows are dropped, and the result
    is written to ``<S3_PREFIX>/<path.stem>/data.csv.gz`` in ``S3_BUCKET``.

    Args:
        path: Local shard file; its stem is used as the postcode in the key.
        s3: S3 client providing ``put_object``.

    Raises:
        ValueError: If an ``LMK_KEY`` still occurs more than once after exact
            duplicates are removed, i.e. the same key maps to differing rows.
    """
    records = pd.read_csv(path, low_memory=False)
    records = records.drop_duplicates()

    # Any key left with more than one row carries conflicting data.
    per_key = records["LMK_KEY"].value_counts()
    bad = per_key[per_key > 1]
    if len(bad) > 0:
        raise ValueError(
            f"Postcode {path.stem}: LMK_KEY appears with conflicting cert data: "
            f"{bad.index.tolist()[:5]}"
        )

    # Compress in memory so nothing extra is written to disk.
    payload = BytesIO()
    records.to_csv(payload, index=False, compression="gzip")

    object_key = f"{S3_PREFIX}/{path.stem}/data.csv.gz"
    s3.put_object(
        Bucket=S3_BUCKET,
        Key=object_key,
        Body=payload.getvalue(),
        ContentType="text/csv",
        ContentEncoding="gzip",
    )
|
|
|
|
|
|
def main() -> None:
    """Shard every local-authority export by postcode and upload new shards.

    Steps:
      1. Run ``shard_la`` over every ``domestic-*`` folder in ``SRC_ROOT``,
         producing one CSV per postcode in ``TMP_ROOT``.
      2. List the keys already present in S3 and skip shards whose
         ``<S3_PREFIX>/<postcode>/data.csv.gz`` object exists.
      3. Upload the remaining shards on a thread pool, keeping at most
         ``workers * 2`` uploads submitted at any time.

    Raises:
        Exception: The first upload failure is logged and re-raised. Uploads
            that are queued but not yet started are cancelled first, so the
            run stops promptly instead of draining the queue.
    """
    TMP_ROOT.mkdir(parents=True, exist_ok=True)
    la_dirs = sorted(
        p for p in SRC_ROOT.iterdir() if p.is_dir() and p.name.startswith("domestic-")
    )
    logger.info(f"Sharding {len(la_dirs)} LA folders -> {TMP_ROOT}")

    # NOTE: shard_la appends to existing shard files, so shards left in
    # TMP_ROOT by an earlier run get the same rows appended again.
    for la in tqdm(la_dirs, desc="shard"):
        shard_la(la)

    s3 = boto3.client(
        "s3",
        config=Config(
            max_pool_connections=512, retries={"max_attempts": 5, "mode": "standard"}
        ),
    )
    pc_files = sorted(TMP_ROOT.glob("*.csv"))
    logger.info(f"Found {len(pc_files)} local shards")

    existing = list_existing_keys(s3)
    todo = [p for p in pc_files if f"{S3_PREFIX}/{p.stem}/data.csv.gz" not in existing]
    skipped = len(pc_files) - len(todo)
    logger.info(
        f"Uploading {len(todo)} shards (skipping {skipped} already in S3) -> "
        f"s3://{S3_BUCKET}/{S3_PREFIX}/"
    )

    workers = 256
    todo_iter = iter(todo)
    inflight: dict[Any, Path] = {}
    # The progress bar is a context manager so it is closed on failure too.
    with tqdm(total=len(todo), desc="upload") as pbar:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            # Prime the pool with a bounded backlog rather than submitting
            # every shard up front.
            for _ in range(workers * 2):
                pc = next(todo_iter, None)
                if pc is None:
                    break
                inflight[pool.submit(upload_postcode, pc, s3)] = pc

            while inflight:
                done, _ = wait(inflight.keys(), return_when=FIRST_COMPLETED)
                for fut in done:
                    pc = inflight.pop(fut)
                    try:
                        fut.result()
                    except Exception as e:
                        logger.error(f"{pc.name}: {e}")
                        # Fail fast: leaving the executor block waits for all
                        # submitted work, so cancel whatever has not started.
                        # cancel() is a no-op for running or finished futures.
                        for pending in inflight:
                            pending.cancel()
                        raise
                    pbar.update(1)
                    # Refill one slot for each completed upload.
                    nxt = next(todo_iter, None)
                    if nxt is not None:
                        inflight[pool.submit(upload_postcode, nxt, s3)] = nxt
|
|
|
|
|
|
# Run the shard-and-upload pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|