diff --git a/backend/app/db/functions/energy_assessment_functions.py b/backend/app/db/functions/energy_assessment_functions.py index e810c168..ca2f721c 100644 --- a/backend/app/db/functions/energy_assessment_functions.py +++ b/backend/app/db/functions/energy_assessment_functions.py @@ -1,19 +1,26 @@ from backend.app.db.models.energy_assessments import ( - EnergyAssessment, EnergyAssessmentScenarios, EnergyAssessmentDocuments + EnergyAssessment, EnergyAssessmentScenarios, EnergyAssessmentDocuments, DocumentTypeEnum ) from sqlalchemy.orm import Session from sqlalchemy.exc import IntegrityError -from typing import Optional, List +from typing import Optional, List, Dict from sqlalchemy import desc +from utils.logger import setup_logger + +logger = setup_logger() -def bulk_insert_energy_assessments(session: Session, data_list): +def bulk_insert_energy_assessments(session: Session, data_list: List[dict]) -> Dict[int, int]: """ - This function inserts or updates multiple energy assessment records into the database. + This function inserts or updates multiple energy assessment records into the database and returns a mapping of + uprn to energy_assessment_id. :param session: The SQLAlchemy session. :param data_list: A list of dictionaries containing energy assessment data. + :return: A dictionary mapping each uprn to its corresponding energy_assessment_id. 
""" + uprn_to_assessment_id = {} + try: for data in data_list: uprn = data.get('uprn') @@ -30,19 +37,30 @@ def bulk_insert_energy_assessments(session: Session, data_list): for key, value in data.items(): setattr(existing_record, key, value) session.add(existing_record) + + # Map the uprn to the existing record's ID + uprn_to_assessment_id[uprn] = existing_record.id else: # Insert a new record new_assessment = EnergyAssessment(**data) session.add(new_assessment) + # Flush the session to get the newly created ID before commit + session.flush() + + # Map the uprn to the new record's ID + uprn_to_assessment_id[uprn] = new_assessment.id + # Commit the transaction session.commit() - print("All records inserted or updated successfully.") + logger.info("All records inserted or updated successfully.") except IntegrityError as e: # Rollback the session in case of error session.rollback() - print(f"Error occurred: {e}") + logger.info(f"Error occurred: {e}") + + return uprn_to_assessment_id def get_latest_assessment_by_uprn(session: Session, uprn: int) -> Optional[EnergyAssessment]: @@ -60,77 +78,81 @@ def get_latest_assessment_by_uprn(session: Session, uprn: int) -> Optional[Energ return latest_assessment.to_dict() if latest_assessment else EnergyAssessment.empty_response() except Exception as e: - print(f"An error occurred: {e}") + logger.info(f"An error occurred: {e}") return None -def create_energy_assessment_scenario(session: Session, data_list: List[dict], energy_assessment_id: int): +def create_scenarios_for_documents(session: Session, document_list: List[dict], uprn_to_assessment_id: dict): """ - This function creates the necessary energy assessment scenarios if they don't already exist. + Creates scenarios for documents by UPRN and links them to the energy assessments. :param session: The SQLAlchemy session. - :param data_list: A list of dictionaries containing document data with scenario information. - :param energy_assessment_id: The ID of the energy assessment. 
def create_scenarios_for_documents(session: Session, document_list: List[dict], uprn_to_assessment_id: dict):
    """
    Resolve each document's scenario name to an energy_assessment_scenarios row, creating missing rows.

    For every document whose ``scenario_id`` entry is truthy, that entry is treated as a scenario name.
    The scenario is looked up by (scenario_name, energy_assessment_id); if it does not exist it is created
    and flushed so that its primary key is available. The document's ``scenario_id`` entry is then
    overwritten in place with the scenario's ID. Documents without a scenario name are left untouched.

    Documents whose UPRN has no entry in ``uprn_to_assessment_id`` are skipped with a warning. Creating a
    scenario for them would violate the NOT NULL constraint on ``energy_assessment_id`` and roll back every
    scenario in the batch.

    NOTE(review): if the commit fails, documents processed before the failure keep the scenario IDs that
    were assigned before the rollback; the caller should not insert documents after a failed run.

    :param session: The SQLAlchemy session.
    :param document_list: A list of dictionaries containing document data; mutated in place.
    :param uprn_to_assessment_id: A dictionary mapping UPRN to energy_assessment_id.
    :return: None. On IntegrityError the transaction is rolled back and the error is logged, not raised.
    """
    # (scenario_name, energy_assessment_id) -> scenario ID resolved during this call, so documents that
    # share a scenario cost one lookup instead of one query each.
    resolved_ids = {}

    try:
        for document in document_list:
            scenario_name = document.get('scenario_id')
            if not scenario_name:
                continue

            uprn = document.get('uprn')
            energy_assessment_id = uprn_to_assessment_id.get(uprn)
            if energy_assessment_id is None:
                logger.warning(
                    f"No energy_assessment_id found for UPRN: {uprn}. Skipping scenario '{scenario_name}'."
                )
                continue

            cache_key = (scenario_name, energy_assessment_id)
            scenario_id = resolved_ids.get(cache_key)
            if scenario_id is None:
                scenario = session.query(EnergyAssessmentScenarios).filter_by(
                    scenario_name=scenario_name,
                    energy_assessment_id=energy_assessment_id
                ).first()

                if scenario is None:
                    scenario = EnergyAssessmentScenarios(
                        scenario_name=scenario_name,
                        energy_assessment_id=energy_assessment_id
                    )
                    session.add(scenario)
                    # Flush (not commit) so the database assigns the new scenario's ID.
                    session.flush()

                scenario_id = scenario.id
                resolved_ids[cache_key] = scenario_id

            # Replace the scenario name with the scenario's primary key.
            document['scenario_id'] = scenario_id

        # Commit the scenarios
        session.commit()
        logger.info("Scenarios created successfully.")

    except IntegrityError as e:
        session.rollback()
        logger.error(f"Error occurred: {e}", exc_info=True)
