# Model/backend/postcode_splitter/main.py
#
# 222 lines
# 7.7 KiB
# Python

import os
import sys
import json
import pandas as pd
import requests
from uuid import UUID
from urllib.parse import unquote
from utils.s3 import read_csv_from_s3 as read_csv_from_s3_dict
from utils.logger import setup_logger
from tqdm import tqdm
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
from backend.address2UPRN.main import (
resolve_uprns_for_postcode_group,
get_epc_data_with_postcode,
)
# Module-level logger built by the project's `utils.logger.setup_logger`
# helper; shared by the functions in this module.
logger = setup_logger()
def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
    """
    Parse an S3 location and return ``(bucket, key)``.

    Supports two formats:
    1. S3 URI format: ``s3://bucket/key``
    2. AWS console URL: ``https://<host>/s3/object/<bucket>?...&prefix=<url-encoded key>``
       (the key is taken from the ``prefix`` query parameter and URL-decoded).

    Raises:
        ValueError: if the input matches neither format, or if the bucket or
            key would be empty. The underlying error is chained as
            ``__cause__`` and its message is included in this one.
    """
    logger.info("Parsing S3 URI")
    try:
        # Check if it's an S3 URI format
        if s3_uri.startswith("s3://"):
            bucket, sep, key = s3_uri[5:].partition("/")
            # Reject "s3://bucket", "s3:///key" and "s3://bucket/" alike.
            if not sep or not bucket or not key:
                raise ValueError("S3 URI must include both bucket and key")
            logger.info(f"Extracted bucket: {bucket}, key: {key}")
            return bucket, key

        # Otherwise, treat as AWS console URL
        logger.info("Parsing as AWS console URL")
        if "?" not in s3_uri:
            raise ValueError("No query string found")
        base, query = s3_uri.split("?", 1)

        # Extract bucket from base URL
        if "/s3/object/" not in base:
            raise ValueError("No '/s3/object/' found in URL path")
        # Bucket names never contain "/", so a trailing slash is URL noise.
        bucket = base.split("/s3/object/", 1)[1].strip("/")
        if not bucket:
            raise ValueError("No bucket found in URL path")
        logger.info(f"Extracted bucket: {bucket}")

        # Extract prefix from query parameters. Split on the first "=" only,
        # so a value that itself contains "=" does not break dict().
        params = dict(item.split("=", 1) for item in query.split("&") if "=" in item)
        key = unquote(params.get("prefix", ""))
        if not key:
            raise ValueError("No 'prefix' query parameter found")
        logger.info(f"Extracted key: {key}")
        return bucket, key
    except Exception as e:
        logger.error(f"Error parsing S3 URI: {type(e).__name__}: {e}")
        raise ValueError(f"Could not parse S3 URI: {e}") from e
def sanitise_postcode(postcode: str) -> str | None:
"""
Normalise postcode for grouping.
- Uppercase
- Remove all whitespace
"""
if pd.isna(postcode):
return None
return postcode.upper().replace(" ", "")
def _mark_subtask_failed(subtask_interface, subtask_id, error) -> None:
    """Best-effort: mark ``subtask_id`` as failed with ``error``; no-op when there is no subtask."""
    if not subtask_id:
        return
    try:
        subtask_interface.update_subtask_status(
            subtask_id, "failed", outputs={"error": str(error)}
        )
    except Exception as db_error:
        # Deliberately swallowed: the processing error is already recorded by
        # the caller, and a status-update failure must not mask it.
        logger.error(f"Failed to update subtask status: {db_error}")


def _parse_record_body(record: dict) -> dict:
    """Return the JSON-object payload carried by one event record.

    The payload is ``record["body"]`` (a JSON string or an already-decoded
    dict). When the record has no ``body`` key, the record itself is used as
    the payload, so a direct invocation such as
    ``{"task_id": ..., "s3_uri": ...}`` works as well as an SQS record.

    Raises:
        json.JSONDecodeError: if ``body`` is a string that is not valid JSON.
        ValueError: if the payload is not a JSON object.
    """
    body = record.get("body", record)
    if isinstance(body, str):
        body = json.loads(body)
    if not isinstance(body, dict):
        raise ValueError("Request body must be a JSON object")
    return body


def _json_safe(value):
    """Map scalar missing values (NaN/NaT/None) to None so json.dumps emits valid JSON."""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


def _group_by_postcode(df: pd.DataFrame) -> list[dict]:
    """Group ``df`` rows by the ``postcode_clean`` column, skipping rows where it is null.

    Returns one dict per postcode (in groupby's sorted key order) with the
    postcode, its row count and the rows as JSON-safe record dicts.
    """
    grouped_data = []
    for postcode, group_df in df.dropna(subset=["postcode_clean"]).groupby(
        "postcode_clean"
    ):
        rows = [
            {column: _json_safe(value) for column, value in row.items()}
            for row in group_df.to_dict(orient="records")
        ]
        grouped_data.append(
            {
                "postcode": postcode,
                "row_count": len(group_df),
                "rows": rows,
            }
        )
        logger.info(f"Postcode: {postcode}, Rows: {len(group_df)}")
    return grouped_data


