import logging
import os
from typing import Any

import boto3

from applications.landlord_description_overrides.landlord_description_overrides_trigger_body import (
    LandlordDescriptionOverridesTriggerBody,
)
from domain.epc.built_form_type import BuiltFormType
from domain.epc.property_type import PropertyType
from domain.epc.roof_type import RoofType
from domain.epc.wall_type import WallType
from domain.epc.wall_type_construction_dates import (
    wall_type_construction_date_prompt_hint,
)
from infrastructure.chatgpt.chatgpt import ChatGPT
from infrastructure.chatgpt.chatgpt_column_classifier import ChatGptColumnClassifier
from infrastructure.landlord_overrides.landlord_overrides_postgres_repository import (
    LandlordOverridesRepository,
)
from infrastructure.postgres.config import PostgresConfig
from infrastructure.postgres.engine import commit_scope, make_engine, make_session
from infrastructure.postgres.landlord_built_form_type_override_table import (
    LandlordBuiltFormTypeOverrideRow,
)
from infrastructure.postgres.landlord_property_type_override_table import (
    LandlordPropertyTypeOverrideRow,
)
from infrastructure.postgres.landlord_roof_type_override_table import (
    LandlordRoofTypeOverrideRow,
)
from infrastructure.postgres.landlord_wall_type_override_table import (
    LandlordWallTypeOverrideRow,
)
from infrastructure.s3.csv_s3_client import CsvS3Client
from infrastructure.s3.s3_uri import parse_s3_uri
from orchestration.classifiable_column import ClassifiableColumn
from orchestration.landlord_description_overrides_orchestrator import (
    LandlordDescriptionOverridesOrchestrator,
)
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.unstandardised_address.unstandardised_address_list_csv_s3_repository import (
    UnstandardisedAddressListCsvS3Repository,
)
from utilities.aws_lambda.subtask_handler import subtask_handler

logger = logging.getLogger(__name__)


def _build_columns(
    column_mapping: dict[str, str], chat_gpt: ChatGPT, session: Any
) -> list[ClassifiableColumn[Any]]:
    """Build one ClassifiableColumn for every recognised category in the mapping.

    ``column_mapping`` maps ``category -> source CSV header``. The same header
    may appear under more than one category (e.g. ``"Property Type"`` feeding
    both property_type and built_form_type); each entry simply yields its own
    column. Categories without a builder are logged at warning level and
    skipped.

    Args:
        column_mapping: ``{category -> source CSV header}``.
        chat_gpt: Client handed to every ``ChatGptColumnClassifier``.
        session: Database session handed to every ``LandlordOverridesRepository``.

    Returns:
        The columns, in the iteration order of ``column_mapping``.
    """

    def _property_type(header: str) -> ClassifiableColumn[Any]:
        classifier = ChatGptColumnClassifier(
            chat_gpt, PropertyType, PropertyType.UNKNOWN
        )
        overrides = LandlordOverridesRepository[PropertyType](
            session, LandlordPropertyTypeOverrideRow
        )
        return ClassifiableColumn(
            name="property_type",
            source_column=header,
            classifier=classifier,
            repo=overrides,
        )

    def _built_form_type(header: str) -> ClassifiableColumn[Any]:
        classifier = ChatGptColumnClassifier(
            chat_gpt, BuiltFormType, BuiltFormType.UNKNOWN
        )
        overrides = LandlordOverridesRepository[BuiltFormType](
            session, LandlordBuiltFormTypeOverrideRow
        )
        return ClassifiableColumn(
            name="built_form_type",
            source_column=header,
            classifier=classifier,
            repo=overrides,
        )

    def _wall_type(header: str) -> ClassifiableColumn[Any]:
        # The construction-date hint is computed only when a wall_type column
        # is actually requested.
        classifier = ChatGptColumnClassifier(
            chat_gpt,
            WallType,
            WallType.UNKNOWN,
            extra_instructions=wall_type_construction_date_prompt_hint(),
        )
        overrides = LandlordOverridesRepository[WallType](
            session, LandlordWallTypeOverrideRow
        )
        return ClassifiableColumn(
            name="wall_type",
            source_column=header,
            classifier=classifier,
            repo=overrides,
        )

    def _roof_type(header: str) -> ClassifiableColumn[Any]:
        classifier = ChatGptColumnClassifier(chat_gpt, RoofType, RoofType.UNKNOWN)
        overrides = LandlordOverridesRepository[RoofType](
            session, LandlordRoofTypeOverrideRow
        )
        return ClassifiableColumn(
            name="roof_type",
            source_column=header,
            classifier=classifier,
            repo=overrides,
        )

    # Dispatch table: category name -> builder taking the source CSV header.
    builders = {
        "property_type": _property_type,
        "built_form_type": _built_form_type,
        "wall_type": _wall_type,
        "roof_type": _roof_type,
    }

    built: list[ClassifiableColumn[Any]] = []
    for category, header in column_mapping.items():
        builder = builders.get(category)
        if builder is not None:
            built.append(builder(header))
        else:
            logger.warning("Unknown classifier category %r; skipping.", category)
    return built


