pull configs to top

This commit is contained in:
Michael Duong 2023-09-29 10:00:56 +00:00
parent b23c395320
commit 240db65698
4 changed files with 159 additions and 94 deletions

View file

@ -16,6 +16,10 @@ from core.Logger import logger
from core.DataClient import dataclient_factory
from core.FeatureProcessor import feature_processor_factory
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
client_path = Path(__file__).parent / "configs" / "client.yaml"
@ -27,6 +31,36 @@ prepare_data_params = yaml.safe_load(open(prepare_data_path))
feature_process_path = Path(__file__).parent / "configs" / "feature_processor.yaml"
feature_process_params = yaml.safe_load(open(feature_process_path))
data_filepath = prepare_data_params["data_filepath"]
train_proportion = prepare_data_params["train_proportion"]
output_train_filepath = prepare_data_params["output_train_filepath"]
output_test_filepath = prepare_data_params["output_test_filepath"]
feature_processor_config = feature_process_params["feature_processor_config"]
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
input_dataclient_type = prepare_data_params["input_dataclient_type"]
output_dataclient_type = prepare_data_params["output_dataclient_type"]
input_dataclient = dataclient_factory(
dataclient_type=input_dataclient_type,
dataclient_config=client_params[input_dataclient_type],
)
output_dataclient = dataclient_factory(
dataclient_type=output_dataclient_type,
dataclient_config=client_params[output_dataclient_type],
)
logger.info("----------------------------------")
logger.info(f"--- Initiate FeatureProcessor ---")
logger.info("----------------------------------")
feature_processor = feature_processor_factory(
feature_process_params["feature_processor_type"]
)
def prepare_data(
input_dataclient: DataClient,
@ -100,30 +134,6 @@ if __name__ == "__main__":
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
input_dataclient_type = prepare_data_params["input_dataclient_type"]
output_dataclient_type = prepare_data_params["output_dataclient_type"]
input_dataclient = dataclient_factory(
dataclient_type=input_dataclient_type,
dataclient_config=client_params[input_dataclient_type],
)
output_dataclient = dataclient_factory(
dataclient_type=output_dataclient_type,
dataclient_config=client_params[output_dataclient_type],
)
logger.info("----------------------------------")
logger.info(f"--- Initiate FeatureProcessor ---")
logger.info("----------------------------------")
feature_processor = feature_processor_factory(
feature_process_params["feature_processor_type"]
)
logger.info("---------------------------")
logger.info(f"--- Prepare Data Stage ---")
logger.info("---------------------------")
@ -132,11 +142,11 @@ if __name__ == "__main__":
input_dataclient=input_dataclient,
output_dataclient=output_dataclient,
feature_processor=feature_processor,
data_filepath=prepare_data_params["data_filepath"],
train_proportion=prepare_data_params["train_proportion"],
output_train_filepath=prepare_data_params["output_train_filepath"],
output_test_filepath=prepare_data_params["output_test_filepath"],
feature_processor_config=feature_process_params["feature_processor_config"],
data_filepath=data_filepath,
train_proportion=train_proportion,
output_train_filepath=output_train_filepath,
output_test_filepath=output_test_filepath,
feature_processor_config=feature_processor_config,
business_logic=business_logic,
new_feature_funcs=new_feature_funcs,
)

View file

@ -17,6 +17,9 @@ from core.MLModels import model_factory
from core.MLMetrics import metrics_factory
from configs.post_prediction_logic import post_prediction_logic
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -32,6 +35,33 @@ feature_process_params = yaml.safe_load(open(feature_process_path))
generate_metrics_path = Path(__file__).parent / "configs" / "generate_metrics.yaml"
generate_metrics_params = yaml.safe_load(open(generate_metrics_path))
model_type = build_model_params["model_type"]
target = feature_process_params["feature_processor_config"]["target"]
model_save_location = build_model_params["model_save_filepath"]
model_hyperparameters = build_model_params[model_type]
train_filepath = prepare_data_params["output_train_filepath"]
test_filepath = prepare_data_params["output_test_filepath"]
fit_metrics_filepath = build_model_params["fit_metrics_filepath"]
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# Output of previous prepare data step, will be where the data is
dataclient = dataclient_factory(prepare_data_params["output_dataclient_type"])
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model = model_factory(model_type)
logger.info("-------------------------")
logger.info(f"--- Initiate Metrics ---")
logger.info("-------------------------")
metrics = metrics_factory(generate_metrics_params["metrics_type"])
def build_model(
dataclient: DataClient,
@ -109,26 +139,6 @@ if __name__ == "__main__":
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# Output of previous prepare data step, will be where the data is
dataclient = dataclient_factory(prepare_data_params["output_dataclient_type"])
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model_type = build_model_params["model_type"]
model = model_factory(model_type)
logger.info("-------------------------")
logger.info(f"--- Initiate Metrics ---")
logger.info("-------------------------")
metrics = metrics_factory(generate_metrics_params["metrics_type"])
logger.info("--------------------------")
logger.info(f"--- Build Model Stage ---")
logger.info("--------------------------")
@ -137,12 +147,12 @@ if __name__ == "__main__":
dataclient=dataclient,
model=model,
metrics=metrics,
target=feature_process_params["feature_processor_config"]["target"],
model_save_location=build_model_params["model_save_filepath"],
target=target,
model_save_location=model_save_location,
model_hyperparameters=build_model_params[model_type],
train_filepath=prepare_data_params["output_train_filepath"],
test_filepath=prepare_data_params["output_test_filepath"],
fit_metrics_filepath=build_model_params["fit_metrics_filepath"],
train_filepath=model_hyperparameters,
test_filepath=test_filepath,
fit_metrics_filepath=fit_metrics_filepath,
)
logger.info("-------------------------------")

