From cb155f1cacddbb75f813885c1bff0e869167f800 Mon Sep 17 00:00:00 2001
From: Khalim Conn-Kowlessar
Date: Mon, 31 Jul 2023 11:02:22 +0100
Subject: [PATCH] Adding database connection to fastapi

---
 .gitignore                               |   2 +
 backend/app/config.py                    |   5 +
 backend/app/db/__init__.py               |   0
 backend/app/db/connection.py             |  26 +++++
 backend/app/db/models/users.py           |  19 ++++
 backend/app/dependencies.py              |  32 ++++++-
 backend/app/exceptions.py                | 102 ++++++++++++++++++++
 backend/requirements/base.txt            |   4 +-
 model_data/simulation_system/__init__.py |   0
 model_data/simulation_system/app.py      | 116 +++++++++++++++++++++++
 10 files changed, 300 insertions(+), 6 deletions(-)
 create mode 100644 backend/app/db/__init__.py
 create mode 100644 backend/app/db/connection.py
 create mode 100644 backend/app/db/models/users.py
 create mode 100644 backend/app/exceptions.py
 create mode 100644 model_data/simulation_system/__init__.py
 create mode 100644 model_data/simulation_system/app.py

diff --git a/.gitignore b/.gitignore
index 95bb5d87..cb17846e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -253,3 +253,5 @@ open_uprn/.idea/
 conservation_areas/.idea/
 
 model_data/.idea/
+model_data/simulation_system/data*
+
diff --git a/backend/app/config.py b/backend/app/config.py
index cfd87ec4..03296ae0 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -9,6 +9,11 @@ class Settings(BaseSettings):
     ENVIRONMENT: str
     PLAN_TRIGGER_BUCKET: str
     EPC_AUTH_TOKEN: str
+    DB_HOST: str
+    DB_PASSWORD: str
+    DB_USERNAME: str
+    DB_PORT: str
+    DB_NAME: str
 
     class Config:
         env_file = "backend/.env"
diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/app/db/connection.py b/backend/app/db/connection.py
new file mode 100644
index 00000000..273a9b9c
--- /dev/null
+++ b/backend/app/db/connection.py
@@ -0,0 +1,26 @@
+"""SQLAlchemy engine shared by the backend's database code."""
+from urllib.parse import quote_plus
+
+from sqlalchemy import create_engine
+
+from backend.app.config import get_settings
+
+settings = get_settings()
+
+connection_string = "postgresql+{drivername}://{username}:{password}@{server}:{port}/{dbname}"
+# Credentials are URL-quoted: a raw "@", ":" or "/" in them would otherwise be
+# read as part of the URL structure
+db_string = connection_string.format(
+    drivername="psycopg2",  # You'll need to use psycopg2 driver for PostgreSQL
+    username=quote_plus(settings.DB_USERNAME),
+    password=quote_plus(settings.DB_PASSWORD),
+    server=settings.DB_HOST,
+    port=settings.DB_PORT,
+    dbname=settings.DB_NAME,
+)
+
+db_engine = create_engine(
+    db_string,
+    # echo=True logs every SQL statement, which is only wanted when developing locally
+    echo=settings.ENVIRONMENT == "local",
+)
diff --git a/backend/app/db/models/users.py b/backend/app/db/models/users.py
new file mode 100644
index 00000000..6e243815
--- /dev/null
+++ b/backend/app/db/models/users.py
@@ -0,0 +1,19 @@
+from sqlalchemy import Column, Integer, String, DateTime
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.sql import func
+
+Base = declarative_base()
+
+
+class UserModel(Base):
+    """ORM mapping of the ``user`` table."""
+
+    __tablename__ = 'user'
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    firstName = Column(String)
+    email = Column(String, nullable=False, unique=True)
+    oauth_id = Column(String)
+    oauth_provider = Column(String, nullable=False)
+    # func.now() is rendered as SQL, so the timestamps come from the database clock
+    created_at = Column(DateTime(timezone=True), nullable=False, default=func.now())
+    updated_at = Column(DateTime(timezone=True), nullable=False, default=func.now(), onupdate=func.now())
diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py
index 05bc0e27..027cfe40 100644
--- a/backend/app/dependencies.py
+++ b/backend/app/dependencies.py
@@ -6,8 +6,12 @@ from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.backends import default_backend
 from typing import Any
 import json
+from sqlalchemy.orm import sessionmaker
 from backend.app.config import get_settings
 from backend.app.utils import setup_logger
+from backend.app.db.connection import db_engine
+from backend.app.db.models.users import UserModel
+from backend.app import exceptions
 
 logger = setup_logger()
 
@@ -23,6 +27,24 @@ async def validate_api_key(api_key_header: str = Depends(api_key_header)):
     return api_key_header
 
 
+# Session factory built once for the module instead of on every lookup
+SessionLocal = sessionmaker(bind=db_engine)
+
+
+def validate_user(user_id) -> tuple:
+    """
+    Look a user up by primary key.
+
+    :param user_id: value matched against UserModel.id
+    :return: (True, UserModel instance) when the user exists, otherwise (False, None)
+    """
+    with SessionLocal() as session:
+        user = session.query(UserModel).filter_by(id=user_id).first()
+    if user is None:
+        return False, None
+    return True, user
+
+
 def get_user(user_id: str):
     # Define here how to fetch a user from your database
     # using the user_id. Here's a simple placeholder implementation:
@@ -30,11 +52,11 @@ def get_user(user_id: str):
     if get_settings().ENVIRONMENT == "local":
         return {"id": user_id, "name": "Dummy User"}
     else:
-        if user_id == "known_id":
-            user = {"id": user_id, "name": "Known User"}
-        else:
-            print("IMPLEMENT ME! - fetch user from database")
-            user = {"id": user_id, "name": "Dummy User"}
+        is_valid, user = validate_user(user_id)
+
+        if not is_valid:
+            # manage_exception always raises for a 401 (AppUnauthorized)
+            exceptions.manage_exception(status_code=401)
     return user
 
 
diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py
new file mode 100644
index 00000000..031caa71
--- /dev/null
+++ b/backend/app/exceptions.py
@@ -0,0 +1,102 @@
+"""Application exceptions and the HTTP status code to exception mapping."""
+
+
+def manage_exception(status_code, response=None):
+    """
+    Given the returned status code, this function will raise the relevant exception
+    This function does not handle non-error responses (below 400), it just returns None
+    Error codes without a dedicated exception raise AppExceptionUnknown
+    :param status_code: HTTP status code to check
+    :param response: object exposing status_code and text, an EmptyResponse is
+        built from status_code when omitted
+    :return: None when status_code is not an error code
+    """
+
+    if status_code is None or status_code < 400:
+        return None
+
+    if response is None:
+        response = EmptyResponse(status_code=status_code)
+
+    # Resolved at call time, once the classes below have been defined
+    exception_class = EXCEPTIONS_BY_STATUS_CODE.get(status_code, AppExceptionUnknown)
+    raise exception_class(response=response)
+
+
+class EmptyResponse:
+    """Stand-in for a response object when only the status code is known"""
+
+    def __init__(self, status_code):
+        self.status_code = status_code
+        self.text = "Generic Error"
+
+
+class AppException(Exception):
+    """
+    Base class of the application exceptions
+    :param response: object exposing status_code and text, or None
+    :param msg: exception message, defaults to response.text
+    """
+
+    def __init__(self, response, msg=None):
+        self.response = response
+        # response is None when the error is not tied to an HTTP response
+        self.status_code = getattr(response, "status_code", None)
+        if msg is None and response is not None:
+            msg = response.text
+        super().__init__(msg)
+
+
+class AppBadRequest(AppException):
+    """HTTP 400: Bad Request"""
+
+
+class AppUnauthorized(AppException):
+    """HTTP 401: Unauthorized"""
+
+
+class AppForbidden(AppException):
+    """HTTP 403: Forbidden"""
+
+
+class AppNotFound(AppException):
+    """HTTP 404: Not Found"""
+
+
+class AppConflict(AppException):
+    """HTTP 409: Conflict"""
+
+
+class AppUnsupportedMediaType(AppException):
+    """HTTP 415: UnsupportedMediaType"""
+
+
+class AppInternalError(AppException):
+    """HTTP 500: Internal Error"""
+
+
+class AppNotImplemented(AppException):
+    """HTTP 501"""
+
+
+class AppExceptionUnknown(AppException):
+    """HTTP Unknown"""
+
+
+class AppNotAuthenticated(AppException):
+    """Not Authenticated, raised without an HTTP response"""
+
+    def __init__(self):
+        super().__init__(None, "Not Authenticated")
+
+
+EXCEPTIONS_BY_STATUS_CODE = {
+    400: AppBadRequest,
+    401: AppUnauthorized,
+    403: AppForbidden,
+    404: AppNotFound,
+    409: AppConflict,
+    415: AppUnsupportedMediaType,
+    500: AppInternalError,
+    501: AppNotImplemented,
+}
diff --git a/backend/requirements/base.txt b/backend/requirements/base.txt
index 9588009b..ff9a74a2 100644
--- a/backend/requirements/base.txt
+++ b/backend/requirements/base.txt
@@ -26,4 +26,6 @@ uvicorn==0.22.0
 uvloop==0.17.0
 urllib3<2
 watchfiles==0.19.0
-websockets==11.0.3
\ No newline at end of file
+websockets==11.0.3
+sqlalchemy==2.0.19
+psycopg2-binary
\ No newline at end of file
diff --git a/model_data/simulation_system/__init__.py b/model_data/simulation_system/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_data/simulation_system/app.py b/model_data/simulation_system/app.py
new file mode 100644
index 00000000..15902d19
--- /dev/null
+++ b/model_data/simulation_system/app.py
@@ -0,0 +1,116 @@
+"""Walk the EPC certificates of every property (UPRN) that has more than one."""
+import numpy as np
+import os
+import pandas as pd
+from tqdm import tqdm
+from model_data.BaseUtility import BaseUtility
+
+
+def list_subdirectories(directory_path):
+    """Return the names of the directories located directly in directory_path."""
+    return [d for d in os.listdir(directory_path) if os.path.isdir(os.path.join(directory_path, d))]
+
+
+DATA_DIRECTORY = os.getcwd() + '/model_data/simulation_system/data/all-domestic-certificates'
+
+FIXED_FEATURES = [
+    'PROPERTY_TYPE',
+    'BUILT_FORM',
+    'CONSTRUCTION_AGE_BAND',
+    'NUMBER_HABITABLE_ROOMS',
+    'CONSTITUENCY',
+    'NUMBER_HEATED_ROOMS',
+    'FIXED_LIGHTING_OUTLETS_COUNT',
+    'GLAZED_AREA',
+    'FLOOR_HEIGHT',
+    'FLOOR_LEVEL',
+    'TOTAL_FLOOR_AREA',
+]
+
+COMPONENT_FEATURES = [
+    'TRANSACTION_TYPE',
+    'WALLS_DESCRIPTION',
+    'FLOOR_DESCRIPTION',
+    'LIGHTING_DESCRIPTION',
+    'ROOF_DESCRIPTION',
+    'MAINHEAT_DESCRIPTION',
+    'HOTWATER_DESCRIPTION',
+    'MAIN_FUEL',
+    'MECHANICAL_VENTILATION',
+    'SECONDHEAT_DESCRIPTION',
+    'ENERGY_TARIFF',  # Not sure if this is relevant
+    'SOLAR_WATER_HEATING_FLAG',
+    'PHOTO_SUPPLY',
+    'WINDOWS_DESCRIPTION',
+    'GLAZED_TYPE',
+    'MULTI_GLAZE_PROPORTION',
+    'LOW_ENERGY_LIGHTING',
+    'NUMBER_OPEN_FIREPLACES',
+    'MAINHEATCONT_DESCRIPTION',
+    'EXTENSION_COUNT'
+]
+
+AVERAGE_FIXED_FEATURES = [
+    "TOTAL_FLOOR_AREA"
+]
+
+# Largest relative spread accepted between the values of an averaged fixed feature
+MAX_AVERAGE_DEVIATION = 0.1
+
+
+def get_fixed_feature_value(field, values):
+    """
+    Reduce the values recorded for a fixed feature of one property to a single value.
+
+    :param field: feature name, selects averaging (AVERAGE_FIXED_FEATURES) and labels errors
+    :param values: pandas Series holding the feature value of each certificate
+    :return: the agreed value, the mean for averaged features, None without valid values
+    :raises ValueError: when the valid values disagree
+    """
+    # Remove missing and invalid values
+    vals = [v for v in values.dropna().unique() if v not in BaseUtility.DATA_ANOMALY_MATCHES]
+
+    if not vals:
+        return None
+
+    if field in AVERAGE_FIXED_FEATURES:
+        # Several values are accepted here as long as they are not too far apart
+        if max(vals) - min(vals) > MAX_AVERAGE_DEVIATION * abs(min(vals)):
+            raise ValueError("Large deviation in fixed feature {} - fix me".format(field))
+        return np.mean(vals)
+
+    if len(vals) > 1:
+        raise ValueError("Fixed feature {} has more than one value - fix me".format(field))
+
+    return vals[0]
+
+
+def app():
+    """Walk the certificates of every UPRN that has more than one EPC."""
+    # Get all the files in the directory
+    directories = list_subdirectories(DATA_DIRECTORY)
+
+    for directory in tqdm(directories):
+        filepath = os.path.join(DATA_DIRECTORY, directory, "certificates.csv")
+        df = pd.read_csv(filepath, low_memory=False)
+        df = df[~pd.isnull(df["UPRN"])]
+        df["UPRN"] = df["UPRN"].astype(int).astype(str)
+
+        # take UPRNS with multiple EPCs
+        df = df[df.duplicated("UPRN", keep=False)]
+        df = df.sort_values(["UPRN", "LODGEMENT_DATE"], ascending=True)
+
+        for uprn, property_data in df.groupby("UPRN"):
+
+            # Fixed features - these are property attributes that shouldn't change over time
+            fixed_data = {
+                field: get_fixed_feature_value(field, property_data[field])
+                for field in FIXED_FEATURES
+            }
+
+            variable_data = property_data[COMPONENT_FEATURES]
+
+            # Each certificate is paired with the next one lodged for the same UPRN
+            for idx in range(property_data.shape[0] - 1):
+                starting_record = variable_data.iloc[idx]
+                ending_record = variable_data.iloc[idx + 1]