diff --git a/scripts/run_audit_generator_local.py b/scripts/run_audit_generator_local.py index 748e3f89..ab0215c8 100644 --- a/scripts/run_audit_generator_local.py +++ b/scripts/run_audit_generator_local.py @@ -1,18 +1,21 @@ """ -Run audit_generator locally. Writes XLSX to ./local_output/ instead of S3. +Run audit_generator locally. Usage: cd /workspaces/model - python scripts/run_audit_generator_local.py + python scripts/run_audit_generator_local.py [] + +Prompts for deal ID and S3 destination (local file or real S3) if not supplied. """ from __future__ import annotations import os import sys -from io import BytesIO from pathlib import Path -from typing import Any +from typing import Any, Union + +import boto3 # Load .env before importing infra modules from dotenv import load_dotenv @@ -21,6 +24,7 @@ load_dotenv(Path(__file__).parent.parent / "backend" / ".env") from infrastructure.postgres.config import PostgresConfig from infrastructure.postgres.engine import make_engine, make_session +from infrastructure.s3.s3_client import S3Client from orchestration.audit_generator_orchestrator import AuditGeneratorOrchestrator from orchestration.audit_generator_unit_of_work import AuditGeneratorUnitOfWork @@ -46,9 +50,19 @@ class _LocalS3Client: return str(dest) +def _make_s3_client() -> Union[S3Client, "_LocalS3Client"]: + use_real = input("Use real S3? [y/N]: ").strip().lower() == "y" + if use_real: + bucket = "retrofit-energy-assessments-dev" + boto3_client: Any = boto3.client + return S3Client(boto_s3_client=boto3_client("s3"), bucket=bucket) + output_dir = Path(__file__).parent.parent / "local_output" + return _LocalS3Client(output_dir) + + def main() -> None: deal_id = sys.argv[1] if len(sys.argv) > 1 else input("hubspot_deal_id: ").strip() - output_dir = Path(__file__).parent.parent / "local_output" + s3_client = _make_s3_client() engine = make_engine(PostgresConfig.from_env(os.environ)) @@ -60,7 +74,7 @@ def main() -> None: AuditGeneratorOrchestrator( hubspot_deal_id=deal_id, - s3_client=_LocalS3Client(output_dir), # type: ignore[arg-type] + s3_client=s3_client, # type: ignore[arg-type] uow_factory=uow_factory, ).run()