from enum import Enum
from typing import Any, Optional

from domain.addresses.unstandardised_address import AddressList
from orchestration.classifiable_column import ClassifiableColumn
from repositories.unstandardised_address.unstandardised_address_list_repository import (
    UnstandardisedAddressListRepository,
)


def _split_variants(value: Optional[str]) -> set[str]:
    """Return the distinct, normalised descriptions held in one cell.

    A comma-separated value is several descriptions in one cell, so it is
    split and each piece becomes its own entry. Pieces are stripped and
    lower-cased so case-only typos collapse to one variant; pieces that are
    empty after stripping are dropped. A missing cell (``None`` or ``""``)
    yields an empty set rather than raising.
    """
    if not value:
        return set()
    normalised = (piece.strip().lower() for piece in value.split(","))
    return {variant for variant in normalised if variant}


class LandlordDescriptionOverridesOrchestrator:
    """Collect, classify and persist landlord free-text descriptions.

    Descriptions are gathered per source column (from an ``AddressList`` or
    from raw CSV rows), handed to each registered column's classifier, and
    optionally written through each column's repository.
    """

    def __init__(
        self,
        unstandardised_address_repo: UnstandardisedAddressListRepository,
        columns: list[ClassifiableColumn[Any]],
    ) -> None:
        self._unstandardised_address_repo = unstandardised_address_repo
        # Each entry is one (source CSV column, target enum) classification.
        # Two entries may share ``source_column`` -- e.g. ``"Property Type"``
        # feeds both PropertyType and BuiltFormType classifiers -- so the
        # registry is a list rather than a dict keyed by header.
        self._columns = columns

    def get_unstandardised_addresses(
        self,
        input_s3_uri: str,
    ) -> AddressList:
        """Load the address batch at ``input_s3_uri`` via the repository."""
        return self._unstandardised_address_repo.load_batch(input_s3_uri)

    def get_col_to_description_mappings(
        self, list_of_unstandardised_address: AddressList
    ) -> dict[str, set[str]]:
        """Map each ``additional_info`` key to its distinct descriptions.

        Every key seen on any address gets an entry, even when all of its
        cells are empty (the set is then empty). Cells are split and
        normalised by ``_split_variants``; a ``None`` cell is treated as
        empty, matching the raw-rows path.
        """
        mappings: dict[str, set[str]] = {}
        for unstandardised_address in list_of_unstandardised_address:
            for key, value in unstandardised_address.additional_info.items():
                mappings.setdefault(key, set()).update(_split_variants(value))
        return mappings

    def classify_columns(
        self, addresses: AddressList
    ) -> dict[str, dict[str, Enum]]:
        """Classify every registered column's descriptions.

        Returns a mapping of ``ClassifiableColumn.name`` to
        ``{description: category}``. A registered column whose
        ``source_column`` is absent from the addresses has its classifier
        called with an empty set, so it contributes whatever the classifier
        returns for no input (expected to be an empty inner mapping).
        """
        return self._classify_descriptions(
            self.get_col_to_description_mappings(addresses)
        )

    def persist(
        self, classified: dict[str, dict[str, Enum]], portfolio_id: int
    ) -> None:
        """Persist already-classified results via each column's repository.

        ``classified`` is keyed by ``ClassifiableColumn.name`` -- the shape
        ``classify_columns`` and ``classify_from_rows`` return. Each non-empty
        mapping is written through the column's own repo under
        ``source = 'classifier'``; an empty or missing mapping (a registered
        column absent from this batch) skips the DB round-trip.

        The orchestrator does not commit -- the caller owns the transaction
        boundary, and is expected to open it only around this call so the
        slow classification never holds a connection.
        """
        for column in self._columns:
            mapping = classified.get(column.name)
            if not mapping:
                continue
            column.repo.upsert_all(portfolio_id, mapping)

    def classify_and_persist(
        self, addresses: AddressList, portfolio_id: int
    ) -> dict[str, dict[str, Enum]]:
        """Classify every registered column and persist the results.

        Returns the same shape as ``classify_columns`` so callers can log
        per-column counts.
        """
        classified = self.classify_columns(addresses)
        self.persist(classified, portfolio_id)
        return classified

    def classify_from_rows(
        self, rows: list[dict[str, str]]
    ) -> dict[str, dict[str, Enum]]:
        """Classify raw CSV rows without touching the database.

        The classification half of ``classify_and_persist_from_rows``, split
        out so a caller can run the slow ChatGPT work *before* opening a
        transaction and then write the finished results with ``persist``
        inside one short-lived connection.

        Unlike the ``AddressList`` path this builds no ``AddressList``, so it
        has no canonical address/postcode requirement -- the classifier only
        needs the raw description cells. Used when reading the original
        landlord upload (raw headers) rather than the address-matching CSV.
        """
        return self._classify_descriptions(self._descriptions_from_rows(rows))

    def classify_and_persist_from_rows(
        self, rows: list[dict[str, str]], portfolio_id: int
    ) -> dict[str, dict[str, Enum]]:
        """Classify + persist straight from raw CSV rows in one call.

        A convenience composition of ``classify_from_rows`` + ``persist``.
        Prefer calling the two separately when classification is slow, so the
        transaction opens only around ``persist`` (see the Lambda handler).
        """
        classified = self.classify_from_rows(rows)
        self.persist(classified, portfolio_id)
        return classified

    def _classify_descriptions(
        self, col_to_desc: dict[str, set[str]]
    ) -> dict[str, dict[str, Enum]]:
        """Run each registered column's classifier over its descriptions.

        Shared by the ``AddressList`` and raw-rows paths so both look up
        ``source_column`` and key the result by ``name`` in exactly the same
        way. A column with no collected descriptions is classified against an
        empty set.
        """
        return {
            column.name: column.classifier.classify(
                col_to_desc.get(column.source_column, set())
            )
            for column in self._columns
        }

    @staticmethod
    def _descriptions_from_rows(rows: list[dict[str, str]]) -> dict[str, set[str]]:
        """Map each CSV header to the distinct descriptions found in ``rows``.

        Every header seen in any row gets an entry, even when all of its
        cells are empty or ``None`` (the set is then empty).
        """
        mappings: dict[str, set[str]] = {}
        for row in rows:
            for key, value in row.items():
                mappings.setdefault(key, set()).update(_split_variants(value))
        return mappings