mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
178 lines
6.7 KiB
Python
178 lines
6.7 KiB
Python
# OpenAI API key, read from the environment so the secret never lives in source.
# Resolves to None when the OPENAI_API_KEY variable is not set.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
|
class DataRemapper:
    """Map raw categorical values onto a fixed set of standard values.

    Each value is resolved by the first strategy that succeeds: the predefined
    mapping, an exact match on the cleaned value, a fuzzy match, and finally a
    single batched OpenAI request for everything still unmapped.
    """

    def __init__(self, standard_values, standard_map=None, max_tokens=1000):
        """
        Initialize the remapper with standard values and a predefined mapping.

        :param standard_values: Set of allowed standardized values.
        :param standard_map: Dictionary of common remappings {raw_value: standard_value}.
            ``None`` (the default) is treated as an empty mapping.
        :param max_tokens: Limit applied both to the prompt size (checked before
            the request is sent) and to the completion length requested.
        """
        self.standard_values = standard_values
        # Normalise None to {} so the `in self.standard_map` checks cannot raise.
        self.standard_map = {} if standard_map is None else standard_map
        self.fuzzy_threshold = 90  # Adjust fuzzy matching sensitivity
        self.ai_model = "gpt-4-turbo"  # Use gpt-3.5-turbo for cheaper processing

        # Tokenizer for counting tokens
        self.tokenizer = tiktoken.encoding_for_model(self.ai_model)

        # Track token usage and remap dictionary
        self.total_tokens_used = 0
        self.total_cost = 0
        self.remap_dict = {}  # {original_value: standardized_value}
        self.max_tokens = max_tokens  # Limit for OpenAI API

        # Memoization for AI calls
        self.ai_cache = {}  # {tuple(unmapped_values): {original_value: standardized_value}}
        # Capture the raw response text for debugging
        self.ai_response = None

        # OpenAI pricing (as of Feb 2024), expressed per single token
        self.pricing = {
            "gpt-4-turbo": {"input": 0.01 / 1000, "output": 0.03 / 1000},
            "gpt-3.5-turbo": {"input": 0.0015 / 1000, "output": 0.002 / 1000},
        }

        self.openai_client = OpenAI(api_key=OPENAI_API_KEY)

    @staticmethod
    def clean_string(text):
        """Basic text cleaning: remove extra spaces, punctuation, and normalize case.

        :param text: Value to clean.
        :return: Cleaned string, or None when ``text`` is not a string.
        """
        if not isinstance(text, str):
            return None
        text = text.strip().lower()
        text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
        # Collapse runs of whitespace into a single space
        text = re.sub(r'\s+', ' ', text)
        return text

    def fuzzy_match(self, text):
        """Use fuzzy matching to find the closest standard value.

        :param text: Cleaned string to match; falsy input is never matched.
        :return: Best-matching standard value when its score reaches
            ``self.fuzzy_threshold``, otherwise None.
        """
        if not text:
            return None
        best = process.extractOne(text, self.standard_values)
        # NOTE(review): depending on the fuzzy library behind `process`, the
        # result may be None (no candidates) or carry extra fields after
        # (match, score); index instead of unpacking so both cases are safe.
        if best is None:
            return None
        match, score = best[0], best[1]
        return match if score >= self.fuzzy_threshold else None

    def count_tokens(self, text):
        """Estimate the number of tokens in a given text (0 for empty input)."""
        return len(self.tokenizer.encode(text)) if text else 0

    @staticmethod
    def _parse_ai_mapping(output_text):
        """Parse the model's reply into a dict without executing it.

        Tries strict JSON first, then a Python literal (covers single-quoted
        dicts). A surrounding markdown code fence is stripped beforehand.

        :param output_text: Raw text returned by the model.
        :return: The parsed dict, or None when the text is not a dict literal.
        """
        import ast
        import json

        text = (output_text or "").strip()
        if text.startswith("```"):
            # The model may wrap the JSON in a fence despite the instructions.
            text = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", text).strip()

        for parser in (json.loads, ast.literal_eval):
            try:
                parsed = parser(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def ai_standardize(self, unmapped_values):
        """Call OpenAI API **once** for all unmapped values to minimize cost, with memoization.

        :param unmapped_values: Raw values no other strategy could map.
        :return: Dictionary {original_value: standardized_value}. When the reply
            cannot be parsed, every value maps to "unknown".
        :raises ValueError: If the prompt exceeds ``self.max_tokens`` tokens.
        """
        if not unmapped_values:
            return {}

        # Sort for an order-independent cache key; key=str keeps mixed-type
        # inputs sortable and leaves the order of all-string inputs unchanged.
        unmapped_tuple = tuple(sorted(unmapped_values, key=str))
        if unmapped_tuple in self.ai_cache:
            return self.ai_cache[unmapped_tuple]  # Return memoized result

        prompt = f"""
        You are an expert in data classification. Standardize each of these values into one of the categories:
        {list(self.standard_values)}.

        Return only a JSON dictionary where:
        - The keys are the original values.
        - The values are the standardized ones.

        Strictly return JSON **without markdown formatting** or extra text.

        Example Output:
        {{
        "BLKHOUS": "block house",
        "BEDSIT": "bedsit"
        }}

        Values to standardize:
        {unmapped_values}
        """

        # Count input tokens and refuse oversized prompts before paying for them
        input_tokens = self.count_tokens(prompt)
        if input_tokens > self.max_tokens:
            raise ValueError("Input tokens exceed the maximum limit.")

        logger.info("Calling OpenAI API for standardization...")
        response = self.openai_client.chat.completions.create(
            model=self.ai_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.max_tokens,
            temperature=0.1,
        )

        # Treat a missing message body as empty so it takes the fallback path.
        output_text = (response.choices[0].message.content or "").strip()
        # Keep the raw reply for debugging, even if parsing fails below.
        self.ai_response = output_text
        output_tokens = self.count_tokens(output_text)  # Count output tokens

        # Track total token usage
        self.total_tokens_used += input_tokens + output_tokens

        # Estimate cost
        model_pricing = self.pricing[self.ai_model]
        self.total_cost += (
            input_tokens * model_pricing["input"]
            + output_tokens * model_pricing["output"]
        )

        # SECURITY: the reply is untrusted text. It was previously passed to
        # eval(); it is now parsed as data only (JSON / Python literal).
        mapping = self._parse_ai_mapping(output_text)
        if mapping is None:
            logger.warning("AI response was not a valid mapping; using 'unknown' fallback.")
            mapping = {val: "unknown" for val in unmapped_values}  # Fallback

        # Memoize the AI response
        self.ai_cache[unmapped_tuple] = mapping
        logger.debug("AI Response: %s", mapping)

        return mapping

    def standardize_list(self, values_to_remap):
        """
        Standardizes a list of values and returns a dictionary {original_value: standardized_value}.

        Results accumulate in ``self.remap_dict`` across calls, and that same
        dictionary is returned.

        :param values_to_remap: List of raw values to standardize.
        :return: Dictionary {original_value: standardized_value}.
        """
        unmapped_values = []
        for value in set(values_to_remap):  # Process only unique values
            if pd.isna(value):  # Handle NaN values
                self.remap_dict[value] = "unknown"
                continue

            cleaned_value = self.clean_string(value)

            # Rule-Based Check (Predefined Mapping): cleaned form first, then
            # the raw value, then the lower-cased raw value.
            if cleaned_value in self.standard_map:
                self.remap_dict[value] = self.standard_map[cleaned_value]
                continue
            if value in self.standard_map:
                self.remap_dict[value] = self.standard_map[value]
                continue
            # Only strings have .lower(); non-string values skip this lookup.
            if isinstance(value, str) and value.lower() in self.standard_map:
                self.remap_dict[value] = self.standard_map[value.lower()]
                continue

            # Exact Match in Standard Values
            if cleaned_value in self.standard_values:
                self.remap_dict[value] = cleaned_value
                continue

            # Fuzzy Matching
            fuzzy_result = self.fuzzy_match(cleaned_value)
            if fuzzy_result:
                self.remap_dict[value] = fuzzy_result
                continue

            # Capture anything that wasn't mapped
            unmapped_values.append(value)

        # AI Model - remap anything unmapped (batch request)
        self.remap_dict.update(self.ai_standardize(unmapped_values))

        return self.remap_dict

    def report_usage(self):
        """Prints a summary of token usage and cost."""
        print(f"\n🔹 Total Tokens Used: {self.total_tokens_used}")
        print(f"💰 Estimated Cost: ${self.total_cost:.4f}")
|