Model/backend/bulk_address2uprn_combiner/main.py
2026-04-22 14:21:39 +00:00

79 lines
2.5 KiB
Python

import os
import boto3
import pandas as pd
from io import BytesIO
from typing import Any
from uuid import UUID
from datetime import datetime, timezone
from utils.logger import setup_logger
from backend.utils.subtasks import subtask_handler
from backend.app.db.functions.bulk_address_uploads_functions import (
set_combined_output_s3_uri,
set_combining_status,
)
# Module-wide logger obtained from the project's logging helper.
logger = setup_logger()

# Bucket name taken from the environment; None when the variable is unset.
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
def list_csv_files(s3_client, bucket: str, task_id: str) -> list[str]:
    """Return the keys of all ``.csv`` objects stored for a task.

    Pages through ``list_objects_v2`` results under the prefix
    ``ara_raw_outputs/<task_id>/`` and keeps only keys ending in ``.csv``
    (case-sensitive match).

    Args:
        s3_client: Object exposing ``get_paginator`` (e.g. a boto3 S3 client).
        bucket: Name of the bucket to list.
        task_id: Task identifier used to build the key prefix.

    Returns:
        Matching keys in the order the paginator yields them; an empty list
        when nothing matches.
    """
    pages = s3_client.get_paginator("list_objects_v2").paginate(
        Bucket=bucket,
        Prefix=f"ara_raw_outputs/{task_id}/",
    )
    # A page without a "Contents" entry contributes nothing.
    return [
        entry["Key"]
        for page in pages
        for entry in page.get("Contents", [])
        if entry["Key"].endswith(".csv")
    ]
def download_csv(s3_client, bucket: str, key: str) -> pd.DataFrame:
    """Fetch a single S3 object and parse it as CSV.

    The whole object body is read into memory before parsing.

    Args:
        s3_client: Object exposing ``get_object`` (e.g. a boto3 S3 client).
        bucket: Name of the bucket holding the object.
        key: Key of the CSV object to download.

    Returns:
        DataFrame parsed from the object's bytes with pandas' default options.
    """
    response = s3_client.get_object(Bucket=bucket, Key=key)
    payload = response["Body"].read()
    # NOTE(review): default dtype inference may alter identifier-like values
    # (leading zeros, ints in columns with blanks) — confirm this is acceptable.
    return pd.read_csv(BytesIO(payload))
@subtask_handler()
def handler(body: dict[str, Any], context: Any) -> str:
    """Merge all raw-output CSVs of a task into one CSV and record its S3 URI.

    Steps, in order: validate ``task_id``, call ``set_combining_status``,
    validate the bucket setting, list and download every CSV under
    ``ara_raw_outputs/<task_id>/``, concatenate them, upload the result to
    ``bulk_final_outputs/<task_id>/combined_<UTC timestamp>.csv`` and pass
    the resulting URI to ``set_combined_output_s3_uri``.

    Args:
        body: Message payload; must contain a non-empty ``task_id`` that
            ``uuid.UUID`` can parse.
        context: Invocation context; not used in this function body.

    Returns:
        The ``s3://`` URI of the combined CSV.

    Raises:
        RuntimeError: If ``task_id`` is missing/empty, ``S3_BUCKET_NAME`` is
            unset, or no CSV objects are found for the task.
        ValueError: If ``task_id`` is a string that is not a valid UUID.
    """
    raw_task_id: str = body.get("task_id", "")
    if not raw_task_id:
        raise RuntimeError("Missing task_id in message body")

    # Parsing doubles as validation; the status call happens before any S3 work.
    task_uuid = UUID(raw_task_id)
    set_combining_status(task_uuid)

    target_bucket = S3_BUCKET_NAME
    if not target_bucket:
        raise RuntimeError("S3_BUCKET_NAME env var not set")

    client = boto3.client("s3")
    logger.info(f"Combining ara_raw_outputs for task {raw_task_id}")

    source_keys = list_csv_files(client, target_bucket, raw_task_id)
    if not source_keys:
        raise RuntimeError(f"No CSV files found under ara_raw_outputs/{raw_task_id}/")
    logger.info(f"Found {len(source_keys)} CSV files")

    frames = [download_csv(client, target_bucket, csv_key) for csv_key in source_keys]
    merged = pd.concat(frames, ignore_index=True)
    logger.info(f"Combined {len(merged)} rows from {len(frames)} files")

    # Hyphens instead of colons in the time part of the object name.
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    destination_key = f"bulk_final_outputs/{raw_task_id}/combined_{stamp}.csv"

    # Serialise into an in-memory buffer; getvalue() returns the full contents
    # regardless of the current stream position.
    out_buffer = BytesIO()
    merged.to_csv(out_buffer, index=False)
    client.put_object(
        Bucket=target_bucket,
        Key=destination_key,
        Body=out_buffer.getvalue(),
    )

    s3_uri = f"s3://{target_bucket}/{destination_key}"
    logger.info(f"Saved combined CSV to {s3_uri}")
    # Also written to stdout in addition to the logger.
    # NOTE(review): presumably parsed by an external consumer — confirm before removing.
    print(f"OUTPUT_S3_URI: {s3_uri}")

    set_combined_output_s3_uri(task_uuid, s3_uri)
    logger.info(f"Persisted combined_output_s3_uri + awaiting_review status for task {raw_task_id}")
    return s3_uri