mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
169 lines
6.5 KiB
Python
169 lines
6.5 KiB
Python
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
import boto3
|
|
|
|
from applications.landlord_description_overrides.landlord_description_overrides_trigger_body import (
|
|
LandlordDescriptionOverridesTriggerBody,
|
|
)
|
|
from domain.epc.built_form_type import BuiltFormType
|
|
from domain.epc.property_type import PropertyType
|
|
from domain.epc.roof_type import RoofType
|
|
from domain.epc.wall_type import WallType
|
|
from domain.epc.wall_type_construction_dates import (
|
|
wall_type_construction_date_prompt_hint,
|
|
)
|
|
from infrastructure.chatgpt.chatgpt import ChatGPT
|
|
from infrastructure.chatgpt.chatgpt_column_classifier import ChatGptColumnClassifier
|
|
from infrastructure.landlord_overrides.landlord_overrides_postgres_repository import (
|
|
LandlordOverridesRepository,
|
|
)
|
|
from infrastructure.postgres.config import PostgresConfig
|
|
from infrastructure.postgres.engine import commit_scope, make_engine, make_session
|
|
from infrastructure.postgres.landlord_built_form_type_override_table import (
|
|
LandlordBuiltFormTypeOverrideRow,
|
|
)
|
|
from infrastructure.postgres.landlord_property_type_override_table import (
|
|
LandlordPropertyTypeOverrideRow,
|
|
)
|
|
from infrastructure.postgres.landlord_roof_type_override_table import (
|
|
LandlordRoofTypeOverrideRow,
|
|
)
|
|
from infrastructure.postgres.landlord_wall_type_override_table import (
|
|
LandlordWallTypeOverrideRow,
|
|
)
|
|
from infrastructure.s3.csv_s3_client import CsvS3Client
|
|
from infrastructure.s3.s3_uri import parse_s3_uri
|
|
from orchestration.classifiable_column import ClassifiableColumn
|
|
from orchestration.landlord_description_overrides_orchestrator import (
|
|
LandlordDescriptionOverridesOrchestrator,
|
|
)
|
|
from orchestration.task_orchestrator import TaskOrchestrator
|
|
from repositories.unstandardised_address.unstandardised_address_list_csv_s3_repository import (
|
|
UnstandardisedAddressListCsvS3Repository,
|
|
)
|
|
from utilities.aws_lambda.subtask_handler import subtask_handler
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _build_columns(
    column_mapping: dict[str, str], chat_gpt: ChatGPT, session: Any
) -> list[ClassifiableColumn[Any]]:
    """Build the classifiable columns requested by ``column_mapping``.

    Args:
        column_mapping: ``{category -> source CSV header}``. The same header
            may appear under several categories (e.g. ``"Property Type"``
            feeding both ``property_type`` and ``built_form_type``); each
            entry produces its own column.
        chat_gpt: Client handed to every ``ChatGptColumnClassifier``.
        session: Session handed to every ``LandlordOverridesRepository``.

    Returns:
        One ``ClassifiableColumn`` per recognised category, in the mapping's
        iteration order. Unrecognised categories are logged as warnings and
        left out.
    """

    def _property_type(header: str) -> ClassifiableColumn[Any]:
        return ClassifiableColumn(
            name="property_type",
            source_column=header,
            classifier=ChatGptColumnClassifier(
                chat_gpt, PropertyType, PropertyType.UNKNOWN
            ),
            repo=LandlordOverridesRepository[PropertyType](
                session, LandlordPropertyTypeOverrideRow
            ),
        )

    def _built_form_type(header: str) -> ClassifiableColumn[Any]:
        return ClassifiableColumn(
            name="built_form_type",
            source_column=header,
            classifier=ChatGptColumnClassifier(
                chat_gpt, BuiltFormType, BuiltFormType.UNKNOWN
            ),
            repo=LandlordOverridesRepository[BuiltFormType](
                session, LandlordBuiltFormTypeOverrideRow
            ),
        )

    def _wall_type(header: str) -> ClassifiableColumn[Any]:
        # Wall type is the only category whose classifier receives extra
        # prompt instructions; the hint is computed only when this builder
        # is actually invoked.
        return ClassifiableColumn(
            name="wall_type",
            source_column=header,
            classifier=ChatGptColumnClassifier(
                chat_gpt,
                WallType,
                WallType.UNKNOWN,
                extra_instructions=wall_type_construction_date_prompt_hint(),
            ),
            repo=LandlordOverridesRepository[WallType](
                session, LandlordWallTypeOverrideRow
            ),
        )

    def _roof_type(header: str) -> ClassifiableColumn[Any]:
        return ClassifiableColumn(
            name="roof_type",
            source_column=header,
            classifier=ChatGptColumnClassifier(
                chat_gpt, RoofType, RoofType.UNKNOWN
            ),
            repo=LandlordOverridesRepository[RoofType](
                session, LandlordRoofTypeOverrideRow
            ),
        )

    # Dispatch table: category name -> builder taking the source CSV header.
    builders = {
        "property_type": _property_type,
        "built_form_type": _built_form_type,
        "wall_type": _wall_type,
        "roof_type": _roof_type,
    }

    built: list[ClassifiableColumn[Any]] = []
    for category, header in column_mapping.items():
        builder = builders.get(category)
        if builder is not None:
            built.append(builder(header))
        else:
            logger.warning("Unknown classifier category %r; skipping.", category)
    return built
