Model/model_data/simulation_system/core/RegistryHandler.py
2023-09-01 11:27:20 +01:00

83 lines
2.7 KiB
Python

"""Load, create and save the model registry, either locally or via S3."""
import pandas as pd
from pathlib import Path
from core.Logger import logger
from core.CloudClient import S3FSClient
from core.Metrics import Metrics
from core.Settings import BEST_MODEL_COLUMN_NAME
class RegistryHandler:
    """
    Handles loading, creating and saving the model registry depending on the
    environment: local development (no S3 client) or S3 (mock or real).
    """

    @staticmethod
    def _s3_key(bucket, path) -> str:
        """Build ``<bucket>/<path>`` using forward slashes, as S3 keys require."""
        # as_posix() keeps "/" separators on every OS; str(Path) would yield
        # backslashes on Windows and therefore a wrong S3 key.
        return f"{bucket}/{Path(path).as_posix()}"

    def load_registry(
        self, registry_path: Path, s3fs_client: S3FSClient, metrics: Metrics
    ) -> pd.DataFrame:
        """
        Load the registry, locally or from S3 depending on the environment.

        When ``s3fs_client.client`` is None (local development) the registry is
        read from ``registry_path`` on disk. Otherwise it is read from
        ``s3://<model_bucket>/<registry_path>``. If no registry exists at the
        chosen location, a new empty registry is created (it is not saved here).

        registry_path: local path of the registry CSV, also used as the key
            inside the model bucket.
        s3fs_client: wrapper exposing ``client`` (None in local mode) and
            ``model_bucket``.
        metrics: used to derive the metric columns of a newly created registry.
        Returns the registry DataFrame.
        """
        if s3fs_client.client is None:
            logger.info("Using local development - no need for s3 load")
            return self.load_local_registry(
                registry_path=registry_path, metrics=metrics
            )

        s3_location = "s3://" + self._s3_key(s3fs_client.model_bucket, registry_path)
        logger.info("Check if registry exists")
        if not s3fs_client.client.exists(s3_location):
            logger.info("No registry found - creating new one")
            return self.create_new_registry(metrics=metrics)

        # Context manager guarantees the remote file handle is closed even if
        # parsing fails (the original never closed it).
        with s3fs_client.client.open(s3_location) as registry_file:
            return pd.read_csv(registry_file, index_col=None)

    def load_local_registry(self, registry_path: Path, metrics: Metrics) -> pd.DataFrame:
        """
        Load the registry from the local filesystem (local development mode).

        registry_path: path of the registry CSV on disk.
        metrics: used to derive the metric columns if a new registry is created.
        Returns the registry read from disk, or a new empty registry when the
        file does not exist.
        """
        if registry_path.exists():
            logger.info("Registry file found - Loading into Dataframe")
            return pd.read_csv(registry_path, index_col=None)
        logger.info("No registry found - creating new one")
        return self.create_new_registry(metrics=metrics)

    def create_new_registry(self, metrics: Metrics) -> pd.DataFrame:
        """
        Create an empty registry when none can be found.

        The columns are the best-model flag column, the model details
        (type, name, location) and one column per name returned by
        ``metrics.list_metric_functions()``.
        """
        # TODO: Moved columns into settings: MODEL_DETAILS and Metrics class columns
        columns = [
            BEST_MODEL_COLUMN_NAME,
            "model_type",
            "model_name",
            "model_location",
        ] + metrics.list_metric_functions()
        return pd.DataFrame(columns=columns)

    def save_registry(self, output_filepath: Path, s3fs_client: S3FSClient) -> None:
        """
        Upload the registry file at ``output_filepath`` to S3.

        This does not serialise a DataFrame itself. It uploads the file already
        present at ``output_filepath`` (presumably written by the caller
        beforehand; confirm) to ``<model_bucket>/<output_filepath>``. In local
        development (``s3fs_client.client`` is None) there is nothing to upload,
        so the method only logs.
        """
        if s3fs_client.client is None:
            logger.info("In local development mode - no need for s3 client")
            return
        logger.info("Saving registry into s3")
        s3_location = self._s3_key(s3fs_client.model_bucket, output_filepath)
        s3fs_client.client.put(str(output_filepath), s3_location, recursive=True)
        logger.info("Save complete")