View file

@ -15,6 +15,10 @@ from core.Logger import logger
from configs.post_prediction_logic import post_prediction_logic
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
client_path = Path(__file__).parent / "configs" / "client.yaml"
@ -34,6 +38,37 @@ generate_predictions_params = yaml.safe_load(open(generate_predictions_path))
feature_process_path = Path(__file__).parent / "configs" / "feature_processor.yaml"
feature_process_params = yaml.safe_load(open(feature_process_path))
target = feature_process_params["feature_processor_config"]["target"]
model_filepath = build_model_params["model_save_filepath"]
test_data_filepath = generate_predictions_params["test_data_filepath"]
predictions_output_filepath = generate_predictions_params["predictions_output_filepath"]
predictions_column_name = generate_predictions_params["predictions_column_name"]
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model = model_factory(build_model_params["model_type"])
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# We may have different locations of loading hence why we use one specified in generate_predictions.yaml
# I.e. for metric runs, this will be a local data client
# For predictions, we will want a cloud data client
input_dataclient_type = generate_predictions_params["input_dataclient_type"]
input_dataclient = dataclient_factory(
dataclient_type=input_dataclient_type,
dataclient_config=client_params[input_dataclient_type],
)
output_dataclient_type = generate_predictions_params["output_dataclient_type"]
output_dataclient = dataclient_factory(
dataclient_type=output_dataclient_type,
dataclient_config=client_params[output_dataclient_type],
)
def generate_predictions(
input_dataclient: DataClient,
@ -93,34 +128,19 @@ if __name__ == "__main__":
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
model = model_factory(build_model_params["model_type"])
# We may have different locations of loading hence why we use one specified in generate_predictions.yaml
# I.e. for metric runs, this will be a local data client
# For predictions, we will want a cloud data client
input_dataclient_type = generate_predictions_params["input_dataclient_type"]
input_dataclient = dataclient_factory(
dataclient_type=input_dataclient_type,
dataclient_config=client_params[input_dataclient_type],
)
output_dataclient_type = generate_predictions_params["output_dataclient_type"]
output_dataclient = dataclient_factory(
dataclient_type=output_dataclient_type,
dataclient_config=client_params[output_dataclient_type],
)
logger.info("----------------------------------")
logger.info(f"--- Generate Predictions Stage---")
logger.info("----------------------------------")
generate_predictions(
input_dataclient=input_dataclient,
output_dataclient=output_dataclient,
model=model,
target=feature_process_params["feature_processor_config"]["target"],
model_filepath=build_model_params["model_save_filepath"],
test_data_filepath=generate_predictions_params["test_data_filepath"],
predictions_output_filepath=generate_predictions_params[
"predictions_output_filepath"
],
predictions_column_name=generate_predictions_params["predictions_column_name"],
target=target,
model_filepath=model_filepath,
test_data_filepath=test_data_filepath,
predictions_output_filepath=predictions_output_filepath,
predictions_column_name=predictions_column_name,
)
logger.info("-------------------------------")

View file

@ -16,6 +16,10 @@ from core.MLMetrics import metrics_factory
from core.Logger import logger
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
client_path = Path(__file__).parent / "configs" / "client.yaml"
@ -38,6 +42,36 @@ generate_metrics_params = yaml.safe_load(open(generate_metrics_path))
feature_process_path = Path(__file__).parent / "configs" / "feature_processor.yaml"
feature_process_params = yaml.safe_load(open(feature_process_path))
target = (feature_process_params["feature_processor_config"]["target"],)
test_data_filepath = generate_predictions_params["test_data_filepath"]
predictions_output_filepath = generate_predictions_params["predictions_output_filepath"]
predictions_column_name = generate_predictions_params["predictions_column_name"]
metrics_output_filepath = generate_metrics_params["metrics_output_filepath"]
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model = model_factory(build_model_params["model_type"])
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# Use data client for input and output, as we use dvc to cache later to the cloud
dataclient_type = generate_metrics_params["dataclient_type"]
dataclient = dataclient_factory(
dataclient_type=dataclient_type,
dataclient_config=client_params[dataclient_type],
)
logger.info("---------------------------")
logger.info(f"--- Initiate MLMetrics ---")
logger.info("---------------------------")
metrics = metrics_factory(generate_metrics_params["metrics_type"])
def generate_metrics(
input_dataclient: DataClient,
@ -94,29 +128,20 @@ if __name__ == "__main__":
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
model = model_factory(build_model_params["model_type"])
# Use data client for input and output, as we use dvc to cache later to the cloud
dataclient_type = generate_metrics_params["dataclient_type"]
dataclient = dataclient_factory(
dataclient_type=dataclient_type,
dataclient_config=client_params[dataclient_type],
)
metrics = metrics_factory(generate_metrics_params["metrics_type"])
logger.info("------------------------------")
logger.info(f"--- Generate Metrics Stage---")
logger.info("------------------------------")
generate_metrics(
input_dataclient=dataclient,
output_dataclient=dataclient,
model=model,
metrics=metrics,
target=feature_process_params["feature_processor_config"]["target"],
test_data_filepath=generate_predictions_params["test_data_filepath"],
predictions_output_filepath=generate_predictions_params[
"predictions_output_filepath"
],
predictions_column_name=generate_predictions_params["predictions_column_name"],
metrics_output_filepath=generate_metrics_params["metrics_output_filepath"],
target=target,
test_data_filepath=test_data_filepath,
predictions_output_filepath=predictions_output_filepath,
predictions_column_name=predictions_column_name,
metrics_output_filepath=metrics_output_filepath,
)
logger.info("-------------------------------")