@subtask_handler()
def handler(
    body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator
) -> dict[str, int]:
    """Classify landlord description columns from a CSV and persist the results.

    The trigger body is validated, the CSV at ``trigger.s3_uri`` is read as raw
    rows, every mapped column is classified, and the outcome is persisted for
    ``trigger.portfolio_id`` inside a single ``commit_scope``.

    Args:
        body: Raw trigger payload; validated into
            ``LandlordDescriptionOverridesTriggerBody``.
        context: Lambda context; not used directly in this function.
        task_orchestrator: Supplied alongside ``body`` and ``context``; not
            used directly in this function.

    Returns:
        ``{column name -> number of entries in that column's classified mapping}``.
    """
    trigger = LandlordDescriptionOverridesTriggerBody.model_validate(body)

    # The input is a dedicated CSV holding only the classifier columns (raw
    # landlord headers preserved), produced from the upload by the frontend.
    # Hence the bucket is taken from the trigger URI, not a fixed env var.
    bucket, _key = parse_s3_uri(trigger.s3_uri)

    # The installed stubs overload boto3.client per service; binding it as Any
    # makes the strict-mode type checker treat it as opaque.
    client_factory: Any = boto3.client  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    s3_client: Any = client_factory("s3")
    csv_client = CsvS3Client(s3_client, bucket)
    address_repo = UnstandardisedAddressListCsvS3Repository(csv_client, bucket)

    # Read raw rows instead of using load_batch: this CSV has the description
    # columns but lacks the canonical address/postcode columns load_batch
    # requires.
    rows = csv_client.read_rows(trigger.s3_uri)

    postgres_config = PostgresConfig.from_env(os.environ)
    engine = make_engine(postgres_config)

    # This handler owns the session and creates it up front; SQLModel sessions
    # are lazy, so no connection is checked out at this point. Classification
    # goes first and calls ChatGPT, which is slow, so no transaction is held
    # open across it on purpose. A connection is held only for the persistence
    # step inside ``commit_scope``.
    session = make_session(engine)
    try:
        orchestrator = LandlordDescriptionOverridesOrchestrator(
            unstandardised_address_repo=address_repo,
            columns=_build_columns(trigger.column_mapping, ChatGPT(), session),
        )
        classified = orchestrator.classify_from_rows(rows)
        with commit_scope(session):
            orchestrator.persist(classified, portfolio_id=trigger.portfolio_id)
    finally:
        session.close()

    # Compute every count before logging any of them.
    counts: dict[str, int] = {}
    for column_name, mapping in classified.items():
        counts[column_name] = len(mapping)
    for column_name, total in counts.items():
        logger.info("Classified %d descriptions for column %r.", total, column_name)
    return counts