mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
"""
|
|
|
|
"""
|
|
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from core.Logger import logger
|
|
from core.CloudClient import S3FSClient
|
|
from core.Metrics import Metrics
|
|
from core.Settings import BEST_MODEL_COLUMN_NAME
|
|
|
|
|
|
class RegistryHandler:
    """
    Handles the loading of the registry depending on the environment.

    The registry is a CSV table read into a pandas DataFrame. When the
    supplied ``S3FSClient`` has no underlying ``client`` (``client is None``)
    the handler works against the local filesystem; otherwise it reads from
    and uploads to the bucket named by ``s3fs_client.model_bucket``.
    """

    def load_registry(
        self, registry_path: Path, s3fs_client: S3FSClient, metrics: Metrics
    ) -> pd.DataFrame:
        """
        Depending on the environment, load the registry locally or from s3
        (mock/real).

        Args:
            registry_path: Location of the registry CSV. Used as a local path
                in local mode and as the object key (below the model bucket)
                in s3 mode.
            s3fs_client: Wrapper whose ``client`` attribute is ``None`` in
                local development mode.
            metrics: Supplies the metric column names when a new, empty
                registry has to be created.

        Returns:
            The registry as a DataFrame; an empty one with the expected
            columns when no registry exists yet.
        """
        if s3fs_client.client is None:
            logger.info("Using local development - no need for s3 load")
            return self.load_local_registry(
                registry_path=registry_path, metrics=metrics
            )

        s3_location = "s3://" + s3fs_client.model_bucket + "/" + str(registry_path)

        logger.info("Check if registry exists")
        if s3fs_client.client.exists(s3_location):
            # Use a context manager so the remote file handle is closed as
            # soon as the CSV has been parsed instead of being leaked.
            with s3fs_client.client.open(s3_location) as registry_file:
                return pd.read_csv(registry_file, index_col=None)

        logger.info("No registry found - creating new one")
        return self.create_new_registry(metrics=metrics)

    def load_local_registry(
        self, registry_path: Path, metrics: Metrics
    ) -> pd.DataFrame:
        """
        In local development mode, load the registry.

        Args:
            registry_path: Local path of the registry CSV.
            metrics: Supplies the metric column names when a new, empty
                registry has to be created.

        Returns:
            The registry read from ``registry_path`` if that file exists,
            otherwise a new empty registry.
        """
        if registry_path.exists():
            logger.info("Registry file found - Loading into Dataframe")
            return pd.read_csv(registry_path, index_col=None)

        logger.info("No registry found - creating new one")
        return self.create_new_registry(metrics=metrics)

    def create_new_registry(self, metrics: Metrics) -> pd.DataFrame:
        """
        If no registry is found, create a new one.

        Args:
            metrics: Its ``list_metric_functions()`` result is appended to
                the fixed model-detail columns, so it must return a list.

        Returns:
            An empty DataFrame whose columns are the model-detail columns
            followed by the metric columns.
        """
        # TODO: Moved columns into settings: MODEL_DETAILS and Metrics class columns
        model_detail_columns = [
            BEST_MODEL_COLUMN_NAME,
            "model_type",
            "model_name",
            "model_location",
        ]
        columns = model_detail_columns + metrics.list_metric_functions()

        return pd.DataFrame(columns=columns)

    def save_registry(self, output_filepath: Path, s3fs_client: S3FSClient) -> None:
        """
        Upload the registry file at the given path to s3 when not running
        locally.

        This method does not write the registry itself: it copies the
        existing local file at ``output_filepath`` to
        ``<model_bucket>/<output_filepath>``. In local development mode
        (``s3fs_client.client is None``) it only logs and does nothing else.

        Args:
            output_filepath: Local path of the registry file; also used as
                the destination key below the model bucket.
            s3fs_client: Wrapper whose ``client`` attribute is ``None`` in
                local development mode.
        """
        if s3fs_client.client is None:
            logger.info("In local development mode - no need for s3 client")
        else:
            logger.info("Saving registry into s3")
            s3_location = s3fs_client.model_bucket + "/" + str(output_filepath)
            s3fs_client.client.put(str(output_filepath), s3_location, recursive=True)
            # NOTE(review): indentation was lost in the mirrored source;
            # confirm this message is meant for the s3 branch only.
            logger.info("Save complete")
|