Model/applications/landlord_description_overrides/handler.py
2026-06-02 10:46:29 +00:00

169 lines
6.5 KiB
Python

import logging
import os
from typing import Any
import boto3
from applications.landlord_description_overrides.landlord_description_overrides_trigger_body import (
LandlordDescriptionOverridesTriggerBody,
)
from domain.epc.built_form_type import BuiltFormType
from domain.epc.property_type import PropertyType
from domain.epc.roof_type import RoofType
from domain.epc.wall_type import WallType
from domain.epc.wall_type_construction_dates import (
wall_type_construction_date_prompt_hint,
)
from infrastructure.chatgpt.chatgpt import ChatGPT
from infrastructure.chatgpt.chatgpt_column_classifier import ChatGptColumnClassifier
from infrastructure.landlord_overrides.landlord_overrides_postgres_repository import (
LandlordOverridesRepository,
)
from infrastructure.postgres.config import PostgresConfig
from infrastructure.postgres.engine import commit_scope, make_engine, make_session
from infrastructure.postgres.landlord_built_form_type_override_table import (
LandlordBuiltFormTypeOverrideRow,
)
from infrastructure.postgres.landlord_property_type_override_table import (
LandlordPropertyTypeOverrideRow,
)
from infrastructure.postgres.landlord_roof_type_override_table import (
LandlordRoofTypeOverrideRow,
)
from infrastructure.postgres.landlord_wall_type_override_table import (
LandlordWallTypeOverrideRow,
)
from infrastructure.s3.csv_s3_client import CsvS3Client
from infrastructure.s3.s3_uri import parse_s3_uri
from orchestration.classifiable_column import ClassifiableColumn
from orchestration.landlord_description_overrides_orchestrator import (
LandlordDescriptionOverridesOrchestrator,
)
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.unstandardised_address.unstandardised_address_list_csv_s3_repository import (
UnstandardisedAddressListCsvS3Repository,
)
from utilities.aws_lambda.subtask_handler import subtask_handler
# Module-level logger named after this module; used for the unknown-category
# warning in _build_columns and the per-column summary lines in handler.
logger = logging.getLogger(__name__)
def _build_columns(
    column_mapping: dict[str, str], chat_gpt: ChatGPT, session: Any
) -> list[ClassifiableColumn[Any]]:
    """Build a ``ClassifiableColumn`` for every recognised mapping entry.

    ``column_mapping`` maps ``category -> source CSV header``. Each category is
    its own entry, so a single header (e.g. ``"Property Type"``) can drive
    several categories at once, such as property_type and built_form_type.
    Categories not handled below are logged at warning level and skipped.

    Args:
        column_mapping: Category name to source CSV header.
        chat_gpt: Client passed to every ``ChatGptColumnClassifier``.
        session: Session passed to every ``LandlordOverridesRepository``.

    Returns:
        The built columns, in ``column_mapping`` iteration order.
    """
    built: list[ClassifiableColumn[Any]] = []
    for category, header in column_mapping.items():
        column: ClassifiableColumn[Any]
        # Each branch constructs its column only when that category is
        # actually mapped, so nothing is built for categories left out.
        if category == "property_type":
            column = ClassifiableColumn(
                name="property_type",
                source_column=header,
                classifier=ChatGptColumnClassifier(
                    chat_gpt, PropertyType, PropertyType.UNKNOWN
                ),
                repo=LandlordOverridesRepository[PropertyType](
                    session, LandlordPropertyTypeOverrideRow
                ),
            )
        elif category == "built_form_type":
            column = ClassifiableColumn(
                name="built_form_type",
                source_column=header,
                classifier=ChatGptColumnClassifier(
                    chat_gpt, BuiltFormType, BuiltFormType.UNKNOWN
                ),
                repo=LandlordOverridesRepository[BuiltFormType](
                    session, LandlordBuiltFormTypeOverrideRow
                ),
            )
        elif category == "wall_type":
            # Wall type is the only category whose classifier receives extra
            # prompt instructions (the construction-date hint).
            column = ClassifiableColumn(
                name="wall_type",
                source_column=header,
                classifier=ChatGptColumnClassifier(
                    chat_gpt,
                    WallType,
                    WallType.UNKNOWN,
                    extra_instructions=wall_type_construction_date_prompt_hint(),
                ),
                repo=LandlordOverridesRepository[WallType](
                    session, LandlordWallTypeOverrideRow
                ),
            )
        elif category == "roof_type":
            column = ClassifiableColumn(
                name="roof_type",
                source_column=header,
                classifier=ChatGptColumnClassifier(
                    chat_gpt, RoofType, RoofType.UNKNOWN
                ),
                repo=LandlordOverridesRepository[RoofType](
                    session, LandlordRoofTypeOverrideRow
                ),
            )
        else:
            logger.warning("Unknown classifier category %r; skipping.", category)
            continue
        built.append(column)
    return built
@subtask_handler()
def handler(
    body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator
) -> dict[str, int]:
    """Classify landlord description columns from a CSV and persist them.

    ``body`` is validated as a ``LandlordDescriptionOverridesTriggerBody``. The
    raw rows of the CSV at its ``s3_uri`` are classified column by column, and
    the result is persisted for ``portfolio_id`` inside a ``commit_scope``.

    Args:
        body: Trigger payload; must validate against the trigger-body model.
        context: Lambda context; not read by this function body.
        task_orchestrator: Supplied per the handler signature; not read by
            this function body.

    Returns:
        For each classified column name, the number of entries in its mapping.
    """
    request = LandlordDescriptionOverridesTriggerBody.model_validate(body)

    # The input is a dedicated CSV holding only the classifier columns (raw
    # landlord headers preserved), converted from the upload by the frontend.
    # Its bucket is therefore taken from the trigger URI, not a fixed env var.
    source_bucket, _ = parse_s3_uri(request.s3_uri)

    # The installed stubs overload boto3.client per service; binding it as Any
    # keeps the strict-mode type checker from inspecting it.
    make_boto_client: Any = (
        boto3.client
    )  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    s3: Any = make_boto_client("s3")
    csv_reader = CsvS3Client(s3, source_bucket)
    address_repo = UnstandardisedAddressListCsvS3Repository(
        csv_reader, source_bucket
    )

    # read_rows rather than load_batch: this CSV has the description columns
    # but lacks the canonical address/postcode columns load_batch requires.
    raw_rows = csv_reader.read_rows(request.s3_uri)

    db_engine = make_engine(PostgresConfig.from_env(os.environ))
    # This handler owns the session and creates it before classification;
    # SQLModel sessions are lazy, so no connection is checked out yet. The
    # classification step calls ChatGPT and is slow, so no transaction is held
    # open across it -- only the persist step inside ``commit_scope`` holds a
    # connection.
    db_session = make_session(db_engine)
    try:
        column_specs = _build_columns(request.column_mapping, ChatGPT(), db_session)
        pipeline = LandlordDescriptionOverridesOrchestrator(
            unstandardised_address_repo=address_repo,
            columns=column_specs,
        )
        results = pipeline.classify_from_rows(raw_rows)
        with commit_scope(db_session):
            pipeline.persist(results, portfolio_id=request.portfolio_id)
    finally:
        # Always close the session, whether or not classify/persist raised.
        db_session.close()

    # Summarise how many descriptions were classified per column.
    totals: dict[str, int] = {}
    for column_name, mapping in results.items():
        size = len(mapping)
        totals[column_name] = size
        logger.info("Classified %d descriptions for column %r.", size, column_name)
    return totals