mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
114 lines
3.6 KiB
Python
114 lines
3.6 KiB
Python
"""
|
|
|
|
"""
|
|
|
|
from io import StringIO
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from core.Logger import logger
|
|
from core.CloudClient import BotoClient
|
|
from core.Metrics import Metrics
|
|
from core.Settings import BEST_MODEL_COLUMN_NAME
|
|
|
|
|
|
def read_csv_from_s3(client, bucket_name, file_key, index_col=None):
    """
    Read a CSV file from S3 using boto3 and pandas.

    :param client: S3 client object exposing ``get_object(Bucket=..., Key=...)``.
    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the file (including directory path within the bucket).
    :param index_col: Forwarded unchanged to ``pandas.read_csv`` as ``index_col``
        (defaults to ``None``).
    :return: DataFrame containing the CSV data.
    """
    # Get the object
    s3_object = client.get_object(Bucket=bucket_name, Key=file_key)
    body = s3_object["Body"]

    # Always close the body stream, even when reading or decoding fails,
    # so the underlying resource is not left open.
    try:
        csv_body = body.read().decode("utf-8")
    finally:
        body.close()

    # Read the CSV body into a DataFrame
    return pd.read_csv(StringIO(csv_body), index_col=index_col)
|
|
|
|
|
|
class RegistryHandler:
    """
    Handles the loading of the registry depending on the environment.

    ``client.client is None`` is treated as local development: the registry is
    read from the local filesystem and nothing is uploaded. Otherwise the
    registry is read from / uploaded to the bucket ``client.model_bucket``.
    """

    def load_registry(self, registry_path: Path, client: BotoClient, metrics: Metrics):
        """
        Depending on the environment, load the registry locally or from s3
        (mock/real).

        :param registry_path: Location of the registry CSV. Its string form is
            used as the S3 object key when an S3 client is configured.
        :param client: Wrapper exposing ``client`` (the S3 client, or ``None``
            in local development) and ``model_bucket``.
        :param metrics: Used to build the columns of a new registry when no
            existing one is found.
        :return: DataFrame with the existing registry, or a new empty one.
        """
        if client.client is None:
            logger.info("Using local development - no need for s3 load")
            return self.load_local_registry(
                registry_path=registry_path, metrics=metrics
            )

        logger.info("Check if registry exists")

        registry_key = str(registry_path)
        listing = client.client.list_objects_v2(
            Bucket=client.model_bucket, Prefix=registry_key
        )

        # ``Prefix`` is a "starts with" filter, so a sibling key such as
        # "<registry_key>.bak" would also be listed. Require an exact key
        # match, otherwise the get_object call below fails on a missing key.
        registry_exists = any(
            entry.get("Key") == registry_key
            for entry in listing.get("Contents", [])
        )

        if registry_exists:
            logger.info("Loading existing registry")
            registry_df = read_csv_from_s3(
                client=client.client,
                bucket_name=client.model_bucket,
                file_key=registry_key,
                index_col=None,
            )
        else:
            logger.info("No registry found - creating new one")
            registry_df = self.create_new_registry(metrics=metrics)

        return registry_df

    def load_local_registry(self, registry_path: Path, metrics: Metrics):
        """
        In local development mode, load the registry from the filesystem.

        :param registry_path: Path of the registry CSV on disk.
        :param metrics: Used to build the columns of a new registry when the
            file does not exist.
        :return: DataFrame read from ``registry_path``, or a new empty registry.
        """
        if registry_path.exists():
            logger.info("Registry file found - Loading into Dataframe")
            registry_df = pd.read_csv(registry_path, index_col=None)
        else:
            logger.info("No registry found - creating new one")
            registry_df = self.create_new_registry(metrics=metrics)

        return registry_df

    def create_new_registry(self, metrics: Metrics):
        """
        If no registry is found, create a new one.

        :param metrics: Object whose ``list_metric_functions()`` result is
            appended to the fixed model-detail columns (presumably the metric
            names - confirm against the Metrics class).
        :return: Empty DataFrame with the registry columns.
        """
        # TODO: Moved columns into settings: MODEL_DETAILS and Metrics class columns
        columns = [
            BEST_MODEL_COLUMN_NAME,
            "model_type",
            "model_name",
            "model_location",
            # Unpacking accepts any iterable of names, not only a list.
            *metrics.list_metric_functions(),
        ]

        return pd.DataFrame(columns=columns)

    def save_registry(self, output_filepath: Path, client: BotoClient) -> None:
        """
        Upload the registry file at ``output_filepath`` to s3.

        In local development (``client.client is None``) nothing is uploaded.
        Otherwise the local file is uploaded to ``client.model_bucket`` using
        the same path string as the object key.

        :param output_filepath: Local path of the registry file; also used as
            the S3 object key.
        :param client: Wrapper exposing ``client`` (the S3 client, or ``None``)
            and ``model_bucket``.
        """
        if client.client is None:
            logger.info("In local development mode - no need for s3 client")
        else:
            logger.info("Saving registry into s3")
            registry_key = str(output_filepath)

            # upload_file(Filename, Bucket, Key): local file and key share a path
            client.client.upload_file(registry_key, client.model_bucket, registry_key)

        logger.info("Save complete")
|