""" Set up the client to be used for downloading and uploading model files """ import os import s3fs from core.Logger import logger class S3FSClient: """ Set up the correct client to upload files to s3 """ def __init__(self, runtime_environment: str = "local") -> None: self.client: s3fs.S3FileSystem | None = None self.model_bucket: str self.client_factory(runtime_environment) self.determine_model_bucket(runtime_environment) def client_factory(self, runtime_environment: str = "local"): """ Select the correct s3 client to use """ if runtime_environment == "local": logger.info("No S3 client setup required") elif runtime_environment == "local-mock": logger.info(f"S3 settings for {runtime_environment}") self.client = s3fs.S3FileSystem( key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"), secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"), client_kwargs={ "endpoint_url": os.environ.get( "ENDPOINT_URL", "http://localhost:9000" ) }, ) elif runtime_environment in ["dev", "staging", "prod"]: logger.info(f"S3 settings for {runtime_environment}") # Key/ token should be in session/lambda for this self.client = s3fs.S3FileSystem() else: raise NotImplementedError("No correspnding runtime environment") def determine_model_bucket(self, runtime_environment: str) -> None: """ For the given environment, return the correct bucket for models """ if runtime_environment == "local": logger.info("In local development - no need for s3") elif runtime_environment in ["local-mock", "dev"]: self.model_bucket = "retrofit-model-directory-dev" elif runtime_environment in ["staging", "prod"]: self.model_bucket = f"retrofit-model-directory-{runtime_environment}" else: raise NotImplementedError("No corresponding runtime environment") # def download_model(self, filepath: str, local_filepath: str = "."): # """ # For the file path, download the model locally so that we can load the model # """ # if local_filepath is None: # self.local_filepath = filepath