mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
cleaned up docker file for local testing
This commit is contained in:
parent
9445c65701
commit
69a5e98619
10 changed files with 102 additions and 99 deletions
|
|
@ -6,6 +6,7 @@ ARG GID=100
|
|||
|
||||
# Install patches
|
||||
RUN apt-get update && apt-get upgrade -y \
|
||||
&& apt-get install libgomp1 -y \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ ARG GID=100
|
|||
|
||||
# Install patches
|
||||
RUN apt-get update && apt-get upgrade -y \
|
||||
&& apt-get install libgomp1 -y \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists
|
||||
|
||||
|
|
|
|||
|
|
@ -3,97 +3,96 @@ Set up the client to be used for downloading and uploading model files
|
|||
"""
|
||||
|
||||
import os
|
||||
import s3fs
|
||||
import boto3
|
||||
from core.Logger import logger
|
||||
|
||||
|
||||
class S3FSClient:
|
||||
"""
|
||||
Set up the correct client to upload files to s3
|
||||
"""
|
||||
# class S3FSClient:
|
||||
# """
|
||||
# Set up the correct client to upload files to s3
|
||||
# """
|
||||
|
||||
def __init__(self, runtime_environment: str = "local") -> None:
|
||||
self.client: s3fs.S3FileSystem | None = None
|
||||
self.model_bucket: str
|
||||
# def __init__(self, runtime_environment: str = "local") -> None:
|
||||
# self.client: s3fs.S3FileSystem | None = None
|
||||
# self.model_bucket: str
|
||||
|
||||
self.client_factory(runtime_environment)
|
||||
self.determine_model_bucket(runtime_environment)
|
||||
# self.client_factory(runtime_environment)
|
||||
# self.determine_model_bucket(runtime_environment)
|
||||
|
||||
def client_factory(self, runtime_environment: str = "local"):
|
||||
"""
|
||||
Select the correct s3 client to use
|
||||
"""
|
||||
# def client_factory(self, runtime_environment: str = "local"):
|
||||
# """
|
||||
# Select the correct s3 client to use
|
||||
# """
|
||||
|
||||
if runtime_environment == "local":
|
||||
logger.info("No S3 client setup required")
|
||||
elif runtime_environment == "local-mock":
|
||||
logger.info(f"S3 settings for {runtime_environment}")
|
||||
self.client = s3fs.S3FileSystem(
|
||||
key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
|
||||
secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
|
||||
client_kwargs={
|
||||
"endpoint_url": os.environ.get(
|
||||
"ENDPOINT_URL", "http://localhost:9000"
|
||||
)
|
||||
},
|
||||
)
|
||||
elif runtime_environment in ["dev", "staging", "prod"]:
|
||||
logger.info(f"S3 settings for {runtime_environment}")
|
||||
# Key/ token should be in session/lambda for this
|
||||
self.client = s3fs.S3FileSystem()
|
||||
else:
|
||||
raise NotImplementedError("No correspnding runtime environment")
|
||||
# if runtime_environment == "local":
|
||||
# logger.info("No S3 client setup required")
|
||||
# elif runtime_environment == "local-mock":
|
||||
# logger.info(f"S3 settings for {runtime_environment}")
|
||||
# self.client = s3fs.S3FileSystem(
|
||||
# key=os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
|
||||
# secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
|
||||
# client_kwargs={
|
||||
# "endpoint_url": os.environ.get(
|
||||
# "ENDPOINT_URL", "http://localhost:9000"
|
||||
# )
|
||||
# },
|
||||
# )
|
||||
# elif runtime_environment in ["dev", "staging", "prod"]:
|
||||
# logger.info(f"S3 settings for {runtime_environment}")
|
||||
# # Key/ token should be in session/lambda for this
|
||||
# self.client = s3fs.S3FileSystem()
|
||||
# else:
|
||||
# raise NotImplementedError("No correspnding runtime environment")
|
||||
|
||||
def determine_model_bucket(self, runtime_environment: str) -> None:
|
||||
"""
|
||||
For the given environment, return the correct bucket for models
|
||||
"""
|
||||
if runtime_environment == "local":
|
||||
logger.info("In local development - no need for s3")
|
||||
elif runtime_environment in ["local-mock", "dev"]:
|
||||
# TODO: get from environment
|
||||
self.model_bucket = "retrofit-model-directory-dev"
|
||||
elif runtime_environment in ["staging", "prod"]:
|
||||
self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
|
||||
else:
|
||||
raise NotImplementedError("No corresponding runtime environment")
|
||||
# def determine_model_bucket(self, runtime_environment: str) -> None:
|
||||
# """
|
||||
# For the given environment, return the correct bucket for models
|
||||
# """
|
||||
# if runtime_environment == "local":
|
||||
# logger.info("In local development - no need for s3")
|
||||
# elif runtime_environment in ["local-mock", "dev"]:
|
||||
# # TODO: get from environment
|
||||
# self.model_bucket = "retrofit-model-directory-dev"
|
||||
# elif runtime_environment in ["staging", "prod"]:
|
||||
# self.model_bucket = f"retrofit-model-directory-{runtime_environment}"
|
||||
# else:
|
||||
# raise NotImplementedError("No corresponding runtime environment")
|
||||
|
||||
def download_model(self, filepath: str, model_folder: str):
|
||||
"""
|
||||
For the file path, download the model locally so that we can load the model
|
||||
"""
|
||||
# def download_model(self, filepath: str, model_folder: str):
|
||||
# """
|
||||
# For the file path, download the model locally so that we can load the model
|
||||
# """
|
||||
|
||||
if self.client is None:
|
||||
logger.info("No need to download model as local development")
|
||||
else:
|
||||
# if self.client is None:
|
||||
# logger.info("No need to download model as local development")
|
||||
# else:
|
||||
|
||||
def list_files_recursively(folder_path, client):
|
||||
all_files = []
|
||||
for root, dirs, files in client.walk(folder_path):
|
||||
for file in files:
|
||||
s3_path = os.path.join(root, file)
|
||||
all_files.append(s3_path)
|
||||
return all_files
|
||||
# def list_files_recursively(folder_path, client):
|
||||
# all_files = []
|
||||
# for root, dirs, files in client.walk(folder_path):
|
||||
# for file in files:
|
||||
# s3_path = os.path.join(root, file)
|
||||
# all_files.append(s3_path)
|
||||
# return all_files
|
||||
|
||||
# List all files in the specified S3 folder and its subfolders
|
||||
files = list_files_recursively(
|
||||
f"{self.model_bucket}/{filepath}", client=self.client
|
||||
)
|
||||
# # List all files in the specified S3 folder and its subfolders
|
||||
# files = list_files_recursively(
|
||||
# f"{self.model_bucket}/{filepath}", client=self.client
|
||||
# )
|
||||
|
||||
# Download each file
|
||||
for file in files:
|
||||
# Extract the filename from the S3 path
|
||||
filename = file.split(filepath)[-1]
|
||||
# # Download each file
|
||||
# for file in files:
|
||||
# # Extract the filename from the S3 path
|
||||
# filename = file.split(filepath)[-1]
|
||||
|
||||
# Define the local path where you want to save the file
|
||||
local_path = os.path.join(model_folder, filename)
|
||||
# # Define the local path where you want to save the file
|
||||
# local_path = os.path.join(model_folder, filename)
|
||||
|
||||
# Download the file from S3 to the local directory
|
||||
self.client.get(file, local_path)
|
||||
print(f"Downloaded {filename} to {local_path}")
|
||||
# # Download the file from S3 to the local directory
|
||||
# self.client.get(file, local_path)
|
||||
# print(f"Downloaded {filename} to {local_path}")
|
||||
|
||||
print("Download completed.")
|
||||
# print("Download completed.")
|
||||
|
||||
|
||||
class BotoClient:
|
||||
|
|
|
|||
|
|
@ -18,20 +18,20 @@ services:
|
|||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
# simulation_system_training:
|
||||
# build:
|
||||
# context: ./
|
||||
# dockerfile: ./Dockerfiles/Dockerfile.training
|
||||
# image: simulation_system_training
|
||||
# environment:
|
||||
# RUNTIME_ENVIRONMENT: local-mock
|
||||
# ENDPOINT_URL: http://minio:9000/
|
||||
# AWS_ACCESS_KEY_ID: *MINIO_USER
|
||||
# AWS_SECRET_ACCESS_KEY: *MINIO_PASS
|
||||
# tty: true
|
||||
# depends_on:
|
||||
# minio:
|
||||
# condition: service_healthy
|
||||
simulation_system_training:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: ./Dockerfiles/Dockerfile.training
|
||||
image: simulation_system_training
|
||||
environment:
|
||||
RUNTIME_ENVIRONMENT: local-mock
|
||||
ENDPOINT_URL: http://minio:9000/
|
||||
AWS_ACCESS_KEY_ID: *MINIO_USER
|
||||
AWS_SECRET_ACCESS_KEY: *MINIO_PASS
|
||||
tty: true
|
||||
depends_on:
|
||||
minio:
|
||||
condition: service_healthy
|
||||
# command:
|
||||
# ["bash"]
|
||||
|
||||
|
|
@ -41,14 +41,15 @@ services:
|
|||
dockerfile: ./Dockerfiles/Dockerfile.prediction
|
||||
image: simulation_system_prediction
|
||||
environment:
|
||||
RUNTIME_ENVIRONMENT: local-mock
|
||||
ENDPOINT_URL: http://minio:9000/
|
||||
AWS_ACCESS_KEY_ID: *MINIO_USER
|
||||
AWS_SECRET_ACCESS_KEY: *MINIO_PASS
|
||||
tty: true
|
||||
# depends_on:
|
||||
# simulation_system_training:
|
||||
# condition: service_completed_successfully
|
||||
# command:
|
||||
depends_on:
|
||||
simulation_system_training:
|
||||
condition: service_completed_successfully
|
||||
# command:
|
||||
# ["bash"]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -155,6 +155,8 @@ def prediction(
|
|||
logger.info("--- Generating Predictions ---")
|
||||
prediction = model.generate_predictions(data=data)
|
||||
|
||||
# logger.info(pd.concat([data["id"], prediction], axis=1))
|
||||
|
||||
return pd.concat([data["id"], prediction], axis=1)
|
||||
|
||||
# Save prediction somewhere?
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
boto3
|
||||
autogluon==0.8.2
|
||||
pandas==1.5.3
|
||||
s3fs==2023.6.0
|
||||
seaborn==0.12.2
|
||||
matplotlib==3.7.2
|
||||
pre-commit==3.3.3
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
boto3
|
||||
autogluon==0.8.2
|
||||
pandas==1.5.3
|
||||
s3fs
|
||||
seaborn==0.12.2
|
||||
matplotlib==3.7.2
|
||||
matplotlib==3.7.2
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
autogluon==0.8.2
|
||||
pandas==1.5.3
|
||||
seaborn==0.12.2
|
||||
s3fs==2023.6.0
|
||||
pre-commit==3.3.3
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
autogluon==0.8.2
|
||||
pandas==1.5.3
|
||||
seaborn==0.12.2
|
||||
s3fs==2023.6.0
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from core.Logger import logger
|
|||
from core.Metrics import Metrics, sort_by_metric
|
||||
from core.DataLoader import dataloader_factory
|
||||
from core.FeatureProcessor import FeatureProcessor
|
||||
from core.CloudClient import S3FSClient, BotoClient
|
||||
from core.CloudClient import BotoClient
|
||||
from core.RegistryHandler import RegistryHandler
|
||||
from core.Settings import (
|
||||
MODEL_DIRECTORY,
|
||||
|
|
@ -28,7 +28,7 @@ from core.Settings import (
|
|||
|
||||
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local-mock")
|
||||
|
||||
CLIENT = BotoClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
# CLIENT = S3FSClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue