mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
added download and load for predcitions
This commit is contained in:
parent
31ce4a670e
commit
7037625e7f
6 changed files with 112 additions and 28 deletions
|
|
@ -52,20 +52,22 @@ class AutogluonModel:
|
|||
self.predictions = None
|
||||
|
||||
def load_model(
|
||||
self, filepath: str | Path, s3_client: S3FSClient | None = None
|
||||
self,
|
||||
filepath: str | Path,
|
||||
s3_client: S3FSClient,
|
||||
model_folder: str = "local_model",
|
||||
) -> None:
|
||||
"""
|
||||
Providing a path, this function will load the model to be used. Will load to internal variable
|
||||
"""
|
||||
if s3_client is None:
|
||||
filepath = str(filepath)
|
||||
if s3_client.client is None:
|
||||
logger.info("In local development mode - no need for s3 client")
|
||||
filepath = str(filepath)
|
||||
self.model = TabularPredictor.load(path=filepath)
|
||||
else:
|
||||
pass
|
||||
# logger.info(f"Loading model from s3")
|
||||
# s3_client.download_model(filepath=filepath, local_filepath=)
|
||||
# self.model =
|
||||
logger.info(f"Loading model from s3")
|
||||
s3_client.download_model(filepath=filepath, model_folder=model_folder)
|
||||
self.model = TabularPredictor.load(path=model_folder)
|
||||
|
||||
def save_model(self, output_filepath: Path, s3fs_client: S3FSClient) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -51,15 +51,45 @@ class S3FSClient:
|
|||
if runtime_environment == "local":
|
||||
logger.info("In local development - no need for s3")
|
||||
elif runtime_environment in ["local-mock", "dev"]:
|
||||
# TODO: get from enironment
|
||||
self.model_bucket = "retrofit-model-directory-dev"
|
||||
elif runtime_environment in ["staging", "prod"]:
|
||||
self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
|
||||
else:
|
||||
raise NotImplementedError("No corresponding runtime environment")
|
||||
|
||||
# def download_model(self, filepath: str, local_filepath: str = "."):
|
||||
# """
|
||||
# For the file path, download the model locally so that we can load the model
|
||||
# """
|
||||
# if local_filepath is None:
|
||||
# self.local_filepath = filepath
|
||||
def download_model(self, filepath: str, model_folder: str):
|
||||
"""
|
||||
For the file path, download the model locally so that we can load the model
|
||||
"""
|
||||
|
||||
if self.client is None:
|
||||
logger.info("No need to download model as local development")
|
||||
else:
|
||||
|
||||
def list_files_recursively(folder_path, client):
|
||||
all_files = []
|
||||
for root, dirs, files in client.walk(folder_path):
|
||||
for file in files:
|
||||
s3_path = os.path.join(root, file)
|
||||
all_files.append(s3_path)
|
||||
return all_files
|
||||
|
||||
# List all files in the specified S3 folder and its subfolders
|
||||
files = list_files_recursively(
|
||||
f"{self.model_bucket}/{filepath}", client=self.client
|
||||
)
|
||||
|
||||
# Download each file
|
||||
for file in files:
|
||||
# Extract the filename from the S3 path
|
||||
filename = file.split(filepath)[-1]
|
||||
|
||||
# Define the local path where you want to save the file
|
||||
local_path = os.path.join("local_model", filename)
|
||||
|
||||
# Download the file from S3 to the local directory
|
||||
self.client.get(file, local_path)
|
||||
print(f"Downloaded {filename} to {local_path}")
|
||||
|
||||
print("Download completed.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,54 @@
|
|||
import pandas as pd
|
||||
import os
|
||||
from typing import Protocol
|
||||
import boto3
|
||||
from io import BytesIO, StringIO
|
||||
|
||||
|
||||
def read_parquet_from_s3(bucket_name, file_key):
|
||||
"""
|
||||
Read a CSV file from S3 using boto3 and pandas.
|
||||
|
||||
:param bucket_name: Name of the S3 bucket.
|
||||
:param file_key: Key of the file (including directory path within the bucket).
|
||||
:param aws_access_key_id: AWS Access Key ID
|
||||
:param aws_secret_access_key: AWS Secret Access Key
|
||||
:return: DataFrame containing the CSV data.
|
||||
"""
|
||||
# Initialize the S3 client
|
||||
s3_client = boto3.client("s3")
|
||||
|
||||
# Get the object
|
||||
s3_object = s3_client.get_object(Bucket=bucket_name, Key=file_key)
|
||||
|
||||
# Read the CSV body into a DataFrame
|
||||
csv_body = s3_object["Body"].read()
|
||||
df = pd.read_parquet(BytesIO(csv_body))
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def read_csv_from_s3(bucket_name, file_key, index_col):
|
||||
"""
|
||||
Read a CSV file from S3 using boto3 and pandas.
|
||||
|
||||
:param bucket_name: Name of the S3 bucket.
|
||||
:param file_key: Key of the file (including directory path within the bucket).
|
||||
:param aws_access_key_id: AWS Access Key ID
|
||||
:param aws_secret_access_key: AWS Secret Access Key
|
||||
:return: DataFrame containing the CSV data.
|
||||
"""
|
||||
# Initialize the S3 client
|
||||
s3_client = boto3.client("s3")
|
||||
|
||||
# Get the object
|
||||
s3_object = s3_client.get_object(Bucket=bucket_name, Key=file_key)
|
||||
|
||||
# Read the CSV body into a DataFrame
|
||||
csv_body = s3_object["Body"].read().decode("utf-8")
|
||||
df = pd.read_csv(StringIO(csv_body), index_col=index_col)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
class DataLoader(Protocol):
|
||||
|
|
@ -80,19 +128,15 @@ class S3DataLoader:
|
|||
@staticmethod
|
||||
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
|
||||
|
||||
storage_options = {
|
||||
"key": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"secret": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
}
|
||||
|
||||
filepath_split = filepath.split("s3://")[-1].split("/", 1)
|
||||
bucket = filepath_split[0]
|
||||
key = filepath_split[1]
|
||||
if filepath.endswith(".parquet"):
|
||||
df = pd.read_parquet(filepath, storage_options=storage_options)
|
||||
df = read_parquet_from_s3(bucket, key)
|
||||
if index_col is not None:
|
||||
df = df.set_index(index_col)
|
||||
elif filepath.endswith(".csv"):
|
||||
df = pd.read_csv(
|
||||
filepath, index_col=index_col, storage_options=storage_options
|
||||
)
|
||||
df = read_csv_from_s3(bucket, key, index_col)
|
||||
else:
|
||||
raise ValueError(f"File format not supported for file: {filepath}")
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def handler(event, context):
|
|||
# model_path = os.environ.get("MODEL_PATH", "http://minio:9000/data/model_directory/")
|
||||
model_path = os.environ.get(
|
||||
"MODEL_PATH",
|
||||
"s3://retrofit-model-directory-{RUNTIME_ENVIRONMENT}/RDSAP_CHANGE/autogluon/rdsap_change-medium_quality-30"
|
||||
f"s3://retrofit-model-directory-{RUNTIME_ENVIRONMENT}/RDSAP_CHANGE/autogluon/rdsap_change-medium_quality-30"
|
||||
"-2023-08-30_11-43-41/deployment/",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from datetime import datetime
|
|||
from MLModel.Models import AutogluonModel
|
||||
from core.Logger import logger
|
||||
from core.DataLoader import dataloader_factory
|
||||
from core.CloudClient import S3FSClient
|
||||
from core.Settings import (
|
||||
BASE_REGISTRY_PATH,
|
||||
REGISTRY_FILE,
|
||||
|
|
@ -23,12 +24,19 @@ from core.Settings import (
|
|||
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "dev")
|
||||
|
||||
CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
|
||||
# FOR TESTING
|
||||
# For now just loading data first and then passing into function (i.e. as if we receive json data and convert to
|
||||
# DataFrame)
|
||||
# TEST_DATA = DataLoader.load(filepath="../simulation_system/model_build_data/change_data/rdsap_full/test_data.parquet")
|
||||
# DATA = TEST_DATA.sample(1)
|
||||
|
||||
# For testing in dev s3
|
||||
# Data path can be passed as so:
|
||||
# python3 predictions.py --data-path s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data.parquet
|
||||
# data_path="s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data.parquet"
|
||||
|
||||
|
||||
def ingest_arguments() -> argparse.Namespace:
|
||||
"""
|
||||
|
|
@ -75,9 +83,7 @@ def prediction(
|
|||
logger.error("No registry path provided or registry doesn't exist")
|
||||
exit(1)
|
||||
elif RUNTIME_ENVIRONMENT == "dev":
|
||||
registry_path = (
|
||||
"s3://retrofit-model-directory-dev/RDSAP_CHANGE/model_registry.csv"
|
||||
)
|
||||
registry_path = "s3://retrofit-model-directory-dev/model_directory/RDSAP_CHANGE/model_registry.csv"
|
||||
else:
|
||||
raise NotImplemented("TO be implemented")
|
||||
|
||||
|
|
@ -130,7 +136,9 @@ def prediction(
|
|||
logger.error("No other model currently")
|
||||
exit(1)
|
||||
|
||||
model.load_model(filepath=model_location)
|
||||
model.load_model(
|
||||
filepath=model_location, s3_client=CLIENT, model_folder="local_model"
|
||||
)
|
||||
|
||||
logger.info("--- Generating Predictions ---")
|
||||
prediction = model.generate_predictions(data=data)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from core.Settings import (
|
|||
|
||||
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "dev")
|
||||
|
||||
CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue