Model/model_data/simulation_system/core/RegistryHandler.py
2023-09-01 19:25:35 +01:00

114 lines
3.6 KiB
Python

"""
"""
from io import StringIO
import pandas as pd
from pathlib import Path
from core.Logger import logger
from core.CloudClient import BotoClient
from core.Metrics import Metrics
from core.Settings import BEST_MODEL_COLUMN_NAME
def read_csv_from_s3(client, bucket_name, file_key, index_col=None):
    """
    Read a CSV object from S3 into a pandas DataFrame.

    Authentication is handled by ``client``; this function takes no
    credentials itself.

    :param client: S3 client exposing ``get_object(Bucket=..., Key=...)``
        whose response holds a readable ``"Body"`` (e.g. a boto3 S3 client).
    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the file (including directory path within the bucket).
    :param index_col: Forwarded unchanged to ``pandas.read_csv``; ``None``
        (the default) keeps a plain integer index.
    :return: DataFrame containing the CSV data.
    """
    # Get the object; its "Body" is a stream over the raw bytes.
    s3_object = client.get_object(Bucket=bucket_name, Key=file_key)
    # The whole payload is read into memory and decoded as UTF-8 text.
    csv_body = s3_object["Body"].read().decode("utf-8")
    return pd.read_csv(StringIO(csv_body), index_col=index_col)
class RegistryHandler:
    """
    Handles the loading and saving of the registry depending on the environment.

    A ``client`` whose ``client`` attribute is ``None`` means local
    development (the registry lives on the local filesystem); otherwise the
    registry is kept in the bucket named by ``client.model_bucket``.
    """

    def load_registry(self, registry_path: Path, client: BotoClient, metrics: Metrics):
        """
        Depending on the environment, load the registry locally or from s3
        (mock/real).

        :param registry_path: Path of the registry CSV; its string form is
            also used as the S3 key.
        :param client: Wrapper holding the S3 client (``client.client``) and
            the bucket name (``client.model_bucket``).
        :param metrics: Supplies the metric column names for a new registry.
        :return: DataFrame with the existing registry, or an empty one with
            the expected columns when no registry exists yet.
        """
        if client.client is None:
            logger.info("Using local development - no need for s3 load")
            return self.load_local_registry(
                registry_path=registry_path, metrics=metrics
            )

        registry_key = str(registry_path)
        logger.info("Check if registry exists")
        listing = client.client.list_objects_v2(
            Bucket=client.model_bucket, Prefix=registry_key
        )
        # ``Prefix`` is a prefix match, so keys that merely start with the
        # registry key (e.g. "registry.csv.bak") are listed as well. Require
        # an exact key match so we never try to fetch an object that is not
        # there. The exact key sorts before any longer key sharing the prefix,
        # so it is always on the first page of results.
        registry_exists = any(
            entry.get("Key") == registry_key
            for entry in listing.get("Contents", [])
        )
        if not registry_exists:
            logger.info("No registry found - creating new one")
            return self.create_new_registry(metrics=metrics)

        logger.info("Loading existing registry")
        return read_csv_from_s3(
            client=client.client,
            bucket_name=client.model_bucket,
            file_key=registry_key,
            index_col=None,
        )

    def load_local_registry(self, registry_path: Path, metrics: Metrics):
        """
        In local development mode, load the registry from disk.

        :param registry_path: Location of the registry CSV on disk.
        :param metrics: Supplies the metric column names for a new registry.
        :return: DataFrame read from ``registry_path`` if that file exists,
            otherwise a new empty registry.
        """
        if not registry_path.exists():
            logger.info("No registry found - creating new one")
            return self.create_new_registry(metrics=metrics)

        logger.info("Registry file found - Loading into Dataframe")
        return pd.read_csv(registry_path, index_col=None)

    def create_new_registry(self, metrics: Metrics):
        """
        If no registry is found, create a new (empty) one.

        :param metrics: ``metrics.list_metric_functions()`` provides the
            metric column names appended after the model detail columns.
        :return: Empty DataFrame with the registry columns.
        """
        # TODO: Move columns into settings: MODEL_DETAILS and Metrics class columns
        columns = [
            BEST_MODEL_COLUMN_NAME,
            "model_type",
            "model_name",
            "model_location",
        ] + metrics.list_metric_functions()
        return pd.DataFrame(columns=columns)

    def save_registry(self, output_filepath: Path, client: BotoClient) -> None:
        """
        Upload the registry file at ``output_filepath`` to s3.

        This method does not write the registry itself; it passes
        ``output_filepath`` to the client's ``upload_file`` as the local
        filename, so the file is expected to exist on disk already. In local
        development mode (``client.client is None``) nothing is uploaded.

        :param output_filepath: Path of the registry file; its string form is
            used both as the local filename and as the S3 key.
        :param client: Wrapper holding the S3 client and the bucket name.
        """
        if client.client is None:
            logger.info("In local development mode - no need for s3 client")
            return

        logger.info("Saving registry into s3")
        local_path = str(output_filepath)
        # Same string for the local filename and the object key.
        client.client.upload_file(local_path, client.model_bucket, local_path)
        logger.info("Save complete")