|
|
|
|
|
|
@subtask_handler()
def handler(
    body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator
) -> dict[str, int]:
    """Classify landlord description columns from a CSV and persist the results.

    The trigger body is validated, the classifier CSV it points at is read
    from S3, each mapped column is classified, and the resulting overrides
    are persisted for the trigger's portfolio.

    Args:
        body: Raw trigger payload; validated into
            ``LandlordDescriptionOverridesTriggerBody``.
        context: Unused in this function body.
        task_orchestrator: Unused in this function body.

    Returns:
        ``{column name -> number of classified descriptions}``.
    """
    trigger = LandlordDescriptionOverridesTriggerBody.model_validate(body)

    # The frontend converts the upload into a dedicated CSV holding just the
    # classifier columns (raw landlord headers preserved), so the bucket is
    # taken from the trigger URI instead of a fixed env var.
    bucket, _key = parse_s3_uri(trigger.s3_uri)

    # The installed stubs overload boto3.client per service; binding it as
    # Any keeps the strict-mode type checker from inspecting it.
    make_boto_client: Any = (
        boto3.client
    )  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
    s3_client: Any = make_boto_client("s3")

    csv_reader = CsvS3Client(s3_client, bucket)
    address_repo = UnstandardisedAddressListCsvS3Repository(csv_reader, bucket)

    # Read raw rows rather than using load_batch: the classifier CSV has the
    # description columns but lacks the canonical address/postcode columns
    # that load_batch requires.
    raw_rows = csv_reader.read_rows(trigger.s3_uri)

    engine = make_engine(PostgresConfig.from_env(os.environ))
    # This handler owns the session and creates it up front (SQLModel
    # sessions are lazy, so no connection is checked out yet).
    # Classification runs first and calls ChatGPT, which is slow, so no
    # transaction is deliberately kept open across it; only the persistence
    # inside ``commit_scope`` holds a connection.
    db_session = make_session(engine)
    try:
        orchestrator = LandlordDescriptionOverridesOrchestrator(
            unstandardised_address_repo=address_repo,
            columns=_build_columns(trigger.column_mapping, ChatGPT(), db_session),
        )

        classified = orchestrator.classify_from_rows(raw_rows)

        with commit_scope(db_session):
            orchestrator.persist(classified, portfolio_id=trigger.portfolio_id)
    finally:
        db_session.close()

    # Tally and report how many descriptions were classified per column.
    counts: dict[str, int] = {}
    for column_name, mapping in classified.items():
        counts[column_name] = len(mapping)
    for column_name, total in counts.items():
        logger.info("Classified %d descriptions for column %r.", total, column_name)
    return counts
|