fixed buffer bug and add id

This commit is contained in:
Michael Duong 2023-10-10 12:35:34 +01:00
parent 70b3008dc5
commit 57934d0ae3
4 changed files with 17 additions and 2 deletions

View file

@ -107,6 +107,7 @@ def handler(event, context):
predictions_column_name=generate_predictions_params[
"predictions_column_name"
],
identifier_column=generate_predictions_params["identifier_column"],
)
return {

View file

@ -43,6 +43,7 @@ default:
test_data_filepath: ./data/prepared_data/test.parquet
predictions_output_filepath: ./data/predictions/predictions.parquet
predictions_column_name: predictions
identifier_column: id
generate_metrics:
dataclient_type: local

View file

@ -142,9 +142,15 @@ class AWSS3Client:
buffer = BytesIO()
obj.to_parquet(buffer, index=False)
# Reset the buffer position to the beginning
buffer.seek(0)
bucket, key = location.strip("s3://").split("/", 1)
self.client.upload_fileobj(buffer, bucket, key)
# Close the buffer
buffer.close()
def _load_parquet(self, location: str, load_config: dict) -> pd.DataFrame:
"""
Load a parquet file

View file

@ -14,6 +14,7 @@ def generate_predictions(
test_data_filepath: str,
predictions_output_filepath: str,
predictions_column_name: str,
identifier_column: str = "id",
):
"""
For a given model, we generate prediction and evaluate this against the true target
@ -52,6 +53,12 @@ def generate_predictions(
predictions_df = pd.DataFrame(predictions)
predictions_df.columns = [predictions_column_name]
output_dataclient.save_data(
obj=predictions_df, location=predictions_output_filepath, save_config=None
output_df = (
pd.concat([test_data[identifier_column], predictions_df], axis=1)
if identifier_column in test_data.columns
else predictions_df
)
output_dataclient.save_data(
obj=output_df, location=predictions_output_filepath, save_config=None
)