From d07fc351a59292a57c3b47eb8b0436d9434f6346 Mon Sep 17 00:00:00 2001 From: Jun-te Kim Date: Thu, 12 Feb 2026 18:04:27 +0000 Subject: [PATCH] added permission to add --- backend/postcode_splitter/main.py | 152 +++++++++++++++--- .../terraform/lambda/postcodeSplitter/main.tf | 2 +- infrastructure/terraform/shared/main.tf | 2 +- 3 files changed, 132 insertions(+), 24 deletions(-) diff --git a/backend/postcode_splitter/main.py b/backend/postcode_splitter/main.py index e834c44e..2714f330 100644 --- a/backend/postcode_splitter/main.py +++ b/backend/postcode_splitter/main.py @@ -4,12 +4,13 @@ import json import pandas as pd import requests import boto3 -from uuid import UUID +from uuid import UUID, uuid4 from urllib.parse import unquote -from utils.s3 import read_csv_from_s3 as read_csv_from_s3_dict +from utils.s3 import read_csv_from_s3 as read_csv_from_s3_dict, save_csv_to_s3 from utils.logger import setup_logger from tqdm import tqdm from backend.app.db.functions.tasks.Tasks import SubTaskInterface +from datetime import datetime logger = setup_logger() @@ -62,13 +63,55 @@ def parse_s3_uri(s3_uri: str) -> tuple[str, str]: raise ValueError(f"Could not parse S3 URI") from e -def send_to_address2uprn_queue(task_id: str, rows: list) -> str: +def upload_batch_to_s3(batch_df: pd.DataFrame, task_id: str, sub_task_id: str, bucket_name: str = None) -> str: """ - Send a postcode group to the address2UPRN SQS queue. + Upload batch DataFrame to S3 as CSV. + + Args: + batch_df: The DataFrame containing batch data + task_id: The parent task ID (used for file path) + sub_task_id: The subtask ID (used for file path) + bucket_name: The S3 bucket name (defaults to env variable) + + Returns: + S3 URI (s3://bucket/key) of the uploaded file + """ + if bucket_name is None: + bucket_name = os.getenv("S3_BUCKET_NAME") + + if not bucket_name: + logger.error( + "S3 bucket name not provided and S3_BUCKET_NAME environment variable not set" + ) + raise ValueError("S3_BUCKET_NAME not configured") + + try: + file_name = f"{datetime.now().isoformat()}_{str(uuid4())[:8]}" + file_key = f"ara_postcode_splitter_batches/{task_id}/{sub_task_id}/{file_name}.csv" + + success = save_csv_to_s3(batch_df, bucket_name, file_key) + + if success: + s3_uri = f"s3://{bucket_name}/{file_key}" + logger.info(f"Successfully uploaded batch to {s3_uri}") + return s3_uri + else: + logger.error(f"Failed to upload batch to S3") + raise ValueError("Failed to save CSV to S3") + + except Exception as e: + logger.error(f"Error uploading batch to S3: {str(e)}") + raise + + +def send_to_address2uprn_queue(task_id: str, sub_task_id: str, s3_uri: str) -> str: + """ + Send a batch to the address2UPRN SQS queue with S3 reference. Args: task_id: The parent task ID - rows: List of row dictionaries for this postcode group + sub_task_id: The new subtask ID for this batch + s3_uri: S3 URI pointing to the batch CSV file Returns: Message ID from SQS @@ -81,7 +124,8 @@ def send_to_address2uprn_queue(task_id: str, rows: list) -> str: message_body = { "task_id": task_id, - "rows": rows, + "sub_task_id": sub_task_id, + "s3_uri": s3_uri, } response = sqs_client.send_message( @@ -91,12 +135,59 @@ def send_to_address2uprn_queue(task_id: str, rows: list) -> str: logger.info( f"Sent message to address2UPRN queue. " - f"Task: {task_id}, MessageId: {response['MessageId']}" + f"Task: {task_id}, SubTask: {sub_task_id}, MessageId: {response['MessageId']}" ) return response["MessageId"] +def create_batch_and_send_to_address2uprn( + batch_rows: list, + task_id: str, + subtask_interface: SubTaskInterface, + bucket_name: str, +) -> str: + """ + Create a batch DataFrame, upload to S3, create subtask, and send to address2UPRN queue. + + Args: + batch_rows: List of row dictionaries for this batch + task_id: The parent task ID + subtask_interface: SubTaskInterface instance + bucket_name: S3 bucket name + + Returns: + The created batch subtask ID + """ + # Generate unique batch subtask ID + batch_sub_task_id = str(uuid4()) + + # Upload batch to S3 + batch_df = pd.DataFrame(batch_rows) + s3_uri = upload_batch_to_s3(batch_df, str(task_id), batch_sub_task_id, bucket_name) + + # Create a new subtask for this batch with all inputs + created_batch_sub_task_id = subtask_interface.create_subtask( + task_id=task_id, + inputs={ + "task_id": str(task_id), + "sub_task_id": batch_sub_task_id, + "batch_size": len(batch_rows), + "s3_uri": s3_uri, + } + ) + logger.info(f"Created batch subtask {created_batch_sub_task_id}") + + # Send message with S3 reference + send_to_address2uprn_queue( + task_id=str(task_id), + sub_task_id=batch_sub_task_id, + s3_uri=s3_uri, + ) + + return created_batch_sub_task_id + + def handler(event, context): print(f"Function: {context.function_name}") print(f"Request ID: {context.aws_request_id}") @@ -112,6 +203,7 @@ def handler(event, context): results = [] errors = [] subtask_interface = SubTaskInterface() + bucket_name = os.getenv("S3_BUCKET_NAME") for record in records: task_id = None @@ -148,6 +240,12 @@ def handler(event, context): ) logger.info(f"Created subtask {subtask_id} for task {task_id}") + # Mark subtask as in progress + subtask_interface.update_subtask_status( + subtask_id, "in progress" + ) + logger.info(f"Marked subtask {subtask_id} as in progress") + # Read CSV from S3 logger.info(f"Processing S3 URI: {s3_uri}") bucket, key = parse_s3_uri(s3_uri) @@ -184,9 +282,11 @@ def handler(event, context): for postcode, rows in postcode_to_addresses.items(): all_rows.extend(rows) try: - send_to_address2uprn_queue( - task_id=str(task_id), - rows=all_rows, + create_batch_and_send_to_address2uprn( + batch_rows=all_rows, + task_id=task_id, + subtask_interface=subtask_interface, + bucket_name=bucket_name, ) logger.info( f"Sent all {len(all_rows)} rows in single batch to address2UPRN queue" @@ -214,9 +314,11 @@ def handler(event, context): # First, send the current batch if it has data if batch_rows: try: - send_to_address2uprn_queue( - task_id=str(task_id), - rows=batch_rows, + create_batch_and_send_to_address2uprn( + batch_rows=batch_rows, + task_id=task_id, + subtask_interface=subtask_interface, + bucket_name=bucket_name, ) logger.info( f"Sent batch of {len(batch_rows)} rows to address2UPRN queue" @@ -236,9 +338,11 @@ def handler(event, context): # Send the large postcode on its own try: - send_to_address2uprn_queue( - task_id=str(task_id), - rows=rows, + create_batch_and_send_to_address2uprn( + batch_rows=rows, + task_id=task_id, + subtask_interface=subtask_interface, + bucket_name=bucket_name, ) logger.info( f"Sent large postcode {postcode} ({len(rows)} rows) to address2UPRN queue" @@ -263,9 +367,11 @@ def handler(event, context): f"Batch threshold reached: current {len(batch_rows)} + next postcode {len(rows)} = {current_batch_size} > {batch_size}" ) try: - send_to_address2uprn_queue( - task_id=str(task_id), - rows=batch_rows, + create_batch_and_send_to_address2uprn( + batch_rows=batch_rows, + task_id=task_id, + subtask_interface=subtask_interface, + bucket_name=bucket_name, ) logger.info( f"Sent batch of {len(batch_rows)} rows to address2UPRN queue (total sent: {total_sent})" @@ -290,9 +396,11 @@ def handler(event, context): # Send remaining batch if batch_rows: try: - send_to_address2uprn_queue( - task_id=str(task_id), - rows=batch_rows, + create_batch_and_send_to_address2uprn( + batch_rows=batch_rows, + task_id=task_id, + subtask_interface=subtask_interface, + bucket_name=bucket_name, ) total_sent += len(batch_rows) logger.info( diff --git a/infrastructure/terraform/lambda/postcodeSplitter/main.tf b/infrastructure/terraform/lambda/postcodeSplitter/main.tf index 78d927d3..e17d272d 100644 --- a/infrastructure/terraform/lambda/postcodeSplitter/main.tf +++ b/infrastructure/terraform/lambda/postcodeSplitter/main.tf @@ -55,7 +55,7 @@ module "lambda" { ENGINE_SQS_URL = "test" ENERGY_ASSESSMENTS_BUCKET = "test" ADDRESS2UPRN_QUEUE_URL = data.terraform_remote_state.address2uprn.outputs.address2uprn_queue_url - S3_BUCKET_NAME = "retrofit-data-dev" # Hardcoded as deployed via serverless i believe + S3_BUCKET_NAME = data.terraform_remote_state.shared.outputs.retrofit_sap_data_bucket_name }, ) } diff --git a/infrastructure/terraform/shared/main.tf b/infrastructure/terraform/shared/main.tf index eb2a679d..acf8c281 100644 --- a/infrastructure/terraform/shared/main.tf +++ b/infrastructure/terraform/shared/main.tf @@ -386,7 +386,7 @@ module "postcode_splitter_s3_read" { policy_name = "PostcodeSplitterReadS3" policy_description = "Allow postcode splitter Lambda to read from retrofit-data bucket" bucket_arns = ["arn:aws:s3:::retrofit-data-${var.stage}"] - actions = ["s3:GetObject"] + actions = ["s3:GetObject", "s3:ListBucket", "s3:PutObject"] resource_paths = ["/*"] }