mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
278 lines
8.6 KiB
Python
278 lines
8.6 KiB
Python
import os
|
|
import sys
|
|
import json
|
|
import pandas as pd
|
|
import requests
|
|
import boto3
|
|
from uuid import UUID, uuid4
|
|
from utils.s3 import (
|
|
read_csv_from_s3 as read_csv_from_s3_dict,
|
|
save_csv_to_s3,
|
|
parse_s3_uri,
|
|
)
|
|
from utils.logger import setup_logger
|
|
from tqdm import tqdm
|
|
from backend.app.db.functions.tasks.Tasks import SubTaskInterface
|
|
from datetime import datetime
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
def upload_batch_to_s3(
    batch_df: pd.DataFrame, task_id: str, sub_task_id: str, bucket_name: str = None
) -> str:
    """
    Write a batch DataFrame to S3 as a CSV object and return its S3 URI.

    The object key has the form
    ``ara_postcode_splitter_batches/<task_id>/<sub_task_id>/<timestamp>_<id>.csv``
    where ``<timestamp>`` is ``datetime.now().isoformat()`` and ``<id>`` is the
    first 8 characters of a fresh ``uuid4``.

    Args:
        batch_df: Rows to store.
        task_id: Task ID used as the first path component of the key.
        sub_task_id: Subtask ID used as the second path component of the key.
        bucket_name: Destination bucket. When ``None`` the ``S3_BUCKET_NAME``
            environment variable is used instead.

    Returns:
        The ``s3://<bucket>/<key>`` URI of the uploaded object.

    Raises:
        ValueError: If no bucket name can be resolved, or if
            ``save_csv_to_s3`` returns a falsy value.
    """
    target_bucket = (
        bucket_name if bucket_name is not None else os.getenv("S3_BUCKET_NAME")
    )

    if not target_bucket:
        logger.error(
            "S3 bucket name not provided and S3_BUCKET_NAME environment variable not set"
        )
        raise ValueError("S3_BUCKET_NAME not configured")

    try:
        timestamp = datetime.now().isoformat()
        short_id = str(uuid4())[:8]
        object_key = (
            f"ara_postcode_splitter_batches/{task_id}/{sub_task_id}/"
            f"{timestamp}_{short_id}.csv"
        )

        # Guard clause: a falsy result from the helper is treated as a failure.
        if not save_csv_to_s3(batch_df, target_bucket, object_key):
            logger.error("Failed to upload batch to S3")
            raise ValueError("Failed to save CSV to S3")

        uploaded_uri = f"s3://{target_bucket}/{object_key}"
        logger.info(f"Successfully uploaded batch to {uploaded_uri}")
        return uploaded_uri

    except Exception as exc:
        # Log every failure (including the ValueError above) and re-raise unchanged.
        logger.error(f"Error uploading batch to S3: {str(exc)}")
        raise
|
|
|
|
|
|
def send_to_address2uprn_queue(task_id: str, sub_task_id: str, s3_uri: str) -> str:
    """
    Publish one batch reference to the address2UPRN SQS queue.

    The message body is a JSON object with the keys ``task_id``,
    ``sub_task_id`` and ``s3_uri``. The queue URL is read from the
    ``ADDRESS2UPRN_QUEUE_URL`` environment variable.

    Args:
        task_id: The parent task ID.
        sub_task_id: The subtask ID created for this batch.
        s3_uri: S3 URI pointing to the batch CSV file.

    Returns:
        The ``MessageId`` reported by SQS for the sent message.

    Raises:
        ValueError: If ``ADDRESS2UPRN_QUEUE_URL`` is not set.
    """
    client = boto3.client("sqs")
    target_queue = os.getenv("ADDRESS2UPRN_QUEUE_URL")

    if not target_queue:
        raise ValueError("ADDRESS2UPRN_QUEUE_URL environment variable not set")

    payload = json.dumps(
        {
            "task_id": task_id,
            "sub_task_id": sub_task_id,
            "s3_uri": s3_uri,
        }
    )

    response = client.send_message(QueueUrl=target_queue, MessageBody=payload)
    message_id = response["MessageId"]

    logger.info(
        f"Sent message to address2UPRN queue. "
        f"Task: {task_id}, SubTask: {sub_task_id}, MessageId: {message_id}"
    )

    return message_id
