mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
import os
|
|
import boto3
|
|
import pandas as pd
|
|
from io import BytesIO
|
|
from typing import Any
|
|
from uuid import UUID
|
|
from datetime import datetime, timezone
|
|
|
|
from utils.logger import setup_logger
|
|
from backend.utils.subtasks import subtask_handler
|
|
from backend.app.db.functions.bulk_address_uploads_functions import (
|
|
set_combined_output_s3_uri,
|
|
set_combining_status,
|
|
)
|
|
|
|
# Module-level logger, created once at import time by the project's setup_logger helper.
logger = setup_logger()
|
|
|
|
# Bucket name read from the S3_BUCKET_NAME environment variable at import time;
# os.getenv yields None when the variable is unset.
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
|
|
|
|
|
def list_csv_files(s3_client, bucket: str, task_id: str) -> list[str]:
    """Return the keys of every ``.csv`` object under ``ara_raw_outputs/<task_id>/``.

    The listing goes through the ``list_objects_v2`` paginator, so all result
    pages are visited rather than only the first one.

    Args:
        s3_client: Client object exposing ``get_paginator``.
        bucket: Name of the bucket to list.
        task_id: Task identifier used to build the key prefix.

    Returns:
        The matching keys, in the order the listing yields them. Pages without
        a ``Contents`` entry contribute nothing.
    """
    pages = s3_client.get_paginator("list_objects_v2").paginate(
        Bucket=bucket,
        Prefix=f"ara_raw_outputs/{task_id}/",
    )
    return [
        entry["Key"]
        for page in pages
        for entry in page.get("Contents", [])
        if entry["Key"].endswith(".csv")
    ]
|
|
|
|
|
|
def download_csv(s3_client, bucket: str, key: str) -> pd.DataFrame:
    """Fetch one object and parse its contents as CSV.

    Args:
        s3_client: Client object exposing ``get_object``.
        bucket: Name of the bucket holding the object.
        key: Key of the object to download.

    Returns:
        A DataFrame built by ``pd.read_csv`` with its default settings.
    """
    response = s3_client.get_object(Bucket=bucket, Key=key)
    # The whole body is read into memory before parsing.
    raw_bytes = response["Body"].read()
    return pd.read_csv(BytesIO(raw_bytes))
|
|
|
|
|
|
@subtask_handler()
def handler(body: dict[str, Any], context: Any) -> str:
    """Merge a task's raw CSV outputs into one CSV and record where it was stored.

    Steps, in order: mark the task via ``set_combining_status``, list the CSV
    keys under ``ara_raw_outputs/<task_id>/``, download and concatenate them,
    upload the result under ``bulk_final_outputs/<task_id>/`` with a UTC
    timestamp in the file name, then pass the resulting URI to
    ``set_combined_output_s3_uri``.

    Args:
        body: Message payload; its ``task_id`` entry is read.
        context: Not used by this function.

    Returns:
        The ``s3://`` URI of the combined CSV.

    Raises:
        RuntimeError: If ``task_id`` is missing or empty, if ``S3_BUCKET_NAME``
            is unset or empty, or if no CSV files exist under the task prefix.
        ValueError: If ``task_id`` is not a valid UUID string (raised by ``UUID``).
    """
    task_id_str: str = body.get("task_id", "")
    if not task_id_str:
        raise RuntimeError("Missing task_id in message body")

    # Parsed once and reused for both database calls.
    task_uuid = UUID(task_id_str)
    set_combining_status(task_uuid)

    # NOTE(review): the status update above runs before this configuration
    # check, so a missing bucket is reported after the status has changed.
    bucket = S3_BUCKET_NAME
    if not bucket:
        raise RuntimeError("S3_BUCKET_NAME env var not set")

    s3 = boto3.client("s3")
    logger.info(f"Combining ara_raw_outputs for task {task_id_str}")

    source_keys = list_csv_files(s3, bucket, task_id_str)
    if not source_keys:
        raise RuntimeError(f"No CSV files found under ara_raw_outputs/{task_id_str}/")
    logger.info(f"Found {len(source_keys)} CSV files")

    # All frames are held in memory and stacked with a fresh 0..n-1 index.
    frames = []
    for source_key in source_keys:
        frames.append(download_csv(s3, bucket, source_key))
    merged = pd.concat(frames, ignore_index=True)
    logger.info(f"Combined {len(merged)} rows from {len(frames)} files")

    # Hyphens instead of colons in the time part of the timestamp.
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    output_key = f"bulk_final_outputs/{task_id_str}/combined_{stamp}.csv"

    payload = BytesIO()
    merged.to_csv(payload, index=False)
    # getvalue() returns the full buffer regardless of the stream position.
    s3.put_object(Bucket=bucket, Key=output_key, Body=payload.getvalue())

    s3_uri = f"s3://{bucket}/{output_key}"
    logger.info(f"Saved combined CSV to {s3_uri}")
    print(f"OUTPUT_S3_URI: {s3_uri}")

    set_combined_output_s3_uri(task_uuid, s3_uri)
    logger.info(f"Persisted combined_output_s3_uri + awaiting_review status for task {task_id_str}")

    return s3_uri
|