Model/model_data/simulation_system/MLModel.py
2023-08-15 18:18:55 +01:00

113 lines
3.3 KiB
Python

"""
MLModel class
Key tasks:
- Template Model class for different model types
- Save model
- Load Model
- Generate Inference
"""
from pathlib import Path
from typing import Protocol, NamedTuple
import pandas as pd
from autogluon import TabularPredictor
# Keys that AutogluonModel.train_model compares against the keys of the
# hyperparameter dict it is given.
AUTOGLUON_HYPERPARAMETERS = [
    'problem_type',
    'eval_metric',
    'time_limit',
    'presets',
    'excluded_model_types',
]
class MLModel(Protocol):
    """
    Structural interface shared by the ML model wrappers.

    Every method here is a declaration only: the bodies are empty and
    return ``None`` when called directly.
    """

    def load_model(self, filepath: Path) -> None:
        """
        Load a previously saved model from ``filepath`` and keep it on the
        instance so it can be used by later calls.
        """
        ...

    def save_model(self, output_filepath: Path) -> None:
        """
        Save the current model to ``output_filepath``.
        """
        ...

    def train_model(
            self,
            data: pd.DataFrame,
            target: str,
            hyperparameter: dict) -> None:
        """
        Train a model on ``data`` using the model-specific settings supplied
        in ``hyperparameter``.

        NOTE(review): ``target`` is presumably the name of the label column
        in ``data`` - confirm against the implementations.
        """
        ...

    def generate_predictions(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Produce predictions for the rows of ``data`` with the loaded model.
        """
        ...

    def model_evaluation(self, validation_data: pd.DataFrame) -> NamedTuple:
        """
        Evaluate the model on ``validation_data`` and return the resulting
        predictions and metrics.
        """
        ...
class AutogluonModel(MLModel):
    """
    Autogluon model that implements the MLModel Protocol.

    The wrapped ``TabularPredictor`` is stored in ``self.model``, which is
    ``None`` until ``load_model`` or ``train_model`` has been called.
    """

    def __init__(self) -> None:
        # Set by load_model() / train_model(); checked by generate_predictions().
        self.model = None

    def load_model(self, filepath: Path) -> None:
        """
        Load a saved predictor from ``filepath`` into ``self.model``.

        Any previously held model is replaced.
        """
        self.model = TabularPredictor.load(path=filepath)

    def save_model(self, output_filepath: Path) -> None:
        """
        Providing a path, this function will save the model to be used.

        NOTE(review): this method has no body, so calling it does nothing
        and ``output_filepath`` is ignored. Presumably the predictor is
        already written to its own output directory during training -
        confirm before relying on this method.
        """

    def train_model(
            self,
            data: pd.DataFrame,
            target_column: str,
            hyperparameters: dict = None) -> None:
        """
        Train a ``TabularPredictor`` on ``data`` and store it in ``self.model``.

        Parameters:
            data: training data passed to ``TabularPredictor.fit``.
            target_column: passed to ``TabularPredictor`` as ``label``.
            hyperparameters: dict that must contain every key listed in
                ``AUTOGLUON_HYPERPARAMETERS``. It may additionally contain
                ``'output_path'``, which is forwarded as the predictor's
                ``path``; when absent, ``None`` is forwarded instead.

        Raises:
            SystemExit: with status 1, after printing a message, when
                ``hyperparameters`` is ``None``, lacks a required key or
                contains a key that is not recognised.
        """
        required_keys = set(AUTOGLUON_HYPERPARAMETERS)
        # 'output_path' is read below but is not part of the required list.
        # Demanding an exact match with the required list (as before) made
        # every call fail: either here, or with a KeyError on 'output_path'.
        accepted_keys = required_keys | {'output_path'}
        # A missing dict is treated like an empty one so that it is reported
        # through the same message instead of failing on ``None.keys()``.
        supplied_keys = set(hyperparameters) if hyperparameters is not None else set()

        has_all_required = required_keys <= supplied_keys
        has_only_known = supplied_keys <= accepted_keys
        if not (has_all_required and has_only_known):
            print("Hyperparameters (dict) is incorrectly defined - please check what hyperparameters are required")
            # Same effect as exit(1) without depending on the site module.
            raise SystemExit(1)

        predictor = TabularPredictor(
            label=target_column,
            path=hyperparameters.get('output_path'),
            problem_type=hyperparameters['problem_type'],
            eval_metric=hyperparameters['eval_metric']
        )
        # The value returned by fit() is what gets stored, as before.
        self.model = predictor.fit(
            data,
            time_limit=hyperparameters['time_limit'],
            presets=hyperparameters['presets'],
            excluded_model_types=hyperparameters['excluded_model_types']
        )

    def generate_predictions(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Return the predictions of ``self.model`` for ``data``.

        Raises:
            SystemExit: with status 1, after printing a message, when no
                model has been loaded or trained yet.

        NOTE(review): the result is whatever ``self.model.predict`` returns;
        AutoGluon documents a pandas Series by default, which does not match
        the ``pd.DataFrame`` annotation - confirm.
        """
        if self.model is None:
            print("No model loaded/ trained")
            # Same effect as exit(1) without depending on the site module.
            raise SystemExit(1)
        return self.model.predict(data)

    def model_evaluation(self, validation_data: pd.DataFrame) -> NamedTuple:
        """
        For any validation data, a set of predictions and metrics are returned.

        NOTE(review): no body beyond this docstring is present here, so the
        call currently returns ``None`` rather than a NamedTuple.
        """