""" """ import pandas as pd from pathlib import Path from core.Logger import logger from core.CloudClient import S3FSClient from core.Metrics import Metrics from core.Settings import BEST_MODEL_COLUMN_NAME class RegistryHandler: """ Handles the loading of the registry depending on the environment """ def load_registry( self, registry_path: Path, s3fs_client: S3FSClient, metrics: Metrics ): """ Depening on the environment, we will have to load from locally or s3 (mock/real) """ if s3fs_client.client is None: logger.info("Using local development - no need for s3 load") return self.load_local_registry( registry_path=registry_path, metrics=metrics ) s3_location = "s3://" + s3fs_client.model_bucket + "/" + str(registry_path) logger.info(f"Check if registry exists") if s3fs_client.client.exists(s3_location): registry_df = pd.read_csv( s3fs_client.client.open(s3_location), index_col=None ) else: logger.info("No registry found - creating new one") registry_df = self.create_new_registry(metrics=metrics) return registry_df def load_local_registry(self, registry_path: Path, metrics: Metrics): """ In local development mode, load the registry """ if registry_path.exists(): logger.info("Registry file found - Loading into Dataframe") registry_df = pd.read_csv(registry_path, index_col=None) else: logger.info("No registry found - creating new one") registry_df = self.create_new_registry(metrics=metrics) return registry_df def create_new_registry(self, metrics: Metrics): """ If no registry is found, create a new one """ # TODO: Moved columns into settings: MODEL_DETAILS and Metrics class columns columns = [ BEST_MODEL_COLUMN_NAME, "model_type", "model_name", "model_location", ] + metrics.list_metric_functions() registry_df = pd.DataFrame(columns=columns) return registry_df def save_registry(self, output_filepath: Path, s3fs_client: S3FSClient) -> None: """ Providing a path, this function will save the model to be used. """ if s3fs_client.client is None: logger.info("In local development mode - no need for s3 client") else: logger.info(f"Saving registry into s3") s3_location = s3fs_client.model_bucket + "/" + str(output_filepath) s3fs_client.client.put(str(output_filepath), s3_location, recursive=True) logger.info("Save complete")