Model/infrastructure/chatgpt/chatgpt_column_classifier.py

99 lines
3.4 KiB
Python

from __future__ import annotations
import json
from enum import Enum
from typing import Any, Optional, TypeVar
from domain.data_transformation.column_classifier import (
ClassificationError,
ColumnClassifier,
)
from infrastructure.chatgpt.chatgpt import ChatGPT
from infrastructure.chatgpt.exceptions import ChatGPTClientError
# Type variable for the category enum a ChatGptColumnClassifier classifies
# into; bound to Enum so every category is an Enum member.
E = TypeVar("E", bound=Enum)
class ChatGptColumnClassifier(ColumnClassifier[E]):
    """ColumnClassifier backed by ChatGPT, parametrised by a category enum.

    The same classification path -- prompt, JSON parsing, UNKNOWN fallback --
    serves any category enum; only ``category_enum`` and its ``unknown``
    member differ between columns.

    The descriptions are sent to ChatGPT as a JSON array, and the reply is
    expected to be a JSON object mapping each description to a category
    value. A description that the reply omits, or maps to something that is
    not a valid value of ``category_enum``, is classified as ``unknown``.
    """

    def __init__(
        self,
        chat_gpt: ChatGPT,
        category_enum: type[E],
        unknown: E,
        extra_instructions: Optional[str] = None,
    ) -> None:
        """Create a classifier for one category enum.

        Args:
            chat_gpt: Client used to send the classification prompt.
            category_enum: Enum whose member values are offered to ChatGPT
                as the categories.
            unknown: Fallback member of ``category_enum``. It is left out of
                the category list shown to ChatGPT.
            extra_instructions: Optional column-specific guidance inserted
                into the system prompt.
        """
        self._chat_gpt = chat_gpt
        self._category_enum = category_enum
        self._unknown = unknown
        # Free-form column-specific guidance appended to the system prompt
        # ahead of the JSON-output instruction. Lets each column ship its
        # own hints (e.g. wall-type construction-era ranges) without the
        # generic classifier knowing what they are.
        self._extra_instructions = extra_instructions

    def classify(self, descriptions: set[str]) -> dict[str, E]:
        """Classify every description into a member of the category enum.

        Args:
            descriptions: Free-text descriptions to classify.

        Returns:
            A mapping with exactly one entry per input description. An empty
            input returns an empty mapping without calling ChatGPT.

        Raises:
            ClassificationError: If the ChatGPT call fails, or if its reply
                is not valid JSON or is not a JSON object.
        """
        if not descriptions:
            return {}
        try:
            reply = self._chat_gpt.generate(
                # Sorted so the prompt is deterministic for a given set.
                # ensure_ascii=False sends non-ASCII text verbatim instead of
                # as \uXXXX escapes, which the model reads less reliably.
                prompt=json.dumps(sorted(descriptions), ensure_ascii=False),
                system_prompt=self._system_prompt(),
            )
        except ChatGPTClientError as error:
            raise ClassificationError(
                f"ChatGPT classification failed for "
                f"{self._category_enum.__name__}."
            ) from error
        raw = self._parse_reply(reply)
        return {
            description: self._to_category(raw.get(description))
            for description in descriptions
        }

    def _parse_reply(self, reply: str) -> dict[str, Any]:
        """Decode ChatGPT's reply into a description -> category-value dict.

        Raises:
            ClassificationError: If the reply is not valid JSON, or if the
                JSON is not an object.
        """
        try:
            raw = json.loads(self._strip_code_fence(reply))
        except json.JSONDecodeError as error:
            raise ClassificationError(
                f"ChatGPT returned a reply that is not valid JSON: {reply!r}"
            ) from error
        # Valid JSON is not enough: a list, string or number would make the
        # ``.get`` lookup in ``classify`` fail with an AttributeError.
        if not isinstance(raw, dict):
            raise ClassificationError(
                f"ChatGPT returned JSON that is not an object: {reply!r}"
            )
        return raw

    def _system_prompt(self) -> str:
        """Build the system prompt listing every category except ``unknown``."""
        categories = ", ".join(
            member.value
            for member in self._category_enum
            if member is not self._unknown
        )
        parts = [
            "Classify each free-text description into exactly one category. ",
            f"Categories: {categories}. ",
        ]
        if self._extra_instructions:
            parts.append(self._extra_instructions + " ")
        parts.append(
            "Reply with only a JSON object mapping each original description "
            "to its category, and nothing else."
        )
        return "".join(parts)

    def _to_category(self, value: Any) -> E:
        """Map a reply value to a category member, defaulting to UNKNOWN.

        A missing entry arrives here as ``None``. Enum lookup raises
        ValueError for any value that is not a member value, and that case
        falls back to ``unknown``.
        """
        try:
            return self._category_enum(value)
        except ValueError:
            return self._unknown

    @staticmethod
    def _strip_code_fence(reply: str) -> str:
        """Remove a surrounding markdown code fence, if ChatGPT added one.

        The opening fence line is dropped along with any language tag (such
        as ``json``). A closing fence is removed whether it sits on its own
        line or is glued to the end of the last content line.
        """
        text = reply.strip()
        if not text.startswith("```"):
            return text
        # Drop the opening fence line, including any language tag.
        _, _, body = text.partition("\n")
        body = body.rstrip()
        if body.endswith("```"):
            body = body[: -len("```")]
        return body.strip()