Model/model_data/simulation_system/core/CloudClient.py
2023-09-01 15:42:04 +01:00

95 lines
3.5 KiB
Python

"""
Set up the client to be used for downloading and uploading model files
"""
import os
import s3fs
from core.Logger import logger
class S3FSClient:
    """
    Hold an ``s3fs.S3FileSystem`` configured for a runtime environment,
    together with the name of the bucket that stores the model files.

    Supported environments:

    * ``"local"`` - no S3 client is created (``client`` stays ``None``) and
      no bucket is selected (``model_bucket`` stays ``None``).
    * ``"local-mock"`` - client pointed at ``ENDPOINT_URL`` (default
      ``http://localhost:9000``) with credentials from ``AWS_ACCESS_KEY_ID`` /
      ``AWS_SECRET_ACCESS_KEY`` (defaults ``admin`` / ``password``); uses the
      dev model bucket.
    * ``"dev"``, ``"staging"``, ``"prod"`` - default ``S3FileSystem()``;
      credentials are expected to come from the session/Lambda environment.

    Any other environment name raises ``NotImplementedError``.
    """

    def __init__(self, runtime_environment: str = "local") -> None:
        """
        :param runtime_environment: one of "local", "local-mock", "dev",
            "staging" or "prod".
        :raises NotImplementedError: for any other environment name.
        """
        self.client: s3fs.S3FileSystem | None = None
        # Stays None in the "local" environment, where no bucket is needed.
        self.model_bucket: str | None = None
        self.client_factory(runtime_environment)
        self.determine_model_bucket(runtime_environment)

    def client_factory(self, runtime_environment: str = "local") -> None:
        """
        Create the S3 client for ``runtime_environment`` and store it on
        ``self.client`` (left as ``None`` for "local").

        :raises NotImplementedError: if the environment is not recognised.
        """
        if runtime_environment == "local":
            logger.info("No S3 client setup required")
        elif runtime_environment == "local-mock":
            logger.info(f"S3 settings for {runtime_environment}")
            # Explicit credentials/endpoint; defaults target a local
            # S3-compatible endpoint.
            self.client = s3fs.S3FileSystem(
                key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
                secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
                client_kwargs={
                    "endpoint_url": os.environ.get(
                        "ENDPOINT_URL", "http://localhost:9000"
                    )
                },
            )
        elif runtime_environment in ["dev", "staging", "prod"]:
            logger.info(f"S3 settings for {runtime_environment}")
            # Key/token should be provided by the session/Lambda environment.
            self.client = s3fs.S3FileSystem()
        else:
            raise NotImplementedError("No corresponding runtime environment")

    def determine_model_bucket(self, runtime_environment: str) -> None:
        """
        Store the model bucket name for ``runtime_environment`` on
        ``self.model_bucket`` (left unchanged for "local").

        :raises NotImplementedError: if the environment is not recognised.
        """
        if runtime_environment == "local":
            logger.info("In local development - no need for s3")
        elif runtime_environment in ["local-mock", "dev"]:
            # TODO: get from environment
            self.model_bucket = "retrofit-model-directory-dev"
        elif runtime_environment in ["staging", "prod"]:
            self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
        else:
            raise NotImplementedError("No corresponding runtime environment")

    @staticmethod
    def _relative_key(remote_file: str, remote_root: str, filepath: str) -> str:
        """
        Return the path of ``remote_file`` relative to ``remote_root`` with no
        leading "/", so it can safely be joined onto a local directory.

        Falls back to the text after the last occurrence of ``filepath`` when
        ``remote_file`` is not under ``remote_root``.
        """
        prefix = remote_root.rstrip("/")
        if remote_file.startswith(prefix + "/"):
            relative = remote_file[len(prefix):]
        else:
            relative = remote_file.split(filepath)[-1]
        # A leading "/" would make os.path.join discard the local directory
        # and write the file at the filesystem root.
        return relative.lstrip("/")

    def download_model(self, filepath: str, model_folder: str) -> None:
        """
        Download every file under ``<model_bucket>/<filepath>`` into the local
        ``local_model`` directory, keeping the sub-folder layout below
        ``filepath``, so the model can be loaded from disk.

        Does nothing when no S3 client is configured ("local" environment).

        :param filepath: key prefix of the model inside the model bucket.
        :param model_folder: NOTE(review): currently unused - the destination
            is always ``local_model``; confirm what callers expect.
        """
        if self.client is None:
            logger.info("No need to download model as local development")
            return

        remote_root = f"{self.model_bucket}/{filepath}"
        # walk() recurses through sub-folders; flatten to full object paths
        # before downloading anything.
        remote_files = [
            os.path.join(root, name)
            for root, _dirs, names in self.client.walk(remote_root)
            for name in names
        ]

        for remote_file in remote_files:
            filename = self._relative_key(remote_file, remote_root, filepath)
            local_path = os.path.join("local_model", filename)
            # Ensure nested sub-folders exist locally before writing.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            self.client.get(remote_file, local_path)
            print(f"Downloaded {filename} to {local_path}")
        print("Download completed.")