# Model/backend/postcode_splitter/main.py
# (source-viewer metadata: 358 lines, 14 KiB, Python)

import os
import sys
import json
import pandas as pd
import requests
import boto3
from uuid import UUID
from urllib.parse import unquote
from utils.s3 import read_csv_from_s3 as read_csv_from_s3_dict
from utils.logger import setup_logger
from tqdm import tqdm
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
# Module-level logger used by every function in this file; its configuration
# comes from utils.logger.setup_logger.
logger = setup_logger()
def parse_s3_uri(s3_uri: str) -> tuple[str, str]:
    """
    Parse an S3 location into ``(bucket, key)``.

    Supports two formats:
    1. S3 URI format: ``s3://bucket/key``
    2. AWS console object URL, e.g.
       ``https://<console-host>/s3/object/<bucket>?region=...&prefix=<url-encoded key>``.
       The bucket is the path text after ``/s3/object/`` and the key is the
       URL-decoded ``prefix`` query parameter.

    Args:
        s3_uri: The location to parse.

    Returns:
        ``(bucket, key)``, both non-empty.

    Raises:
        ValueError: if the input matches neither format, or yields an empty
            bucket or key. The underlying cause is chained via ``__cause__``.
    """
    logger.info("Parsing S3 URI")
    try:
        if s3_uri.startswith("s3://"):
            bucket, sep, key = s3_uri[5:].partition("/")
            if not sep:
                raise ValueError("S3 URI must include both bucket and key")
        else:
            # Otherwise, treat as AWS console URL
            logger.info("Parsing as AWS console URL")
            base, sep, query = s3_uri.partition("?")
            if not sep:
                raise ValueError("No query string found")
            if "/s3/object/" not in base:
                raise ValueError("No '/s3/object/' found in URL path")
            bucket = base.split("/s3/object/")[1]
            # maxsplit=1 so a literal '=' inside a value cannot produce a
            # 3-element item, which would make dict() raise.
            params = dict(
                item.split("=", 1) for item in query.split("&") if "=" in item
            )
            key = unquote(params.get("prefix", ""))
        # Previously an empty bucket/key (e.g. "s3://bucket/" or a console URL
        # without a prefix) was returned as if valid and only failed later.
        if not bucket or not key:
            raise ValueError("S3 URI must include both bucket and key")
        logger.info(f"Extracted bucket: {bucket}, key: {key}")
        return bucket, key
    except Exception as e:
        logger.error(f"Error parsing S3 URI: {type(e).__name__}: {e}")
        raise ValueError("Could not parse S3 URI") from e
# SQS client shared by all sends in this process; created lazily on first use.
_sqs_client = None


def _get_sqs_client():
    """Return the shared SQS client, creating it with ``boto3.client("sqs")`` on first use."""
    global _sqs_client
    if _sqs_client is None:
        _sqs_client = boto3.client("sqs")
    return _sqs_client


def send_to_address2uprn_queue(task_id: str, rows: list, *, sqs_client=None) -> str:
    """
    Send a postcode group to the address2UPRN SQS queue.

    The message body is ``{"task_id": task_id, "rows": rows}`` encoded with
    ``json.dumps``. The queue URL is read from ``ADDRESS2UPRN_QUEUE_URL``.

    Args:
        task_id: The parent task ID
        rows: List of row dictionaries for this postcode group; must be
            JSON-serialisable.
        sqs_client: Optional boto3 SQS client (e.g. for tests). By default, one
            client is created on first use and reused by later calls, rather
            than building a new client for every message.

    Returns:
        Message ID from SQS

    Raises:
        ValueError: if ``ADDRESS2UPRN_QUEUE_URL`` is not set.
        TypeError: if ``rows`` contains values ``json.dumps`` cannot encode.
    """
    # Validate configuration before touching AWS so a missing env var fails fast.
    queue_url = os.getenv("ADDRESS2UPRN_QUEUE_URL")
    if not queue_url:
        raise ValueError("ADDRESS2UPRN_QUEUE_URL environment variable not set")
    client = sqs_client if sqs_client is not None else _get_sqs_client()
    # NOTE(review): SQS enforces a maximum message size; a large batch of wide
    # rows could exceed it. Confirm the caller's batch size against the limit.
    response = client.send_message(
        QueueUrl=queue_url,
        MessageBody=json.dumps({"task_id": task_id, "rows": rows}),
    )
    message_id = response["MessageId"]
    logger.info(
        f"Sent message to address2UPRN queue. "
        f"Task: {task_id}, MessageId: {message_id}"
    )
    return message_id
def _group_rows_by_postcode(df: pd.DataFrame) -> dict:
    """Group ``df`` rows by normalised postcode: ``{postcode_clean: [row_dict, ...]}``.

    Adds a ``postcode_clean`` column to ``df`` in place: ``postcode`` upper-cased
    with all whitespace removed. Rows whose cleaned postcode is missing or empty
    are skipped. Group order follows first appearance in ``df``.

    Raises:
        ValueError: if ``df`` has no ``postcode`` column.
    """
    if "postcode" not in df.columns:
        raise ValueError("CSV has no 'postcode' column")
    # \s also catches tabs and non-breaking spaces, not just " ".
    cleaned = df["postcode"].str.upper().str.replace(r"\s+", "", regex=True)
    # Blank/whitespace-only postcodes clean to "". Treat them as missing rather
    # than grouping them all under one bogus "" postcode.
    df["postcode_clean"] = cleaned.mask(cleaned == "")
    clean_df = df.dropna(subset=["postcode_clean"])
    return {
        postcode: group.to_dict(orient="records")
        for postcode, group in clean_df.groupby("postcode_clean", sort=False)
    }


def _build_batches(postcode_to_addresses: dict, batch_size: int) -> list:
    """Greedily pack postcode groups into batches of at most ``batch_size`` rows.

    A postcode's rows are never split across batches. A single postcode with
    more than ``batch_size`` rows becomes a batch on its own, so that batch
    exceeds the limit. Batch order follows the dict's iteration order.
    """
    batches = []
    current = []
    for rows in postcode_to_addresses.values():
        if len(rows) > batch_size:
            # Oversized postcode: flush what we have, then ship it alone.
            if current:
                batches.append(current)
                current = []
            batches.append(list(rows))
            continue
        if current and len(current) + len(rows) > batch_size:
            batches.append(current)
            current = []
        current.extend(rows)
    if current:
        batches.append(current)
    return batches


def _send_batches(task_id: str, batches: list, errors: list) -> tuple:
    """Send each batch to the address2UPRN queue once; return ``(rows_sent, failed_batches)``.

    A failed send is logged and appended to ``errors``, and the remaining
    batches are still attempted. A failed batch is NOT merged into later
    messages, which previously made messages grow after every failure.
    """
    rows_sent = 0
    failed = 0
    for index, batch in enumerate(batches, start=1):
        try:
            send_to_address2uprn_queue(task_id=task_id, rows=batch)
        except Exception as e:
            failed += 1
            logger.error(
                f"Failed to send batch {index}/{len(batches)} ({len(batch)} rows) "
                f"to address2UPRN queue: {e}",
                exc_info=True,
            )
            errors.append(
                {
                    "error": "Failed to send to address2UPRN queue",
                    "details": str(e),
                }
            )
        else:
            rows_sent += len(batch)
            logger.info(
                f"Sent batch {index}/{len(batches)} of {len(batch)} rows to "
                f"address2UPRN queue (total sent: {rows_sent})"
            )
    return rows_sent, failed


def _mark_subtask_failed(subtask_interface, subtask_id, message: str) -> None:
    """Best-effort: set the subtask to "failed" with ``message``; DB errors are only logged."""
    if not subtask_id:
        return
    try:
        subtask_interface.update_subtask_status(
            subtask_id, "failed", outputs={"error": message}
        )
    except Exception as db_error:
        logger.error(f"Failed to update subtask status: {db_error}")