|
|
|
|
|
|
def create_batch_and_send_to_address2uprn(
    batch_df: pd.DataFrame,
    task_id: str,
    sub_task_id: str,
    subtask_interface: SubTaskInterface,
    bucket_name: str,
) -> str:
    """
    Upload a batch to S3, register a subtask for it, and queue it for address2UPRN.

    Steps, in order:
      1. Upload ``batch_df`` as CSV via ``upload_batch_to_s3``.
      2. Create a new subtask under ``task_id`` whose inputs record the task ID
         and the uploaded file's S3 URI.
      3. Send a message referencing the new subtask and the S3 URI to the
         address2UPRN queue.

    Args:
        batch_df: Rows that make up this batch.
        task_id: Parent task ID.
        sub_task_id: ID of the subtask being split; used only as part of the
            S3 key for the uploaded batch.
        subtask_interface: Object used to create the new subtask.
        bucket_name: Bucket the batch CSV is uploaded to.

    Returns:
        The identifier returned by ``subtask_interface.create_subtask``.
    """
    # The S3 path is keyed on the *incoming* subtask, not the one created below.
    batch_uri = upload_batch_to_s3(
        batch_df, str(task_id), str(sub_task_id), bucket_name
    )

    # Register a dedicated subtask carrying everything the consumer needs.
    subtask_inputs = {
        "task_id": str(task_id),
        "s3_uri": batch_uri,
    }
    created_batch_sub_task_id = subtask_interface.create_subtask(
        task_id=task_id,
        inputs=subtask_inputs,
    )
    logger.info(f"Created batch subtask {created_batch_sub_task_id}")

    # Hand the batch over by reference (S3 URI) rather than by value.
    send_to_address2uprn_queue(
        task_id=str(task_id),
        sub_task_id=str(created_batch_sub_task_id),
        s3_uri=batch_uri,
    )

    return created_batch_sub_task_id
|
|
|
|
|
|
def _iter_postcode_batches(df: pd.DataFrame, batch_size: int):
    """
    Yield batch DataFrames from ``df`` without splitting a postcode across batches.

    ``df`` must contain a ``postcode_clean`` column.

    * If ``df`` has fewer than ``batch_size`` rows it is yielded unchanged as
      the only batch.
    * Otherwise rows are grouped by ``postcode_clean`` (first-seen order) and
      groups are packed greedily: groups accumulate until adding the next one
      would push the pending batch above ``batch_size`` rows, at which point
      the pending batch is yielded.
    * A single postcode with ``batch_size`` rows or more is yielded on its own
      (after flushing whatever was pending), so such a batch may exceed
      ``batch_size``.

    Packed batches are concatenated with ``ignore_index=True``.
    """
    if len(df) < batch_size:
        yield df
        return

    pending = []
    pending_rows = 0

    for _, group_df in df.groupby("postcode_clean", sort=False):
        group_len = len(group_df)

        # Oversized postcode: flush what we have, then emit the group as-is.
        if group_len >= batch_size:
            if pending:
                yield pd.concat(pending, ignore_index=True)
                pending, pending_rows = [], 0
            yield group_df
            continue

        # Adding this group would overflow the pending batch: flush first.
        if pending and pending_rows + group_len > batch_size:
            yield pd.concat(pending, ignore_index=True)
            pending, pending_rows = [], 0

        pending.append(group_df)
        pending_rows += group_len

    # Final flush of any remainder.
    if pending:
        yield pd.concat(pending, ignore_index=True)


