from __future__ import annotations

import json
from enum import Enum
from typing import Any, Optional, TypeVar

from domain.data_transformation.column_classifier import (
    ClassificationError,
    ColumnClassifier,
)
from infrastructure.chatgpt.chatgpt import ChatGPT
from infrastructure.chatgpt.exceptions import ChatGPTClientError

E = TypeVar("E", bound=Enum)


class ChatGptColumnClassifier(ColumnClassifier[E]):
    """ColumnClassifier backed by ChatGPT, parametrised by a category enum.

    The same classification path -- prompt, JSON parsing, UNKNOWN fallback --
    serves any category enum; only ``category_enum`` and its ``unknown``
    member differ between columns.
    """

    def __init__(
        self,
        chat_gpt: ChatGPT,
        category_enum: type[E],
        unknown: E,
        extra_instructions: Optional[str] = None,
    ) -> None:
        """Store the client, the category enum and the fallback member.

        Args:
            chat_gpt: Client used to send the classification request.
            category_enum: Enum whose member values are the allowed
                categories offered to the model.
            unknown: Member returned for any description the reply does not
                map to a valid category value. It is left out of the category
                list shown to the model.
            extra_instructions: Optional column-specific guidance inserted
                into the system prompt before the JSON-output instruction.
        """
        self._chat_gpt = chat_gpt
        self._category_enum = category_enum
        self._unknown = unknown
        # Free-form column-specific guidance appended to the system prompt
        # ahead of the JSON-output instruction. Lets each column ship its
        # own hints (e.g. wall-type construction-era ranges) without the
        # generic classifier knowing what they are.
        self._extra_instructions = extra_instructions

    def classify(self, descriptions: set[str]) -> dict[str, E]:
        """Classify every description into a member of ``category_enum``.

        All descriptions are sent in a single request. A description that is
        missing from the reply, or mapped to something that is not a valid
        member value, is classified as ``unknown``.

        Args:
            descriptions: Distinct free-text descriptions to classify.

        Returns:
            A mapping from each input description to its category. Empty
            (and no request is made) when ``descriptions`` is empty.

        Raises:
            ClassificationError: If the ChatGPT call fails, if the reply is
                not valid JSON, or if the JSON is not an object.
        """
        if not descriptions:
            return {}

        try:
            reply = self._chat_gpt.generate(
                # Sorted so the prompt does not depend on set iteration
                # order. ensure_ascii=False sends non-ASCII text as-is
                # rather than as \uXXXX escapes the model must decode.
                prompt=json.dumps(sorted(descriptions), ensure_ascii=False),
                system_prompt=self._system_prompt(),
            )
        except ChatGPTClientError as error:
            raise ClassificationError(
                f"ChatGPT classification failed for "
                f"{self._category_enum.__name__}."
            ) from error

        try:
            raw: object = json.loads(self._strip_code_fence(reply))
        except json.JSONDecodeError as error:
            raise ClassificationError(
                f"ChatGPT returned a reply that is not valid JSON: {reply!r}"
            ) from error

        # Valid JSON is not necessarily an object: a list, string or number
        # has no .get() and would otherwise escape as an AttributeError
        # instead of the ClassificationError callers handle.
        if not isinstance(raw, dict):
            raise ClassificationError(
                f"ChatGPT returned JSON that is not an object: {reply!r}"
            )

        return {
            description: self._to_category(raw.get(description))
            for description in descriptions
        }

    def _system_prompt(self) -> str:
        """Build the system prompt listing every category except ``unknown``."""
        # str() so enums with non-string values (e.g. ints) can be listed;
        # for string-valued enums this is the value itself.
        categories = ", ".join(
            str(member.value)
            for member in self._category_enum
            if member is not self._unknown
        )
        parts = [
            "Classify each free-text description into exactly one category. ",
            f"Categories: {categories}. ",
        ]
        if self._extra_instructions:
            parts.append(self._extra_instructions + " ")
        parts.append(
            "Reply with only a JSON object mapping each original description "
            "to its category, and nothing else."
        )
        return "".join(parts)

    def _to_category(self, value: Any) -> E:
        """Map a reply value to a category member, defaulting to UNKNOWN.

        ``value`` is ``None`` when the description was absent from the
        reply; that, like any other non-member value, yields ``unknown``.
        """
        try:
            return self._category_enum(value)
        except ValueError:
            return self._unknown

    @staticmethod
    def _strip_code_fence(reply: str) -> str:
        """Remove a surrounding markdown code fence, if ChatGPT added one.

        The opening fence line (including any language tag such as
        ``json``) is dropped, as is a closing line consisting only of the
        fence. Replies without a leading fence are returned stripped.
        """
        text = reply.strip()
        if not text.startswith("```"):
            return text
        lines = text.splitlines()[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        return "\n".join(lines)