mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
66 lines
1.7 KiB
Python
66 lines
1.7 KiB
Python
"""
MLModel base protocol

This is the base protocol for ML models:

- Each implementation lives in its own separate file

Key tasks:

- Template model class for different model types
- Save model
- Load model
- Generate inference
"""
|
|
|
|
from numpy import ndarray
|
|
from pathlib import Path
|
|
from typing import Protocol, NamedTuple, Any
|
|
import pandas as pd
|
|
from training import S3FSClient
|
|
|
|
|
|
class MLModel(Protocol):
    """
    Base ML model protocol.

    Structural interface that concrete model implementations (each living in
    its own separate file) are expected to satisfy. Method bodies are
    intentionally empty: this class only declares the signatures.
    """

    def load_model(self, filepath: Path) -> None:
        """
        Load the model stored at ``filepath`` into an internal variable.

        :param filepath: Path of a previously saved model.
        """

    def save_model(
        self, output_filepath: Path, s3_client: S3FSClient | None = None
    ) -> None:
        """
        Save the model to ``output_filepath``.

        :param output_filepath: Destination path for the saved model.
        :param s3_client: Optional S3 filesystem client. NOTE(review):
            presumably used to write to S3 instead of the local filesystem
            when supplied - confirm against implementations.
        """

    def train_model(
        self,
        data: pd.DataFrame,
        target_column: str,
        hyperparameters: dict[str, Any],
    ) -> None:
        """
        Train a model on ``data`` using the given hyperparameters.

        :param data: Training data.
        :param target_column: Name of the target column in ``data``.
        :param hyperparameters: Hyperparameters specific to the model type.
        """

    def generate_predictions(self, data: pd.DataFrame) -> ndarray[Any, Any] | None:
        """
        Generate predictions for ``data`` using the loaded model.

        :param data: Input data to predict on.
        :returns: An array of predictions, or ``None`` (permitted by the
            signature; when it is returned is implementation-defined).
        """

    def model_evaluation(
        self,
        validation_data: pd.DataFrame,
        target_column: str,
        metrics_location: Path | None = None,
    ) -> pd.DataFrame | None:
        """
        Evaluate the model on ``validation_data``.

        :param validation_data: Validation data, including ``target_column``.
        :param target_column: Name of the target column in ``validation_data``.
        :param metrics_location: Optional path. NOTE(review): presumably where
            metrics are written when given - confirm against implementations.
        :returns: A DataFrame of predictions and metrics, or ``None``.
        """

    def optimise_model_for_deployment(self) -> None:
        """
        Perform post-processing on the model so it is ready for deployment.
        """

    def model_metadata(self) -> dict[str, Any] | None:
        """
        Return the model's metadata as a dictionary, or ``None``.
        """