""" Set up the client to be used for downloading and uploading model files """ import os import s3fs import boto3 from core.Logger import logger class S3FSClient: """ Set up the correct client to upload files to s3 """ def __init__(self, runtime_environment: str = "local") -> None: self.client: s3fs.S3FileSystem | None = None self.model_bucket: str self.client_factory(runtime_environment) self.determine_model_bucket(runtime_environment) def client_factory(self, runtime_environment: str = "local"): """ Select the correct s3 client to use """ if runtime_environment == "local": logger.info("No S3 client setup required") elif runtime_environment == "local-mock": logger.info(f"S3 settings for {runtime_environment}") self.client = s3fs.S3FileSystem( key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"), secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"), client_kwargs={ "endpoint_url": os.environ.get( "ENDPOINT_URL", "http://localhost:9000" ) }, ) elif runtime_environment in ["dev", "staging", "prod"]: logger.info(f"S3 settings for {runtime_environment}") # Key/ token should be in session/lambda for this self.client = s3fs.S3FileSystem() else: raise NotImplementedError("No correspnding runtime environment") def determine_model_bucket(self, runtime_environment: str) -> None: """ For the given environment, return the correct bucket for models """ if runtime_environment == "local": logger.info("In local development - no need for s3") elif runtime_environment in ["local-mock", "dev"]: # TODO: get from enironment self.model_bucket = "retrofit-model-directory-dev" elif runtime_environment in ["staging", "prod"]: self.model_bucket = f"retrofit-model-directory-{runtime_environment}" else: raise NotImplementedError("No corresponding runtime environment") def download_model(self, filepath: str, model_folder: str): """ For the file path, download the model locally so that we can load the model """ if self.client is None: logger.info("No need to download model as local development") else: def list_files_recursively(folder_path, client): all_files = [] for root, dirs, files in client.walk(folder_path): for file in files: s3_path = os.path.join(root, file) all_files.append(s3_path) return all_files # List all files in the specified S3 folder and its subfolders files = list_files_recursively( f"{self.model_bucket}/{filepath}", client=self.client ) # Download each file for file in files: # Extract the filename from the S3 path filename = file.split(filepath)[-1] # Define the local path where you want to save the file local_path = os.path.join(model_folder, filename) # Download the file from S3 to the local directory self.client.get(file, local_path) print(f"Downloaded {filename} to {local_path}") print("Download completed.") class BotoClient: """ Using boto3 to access the different aws storage configurations """ def __init__(self, runtime_environment: str = "local") -> None: self.client = None self.model_bucket: str self.client_factory(runtime_environment) self.determine_model_bucket(runtime_environment) def client_factory(self, runtime_environment: str = "local"): """ Select the correct s3 client to use """ if runtime_environment == "local": logger.info("No S3 client setup required") elif runtime_environment == "local-mock": logger.info(f"S3 settings for {runtime_environment}") session = boto3.Session() self.client = session.client( service_name="s3", aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "admin"), aws_secret_access_key=os.environ.get( "AWS_SECRET_ACCESS_KEY", "password" ), endpoint_url=os.environ.get("ENDPOINT_URL", "http://localhost:9000"), ) elif runtime_environment in ["dev", "staging", "prod"]: logger.info(f"S3 settings for {runtime_environment}") # Key/ token should be in session/lambda for this self.client = boto3.client("s3") else: raise NotImplementedError("No correspnding runtime environment") def determine_model_bucket(self, runtime_environment: str) -> None: """ For the given environment, return the correct bucket for models """ if runtime_environment == "local": logger.info("In local development - no need for s3") elif runtime_environment in ["local-mock", "dev"]: # TODO: get from enironment self.model_bucket = "retrofit-model-directory-dev" elif runtime_environment in ["staging", "prod"]: self.model_bucket = f"retrofit-model-directory-{runtime_environment}" else: raise NotImplementedError("No corresponding runtime environment") def download_model(self, filepath: str, model_folder: str): """ For the file path, download the model locally so that we can load the model """ pass