def handler(event, context, local=False):
    """
    Lambda entry point: split an input CSV into postcode batches and queue them.

    For every record in ``event["Records"]`` (or the event itself when there is
    no ``Records`` key) the handler:

      1. Parses the record body (JSON string or dict) and reads ``task_id``,
         ``sub_task_id`` and ``s3_uri`` from it.
      2. Marks the subtask as ``"in progress"``.
      3. Loads the CSV referenced by ``s3_uri`` and derives ``postcode_clean``
         (upper-cased ``postcode`` with spaces removed).
      4. Splits the rows into batches (see ``_iter_postcode_batches``) and, for
         each one, uploads it, creates a subtask and sends an SQS message.
      5. Marks the subtask as ``"completed"``.

    Args:
        event: Lambda event. Ignored when ``local`` is true.
        context: Lambda context; only ``function_name`` and ``aws_request_id``
            are read, for logging. May be ``None`` for local runs.
        local: When true, a hard-coded sample event and the
            ``retrofit-data-dev`` bucket are used instead of the real inputs.

    Returns:
        ``{"statusCode": 200, "body": <json>}`` where the JSON body has a
        ``processed`` list (one entry per handled record) and ``errors``
        (always ``None`` on a normal return, because failures propagate).

    Raises:
        ValueError: If a record body lacks ``task_id``, ``sub_task_id`` or
            ``s3_uri``.
        Exception: Any error raised while processing is propagated unchanged
            so the invoking service (e.g. SQS) can retry the message.
    """
    # getattr keeps local runs working when no real Lambda context is supplied.
    print(f"Function: {getattr(context, 'function_name', None)}")
    print(f"Request ID: {getattr(context, 'aws_request_id', None)}")

    # Example SQS message for testing (copy and paste into SQS):
    if local is True:
        event = {
            "Records": [
                {
                    "body": json.dumps(
                        {
                            "task_id": "e31f2f21-175b-4a91-a3ec-a6baa325e917",
                            "sub_task_id": "8673913b-1a88-42d7-8578-0449123d94b0",
                            "s3_uri": "s3://retrofit-data-dev/ara_raw_inputs/peabody/2025_11_11 - Peabody - Data Extracts for Domna_transformed.csv",
                        }
                    )
                }
            ]
        }

    # Handle both single event and batch events (SQS, etc.)
    records = event.get("Records", [event])
    results = []
    subtask_interface = SubTaskInterface()
    bucket_name = "retrofit-data-dev" if local else os.getenv("S3_BUCKET_NAME")
    batch_size = 500

    for record in records:
        # Parse body (inputs): SQS delivers a JSON string, direct calls a dict.
        raw_body = record.get("body")
        if isinstance(raw_body, str):
            body = json.loads(raw_body)
        else:
            body = raw_body or {}

        # Validate required fields up front so a malformed message fails with
        # a clear error before any subtask state is touched.
        missing = [
            field
            for field in ("task_id", "sub_task_id", "s3_uri")
            if not body.get(field)
        ]
        if missing:
            raise ValueError(
                f"Message body is missing required field(s): {', '.join(missing)}"
            )

        task_id = body["task_id"]
        subtask_id = body["sub_task_id"]
        s3_uri = body["s3_uri"]

        # Convert IDs to UUID when they arrive as strings.
        task_id = UUID(task_id) if isinstance(task_id, str) else task_id
        subtask_id = UUID(subtask_id) if isinstance(subtask_id, str) else subtask_id

        # Mark subtask as in progress
        subtask_interface.update_subtask_status(subtask_id, "in progress")
        logger.info(f"Marked subtask {subtask_id} as in progress")

        # Read CSV from S3
        bucket, key = parse_s3_uri(s3_uri)
        logger.info(f"S3 Bucket: {bucket}, Key: {key}")

        csv_data = read_csv_from_s3_dict(bucket, key)
        df = pd.DataFrame(csv_data)

        logger.info(f"CSV loaded: {len(df)} rows, {len(df.columns)} columns")

        # Sanitise postcodes
        df["postcode_clean"] = df["postcode"].str.upper().str.replace(" ", "")

        # NOTE(review): this only drops NaN postcodes; if the CSV reader yields
        # empty strings for blanks they are kept and grouped together — confirm
        # that is intended.
        df = df.dropna(subset=["postcode_clean"])

        batches_sent = 0
        for batch_df in _iter_postcode_batches(df, batch_size):
            create_batch_and_send_to_address2uprn(
                batch_df=batch_df,
                task_id=task_id,
                sub_task_id=subtask_id,
                subtask_interface=subtask_interface,
                bucket_name=bucket_name,
            )
            batches_sent += 1

        # Mark subtask as completed
        subtask_interface.update_subtask_status(
            subtask_id,
            "completed",
            outputs={"rows_processed": "completed"},
        )

        results.append(
            {
                "task_id": str(task_id),
                "sub_task_id": str(subtask_id),
                "batches_sent": batches_sent,
            }
        )

    return {
        "statusCode": 200,
        "body": json.dumps(
            # Failures raise instead of being collected, so there are never
            # errors to report on a normal return.
            {"processed": results, "errors": None}
        ),
    }
|