mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
import logging
|
|
from typing import Any
|
|
|
|
import boto3
|
|
from orchestration.sal_orchestrator import (
|
|
SALOrchestrator,
|
|
)
|
|
from infrastructure.s3.csv_s3_client import CsvS3Client
|
|
from repositories.unstandardised_address.unstandardised_address_list_csv_s3_repository import (
|
|
UnstandardisedAddressListCsvS3Repository,
|
|
)
|
|
from domain.addresses.unstandardised_address import AddressList
|
|
from domain.sal.column_classifier import ColumnClassifier
|
|
from domain.sal.property_type import PropertyType
|
|
from domain.sal.wall_type import WallType
|
|
from infrastructure.chatgpt.chatgpt import ChatGPT
|
|
from infrastructure.chatgpt.chatgpt_column_classifier import (
|
|
ChatGptColumnClassifier,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Dev-only input file used while the ChatGPT classification path is under test.
_INPUT_S3_URI = "s3://retrofit-data-dev/bulk_onboarding_inputs/hyde2 (1).csv"

# Maximum number of addresses classified per invocation while under test.
_TEST_BATCH_LIMIT = 20


def handler(
    body: dict[str, Any],
    context: Any,
) -> dict[str, list[str]]:
    """Classify landlord-CSV columns for the dev input file via ChatGPT.

    Loads the unstandardised address list from ``_INPUT_S3_URI``, keeps only
    the first ``_TEST_BATCH_LIMIT`` entries, classifies the "Property Type"
    and "Walls" columns with ChatGPT-backed classifiers, and logs how many
    descriptions were classified for each column.

    Args:
        body: Invocation payload. Currently unused -- the input location is
            hard-coded. (Presumably an AWS Lambda event; confirm against the
            deployment config.)
        context: Runtime context object. Currently unused.

    Returns:
        The placeholder payload ``{"hello": ["200"]}``; classification
        results are neither returned nor persisted yet.

    Raises:
        ValueError: If ``_INPUT_S3_URI`` is not an ``s3://`` URI.
    """
    s3_uri = _INPUT_S3_URI
    if not s3_uri.startswith("s3://"):
        raise ValueError(f"Expected an s3:// URI, got {s3_uri!r}")
    # Derive the bucket from the URI instead of repeating it, so the two
    # cannot drift apart when the input file is changed.
    bucket = s3_uri.removeprefix("s3://").partition("/")[0]

    # boto3.client is overloaded per-service in the installed stubs; cast
    # to Any so the strict-mode checker treats it as opaque.
    boto3_client: Any = boto3.client  # noqa
    boto_s3: Any = boto3_client("s3")

    csv_client = CsvS3Client(boto_s3, bucket)
    unstandardised_address_repo = UnstandardisedAddressListCsvS3Repository(
        csv_client, bucket
    )

    # One ChatGPT-backed classifier per landlord-CSV column, keyed by column
    # name. A single ChatGPT client instance is shared by all classifiers.
    chat_gpt = ChatGPT()
    classifiers: dict[str, ColumnClassifier[Any]] = {
        "Property Type": ChatGptColumnClassifier(
            chat_gpt, PropertyType, PropertyType.UNKNOWN
        ),
        "Walls": ChatGptColumnClassifier(chat_gpt, WallType, WallType.UNKNOWN),
    }

    sal = SALOrchestrator(
        unstandardised_address_repo=unstandardised_address_repo,
        classifiers=classifiers,
    )

    addresses: AddressList = sal.get_unstandardised_addresses(
        input_s3_uri=s3_uri
    )

    # Cap the batch while the ChatGPT path is under test.
    addresses = AddressList(addresses[:_TEST_BATCH_LIMIT])

    classified = sal.classify_columns(addresses)
    for column, mapping in classified.items():
        logger.info(
            "Classified %d descriptions for column %r.", len(mapping), column
        )

    # TODO: persist `classified` to landlord overrides.

    return {"hello": ["200"]}
|