diff --git a/etl/surveyedData/surveryedData.py b/etl/surveyedData/surveryedData.py index ec1acb6..3bfb418 100644 --- a/etl/surveyedData/surveryedData.py +++ b/etl/surveyedData/surveryedData.py @@ -1,7 +1,7 @@ from etl.pdfReader.pdfReaderToText import pdfReaderToText from etl.pdfReader.reportType import ReportType import math -from etl.transform.preSiteNoteTypes import AssessorInfo +from etl.transform.preSiteNoteTypes import AssessorInfo, CompanyInfo class surveyedDataProcessor(): def __init__(self, address, files): @@ -23,8 +23,28 @@ class surveyedDataProcessor(): elif pdf.type == ReportType.CHARTED_SURVEYOR_REPORT: self.csr = pdf.get_reader() + def load_company_table(self, db_session): + company_data = self.pre_site_note.company_information.__dict__ + + company_name = company_data.get('trading_name') + + existing_company = db_session.query(CompanyInfo).filter_by( + trading_name=company_name + ).first() + + if existing_company: + return existing_company + else: + new_company = CompanyInfo(**company_data) + db_session.add(new_company) + db_session.commit() + return new_company + def load_assessor_table(self, db_session): + company = self.load_company_table(db_session) assessor_data = self.pre_site_note.assessor_information.__dict__ + assessor_data['company_id'] = company.id + accreditation_number = assessor_data.get('accreditation_number') existing_assessor = db_session.query(AssessorInfo).filter_by( diff --git a/etl/transform/preSiteNoteTypes.py b/etl/transform/preSiteNoteTypes.py index 7e25682..278e153 100644 --- a/etl/transform/preSiteNoteTypes.py +++ b/etl/transform/preSiteNoteTypes.py @@ -9,7 +9,7 @@ from sqlalchemy.dialects.postgresql import UUID class BaseModel(SQLModel): id: uuid.UUID = Field( default_factory=uuid.uuid4, - sa_column=Column(UUID(as_uuid=True), primary_key=True) + primary_key=True, ) class Dimension(BaseModel): @@ -18,11 +18,13 @@ class Dimension(BaseModel): loss_perimeter_m: float party_wall_length_m: float -class CompanyInfo(BaseModel): +class CompanyInfo(BaseModel, table=True): trading_name: str post_code: str fax_number: Optional[str] = None related_party_disclosure: Optional[str] = None + + __table_args__ = {"extend_existing": True} @field_validator('related_party_disclosure') def set_none_if_none_of_the_above(cls, v): @@ -84,6 +86,10 @@ class AssessorInfo(BaseModel, table=True): name: str phone_number: Optional[str] = None email_address: Optional[EmailStr] = None + company_id: uuid.UUID = Field( + foreign_key="companyinfo.id", # Referencing CompanyInfo + nullable=False + ) class VentilationAndCooling(BaseModel): no_of_open_fireplaces: int