mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
"""
|
|
Set up the client to be used for downloading and uploading model files
|
|
"""
|
|
|
|
import os
|
|
import s3fs
|
|
from core.Logger import logger
|
|
|
|
|
|
class S3FSClient:
    """
    Set up the S3 filesystem client and model bucket used to download and
    upload model files.

    Supported runtime environments:

    - ``"local"``: S3 is not used; ``client`` and ``model_bucket`` are ``None``.
    - ``"local-mock"``: a client pointed at a configurable endpoint
      (presumably a local S3-compatible server - confirm), using the dev
      bucket name.
    - ``"dev"``, ``"staging"``, ``"prod"``: a default ``s3fs.S3FileSystem``
      and the matching ``retrofit-model-directory-<env>`` bucket.

    Attributes:
        client: the ``s3fs.S3FileSystem`` to use, or ``None`` for ``"local"``.
        model_bucket: name of the bucket holding model files, or ``None`` for
            ``"local"``.
    """

    # Environments whose client uses s3fs defaults; credentials are expected
    # to come from the session / Lambda environment.
    _DEPLOYED_ENVIRONMENTS = frozenset({"dev", "staging", "prod"})

    # Model bucket per S3-backed environment; "local-mock" reuses the dev bucket.
    _MODEL_BUCKETS = {
        "local-mock": "retrofit-model-directory-dev",
        "dev": "retrofit-model-directory-dev",
        "staging": "retrofit-model-directory-staging",
        "prod": "retrofit-model-directory-prod",
    }

    def __init__(self, runtime_environment: str = "local") -> None:
        """
        Build the client and resolve the model bucket for ``runtime_environment``.

        Args:
            runtime_environment: one of ``"local"``, ``"local-mock"``,
                ``"dev"``, ``"staging"`` or ``"prod"``.

        Raises:
            NotImplementedError: if ``runtime_environment`` is not recognised.
        """
        # Both attributes are assigned (not merely annotated) so they always
        # exist, even in "local" where no S3 client or bucket is configured.
        self.client: s3fs.S3FileSystem | None = None
        self.model_bucket: str | None = None

        self.client_factory(runtime_environment)
        self.determine_model_bucket(runtime_environment)

    def client_factory(self, runtime_environment: str = "local") -> None:
        """
        Select and build the S3 client for ``runtime_environment``, storing it
        in ``self.client``.

        - ``"local"``: no S3 access; ``self.client`` is set to ``None``.
        - ``"local-mock"``: credentials come from ``AWS_ACCESS_KEY_ID`` /
          ``AWS_SECRET_ACCESS_KEY`` (defaulting to ``admin`` / ``password``)
          and the endpoint from ``ENDPOINT_URL`` (defaulting to
          ``http://localhost:9000``).
        - ``"dev"``, ``"staging"``, ``"prod"``: ``s3fs.S3FileSystem()`` with
          no explicit credentials.

        Args:
            runtime_environment: the environment to configure for.

        Raises:
            NotImplementedError: if ``runtime_environment`` is not recognised.
        """
        if runtime_environment == "local":
            logger.info("No S3 client setup required")
            # Drop any client from a previous configuration call.
            self.client = None
        elif runtime_environment == "local-mock":
            logger.info(f"S3 settings for {runtime_environment}")
            self.client = s3fs.S3FileSystem(
                key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
                secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
                client_kwargs={
                    "endpoint_url": os.environ.get(
                        "ENDPOINT_URL", "http://localhost:9000"
                    )
                },
            )
        elif runtime_environment in self._DEPLOYED_ENVIRONMENTS:
            logger.info(f"S3 settings for {runtime_environment}")
            # Key/ token should be in session/lambda for this
            self.client = s3fs.S3FileSystem()
        else:
            raise NotImplementedError("No corresponding runtime environment")

    def determine_model_bucket(self, runtime_environment: str = "local") -> None:
        """
        Set ``self.model_bucket`` to the bucket holding models for
        ``runtime_environment``.

        ``"local-mock"`` and ``"dev"`` share ``retrofit-model-directory-dev``;
        ``"staging"`` and ``"prod"`` use ``retrofit-model-directory-<env>``.
        ``"local"`` needs no bucket, so ``self.model_bucket`` is set to ``None``.

        Args:
            runtime_environment: the environment to resolve the bucket for.

        Raises:
            NotImplementedError: if ``runtime_environment`` is not recognised.
        """
        if runtime_environment == "local":
            logger.info("In local development - no need for s3")
            self.model_bucket = None
            return

        bucket = self._MODEL_BUCKETS.get(runtime_environment)
        if bucket is None:
            raise NotImplementedError("No corresponding runtime environment")
        self.model_bucket = bucket

    # TODO: unfinished download helper, kept for reference.
    # def download_model(self, filepath: str, local_filepath: str = "."):
    #     """
    #     For the file path, download the model locally so that we can load the model
    #     """
    #     if local_filepath is None:
    #         self.local_filepath = filepath
|