Model/asset_list/DataMapper.py
2025-03-13 11:08:59 +00:00

178 lines
6.7 KiB
Python

# OpenAI API key, looked up from the process environment so the secret is not
# hard-coded in this file. Evaluates to None when the variable is not set.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
class DataRemapper:
    """Map raw categorical values onto a fixed set of standard values.

    Each value is resolved by the cheapest strategy that succeeds, in this
    order: the predefined mapping, an exact match against the standard
    values, fuzzy matching, and finally one batched OpenAI chat-completion
    call for everything still unmapped. Resolved values accumulate in
    ``remap_dict``; token usage and estimated cost of the AI calls accumulate
    in ``total_tokens_used`` and ``total_cost``.
    """

    def __init__(self, standard_values, standard_map=None, max_tokens=1000):
        """
        Initialize the remapper with standard values and a predefined mapping.

        :param standard_values: Set of allowed standardized values.
        :param standard_map: Dictionary of common remappings {raw_value: standard_value}.
            Treated as an empty mapping when omitted.
        :param max_tokens: Maximum prompt size in tokens; the same number is also
            passed to the OpenAI API as the completion ``max_tokens`` limit.
        """
        self.standard_values = standard_values
        # Normalise None to an empty dict so later membership tests cannot fail.
        self.standard_map = standard_map if standard_map is not None else {}
        self.fuzzy_threshold = 90  # Adjust fuzzy matching sensitivity
        self.ai_model = "gpt-4-turbo"  # Use gpt-3.5-turbo for cheaper processing
        # Tokenizer for counting tokens
        self.tokenizer = tiktoken.encoding_for_model(self.ai_model)
        # Track token usage and remap dictionary
        self.total_tokens_used = 0
        self.total_cost = 0
        self.remap_dict = {}  # {original_value: standardized_value}
        self.max_tokens = max_tokens  # Limit for OpenAI API
        # Memoization for AI calls
        self.ai_cache = {}  # {tuple(unmapped_values): {original_value: standardized_value}}
        # Capture the response for debugging
        self.ai_response = None
        # OpenAI pricing (as of Feb 2024), expressed in dollars per token
        self.pricing = {
            "gpt-4-turbo": {"input": 0.01 / 1000, "output": 0.03 / 1000},
            "gpt-3.5-turbo": {"input": 0.0015 / 1000, "output": 0.002 / 1000},
        }
        self.openai_client = OpenAI(api_key=OPENAI_API_KEY)

    @staticmethod
    def clean_string(text):
        """Basic text cleaning: remove extra spaces, punctuation, and normalize case.

        :param text: Raw value to clean.
        :return: The lower-cased text without punctuation and with whitespace
            collapsed and trimmed, or None when ``text`` is not a string.
        """
        if not isinstance(text, str):
            return None
        text = re.sub(r'[^\w\s]', '', text.lower())  # Remove punctuation
        # Collapse every run of whitespace into a single space
        text = re.sub(r'\s+', ' ', text)
        # Trim last: removing punctuation can expose new leading/trailing
        # spaces (e.g. "- flat" -> " flat").
        return text.strip()

    def fuzzy_match(self, text):
        """Use fuzzy matching to find the closest standard value.

        :param text: Cleaned value to match; falsy input is never matched.
        :return: The best-matching standard value when its score reaches
            ``fuzzy_threshold``, otherwise None.
        """
        if not text:
            return None
        best = process.extractOne(text, self.standard_values)
        # No candidate at all (e.g. no standard values to choose from).
        if best is None:
            return None
        # Read only the first two fields so that a result tuple carrying
        # extra trailing items (such as an index or key) is still accepted.
        match, score = best[0], best[1]
        return match if score >= self.fuzzy_threshold else None

    def count_tokens(self, text):
        """Estimate the number of tokens in a given text (0 for empty input)."""
        return len(self.tokenizer.encode(text)) if text else 0

    @staticmethod
    def _parse_ai_mapping(output_text):
        """Parse the model's reply into a dictionary without executing it.

        Strict JSON is tried first; a Python-style dict literal (e.g. single
        quotes) is accepted as a fallback through ``ast.literal_eval``, which
        only evaluates literals. A surrounding markdown code fence is removed
        before parsing.

        :param output_text: Raw text returned by the model.
        :return: The parsed dictionary, or None when the text is not a dict.
        """
        import ast
        import json

        text = (output_text or "").strip()
        if text.startswith("```"):
            # Drop the opening fence line (it may carry a language tag) ...
            newline_at = text.find("\n")
            text = text[newline_at + 1:] if newline_at != -1 else ""
            # ... and the closing fence, if present.
            if text.rstrip().endswith("```"):
                text = text.rstrip()[:-3]

        for parser in (json.loads, ast.literal_eval):
            try:
                parsed = parser(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def ai_standardize(self, unmapped_values):
        """Call OpenAI API **once** for all unmapped values to minimize cost, with memoization.

        :param unmapped_values: List of raw values no cheaper strategy resolved.
        :return: Dictionary {original_value: standardized_value}. When the
            reply cannot be parsed, every value is mapped to "unknown".
        :raises ValueError: If the prompt exceeds ``max_tokens`` tokens.
        """
        if not unmapped_values:
            return {}
        # Sort for a stable cache key; key=str keeps mixed-type values sortable.
        unmapped_tuple = tuple(sorted(unmapped_values, key=str))
        if unmapped_tuple in self.ai_cache:
            return self.ai_cache[unmapped_tuple]  # Return memoized result
        prompt = f"""
You are an expert in data classification. Standardize each of these values into one of the categories:
{list(self.standard_values)}.
Return only a JSON dictionary where:
- The keys are the original values.
- The values are the standardized ones.
Strictly return JSON **without markdown formatting** or extra text.
Example Output:
{{
"BLKHOUS": "block house",
"BEDSIT": "bedsit"
}}
Values to standardize:
{unmapped_values}
"""
        # Count input tokens and refuse oversized prompts before spending money
        input_tokens = self.count_tokens(prompt)
        if input_tokens > self.max_tokens:
            raise ValueError("Input tokens exceed the maximum limit.")
        logger.info("Calling OpenAI API for standardization...")
        response = self.openai_client.chat.completions.create(
            model=self.ai_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.max_tokens,
            temperature=0.1,
        )
        # The message content may be None; treat that as an empty reply.
        output_text = (response.choices[0].message.content or "").strip()
        output_tokens = self.count_tokens(output_text)  # Count output tokens
        # Track total token usage
        self.total_tokens_used += input_tokens + output_tokens
        # Estimate cost
        model_pricing = self.pricing[self.ai_model]
        self.total_cost += (
            input_tokens * model_pricing["input"]
            + output_tokens * model_pricing["output"]
        )
        # SECURITY: the reply is untrusted text and must never be passed to
        # eval(); it is parsed as data only.
        mapping = self._parse_ai_mapping(output_text)
        if mapping is None:
            logger.warning("Could not parse AI response; marking values as unknown.")
            mapping = {val: "unknown" for val in unmapped_values}  # Fallback
        # Memoize the AI response
        self.ai_cache[unmapped_tuple] = mapping
        logger.debug("AI Response: %s", mapping)
        # Keep the raw reply text around for debugging
        self.ai_response = output_text
        return mapping

    def standardize_list(self, values_to_remap):
        """
        Standardizes a list of values and returns a dictionary {original_value: standardized_value}.

        Missing values (as reported by ``pd.isna``) map to "unknown". Results
        are added to ``remap_dict``, which persists across calls, and that
        accumulated dictionary is what gets returned.

        :param values_to_remap: List of raw values to standardize.
        :return: Dictionary {original_value: standardized_value}.
        """
        # Tolerate standard_map having been left or reset to None.
        standard_map = self.standard_map if self.standard_map is not None else {}
        unmapped_values = []
        for value in set(values_to_remap):  # Process only unique values
            if pd.isna(value):  # Handle NaN values
                self.remap_dict[value] = "unknown"
                continue
            cleaned_value = self.clean_string(value)

            # Rule-Based Check (Predefined Mapping): try the cleaned form,
            # then the raw value, then the lower-cased raw value.
            candidate_keys = []
            if cleaned_value is not None:
                candidate_keys.append(cleaned_value)
            candidate_keys.append(value)
            if isinstance(value, str):  # non-strings have no .lower()
                candidate_keys.append(value.lower())
            # None is a safe "not found" marker: no candidate key is ever None.
            matched_key = next(
                (key for key in candidate_keys if key in standard_map), None
            )
            if matched_key is not None:
                self.remap_dict[value] = standard_map[matched_key]
                continue

            # Exact Match in Standard Values
            if cleaned_value in self.standard_values:
                self.remap_dict[value] = cleaned_value
                continue

            # Fuzzy Matching
            fuzzy_match = self.fuzzy_match(cleaned_value)
            if fuzzy_match:
                self.remap_dict[value] = fuzzy_match
                continue

            # Capture anything that wasn't mapped
            unmapped_values.append(value)

        # AI Model - remap anything unmapped (batch request)
        ai_mapping = self.ai_standardize(unmapped_values)
        self.remap_dict.update(ai_mapping)
        return self.remap_dict

    def report_usage(self):
        """Prints a summary of token usage and cost."""
        print(f"\n🔹 Total Tokens Used: {self.total_tokens_used}")
        print(f"💰 Estimated Cost: ${self.total_cost:.4f}")