mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
99 lines
3.4 KiB
Python
99 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from enum import Enum
|
|
from typing import Any, Optional, TypeVar
|
|
|
|
from domain.data_transformation.column_classifier import (
|
|
ClassificationError,
|
|
ColumnClassifier,
|
|
)
|
|
from infrastructure.chatgpt.chatgpt import ChatGPT
|
|
from infrastructure.chatgpt.exceptions import ChatGPTClientError
|
|
|
|
# Type variable for the category enum a classifier is parametrised by.
# It is bound to Enum, so type checkers only accept Enum subclasses for it.
E = TypeVar("E", bound=Enum)
|
|
|
|
|
|
class ChatGptColumnClassifier(ColumnClassifier[E]):
    """ColumnClassifier backed by ChatGPT, parametrised by a category enum.

    The same classification path -- prompt, JSON parsing, UNKNOWN fallback --
    serves any category enum; only ``category_enum`` and its ``unknown``
    member differ between columns.
    """

    def __init__(
        self,
        chat_gpt: ChatGPT,
        category_enum: type[E],
        unknown: E,
        extra_instructions: Optional[str] = None,
    ) -> None:
        """Store the collaborators and settings used by ``classify``.

        Args:
            chat_gpt: Client whose ``generate`` method receives the prompt
                and the system prompt.
            category_enum: Enum class whose member values are offered to
                the model as the available categories.
            unknown: Member returned when a reply value matches no member
                of ``category_enum``; it is left out of the category list
                sent to the model.
            extra_instructions: Optional column-specific guidance inserted
                into the system prompt.
        """
        self._chat_gpt = chat_gpt
        self._category_enum = category_enum
        self._unknown = unknown
        # Free-form column-specific guidance appended to the system prompt
        # ahead of the JSON-output instruction. Lets each column ship its
        # own hints (e.g. wall-type construction-era ranges) without the
        # generic classifier knowing what they are.
        self._extra_instructions = extra_instructions

    def classify(self, descriptions: set[str]) -> dict[str, E]:
        """Classify every description into a member of the category enum.

        Args:
            descriptions: Free-text descriptions to classify. An empty set
                returns an empty dict without calling ChatGPT.

        Returns:
            A dict with one entry per input description. A description
            that is missing from the reply, or whose reply value is not a
            value of the category enum, maps to the ``unknown`` member.

        Raises:
            ClassificationError: If the ChatGPT call raises
                ``ChatGPTClientError``, if the reply is not valid JSON, or
                if the reply is valid JSON but not a JSON object.
        """
        if not descriptions:
            return {}
        try:
            reply = self._chat_gpt.generate(
                # Sorting gives the same prompt for the same set of
                # descriptions regardless of set iteration order.
                prompt=json.dumps(sorted(descriptions)),
                system_prompt=self._system_prompt(),
            )
        except ChatGPTClientError as error:
            raise ClassificationError(
                f"ChatGPT classification failed for "
                f"{self._category_enum.__name__}."
            ) from error
        try:
            raw: Any = json.loads(self._strip_code_fence(reply))
        except json.JSONDecodeError as error:
            raise ClassificationError(
                f"ChatGPT returned a reply that is not valid JSON: {reply!r}"
            ) from error
        # json.loads also accepts lists, strings, numbers and null. Reject
        # those explicitly so callers get a ClassificationError instead of
        # an AttributeError from the ``.get`` lookups below.
        if not isinstance(raw, dict):
            raise ClassificationError(
                f"ChatGPT returned a reply that is not a JSON object: {reply!r}"
            )
        return {
            description: self._to_category(raw.get(description))
            for description in descriptions
        }

    def _system_prompt(self) -> str:
        """Build the system prompt sent with every classification request.

        The prompt lists the values of all enum members except ``unknown``,
        then the optional extra instructions, then the instruction to reply
        with a JSON object only.
        """
        categories = ", ".join(
            member.value
            for member in self._category_enum
            if member is not self._unknown
        )
        parts = [
            "Classify each free-text description into exactly one category. ",
            f"Categories: {categories}. ",
        ]
        if self._extra_instructions:
            parts.append(self._extra_instructions + " ")
        parts.append(
            "Reply with only a JSON object mapping each original description "
            "to its category, and nothing else."
        )
        return "".join(parts)

    def _to_category(self, value: Any) -> E:
        """Map a reply value to a category member, defaulting to UNKNOWN."""
        try:
            return self._category_enum(value)
        except ValueError:
            # Enum lookup by value raises ValueError when nothing matches.
            return self._unknown

    @staticmethod
    def _strip_code_fence(reply: str) -> str:
        """Remove a surrounding markdown code fence, if ChatGPT added one.

        The whole opening fence line is dropped, including any language
        tag after the backticks. A closing fence is removed whether it
        sits on its own line or directly follows the last line of content.
        Replies that do not start with a fence are returned stripped of
        surrounding whitespace.
        """
        text = reply.strip()
        if not text.startswith("```"):
            return text
        lines = text.splitlines()[1:]
        if lines:
            last = lines[-1].rstrip()
            if last.endswith("```"):
                remainder = last[:-3]
                # Keep content that shared a line with the closing fence;
                # drop the line when the fence was alone on it.
                lines[-1:] = [remainder] if remainder.strip() else []
        return "\n".join(lines)
|