mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
183 lines
7.1 KiB
Python
183 lines
7.1 KiB
Python
"""
|
|
Set up the client to be used for downloading and uploading model files
|
|
"""
|
|
|
|
import os
|
|
import boto3
|
|
from core.Logger import logger
|
|
|
|
|
|
# class S3FSClient:
|
|
# """
|
|
# Set up the correct client to upload files to s3
|
|
# """
|
|
|
|
# def __init__(self, runtime_environment: str = "local") -> None:
|
|
# self.client: s3fs.S3FileSystem | None = None
|
|
# self.model_bucket: str
|
|
|
|
# self.client_factory(runtime_environment)
|
|
# self.determine_model_bucket(runtime_environment)
|
|
|
|
# def client_factory(self, runtime_environment: str = "local"):
|
|
# """
|
|
# Select the correct s3 client to use
|
|
# """
|
|
|
|
# if runtime_environment == "local":
|
|
# logger.info("No S3 client setup required")
|
|
# elif runtime_environment == "local-mock":
|
|
# logger.info(f"S3 settings for {runtime_environment}")
|
|
# self.client = s3fs.S3FileSystem(
|
|
# key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
|
|
# secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
|
|
# client_kwargs={
|
|
# "endpoint_url": os.environ.get(
|
|
# "ENDPOINT_URL", "http://localhost:9000"
|
|
# )
|
|
# },
|
|
# )
|
|
# elif runtime_environment in ["dev", "staging", "prod"]:
|
|
# logger.info(f"S3 settings for {runtime_environment}")
|
|
# # Key/ token should be in session/lambda for this
|
|
# self.client = s3fs.S3FileSystem()
|
|
# else:
|
|
# raise NotImplementedError("No correspnding runtime environment")
|
|
|
|
# def determine_model_bucket(self, runtime_environment: str) -> None:
|
|
# """
|
|
# For the given environment, return the correct bucket for models
|
|
# """
|
|
# if runtime_environment == "local":
|
|
# logger.info("In local development - no need for s3")
|
|
# elif runtime_environment in ["local-mock", "dev"]:
|
|
# # TODO: get from environment
|
|
# self.model_bucket = "retrofit-model-directory-dev"
|
|
# elif runtime_environment in ["staging", "prod"]:
|
|
# self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
|
|
# else:
|
|
# raise NotImplementedError("No corresponding runtime environment")
|
|
|
|
# def download_model(self, filepath: str, model_folder: str):
|
|
# """
|
|
# For the file path, download the model locally so that we can load the model
|
|
# """
|
|
|
|
# if self.client is None:
|
|
# logger.info("No need to download model as local development")
|
|
# else:
|
|
|
|
# def list_files_recursively(folder_path, client):
|
|
# all_files = []
|
|
# for root, dirs, files in client.walk(folder_path):
|
|
# for file in files:
|
|
# s3_path = os.path.join(root, file)
|
|
# all_files.append(s3_path)
|
|
# return all_files
|
|
|
|
# # List all files in the specified S3 folder and its subfolders
|
|
# files = list_files_recursively(
|
|
# f"{self.model_bucket}/{filepath}", client=self.client
|
|
# )
|
|
|
|
# # Download each file
|
|
# for file in files:
|
|
# # Extract the filename from the S3 path
|
|
# filename = file.split(filepath)[-1]
|
|
|
|
# # Define the local path where you want to save the file
|
|
# local_path = os.path.join(model_folder, filename)
|
|
|
|
# # Download the file from S3 to the local directory
|
|
# self.client.get(file, local_path)
|
|
# print(f"Downloaded {filename} to {local_path}")
|
|
|
|
# print("Download completed.")
|
|
|
|
|
|
class BotoClient:
    """
    Using boto3 to access the different aws storage configurations

    On construction the runtime environment name selects both the S3 client
    (see ``client_factory``) and the bucket holding the models (see
    ``determine_model_bucket``). For the ``"local"`` environment no client is
    created and no bucket is assigned.
    """

    def __init__(self, runtime_environment: str = "local") -> None:
        """
        Build the client and resolve the model bucket for the environment

        :param runtime_environment: one of "local", "local-mock", "dev",
            "staging" or "prod"
        :raises NotImplementedError: for any other environment name
        """
        # Stays None for "local"; download_model refuses to run in that case
        self.client = None
        # Declared only: it is not assigned when the environment is "local"
        self.model_bucket: str

        self.client_factory(runtime_environment)
        self.determine_model_bucket(runtime_environment)

    def client_factory(self, runtime_environment: str = "local") -> None:
        """
        Select the correct s3 client to use

        :param runtime_environment: name of the environment being run in
        :raises NotImplementedError: if the environment name is not recognised
        """

        if runtime_environment == "local":
            logger.info("No S3 client setup required")
        elif runtime_environment == "local-mock":
            logger.info(f"S3 settings for {runtime_environment}")
            # Credentials and endpoint come from the process environment and
            # fall back to the defaults given here when the variables are unset
            session = boto3.Session()
            self.client = session.client(
                service_name="s3",
                aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
                aws_secret_access_key=os.environ.get(
                    "AWS_SECRET_ACCESS_KEY", "password"
                ),
                endpoint_url=os.environ.get("ENDPOINT_URL", "http://localhost:9000"),
            )
        elif runtime_environment in ("dev", "staging", "prod"):
            logger.info(f"S3 settings for {runtime_environment}")
            # Key/ token should be in session/lambda for this
            self.client = boto3.client("s3")
        else:
            raise NotImplementedError("No correspnding runtime environment")

    def determine_model_bucket(self, runtime_environment: str) -> None:
        """
        For the given environment, set ``self.model_bucket`` to the correct
        bucket for models

        :param runtime_environment: name of the environment being run in
        :raises NotImplementedError: if the environment name is not recognised
        """
        if runtime_environment == "local":
            logger.info("In local development - no need for s3")
        elif runtime_environment in ("local-mock", "dev"):
            # TODO: get from environment
            self.model_bucket = "retrofit-model-directory-dev"
        elif runtime_environment in ("staging", "prod"):
            self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
        else:
            raise NotImplementedError("No corresponding runtime environment")

    def download_model(self, filepath: str, model_folder: str) -> None:
        """
        For the file path, download the model locally so that we can load the model

        Every object in ``self.model_bucket`` whose key starts with
        ``filepath`` is downloaded into ``model_folder``. The leading
        ``"<filepath>/"`` part of each key is dropped and the remainder of the
        key is kept as the path below ``model_folder``.

        :param filepath: key prefix of the model inside the bucket
        :param model_folder: local directory the files are written to
        :raises ValueError: if no client was configured (``self.client`` is None)
        """
        if self.client is None:
            raise ValueError("SHould not be in here!")

        # Ensure the local directory for downloads exists
        if model_folder:
            os.makedirs(model_folder, exist_ok=True)

        key_prefix = f"{filepath}/"

        # A single list_objects_v2 call returns at most 1000 keys, so walk
        # every page instead of only looking at the first response
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.model_bucket, Prefix=filepath):
            for obj in page.get("Contents", []):
                # Get the object key (file path)
                object_key = obj["Key"]

                # Keys ending in "/" are folder placeholders, there is no
                # file to write for them
                if object_key.endswith("/"):
                    continue

                # Only strip the prefix at the start of the key: splitting on
                # every occurrence would flatten keys that repeat the prefix
                # deeper in their path
                if object_key.startswith(key_prefix):
                    relative_path = object_key[len(key_prefix):]
                else:
                    relative_path = object_key
                local_file_path = os.path.join(model_folder, relative_path)

                # Create the local directory if it doesn't exist
                local_directory = os.path.dirname(local_file_path)
                if local_directory:
                    os.makedirs(local_directory, exist_ok=True)

                # Download the object from S3 to the local directory
                self.client.download_file(
                    self.model_bucket, object_key, local_file_path
                )
                print(f"Downloaded {object_key} to {local_file_path}")

        print("Download completed.")
|