mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
begin boto3 chagne
This commit is contained in:
parent
baec5e7cc0
commit
58e6ce54d8
8 changed files with 108 additions and 26 deletions
|
|
@ -34,4 +34,4 @@ USER ${USER}
|
|||
WORKDIR /home/simulation_system
|
||||
|
||||
# Run the python command
|
||||
CMD ["python3", "predictions.py", "--data-path", "./model_build_data/change_data/rdsap_full/test_data.parquet"]
|
||||
CMD ["python3", "predictions.py", "--data-path", "s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data_with_id.parquet"]
|
||||
|
|
|
|||
|
|
@ -34,4 +34,4 @@ USER ${USER}
|
|||
WORKDIR /home/simulation_system
|
||||
|
||||
# Run the python command
|
||||
CMD ["python3", "training.py", "--train-filepath", "./model_build_data/change_data/rdsap_full/train_validation_data.parquet", "--test-filepath", "./model_build_data/change_data/rdsap_full/test_data.parquet"]
|
||||
CMD ["python3", "training.py", "--train-filepath", "s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/train_validation_data.parquet", "--test-filepath", "s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data.parquet"]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Set up the client to be used for downloading and uploading model files
|
|||
|
||||
import os
|
||||
import s3fs
|
||||
import boto3
|
||||
from core.Logger import logger
|
||||
|
||||
|
||||
|
|
@ -86,10 +87,68 @@ class S3FSClient:
|
|||
filename = file.split(filepath)[-1]
|
||||
|
||||
# Define the local path where you want to save the file
|
||||
local_path = os.path.join("local_model", filename)
|
||||
local_path = os.path.join(model_folder, filename)
|
||||
|
||||
# Download the file from S3 to the local directory
|
||||
self.client.get(file, local_path)
|
||||
print(f"Downloaded {filename} to {local_path}")
|
||||
|
||||
print("Download completed.")
|
||||
|
||||
|
||||
class BotoClient:
|
||||
"""
|
||||
Using boto3 to access the different aws storage configurations
|
||||
"""
|
||||
|
||||
def __init__(self, runtime_environment: str = "local") -> None:
|
||||
self.client = None
|
||||
self.model_bucket: str
|
||||
|
||||
self.client_factory(runtime_environment)
|
||||
self.determine_model_bucket(runtime_environment)
|
||||
|
||||
def client_factory(self, runtime_environment: str = "local"):
|
||||
"""
|
||||
Select the correct s3 client to use
|
||||
"""
|
||||
|
||||
if runtime_environment == "local":
|
||||
logger.info("No S3 client setup required")
|
||||
elif runtime_environment == "local-mock":
|
||||
logger.info(f"S3 settings for {runtime_environment}")
|
||||
session = boto3.Session()
|
||||
self.client = session.client(
|
||||
service_name="s3",
|
||||
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
|
||||
aws_secret_access_key=os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", "password"
|
||||
),
|
||||
endpoint_url=os.environ.get("ENDPOINT_URL", "http://localhost:9000"),
|
||||
)
|
||||
elif runtime_environment in ["dev", "staging", "prod"]:
|
||||
logger.info(f"S3 settings for {runtime_environment}")
|
||||
# Key/ token should be in session/lambda for this
|
||||
self.client = boto3.client("s3")
|
||||
else:
|
||||
raise NotImplementedError("No correspnding runtime environment")
|
||||
|
||||
def determine_model_bucket(self, runtime_environment: str) -> None:
|
||||
"""
|
||||
For the given environment, return the correct bucket for models
|
||||
"""
|
||||
if runtime_environment == "local":
|
||||
logger.info("In local development - no need for s3")
|
||||
elif runtime_environment in ["local-mock", "dev"]:
|
||||
# TODO: get from enironment
|
||||
self.model_bucket = "retrofit-model-directory-dev"
|
||||
elif runtime_environment in ["staging", "prod"]:
|
||||
self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
|
||||
else:
|
||||
raise NotImplementedError("No corresponding runtime environment")
|
||||
|
||||
def download_model(self, filepath: str, model_folder: str):
|
||||
"""
|
||||
For the file path, download the model locally so that we can load the model
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
from typing import Protocol
|
||||
import boto3
|
||||
from io import BytesIO, StringIO
|
||||
from core.CloudClient import BotoClient
|
||||
|
||||
|
||||
def read_parquet_from_s3(bucket_name, file_key):
|
||||
|
|
@ -57,7 +58,9 @@ class DataLoader(Protocol):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame | None:
|
||||
def load(
|
||||
client: BotoClient, filepath: str, index_col: str | None = None
|
||||
) -> pd.DataFrame | None:
|
||||
"""
|
||||
Loading data from the relevant source
|
||||
"""
|
||||
|
|
@ -92,7 +95,9 @@ class S3MockDataLoader:
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
|
||||
def load(
|
||||
client: BotoClient, filepath: str, index_col: str | None = None
|
||||
) -> pd.DataFrame:
|
||||
|
||||
# TODO: Ingest these as environment variables in the docker compose file
|
||||
storage_options = {
|
||||
|
|
@ -126,7 +131,9 @@ class S3DataLoader:
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
|
||||
def load(
|
||||
client: BotoClient, filepath: str, index_col: str | None = None
|
||||
) -> pd.DataFrame:
|
||||
|
||||
filepath_split = filepath.split("s3://")[-1].split("/", 1)
|
||||
bucket = filepath_split[0]
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ services:
|
|||
# dockerfile: ./Dockerfiles/Dockerfile.training
|
||||
# image: simulation_system_training
|
||||
# environment:
|
||||
# RUNTIME_ENVIRONMENT: local-mock
|
||||
# ENDPOINT_URL: http://minio:9000/
|
||||
# AWS_ACCESS_KEY_ID: *MINIO_USER
|
||||
# AWS_SECRET_ACCESS_KEY: *MINIO_PASS
|
||||
|
|
@ -34,19 +35,19 @@ services:
|
|||
# command:
|
||||
# ["bash"]
|
||||
|
||||
# simulation_system_prediction:
|
||||
# build:
|
||||
# context: ./
|
||||
# dockerfile: ./Dockerfiles/Dockerfile.prediction
|
||||
# image: simulation_system_prediction
|
||||
# environment:
|
||||
# ENDPOINT_URL: http://minio:9000/
|
||||
# AWS_ACCESS_KEY_ID: *MINIO_USER
|
||||
# AWS_SECRET_ACCESS_KEY: *MINIO_PASS
|
||||
# tty: true
|
||||
# depends_on:
|
||||
# simulation_system_training:
|
||||
# condition: service_completed_successfully
|
||||
simulation_system_prediction:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: ./Dockerfiles/Dockerfile.prediction
|
||||
image: simulation_system_prediction
|
||||
environment:
|
||||
ENDPOINT_URL: http://minio:9000/
|
||||
AWS_ACCESS_KEY_ID: *MINIO_USER
|
||||
AWS_SECRET_ACCESS_KEY: *MINIO_PASS
|
||||
tty: true
|
||||
# depends_on:
|
||||
# simulation_system_training:
|
||||
# condition: service_completed_successfully
|
||||
# command:
|
||||
# ["bash"]
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -5,6 +5,7 @@ Script to load MLModel class and generate predictions
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
|
@ -12,6 +13,8 @@ from MLModel.Models import AutogluonModel
|
|||
from core.Logger import logger
|
||||
from core.DataLoader import dataloader_factory
|
||||
from core.CloudClient import S3FSClient
|
||||
from core.Metrics import Metrics
|
||||
from core.RegistryHandler import RegistryHandler
|
||||
from core.Settings import (
|
||||
BASE_REGISTRY_PATH,
|
||||
REGISTRY_FILE,
|
||||
|
|
@ -19,10 +22,11 @@ from core.Settings import (
|
|||
PREDICTION_FILE,
|
||||
METADATA_FILE,
|
||||
TIMESTAMP_FORMAT,
|
||||
MODEL_DIRECTORY,
|
||||
)
|
||||
|
||||
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "dev")
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local-mock")
|
||||
|
||||
CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
|
||||
|
|
@ -82,7 +86,7 @@ def prediction(
|
|||
if registry_path is None or not registry_path.exists():
|
||||
logger.error("No registry path provided or registry doesn't exist")
|
||||
exit(1)
|
||||
elif RUNTIME_ENVIRONMENT == "dev":
|
||||
elif RUNTIME_ENVIRONMENT in ["local-mock", "dev"]:
|
||||
registry_path = "s3://retrofit-model-directory-dev/model_directory/RDSAP_CHANGE/model_registry.csv"
|
||||
else:
|
||||
raise NotImplemented("TO be implemented")
|
||||
|
|
@ -95,7 +99,17 @@ def prediction(
|
|||
else:
|
||||
# TODO: Think about where registry will sit/ type
|
||||
logger.info("Loading best model from registry")
|
||||
registry_df = pd.read_csv(registry_path)
|
||||
|
||||
metrics = Metrics()
|
||||
registry_handler = RegistryHandler()
|
||||
|
||||
registry_path = Path(MODEL_DIRECTORY) / target_column / REGISTRY_FILE
|
||||
|
||||
registry_df = registry_handler.load_registry(
|
||||
registry_path=registry_path, s3fs_client=CLIENT, metrics=metrics
|
||||
)
|
||||
|
||||
# registry_df = pd.read_csv(registry_path)
|
||||
best_model_df = registry_df[registry_df["best_model"]]
|
||||
|
||||
model_location = best_model_df["model_location"].values[0]
|
||||
|
|
@ -120,7 +134,7 @@ def prediction(
|
|||
raise ValueError("No data loaded")
|
||||
|
||||
# # TODO: DOWNSAMPLING DOWN TO JUST USE ONE FOR PREDICTION
|
||||
# data = data.sample(1)
|
||||
data = data.sample(1)
|
||||
else:
|
||||
logger.info("Using data provided")
|
||||
data = json.loads(str(data))
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from core.Logger import logger
|
|||
from core.Metrics import Metrics, sort_by_metric
|
||||
from core.DataLoader import dataloader_factory
|
||||
from core.FeatureProcessor import FeatureProcessor
|
||||
from core.CloudClient import S3FSClient
|
||||
from core.CloudClient import S3FSClient, BotoClient
|
||||
from core.RegistryHandler import RegistryHandler
|
||||
from core.Settings import (
|
||||
MODEL_DIRECTORY,
|
||||
|
|
@ -28,9 +28,10 @@ from core.Settings import (
|
|||
|
||||
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local-mock")
|
||||
|
||||
CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
CLIENT = BotoClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
# CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
|
||||
|
||||
# FOR TESTING
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue