mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
decent performing model
This commit is contained in:
parent
14417c37df
commit
b0449b9e90
1 changed file with 41 additions and 17 deletions
|
|
@ -16,8 +16,7 @@ class EnergyConsumptionModel:
|
|||
"heating_kwh": [
|
||||
"lodgement-year", "lodgement-month", "current-energy-efficiency", "energy-consumption-current",
|
||||
"heating-cost-current", "total-floor-area", "number-heated-rooms",
|
||||
# "number-habitable-rooms",
|
||||
# "mainheat-energy-eff", "mainheat-description", "main-fuel",
|
||||
"mainheat-description", "main-fuel", "mainheat-energy-eff", "number-habitable-rooms",
|
||||
],
|
||||
"hot_water_kwh": [
|
||||
"lodgement-year", "lodgement-month", "current-energy-efficiency", "energy-consumption-current",
|
||||
|
|
@ -104,25 +103,41 @@ class EnergyConsumptionModel:
|
|||
x, y, test_size=test_size, random_state=random_state
|
||||
)
|
||||
|
||||
def feature_selection(self, target):
|
||||
"""Performs feature selection using RFECV."""
|
||||
def feature_selection(self, target, cv_folds=3, sample_fraction=0.1, random_state=42):
|
||||
"""
|
||||
Performs feature selection using RFECV with XGBoost.
|
||||
|
||||
Parameters:
|
||||
- target: The target variable for feature selection.
|
||||
- cv_folds: Number of cross-validation folds.
|
||||
- sample_fraction: Fraction of the data to use for feature selection.
|
||||
- random_state: Random state for reproducibility.
|
||||
"""
|
||||
if target not in self.TARGETS:
|
||||
raise ValueError(f"Target {target} not in {self.TARGETS}")
|
||||
|
||||
logging.info(f"Starting feature selection for target {target}")
|
||||
x = self.x_train[target]
|
||||
y = self.y_train[target]
|
||||
|
||||
# Subsample the training data when sample_fraction < 1.0; otherwise use all of it
|
||||
if sample_fraction < 1.0:
|
||||
x_sample, _, y_sample, _ = train_test_split(
|
||||
self.x_train[target], self.y_train[target],
|
||||
train_size=sample_fraction, random_state=random_state
|
||||
)
|
||||
else:
|
||||
x_sample = self.x_train[target]
|
||||
y_sample = self.y_train[target]
|
||||
|
||||
# Initialize the XGBoost model and RFECV
|
||||
model = XGBRegressor(objective='reg:squarederror')
|
||||
selector = RFECV(model, step=1, cv=5, scoring='neg_mean_absolute_percentage_error')
|
||||
selector = selector.fit(x, y)
|
||||
model = XGBRegressor(objective='reg:squarederror', n_jobs=-1)
|
||||
selector = RFECV(model, step=1, cv=cv_folds, scoring='neg_mean_absolute_percentage_error')
|
||||
selector = selector.fit(x_sample, y_sample)
|
||||
|
||||
# Get the selected features
|
||||
self.selected_features[target] = x.columns[selector.support_]
|
||||
self.selected_features[target] = x_sample.columns[selector.support_]
|
||||
|
||||
# Update x_train and x_test with selected features
|
||||
self.x_train[target] = x[self.selected_features[target]]
|
||||
self.x_train[target] = self.x_train[target][self.selected_features[target]]
|
||||
self.x_test[target] = self.x_test[target][self.selected_features[target]]
|
||||
|
||||
logging.info(f"Feature selection completed for target {target}")
|
||||
|
|
@ -218,6 +233,14 @@ class EnergyConsumptionModel:
|
|||
def error_analysis(self, target, top_n=10, unique_threshold=0.8):
|
||||
"""
|
||||
Perform error analysis on the provided model and dataset.
|
||||
|
||||
Parameters:
|
||||
- target: The target variable to analyze.
|
||||
- top_n: Number of top residuals to consider for analysis.
|
||||
- unique_threshold: Threshold to exclude columns with high unique values.
|
||||
|
||||
Returns:
|
||||
- summary: Dictionary summarizing common features among poorly performing rows.
|
||||
"""
|
||||
|
||||
# Calculate predictions and residuals
|
||||
|
|
@ -234,6 +257,7 @@ class EnergyConsumptionModel:
|
|||
top_train_data = self.input_data.loc[top_train_indices]
|
||||
top_test_data = self.input_data.loc[top_test_indices]
|
||||
|
||||
# Automatically detect and exclude columns
|
||||
def exclude_columns(data, threshold):
|
||||
exclude_cols = []
|
||||
num_rows = data.shape[0]
|
||||
|
|
@ -247,16 +271,14 @@ class EnergyConsumptionModel:
|
|||
top_train_data = top_train_data.drop(columns=exclude_cols)
|
||||
top_test_data = top_test_data.drop(columns=exclude_cols)
|
||||
|
||||
# TODO: Not working
|
||||
|
||||
# One-hot encode categorical variables
|
||||
categorical_columns = top_train_data.select_dtypes(include=['object']).columns.tolist()
|
||||
top_train_data_encoded = pd.get_dummies(top_train_data, columns=categorical_columns, drop_first=True)
|
||||
top_test_data_encoded = pd.get_dummies(top_test_data, columns=categorical_columns, drop_first=True)
|
||||
|
||||
# Align the encoded data with the training data
|
||||
top_train_data_encoded = top_train_data_encoded.reindex(columns=self.x_train[target].columns, fill_value=0)
|
||||
top_test_data_encoded = top_test_data_encoded.reindex(columns=self.x_test[target].columns, fill_value=0)
|
||||
# Ensure all original columns are included in the encoded data
|
||||
top_train_data_encoded = top_train_data_encoded.reindex(columns=self.input_data.columns, fill_value=0)
|
||||
top_test_data_encoded = top_test_data_encoded.reindex(columns=self.input_data.columns, fill_value=0)
|
||||
|
||||
# Correlation analysis with residuals
|
||||
train_corr = top_train_data_encoded.corrwith(train_residuals.loc[top_train_indices])
|
||||
|
|
@ -264,6 +286,8 @@ class EnergyConsumptionModel:
|
|||
|
||||
# Return summaries
|
||||
summary = {
|
||||
"train_summary": top_train_data.describe(include='all').T,
|
||||
"test_summary": top_test_data.describe(include='all').T,
|
||||
"train_corr": train_corr,
|
||||
"test_corr": test_corr,
|
||||
"top_train_data": top_train_data,
|
||||
|
|
@ -280,7 +304,7 @@ model.feature_engineering()
|
|||
|
||||
# For heating_kwh
|
||||
model.split_dataset(target='heating_kwh')
|
||||
model.feature_selection(target='heating_kwh')
|
||||
model.feature_selection(target='heating_kwh', cv_folds=3, sample_fraction=0.1)
|
||||
model.fit_model(target='heating_kwh')
|
||||
evaluation_results = model.evaluate_model(target='heating_kwh')
|
||||
from pprint import pprint
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue