# Model/backend/ml_models/sap_change_model/api.py
# 2023-10-10 13:29:05 +08:00
#
# 83 lines
# 3 KiB
# Python

import pandas as pd
import requests
from requests.exceptions import RequestException
from utils.logger import setup_logger
from utils.s3 import save_dataframe_to_s3_parquet
# Module-level logger used by the SAPChangeModelAPI methods below.
logger = setup_logger()
class SAPChangeModelAPI:
    """Client for the SAP Change Model prediction API.

    A scoring dataset is first uploaded to S3 with ``upload_scoring_data``;
    the file key it returns is then passed to ``predict``.
    """

    def __init__(
        self,
        portfolio_id,
        timestamp,
        base_url="https://api.dev.hestia.homes",
    ):
        """
        :param portfolio_id: The portfolio ID sent in the request payload and
            used as a path segment of the S3 file key.
        :param timestamp: The creation timestamp sent as ``created_at`` in the
            request payload and used as the S3 file name.
        :param base_url: Root URL of the API; ``/sapmodel/predict`` is appended
            to it when making a prediction request.
        """
        self.base_url = base_url
        self.portfolio_id = portfolio_id
        self.timestamp = timestamp

    def _scoring_file_key(self) -> str:
        """Return the S3 key under which this portfolio's scoring data is stored."""
        return f"sap_change_predictions/{self.portfolio_id}/{self.timestamp}.parquet"

    def _build_payload(self, file_location) -> dict:
        """Return the JSON payload sent to the predict endpoint."""
        return {
            "file_location": file_location,
            "property_id": "",  # This should get removed
            "portfolio_id": self.portfolio_id,
            "created_at": self.timestamp,
        }

    def upload_scoring_data(self, df: pd.DataFrame, bucket: str) -> str:
        """
        Upload the scoring dataset to S3 as a parquet file.

        The SAP model API scores a dataset that is sitting in S3. This method
        stores ``df`` there and returns the file location, which can be passed
        as ``file_location`` to ``predict()``.

        :param df: Pandas dataframe with scoring data to be uploaded to s3
        :param bucket: Name of the bucket in s3 to upload to
        :return: The S3 file key the dataframe was written to, of the form
            ``sap_change_predictions/<portfolio_id>/<timestamp>.parquet``
        """
        # Store parquet file in s3 for scoring
        file_location = self._scoring_file_key()
        logger.info("Storing scoring data to s3")
        save_dataframe_to_s3_parquet(
            df=df,
            bucket_name=bucket,
            file_key=file_location,
        )
        return file_location

    def predict(self, file_location, timeout=None):
        """
        Make a POST request to the SAP Change Model API.

        :param file_location: The S3 file location to be passed in the request
            payload (as returned by ``upload_scoring_data``).
        :param timeout: Optional timeout forwarded to ``requests.post`` (seconds,
            or a ``(connect, read)`` tuple). ``None`` (the default) waits
            indefinitely; callers are encouraged to pass a value so a stalled
            server cannot hang the process.
        :return: The API response as a dictionary if the request was
            successful, ``None`` otherwise (request failure, non-2xx status, or
            a response body that is not valid JSON).
        """
        logger.info("Making request to sap change api")
        url = f"{self.base_url}/sapmodel/predict"
        payload = self._build_payload(file_location)

        try:
            response = requests.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=timeout,
            )
            # Raise for any non-2xx status so it is handled like other failures
            response.raise_for_status()
        except RequestException as e:
            logger.error(f"An error occurred: {e}")
            return None

        try:
            # Return the JSON response as a Python dictionary
            return response.json()
        except ValueError as e:
            # Depending on the installed requests version, a non-JSON body is
            # reported as a plain ValueError subclass rather than a
            # RequestException; treat it as a failed request either way.
            logger.error(f"An error occurred: {e}")
            return None