def handler(event, context, local=False):
    """Split each input CSV into postcode-grouped batches and queue them for address2UPRN.

    Each record's ``body`` (a JSON string or a dict) must contain ``task_id``
    (a UUID) and ``s3_uri`` (in any format ``parse_s3_uri`` accepts). A record
    without a ``body`` key is used as the body itself, so a bare
    ``{"task_id": ..., "s3_uri": ...}`` event also works. For each valid record
    the handler:

    - creates a subtask;
    - loads the CSV from S3 and groups rows by normalised postcode;
    - packs the groups into batches of up to 500 rows (a postcode is never split);
    - sends each batch to the address2UPRN queue.

    Args:
        event: Lambda event, either ``{"Records": [...]}`` or a single record.
        context: Lambda context; ``function_name`` and ``aws_request_id`` are printed.
        local: When True, every record is replaced by a fixed development
            message and only the first 5 CSV rows are processed.

    Returns:
        ``statusCode`` 500 with ``{"errors": [...]}`` when there were errors and
        no record succeeded. Otherwise ``statusCode`` 200 with per-record
        summaries under ``processed`` and any errors (or null) under ``errors``.
        The body is JSON-encoded in both cases.
    """
    print(f"Function: {context.function_name}")
    print(f"Request ID: {context.aws_request_id}")
    # Example SQS message for testing (copy and paste into SQS):
    # {
    # "task_id":"e31f2f21-175b-4a91-a3ec-a6baa325e917",
    # "s3_uri":"s3://retrofit-data-dev/ara_raw_inputs/peabody/2025_11_11 - Peabody - Data Extracts for Domna_transformed.csv"
    # }
    # Handle both single event and batch events (SQS, etc.)
    records = event.get("Records", [event])
    results = []
    errors = []
    batch_size = 500
    subtask_interface = SubTaskInterface()

    for record in records:
        task_id = None
        subtask_id = None
        try:
            # For local development
            if local is True:
                record = {
                    "body": '{"task_id":"e31f2f21-175b-4a91-a3ec-a6baa325e917","s3_uri":"s3://retrofit-data-dev/ara_raw_inputs/peabody/2025_11_11 - Peabody - Data Extracts for Domna_transformed.csv"}'
                }

            # SQS records carry the payload in "body"; a bare event *is* the payload.
            raw_body = record.get("body", record)
            body = json.loads(raw_body) if isinstance(raw_body, str) else raw_body

            # Validate required fields
            task_id = body.get("task_id")
            s3_uri = body.get("s3_uri")
            if not task_id:
                errors.append({"error": "Missing required field: task_id"})
                continue
            if not s3_uri:
                errors.append({"error": "Missing required field: s3_uri"})
                continue
            try:
                task_id = UUID(task_id) if isinstance(task_id, str) else task_id
            except ValueError as e:
                errors.append({"error": f"Invalid UUID format for task_id: {str(e)}"})
                continue

            # Create a new subtask for this postcode splitter invocation
            subtask_id = subtask_interface.create_subtask(
                task_id=task_id, inputs={"s3_uri": s3_uri}
            )
            logger.info(f"Created subtask {subtask_id} for task {task_id}")

            # Read CSV from S3
            logger.info(f"Processing S3 URI: {s3_uri}")
            bucket, key = parse_s3_uri(s3_uri)
            logger.info(f"S3 Bucket: {bucket}, Key: {key}")
            df = pd.DataFrame(read_csv_from_s3_dict(bucket, key))
            if local:
                # Keep local runs small; copy so adding a column is not done on a slice.
                df = df.head(5).copy()
            logger.info(f"CSV loaded: {len(df)} rows, {len(df.columns)} columns")

            postcode_to_addresses = _group_rows_by_postcode(df)
            total_rows = sum(len(rows) for rows in postcode_to_addresses.values())
            logger.info(f"Total postcodes: {len(postcode_to_addresses)}")
            logger.info(f"Total rows to send: {total_rows}")

            if not postcode_to_addresses:
                # Previously an empty SQS message was sent in this case.
                message = "No rows with a usable postcode found in CSV"
                logger.error(f"{message} (task {task_id})")
                errors.append({"error": message})
                _mark_subtask_failed(subtask_interface, subtask_id, message)
                continue

            batches = _build_batches(postcode_to_addresses, batch_size)
            rows_sent, failed_batches = _send_batches(str(task_id), batches, errors)
            if failed_batches:
                # Some rows never reached the queue; surface it on the subtask
                # the same way unexpected errors are surfaced below.
                _mark_subtask_failed(
                    subtask_interface,
                    subtask_id,
                    f"{failed_batches} of {len(batches)} batches failed to send "
                    f"to address2UPRN queue",
                )
                continue

            results.append(
                {
                    "task_id": str(task_id),
                    "subtask_id": str(subtask_id),
                    "postcodes": len(postcode_to_addresses),
                    "rows_sent": rows_sent,
                    "messages_sent": len(batches),
                }
            )
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in request body: {e}")
            errors.append({"error": "Invalid JSON in request body", "details": str(e)})
            _mark_subtask_failed(subtask_interface, subtask_id, str(e))
        except Exception as e:
            logger.error(f"Unexpected error processing record: {e}", exc_info=True)
            errors.append({"error": "Unexpected error", "details": str(e)})
            _mark_subtask_failed(subtask_interface, subtask_id, str(e))

    # NOTE(review): when invoked via an SQS event source mapping, Lambda does not
    # use this statusCode to decide redelivery. Confirm failures are surfaced
    # elsewhere (e.g. raising, or partial batch responses).
    # Return error if all records failed
    if errors and not results:
        return {"statusCode": 500, "body": json.dumps({"errors": errors})}
    return {
        "statusCode": 200,
        "body": json.dumps({"processed": results, "errors": errors or None}),
    }