def handler(event, context, local=False):
    """
    Lambda entry point: group the rows of an S3-hosted CSV by postcode.

    Each record is either an entry of ``event["Records"]`` (e.g. SQS) or, for
    a single direct invocation, ``event`` itself. Its payload — under a
    ``body`` key (JSON string or dict) or at the top level — must contain
    ``task_id`` (UUID string) and ``s3_uri`` (``s3://`` URI or AWS console
    URL). For each valid record a subtask is created, the CSV is loaded,
    rows are grouped by sanitised ``Postcode``, and the subtask is marked
    ``complete``; if processing fails after creation it is marked ``failed``.

    Args:
        event: Invocation event (ignored when ``local`` is True).
        context: Lambda context; only ``function_name`` and
            ``aws_request_id`` are read, and it may be None.
        local: When True, process one hard-coded development message instead
            of ``event``.

    Returns:
        ``{"statusCode": 500, "body": {"errors": [...]}}`` (JSON-encoded) when
        at least one error occurred and no record succeeded; otherwise
        ``statusCode`` 200 with ``{"processed": [...], "errors": [...] | None}``.
    """
    # `context` may be None (e.g. local runs), so read its attributes defensively.
    logger.info(f"Function: {getattr(context, 'function_name', None)}")
    logger.info(f"Request ID: {getattr(context, 'aws_request_id', None)}")

    # Example SQS message for testing (copy and paste into SQS):
    # {
    # "task_id":"e31f2f21-175b-4a91-a3ec-a6baa325e917",
    # "s3_uri":"s3://retrofit-data-dev/ara_raw_inputs/peabody/2025_11_11 - Peabody - Data Extracts for Domna_transformed.csv"
    # }
    if local is True:
        # For local development: ignore the incoming event entirely.
        records = [
            {
                "body": '{"task_id":"e31f2f21-175b-4a91-a3ec-a6baa325e917","s3_uri":"s3://retrofit-data-dev/ara_raw_inputs/peabody/2025_11_11 - Peabody - Data Extracts for Domna_transformed.csv"}'
            }
        ]
    else:
        # Handle both single event and batch events (SQS, etc.)
        records = event.get("Records", [event])

    results = []
    errors = []
    subtask_interface = SubTaskInterface()

    for record in records:
        subtask_id = None
        try:
            body = _parse_record_body(record)

            # Validate required fields
            task_id = body.get("task_id")
            s3_uri = body.get("s3_uri")
            if not task_id:
                errors.append({"error": "Missing required field: task_id"})
                continue
            if not s3_uri:
                errors.append({"error": "Missing required field: s3_uri"})
                continue

            try:
                task_id = UUID(task_id) if isinstance(task_id, str) else task_id
            except ValueError as e:
                errors.append({"error": f"Invalid UUID format for task_id: {str(e)}"})
                continue

            # Create a new subtask for this postcode splitter invocation
            subtask_id = subtask_interface.create_subtask(
                task_id=task_id, inputs={"s3_uri": s3_uri}
            )
            logger.info(f"Created subtask {subtask_id} for task {task_id}")

            # Read CSV from S3
            logger.info(f"Processing S3 URI: {s3_uri}")
            bucket, key = parse_s3_uri(s3_uri)
            logger.info(f"S3 Bucket: {bucket}, Key: {key}")
            df = pd.DataFrame(read_csv_from_s3_dict(bucket, key))
            logger.info(f"CSV loaded: {len(df)} rows, {len(df.columns)} columns")
            if "Postcode" not in df.columns:
                raise ValueError("CSV is missing required column: Postcode")

            df["postcode_clean"] = df["Postcode"].apply(sanitise_postcode)
            grouped_data = _group_by_postcode(df)
            logger.info(f"Total postcodes: {len(grouped_data)}")

            # Mark the subtask complete BEFORE reporting success, so a failed
            # status update does not leave the record in both lists.
            subtask_interface.update_subtask_status(
                subtask_id,
                "complete",
                outputs={
                    "status": "processing_complete",
                    "s3_uri": s3_uri,
                    "rows_processed": len(df),
                    "total_postcodes": len(grouped_data),
                },
            )
            logger.info(f"Subtask {subtask_id} marked as complete")

            # NOTE(review): grouped_data embeds every CSV row in the response;
            # very large files may exceed the invoker's payload limit — confirm.
            results.append(
                {
                    "message": "Postcode splitter processing completed",
                    "task_id": str(task_id),
                    "s3_uri": s3_uri,
                    "subtask_id": str(subtask_id),
                    "total_rows": len(df),
                    "total_postcodes": len(grouped_data),
                    "grouped_data": grouped_data,
                }
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in request body: {e}")
            errors.append({"error": "Invalid JSON in request body", "details": str(e)})
            _mark_subtask_failed(subtask_interface, subtask_id, e)
        except Exception as e:
            logger.error(f"Unexpected error processing record: {e}", exc_info=True)
            errors.append({"error": "Unexpected error", "details": str(e)})
            _mark_subtask_failed(subtask_interface, subtask_id, e)

    # Return error if all records failed
    if errors and not results:
        return {"statusCode": 500, "body": json.dumps({"errors": errors})}
    return {
        "statusCode": 200,
        "body": json.dumps({"processed": results, "errors": errors or None}),
    }