def create_documents(session: Session, document_list: List[dict]):
    """
    Insert documents into the energy_assessment_documents table, linking them to scenarios and assessments.

    Each document's ``document_type`` is validated against DocumentTypeEnum and stored as the enum's value.
    Documents that carry no ``energy_assessment_id`` are skipped with a warning rather than raising
    KeyError, because that ID is required by the table.

    All rows are built before anything is added to the session, so a validation error leaves the session
    without pending documents.

    :param session: The SQLAlchemy session.
    :param document_list: A list of dictionaries containing document data. Each needs ``uprn``,
        ``document_type`` and ``document_location``; ``scenario_id`` is optional.
    :return: None. On IntegrityError the transaction is rolled back and the error is logged, not raised.
    :raises KeyError: If a document lacks ``uprn``, ``document_type`` or ``document_location``.
    :raises ValueError: If ``document_type`` is not a DocumentTypeEnum value.
    """
    new_documents = []
    for document in document_list:
        energy_assessment_id = document.get('energy_assessment_id')
        if energy_assessment_id is None:
            logger.warning(
                f"No energy_assessment_id for document with UPRN: {document.get('uprn')}. Skipping document."
            )
            continue

        new_documents.append(
            EnergyAssessmentDocuments(
                uprn=document['uprn'],
                # Validate against the enum, then store its string value.
                document_type=DocumentTypeEnum(document['document_type']).value,
                document_location=document['document_location'],
                energy_assessment_id=energy_assessment_id,
                scenario_id=document.get('scenario_id')  # Might be None if no scenario
            )
        )

    try:
        session.add_all(new_documents)

        # Commit all document insertions
        session.commit()
        logger.info("Documents created successfully.")

    except IntegrityError as e:
        session.rollback()
        logger.error(f"Error occurred: {e}", exc_info=True)
def _utc_now():
    """Return the current time as a timezone-aware UTC datetime."""
    from datetime import timezone  # local import: the module only imports `datetime` itself
    return datetime.now(timezone.utc)


class DocumentTypeEnum(enum.Enum):
    """Closed set of document types accepted for energy assessment documents."""

    EPR = "EPR"
    ConditionReport = "Condition Report"
    EvidenceReport = "Evidence Report"
    SummaryInformation = "Summary Information"
    FloorPlan = "Floor Plan"
    ScenarioDraftEPC = "Scenario Draft EPC"
    ScenarioSiteNotes = "Scenario Site Notes"


class EnergyAssessmentDocuments(Base):
    """ORM model for a document attached to an energy assessment and, optionally, to a scenario."""

    __tablename__ = 'energy_assessment_documents'

    id = Column(BigInteger, primary_key=True, autoincrement=True)
    uprn = Column(BigInteger, nullable=False)
    energy_assessment_id = Column(BigInteger, ForeignKey('energy_assessments.id'), nullable=False)
    # SQLAlchemy persists enum member *names* unless values_callable is given, whereas
    # create_documents writes member *values* (e.g. "Condition Report"); without this mapping such
    # rows cannot be converted back to DocumentTypeEnum on load.
    # NOTE(review): assumes the existing Postgres `document_type` type is labelled with the enum
    # values - confirm against the migration that creates it (create_type=False).
    document_type = Column(
        PgEnum(
            DocumentTypeEnum,
            name="document_type",
            create_type=False,
            values_callable=lambda members: [member.value for member in members],
        ),
        nullable=False,
    )
    document_location = Column(Text, nullable=False)
    # The column is timezone-aware, so the default must be an aware datetime; datetime.utcnow
    # returns a naive one.
    uploaded_at = Column(DateTime(timezone=True), nullable=False, default=_utc_now)
    scenario_id = Column(BigInteger, ForeignKey('energy_assessment_scenarios.id'), nullable=True)

    @staticmethod
    def empty_response():
        """Return a dictionary of document fields with every value set to None."""
        return {
            "id": None,
            "uprn": None,
            "document_type": None,
            "document_location": None,
            "uploaded_at": None,
            "scenario_id": None
        }
def insert_energy_assessment_documents(document_list: List[dict], uprn_to_assessment_id: dict):
    """
    Attach the matching energy_assessment_id to each document and drop documents that have none.

    Despite its name this function does not touch the database: it only annotates the dictionaries.
    Documents whose UPRN has no entry in ``uprn_to_assessment_id`` are removed from ``document_list``
    in place, so the caller's list afterwards contains only documents with an
    ``energy_assessment_id``.

    :param document_list: A list of dictionaries containing document data; each needs a ``uprn`` key.
        Mutated in place.
    :param uprn_to_assessment_id: A dictionary mapping UPRN to energy_assessment_id.
    :return: None.
    :raises KeyError: If a document has no ``uprn`` key.
    """
    matched_documents = []

    for document in document_list:
        uprn = document['uprn']
        # Assign the energy_assessment_id based on uprn
        energy_assessment_id = uprn_to_assessment_id.get(uprn)

        if energy_assessment_id is None:
            logger.warning(f"No energy_assessment_id found for UPRN: {uprn}. Skipping document.")
            continue

        # Attach energy_assessment_id to each document
        document['energy_assessment_id'] = energy_assessment_id
        matched_documents.append(document)

    # Slice-assign so the caller's list object itself loses the skipped documents.
    document_list[:] = matched_documents

    logger.info("Energy Assessment IDs assigned to documents.")