""" Set up the client to be used for downloading and uploading model files """ import os import boto3 from core.Logger import logger # class S3FSClient: # """ # Set up the correct client to upload files to s3 # """ # def __init__(self, runtime_environment: str = "local") -> None: # self.client: s3fs.S3FileSystem | None = None # self.model_bucket: str # self.client_factory(runtime_environment) # self.determine_model_bucket(runtime_environment) # def client_factory(self, runtime_environment: str = "local"): # """ # Select the correct s3 client to use # """ # if runtime_environment == "local": # logger.info("No S3 client setup required") # elif runtime_environment == "local-mock": # logger.info(f"S3 settings for {runtime_environment}") # self.client = s3fs.S3FileSystem( # key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"), # secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"), # client_kwargs={ # "endpoint_url": os.environ.get( # "ENDPOINT_URL", "http://localhost:9000" # ) # }, # ) # elif runtime_environment in ["dev", "staging", "prod"]: # logger.info(f"S3 settings for {runtime_environment}") # # Key/ token should be in session/lambda for this # self.client = s3fs.S3FileSystem() # else: # raise NotImplementedError("No correspnding runtime environment") # def determine_model_bucket(self, runtime_environment: str) -> None: # """ # For the given environment, return the correct bucket for models # """ # if runtime_environment == "local": # logger.info("In local development - no need for s3") # elif runtime_environment in ["local-mock", "dev"]: # # TODO: get from enironment # self.model_bucket = "retrofit-model-directory-dev" # elif runtime_environment in ["staging", "prod"]: # self.model_bucket = f"retrofit-model-directory-{runtime_environment}" # else: # raise NotImplementedError("No corresponding runtime environment") # def download_model(self, filepath: str, model_folder: str): # """ # For the file path, download the model locally so that we can load the model # 
#         """
#         if self.client is None:
#             logger.info("No need to download model as local development")
#         else:
#
#             def list_files_recursively(folder_path, client):
#                 all_files = []
#                 for root, dirs, files in client.walk(folder_path):
#                     for file in files:
#                         s3_path = os.path.join(root, file)
#                         all_files.append(s3_path)
#                 return all_files
#
#             # List all files in the specified S3 folder and its subfolders
#             files = list_files_recursively(
#                 f"{self.model_bucket}/{filepath}", client=self.client
#             )
#
#             # Download each file
#             for file in files:
#                 # Extract the filename from the S3 path
#                 filename = file.split(filepath)[-1]
#                 # Define the local path where you want to save the file
#                 local_path = os.path.join(model_folder, filename)
#                 # Download the file from S3 to the local directory
#                 self.client.get(file, local_path)
#                 print(f"Downloaded {filename} to {local_path}")
#             print("Download completed.")


class BotoClient:
    """
    Using boto3 to access the different aws storage configurations

    On construction the instance selects an S3 client and a model bucket for
    the given runtime environment:

    * ``"local"``: no client is created (``self.client`` stays ``None``) and
      ``self.model_bucket`` is never assigned.
    * ``"local-mock"``: a client pointed at ``ENDPOINT_URL`` using the
      credentials found in the environment (with local fallbacks).
    * ``"dev"``, ``"staging"``, ``"prod"``: a default ``boto3`` S3 client.

    Any other value raises ``NotImplementedError``.
    """

    def __init__(self, runtime_environment: str = "local") -> None:
        """
        Build the client and resolve the model bucket for the environment
        """
        self.client = None
        # Only annotated here; assigned in determine_model_bucket for every
        # environment except "local".
        self.model_bucket: str
        self.client_factory(runtime_environment)
        self.determine_model_bucket(runtime_environment)

    def client_factory(self, runtime_environment: str = "local") -> None:
        """
        Select the correct s3 client to use

        Stores the client on ``self.client``; for ``"local"`` the attribute is
        left untouched. Raises ``NotImplementedError`` for an unknown
        environment.
        """
        if runtime_environment == "local":
            logger.info("No S3 client setup required")
        elif runtime_environment == "local-mock":
            logger.info(f"S3 settings for {runtime_environment}")
            session = boto3.Session()
            # Credentials and endpoint come from the environment, falling back
            # to the local defaults below when the variables are not set.
            self.client = session.client(
                service_name="s3",
                aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
                aws_secret_access_key=os.environ.get(
                    "AWS_SECRET_ACCESS_KEY", "password"
                ),
                endpoint_url=os.environ.get("ENDPOINT_URL", "http://localhost:9000"),
            )
        elif runtime_environment in ["dev", "staging", "prod"]:
            logger.info(f"S3 settings for {runtime_environment}")
            # Key/ token should be in session/lambda for this
            self.client = boto3.client("s3")
        else:
            raise NotImplementedError("No correspnding runtime environment")

    def determine_model_bucket(self, runtime_environment: str) -> None:
        """
        For the given environment, set the correct bucket for models

        Assigns ``self.model_bucket``; for ``"local"`` nothing is assigned.
        Raises ``NotImplementedError`` for an unknown environment.
        """
        if runtime_environment == "local":
            logger.info("In local development - no need for s3")
        elif runtime_environment in ["local-mock", "dev"]:
            # TODO: get from environment
            self.model_bucket = "retrofit-model-directory-dev"
        elif runtime_environment in ["staging", "prod"]:
            self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
        else:
            raise NotImplementedError("No corresponding runtime environment")

    def download_model(self, filepath: str, model_folder: str) -> None:
        """
        For the file path, download the model locally so that we can load the model

        Every object in ``self.model_bucket`` whose key starts with
        ``filepath`` is downloaded into ``model_folder``. Keys located under
        ``<filepath>/`` keep their path relative to that prefix; any other
        matching key is stored under its full key.

        Args:
            filepath: Key prefix of the model inside the bucket, with or
                without a trailing ``/``.
            model_folder: Local directory to download into; created if
                missing.

        Raises:
            ValueError: If no S3 client was configured (``"local"``
                environment).
        """
        if self.client is None:
            raise ValueError("SHould not be in here!")

        # Ensure the local directory for downloads exists
        os.makedirs(model_folder, exist_ok=True)

        # Normalise so that "models/v1" and "models/v1/" behave the same when
        # the prefix is stripped from each key.
        dir_prefix = f"{filepath.rstrip('/')}/"

        # A single list_objects_v2 call returns at most 1000 keys; the
        # paginator follows the continuation tokens so nothing is missed.
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.model_bucket, Prefix=filepath):
            for obj in page.get("Contents", []):
                object_key = obj["Key"]

                # Keys ending in "/" are folder placeholders, not files;
                # trying to download one would target a directory path.
                if object_key.endswith("/"):
                    continue

                # Strip the prefix once, from the front only, so a repeated
                # occurrence further down the key is left intact.
                if object_key.startswith(dir_prefix):
                    relative_key = object_key[len(dir_prefix):]
                else:
                    relative_key = object_key

                local_file_path = os.path.join(model_folder, relative_key)

                # Create the local directory if it doesn't exist
                local_directory = os.path.dirname(local_file_path)
                if local_directory:
                    os.makedirs(local_directory, exist_ok=True)

                # Download the object from S3 to the local directory
                self.client.download_file(
                    self.model_bucket, object_key, local_file_path
                )
                print(f"Downloaded {object_key} to {local_file_path}")

        print("Download completed.")