From ff44212f9844d0de405401e28be2fe97934800e1 Mon Sep 17 00:00:00 2001 From: Khalim Conn-Kowlessar Date: Tue, 22 Oct 2024 11:17:57 +0100 Subject: [PATCH] making model api async so we can run from within fastapi endpoint and locally --- backend/ml_models/api.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backend/ml_models/api.py b/backend/ml_models/api.py index 6e0a4162..c2f2dcd9 100644 --- a/backend/ml_models/api.py +++ b/backend/ml_models/api.py @@ -289,8 +289,7 @@ class ModelApi: logger.info("Lambda functions are warmed up and ready to go!") - def async_paginated_predictions(self, data, bucket, batch_size, model_prefixes=None, extract_ids=True): - + async def async_paginated_predictions(self, data, bucket, batch_size, model_prefixes=None, extract_ids=True): all_predictions = self.predictions_template() to_loop_over = range(0, data.shape[0], batch_size) @@ -306,7 +305,13 @@ class ModelApi: for key, scored in predictions_dict.items(): all_predictions[key] = pd.concat([all_predictions[key], scored]) - # Run the async function - asyncio.run(run_batches()) + # Check if there is an existing event loop + try: + # If there is an existing event loop, await the coroutine directly + loop = asyncio.get_running_loop() + await run_batches() + except RuntimeError: # No running event loop + # If no event loop is running, use asyncio.run() + asyncio.run(run_batches()) return all_predictions