mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
added factory pattern
This commit is contained in:
parent
ca16b1d872
commit
a137cccc05
5 changed files with 121 additions and 44 deletions
|
|
@ -13,6 +13,7 @@ import pandas as pd
|
|||
from autogluon.tabular import TabularDataset, TabularPredictor
|
||||
from sklearn.metrics import mean_absolute_percentage_error
|
||||
from core.Logger import logger
|
||||
from core.Metrics import Metrics
|
||||
from MLModel.BaseMLModel import MLModel
|
||||
|
||||
AUTOGLUON_HYPERPARAMETERS = [
|
||||
|
|
@ -25,18 +26,18 @@ AUTOGLUON_HYPERPARAMETERS = [
|
|||
METRIC_FILENAME = "metrics.csv"
|
||||
|
||||
|
||||
def model_factory(model_type: str, hyperparameters: dict = None) -> dict:
    """
    Use factory pattern to register the different ML implementations

    Args:
        model_type: name of the registered implementation (currently only
            "autogluon").
        hyperparameters: optional hyperparameters used to build a
            recognisable model name. For "autogluon" the 'presets' and
            'time_limit' keys are read.

    Returns:
        A dict with two entries:
            "model": the model class itself (not an instance) - call it to
                build a model.
            "naming_attributes": "<presets>-<time_limit>" string, or None
                when no hyperparameters were supplied.

    Raises:
        ValueError: if model_type is not a registered implementation.
        KeyError: if hyperparameters is given but lacks a key the
            implementation uses for naming.
    """
    if model_type == "autogluon":
        # Only build the naming string when hyperparameters exist; indexing
        # the default None would raise TypeError before any lookup happened.
        naming_attributes = None
        if hyperparameters is not None:
            naming_attributes = (
                f"{hyperparameters['presets']}-{hyperparameters['time_limit']}"
            )

        return {
            "model": AutogluonModel,
            "naming_attributes": naming_attributes,
        }

    # Explicit error instead of a bare KeyError from a dict lookup
    raise ValueError(f"No model implemented for model type: {model_type!r}")
|
||||
|
||||
|
||||
class AutogluonModel:
|
||||
|
|
@ -108,6 +109,7 @@ class AutogluonModel:
|
|||
self,
|
||||
validation_data: pd.DataFrame,
|
||||
target_column: str,
|
||||
metrics: Metrics,
|
||||
metrics_location: Path = None,
|
||||
metric_filename: str = METRIC_FILENAME,
|
||||
) -> pd.DataFrame:
|
||||
|
|
@ -121,9 +123,11 @@ class AutogluonModel:
|
|||
logger.error("No model loaded/ trained - Unable to generate evaluation")
|
||||
exit(1)
|
||||
|
||||
performance = self.model.evaluate(validation_data)
|
||||
# Generate predictions, load the metrics suite, then generate metrics between the two
|
||||
predictions = self.generate_predictions(validation_data)
|
||||
|
||||
performance = self.model.evaluate(validation_data)
|
||||
|
||||
logger.info("Prediction used for evaluations are saved in self.prediction")
|
||||
self.predictions = predictions
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,28 @@
|
|||
import pandas as pd
|
||||
import os
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class DataLoader(Protocol):
    """
    Structural interface shared by every DataLoader implementation

    Any class exposing a compatible `load` satisfies this Protocol.
    """

    @staticmethod
    def load(filepath: str, index_col: str = None) -> pd.DataFrame:
        """
        Load the data found at `filepath` from the relevant source

        Args:
            filepath: location of the dataset to read.
            index_col: optional column to use as the DataFrame index.
        """
        ...
|
||||
|
||||
|
||||
class LocalDataLoader:
|
||||
"""
|
||||
Implements the DataLoader Protocol for local files
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(filepath: str, index_col: str = None) -> pd.DataFrame:
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
raise FileNotFoundError(f"File not found: {filepath}")
|
||||
|
||||
|
|
@ -23,42 +37,82 @@ class DataLoader:
|
|||
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def s3_load(filepath: str, index_col: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
Load different datasets from s3
|
||||
"""
|
||||
|
||||
STORAGE_OPTIONS = {
|
||||
class S3MockDataLoader:
    """
    DataLoader implementation for s3 files served by a locally hosted mock service
    """

    @staticmethod
    def load(filepath: str, index_col: str = None) -> pd.DataFrame:
        """
        Read a .parquet or .csv file through pandas using mock-service credentials

        Args:
            filepath: path of the file to read; must end in ".parquet" or ".csv".
            index_col: optional column to use as the DataFrame index.

        Raises:
            ValueError: if the file extension is neither ".parquet" nor ".csv".
        """
        env = os.environ

        # TODO: Ingest these as environment variables in the docker compose file
        mock_options = {
            "key": env.get("AWS_ACCESS_KEY_ID", "admin"),
            "secret": env.get("AWS_SECRET_ACCESS_KEY", "password"),
            "client_kwargs": {
                "endpoint_url": env.get("ENDPOINT_URL", "http://localhost:9000"),
            },
        }

        if filepath.endswith(".csv"):
            return pd.read_csv(
                filepath, index_col=index_col, storage_options=mock_options
            )

        if not filepath.endswith(".parquet"):
            raise ValueError(f"File format not supported for file: {filepath}")

        # read_parquet is not given index_col here, so the index is applied afterwards
        frame = pd.read_parquet(filepath, storage_options=mock_options)
        return frame if index_col is None else frame.set_index(index_col)
|
||||
|
||||
def process(self, filepath: str, index_col: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
Based off the filepath, we choose a loader style
|
||||
"""
|
||||
|
||||
if filepath.startswith("s3://"):
|
||||
df = self.s3_load(filepath=filepath, index_col=index_col)
|
||||
class S3DataLoader:
    """
    DataLoader implementation for files stored in s3
    """

    @staticmethod
    def load(filepath: str, index_col: str = None) -> pd.DataFrame:
        """
        Read a .parquet or .csv file through pandas using AWS credentials from the environment

        Args:
            filepath: path of the file to read; must end in ".parquet" or ".csv".
            index_col: optional column to use as the DataFrame index.

        Raises:
            ValueError: if the file extension is neither ".parquet" nor ".csv".
        """
        # Values are None when the variables are unset; they are passed through as-is
        credentials = dict(
            key=os.environ.get("AWS_ACCESS_KEY_ID"),
            secret=os.environ.get("AWS_SECRET_ACCESS_KEY"),
        )

        if filepath.endswith(".csv"):
            return pd.read_csv(
                filepath, index_col=index_col, storage_options=credentials
            )

        if not filepath.endswith(".parquet"):
            raise ValueError(f"File format not supported for file: {filepath}")

        # read_parquet is not given index_col here, so the index is applied afterwards
        frame = pd.read_parquet(filepath, storage_options=credentials)
        return frame if index_col is None else frame.set_index(index_col)
|
||||
|
||||
|
||||
def dataloader_factory(runtime_environment: str = None) -> DataLoader:
    """
    Use factory pattern to determine which loading method we use

    Args:
        runtime_environment: one of "local", "local-mock", "dev", "staging"
            or "prod". None is treated as "local".

    Returns:
        A new instance of the loader registered for the environment.

    Raises:
        ValueError: if runtime_environment is not a registered environment.
    """
    if runtime_environment is None:
        runtime_environment = "local"

    # Register classes rather than instances so that only the selected
    # loader is constructed on each call.
    dataloader_types = {
        "local": LocalDataLoader,
        "local-mock": S3MockDataLoader,
        "dev": S3DataLoader,
        "staging": S3DataLoader,
        "prod": S3DataLoader,
    }

    if runtime_environment not in dataloader_types:
        raise ValueError(
            f"Incorrect runtime environment specified: {runtime_environment!r} "
            f"(expected one of {sorted(dataloader_types)})"
        )

    return dataloader_types[runtime_environment]()
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ MODEL_HYPERPARAMETERS = {
|
|||
}
|
||||
}
|
||||
|
||||
TIMESTAMP_FORMAT = "%Y-%m-%d_%H-%M-%S"
|
||||
TIMESTAMP_FORMAT = "%Y_%m_%d_%H_%M_%S"
|
||||
|
||||
RANDOM_SEED = 0
|
||||
SUBSAMPLE_FACTOR = 200
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from core.Settings import (
|
|||
REGISTRY_FILE,
|
||||
BEST_MODEL_COLUMN_NAME,
|
||||
)
|
||||
from MLModel.Models import AutogluonModel, select_model
|
||||
from MLModel.Models import AutogluonModel, model_factory
|
||||
|
||||
|
||||
def ingest_arguments() -> argparse.Namespace:
|
||||
|
|
@ -76,7 +76,7 @@ def regenerate_metrics(test_filepath: str, target_column: str) -> None:
|
|||
|
||||
logger.info(f"--- Loading Model ({row['model_name']}) ---")
|
||||
|
||||
model = select_model(model_type=row["model_type"])
|
||||
model = model_factory(model_type=row["model_type"])()
|
||||
|
||||
model.load_model(filepath=row["model_location"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
|
||||
# import os
|
||||
import os
|
||||
|
||||
# from s3pathlib import S3Path
|
||||
|
||||
# import boto3
|
||||
|
|
@ -9,9 +10,10 @@ from datetime import datetime
|
|||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from MLModel.Models import AutogluonModel
|
||||
from MLModel.Models import AutogluonModel, model_factory
|
||||
from core.Logger import logger
|
||||
from core.DataLoader import DataLoader
|
||||
from core.Metrics import Metrics
|
||||
from core.DataLoader import dataloader_factory
|
||||
from core.FeatureProcessor import FeatureProcessor
|
||||
from core.Settings import (
|
||||
MODEL_DIRECTORY,
|
||||
|
|
@ -38,6 +40,8 @@ from core.Settings import (
|
|||
|
||||
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
|
||||
# STORAGE_OPTIONS = {
|
||||
# "key": os.environ.get("AWS_ACCESS_KEY_ID", 'admin'),
|
||||
# "secret": os.environ.get("AWS_SECRET_ACCESS_KEY", 'password'),
|
||||
|
|
@ -123,14 +127,15 @@ def training(
|
|||
"""
|
||||
|
||||
logger.info("--- Loading data ---")
|
||||
dataloader = DataLoader()
|
||||
train_df = dataloader.process(filepath=train_filepath)
|
||||
test_df = dataloader.process(filepath=test_filepath)
|
||||
dataloader = dataloader_factory(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
train_df = dataloader.load(filepath=train_filepath)
|
||||
test_df = dataloader.load(filepath=test_filepath)
|
||||
|
||||
logger.info("--- Feature processing ---")
|
||||
|
||||
feature_processor = FeatureProcessor()
|
||||
|
||||
# This is for convenience for now
|
||||
subsample_amount = round(len(train_df) / SUBSAMPLE_FACTOR)
|
||||
|
||||
train_df = feature_processor.process(
|
||||
|
|
@ -147,13 +152,20 @@ def training(
|
|||
hyperparameters = MODEL_HYPERPARAMETERS[model_type]
|
||||
logger.info(f"Hyperparameters are: {hyperparameters}")
|
||||
|
||||
if model_type == "autogluon":
|
||||
model_root = f"{target_column}-{hyperparameters['presets']}-{hyperparameters['time_limit']}-{TIMESTAMP}".lower()
|
||||
output_base = Path(MODEL_DIRECTORY) / target_column / model_type / model_root
|
||||
logger.info(
|
||||
"--- Loading model configuration (Model type and Naming convention) ---"
|
||||
)
|
||||
# We might want to have hyperparameters in the names to make models more recognisable
|
||||
model_toolkit = model_factory(
|
||||
model_type=model_type, hyperparameters=hyperparameters
|
||||
)
|
||||
|
||||
model = AutogluonModel(output_filepath=output_base / MODEL_FOLDER)
|
||||
else:
|
||||
raise ValueError("No alternative model implemented yet")
|
||||
model_root = (
|
||||
f"{target_column}-{model_toolkit['naming_attributes']}-{TIMESTAMP}".lower()
|
||||
)
|
||||
output_base = Path(MODEL_DIRECTORY) / target_column / model_type / model_root
|
||||
|
||||
model = model_toolkit["model"](output_filepath=output_base / MODEL_FOLDER)
|
||||
|
||||
model.train_model(
|
||||
data=train_df, target_column=target_column, hyperparameters=hyperparameters
|
||||
|
|
@ -163,6 +175,13 @@ def training(
|
|||
model.save_model(output_filepath=model.output_filepath)
|
||||
|
||||
logger.info("--- Generate evaluation metrics ---")
|
||||
# TODO: replace this with metrics class
|
||||
# metrics_df = model.model_evaluation(
|
||||
# validation_data=test_df,
|
||||
# target_column=target_column,
|
||||
# metrics_location=output_base / METRICS_FOLDER,
|
||||
# metrics = Metrics
|
||||
# )
|
||||
metrics_df = model.model_evaluation(
|
||||
validation_data=test_df,
|
||||
target_column=target_column,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue