mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
handling concurrency issues:
This commit is contained in:
parent
0932a5b8d9
commit
7ac3833a7c
3 changed files with 21 additions and 23 deletions
|
|
@ -1,6 +1,9 @@
|
|||
import boto3
|
||||
import json
|
||||
import math
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
|
@ -83,7 +86,6 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
|
|||
index_end = min((i + 1) * chunk_size, total_rows)
|
||||
|
||||
message_payload = {**data, "index_start": index_start, "index_end": index_end}
|
||||
|
||||
message_body = json.dumps(message_payload)
|
||||
|
||||
response = sqs_client.send_message(
|
||||
|
|
@ -91,7 +93,10 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest):
|
|||
MessageBody=message_body
|
||||
)
|
||||
logger.info(
|
||||
f"Chunk {i} sent to SQS. Rows {index_start}–{index_end}. Message ID: {response.get('MessageId')}")
|
||||
f"Chunk {i} sent to SQS. Rows {index_start}–{index_end}. Message ID: {response.get('MessageId')}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(random.uniform(0.1, 0.5)) # Delay to reduce spike pressure
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during Excel file handling: %s", e)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import random
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
|
|
@ -18,9 +19,6 @@ class ModelApi:
|
|||
"sap_change_predictions",
|
||||
"heat_demand_predictions",
|
||||
"carbon_change_predictions",
|
||||
# "lighting_cost_predictions",
|
||||
# "heating_cost_predictions",
|
||||
# "hot_water_cost_predictions",
|
||||
]
|
||||
|
||||
KWH_MODEL_PREFIXES = ["heating_kwh_predictions", "hotwater_kwh_predictions"]
|
||||
|
|
@ -31,9 +29,6 @@ class ModelApi:
|
|||
"carbon_change_predictions": "carbonmodel",
|
||||
"hotwater_kwh_predictions": "hotwaterkwhmodel",
|
||||
"heating_kwh_predictions": "heatingkwhmodel",
|
||||
# "lighting_cost_predictions": "lightingmodel",
|
||||
# "heating_cost_predictions": "heatingmodel",
|
||||
# "hot_water_cost_predictions": "hotwatermodel",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
|
@ -44,21 +39,12 @@ class ModelApi:
|
|||
base_url="https://api.dev.hestia.homes",
|
||||
max_retries=2,
|
||||
):
|
||||
"""
|
||||
This class handles the communication with the Model APIs. These models include SAP change, heat demain change
|
||||
and carbon change
|
||||
|
||||
property_id (int, optional): :
|
||||
:param portfolio_id: The portfolio ID to be passed in the request payload. Defaults to 4.
|
||||
:param timestamp: The creation timestamp to be passed in the request payload. Defaults to None.
|
||||
:param base_url:
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.portfolio_id = portfolio_id
|
||||
self.timestamp = timestamp
|
||||
self.prediction_buckets = prediction_buckets
|
||||
self.max_retries = max_retries
|
||||
self.semaphore = asyncio.Semaphore(5)
|
||||
self.semaphore = asyncio.Semaphore(2)
|
||||
|
||||
@staticmethod
|
||||
def get_aiohttp_session():
|
||||
|
|
@ -131,6 +117,7 @@ class ModelApi:
|
|||
}
|
||||
|
||||
async with self.semaphore:
|
||||
await asyncio.sleep(random.uniform(0.3, 1.2))
|
||||
try:
|
||||
async with session.post(url, json=payload, headers=headers, timeout=120) as response:
|
||||
if response.status != 200:
|
||||
|
|
@ -215,18 +202,21 @@ class ModelApi:
|
|||
async def predict_all_async(self, df, bucket, model_prefixes=None, extract_ids=True) -> dict:
|
||||
model_prefixes = self.MODEL_PREFIXES if model_prefixes is None else model_prefixes
|
||||
predictions = {}
|
||||
tasks = []
|
||||
|
||||
session = self.get_aiohttp_session()
|
||||
|
||||
for model_prefix in model_prefixes:
|
||||
async def run_model(model_prefix):
|
||||
logger.info(f"Scoring for model prefix: {model_prefix}")
|
||||
file_location = self.upload_scoring_data(df, bucket, model_prefix)
|
||||
tasks.append(self.predict_async(f"s3://{bucket}/" + file_location, model_prefix, session=session))
|
||||
response = await self.predict_async(f"s3://{bucket}/" + file_location, model_prefix, session=session)
|
||||
return model_prefix, response
|
||||
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
results = []
|
||||
for coro in asyncio.as_completed([run_model(mp) for mp in model_prefixes]):
|
||||
result = await coro
|
||||
results.append(result)
|
||||
|
||||
for model_prefix, response in zip(model_prefixes, responses):
|
||||
for model_prefix, response in results:
|
||||
if response:
|
||||
predictions_bucket = self.prediction_buckets[model_prefix]
|
||||
predictions_df = pd.DataFrame(
|
||||
|
|
|
|||
|
|
@ -62,10 +62,13 @@ functions:
|
|||
timeout: 900
|
||||
memorySize: 3008
|
||||
role: EngineLambdaRole
|
||||
reservedConcurrency: 5
|
||||
events:
|
||||
- sqs:
|
||||
arn: arn:aws:sqs:${self:provider.region}:${aws:accountId}:model-engine-queue
|
||||
batchSize: 1
|
||||
maximumConcurrency: 2
|
||||
|
||||
|
||||
resources:
|
||||
Resources:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue