added download and load for predcitions

This commit is contained in:
Michael Duong 2023-09-01 15:42:04 +01:00
parent 31ce4a670e
commit 7037625e7f
6 changed files with 112 additions and 28 deletions

View file

@ -52,20 +52,22 @@ class AutogluonModel:
self.predictions = None
def load_model(
self, filepath: str | Path, s3_client: S3FSClient | None = None
self,
filepath: str | Path,
s3_client: S3FSClient,
model_folder: str = "local_model",
) -> None:
"""
Providing a path, this function will load the model to be used. Will load to internal variable
"""
if s3_client is None:
filepath = str(filepath)
if s3_client.client is None:
logger.info("In local development mode - no need for s3 client")
filepath = str(filepath)
self.model = TabularPredictor.load(path=filepath)
else:
pass
# logger.info(f"Loading model from s3")
# s3_client.download_model(filepath=filepath, local_filepath=)
# self.model =
logger.info(f"Loading model from s3")
s3_client.download_model(filepath=filepath, model_folder=model_folder)
self.model = TabularPredictor.load(path=model_folder)
def save_model(self, output_filepath: Path, s3fs_client: S3FSClient) -> None:
"""

View file

@ -51,15 +51,45 @@ class S3FSClient:
if runtime_environment == "local":
logger.info("In local development - no need for s3")
elif runtime_environment in ["local-mock", "dev"]:
# TODO: get from enironment
self.model_bucket = "retrofit-model-directory-dev"
elif runtime_environment in ["staging", "prod"]:
self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
else:
raise NotImplementedError("No corresponding runtime environment")
# def download_model(self, filepath: str, local_filepath: str = "."):
# """
# For the file path, download the model locally so that we can load the model
# """
# if local_filepath is None:
# self.local_filepath = filepath
def download_model(self, filepath: str, model_folder: str):
"""
For the file path, download the model locally so that we can load the model
"""
if self.client is None:
logger.info("No need to download model as local development")
else:
def list_files_recursively(folder_path, client):
all_files = []
for root, dirs, files in client.walk(folder_path):
for file in files:
s3_path = os.path.join(root, file)
all_files.append(s3_path)
return all_files
# List all files in the specified S3 folder and its subfolders
files = list_files_recursively(
f"{self.model_bucket}/{filepath}", client=self.client
)
# Download each file
for file in files:
# Extract the filename from the S3 path
filename = file.split(filepath)[-1]
# Define the local path where you want to save the file
local_path = os.path.join("local_model", filename)
# Download the file from S3 to the local directory
self.client.get(file, local_path)
print(f"Downloaded {filename} to {local_path}")
print("Download completed.")

View file

@ -1,6 +1,54 @@
import pandas as pd
import os
from typing import Protocol
import boto3
from io import BytesIO, StringIO
def read_parquet_from_s3(bucket_name, file_key):
"""
Read a CSV file from S3 using boto3 and pandas.
:param bucket_name: Name of the S3 bucket.
:param file_key: Key of the file (including directory path within the bucket).
:param aws_access_key_id: AWS Access Key ID
:param aws_secret_access_key: AWS Secret Access Key
:return: DataFrame containing the CSV data.
"""
# Initialize the S3 client
s3_client = boto3.client("s3")
# Get the object
s3_object = s3_client.get_object(Bucket=bucket_name, Key=file_key)
# Read the CSV body into a DataFrame
csv_body = s3_object["Body"].read()
df = pd.read_parquet(BytesIO(csv_body))
return df
def read_csv_from_s3(bucket_name, file_key, index_col):
"""
Read a CSV file from S3 using boto3 and pandas.
:param bucket_name: Name of the S3 bucket.
:param file_key: Key of the file (including directory path within the bucket).
:param aws_access_key_id: AWS Access Key ID
:param aws_secret_access_key: AWS Secret Access Key
:return: DataFrame containing the CSV data.
"""
# Initialize the S3 client
s3_client = boto3.client("s3")
# Get the object
s3_object = s3_client.get_object(Bucket=bucket_name, Key=file_key)
# Read the CSV body into a DataFrame
csv_body = s3_object["Body"].read().decode("utf-8")
df = pd.read_csv(StringIO(csv_body), index_col=index_col)
return df
class DataLoader(Protocol):
@ -80,19 +128,15 @@ class S3DataLoader:
@staticmethod
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
storage_options = {
"key": os.environ.get("AWS_ACCESS_KEY_ID"),
"secret": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}
filepath_split = filepath.split("s3://")[-1].split("/", 1)
bucket = filepath_split[0]
key = filepath_split[1]
if filepath.endswith(".parquet"):
df = pd.read_parquet(filepath, storage_options=storage_options)
df = read_parquet_from_s3(bucket, key)
if index_col is not None:
df = df.set_index(index_col)
elif filepath.endswith(".csv"):
df = pd.read_csv(
filepath, index_col=index_col, storage_options=storage_options
)
df = read_csv_from_s3(bucket, key, index_col)
else:
raise ValueError(f"File format not supported for file: {filepath}")

View file

@ -30,7 +30,7 @@ def handler(event, context):
# model_path = os.environ.get("MODEL_PATH", "http://minio:9000/data/model_directory/")
model_path = os.environ.get(
"MODEL_PATH",
"s3://retrofit-model-directory-{RUNTIME_ENVIRONMENT}/RDSAP_CHANGE/autogluon/rdsap_change-medium_quality-30"
f"s3://retrofit-model-directory-{RUNTIME_ENVIRONMENT}/RDSAP_CHANGE/autogluon/rdsap_change-medium_quality-30"
"-2023-08-30_11-43-41/deployment/",
)

View file

@ -11,6 +11,7 @@ from datetime import datetime
from MLModel.Models import AutogluonModel
from core.Logger import logger
from core.DataLoader import dataloader_factory
from core.CloudClient import S3FSClient
from core.Settings import (
BASE_REGISTRY_PATH,
REGISTRY_FILE,
@ -23,12 +24,19 @@ from core.Settings import (
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "dev")
CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
# FOR TESTING
# For now just loading data first and then passing into function (i.e. as if we receive json data and convert to
# DataFrame)
# TEST_DATA = DataLoader.load(filepath="../simulation_system/model_build_data/change_data/rdsap_full/test_data.parquet")
# DATA = TEST_DATA.sample(1)
# For testing in dev s3
# Data path can be passed as so:
# python3 predictions.py --data-path s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data.parquet
# data_path="s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data.parquet"
def ingest_arguments() -> argparse.Namespace:
"""
@ -75,9 +83,7 @@ def prediction(
logger.error("No registry path provided or registry doesn't exist")
exit(1)
elif RUNTIME_ENVIRONMENT == "dev":
registry_path = (
"s3://retrofit-model-directory-dev/RDSAP_CHANGE/model_registry.csv"
)
registry_path = "s3://retrofit-model-directory-dev/model_directory/RDSAP_CHANGE/model_registry.csv"
else:
raise NotImplemented("TO be implemented")
@ -130,7 +136,9 @@ def prediction(
logger.error("No other model currently")
exit(1)
model.load_model(filepath=model_location)
model.load_model(
filepath=model_location, s3_client=CLIENT, model_folder="local_model"
)
logger.info("--- Generating Predictions ---")
prediction = model.generate_predictions(data=data)

View file

@ -28,7 +28,7 @@ from core.Settings import (
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "dev")
CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)