mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
import csv
|
|
from io import StringIO
|
|
|
|
from infrastructure.s3.s3_client import S3Client
|
|
from infrastructure.s3.s3_uri import parse_s3_uri
|
|
|
|
|
|
def _dedupe_fieldnames(fieldnames: list[str]) -> list[str]:
|
|
"""Disambiguate repeated CSV headers by appending an index.
|
|
|
|
The first occurrence keeps its name; each later one becomes
|
|
``name_1``, ``name_2`` … so duplicate columns survive as distinct
|
|
keys instead of collapsing onto one (last-wins) dict entry.
|
|
"""
|
|
deduped: list[str] = []
|
|
counts: dict[str, int] = {}
|
|
for name in fieldnames:
|
|
if name not in counts:
|
|
counts[name] = 0
|
|
deduped.append(name)
|
|
continue
|
|
counts[name] += 1
|
|
candidate = f"{name}_{counts[name]}"
|
|
while candidate in counts:
|
|
counts[name] += 1
|
|
candidate = f"{name}_{counts[name]}"
|
|
counts[candidate] = 0
|
|
deduped.append(candidate)
|
|
return deduped
|
|
|
|
|
|
class CsvS3Client(S3Client):
    """``S3Client`` extension that moves CSV data in and out of S3 as row dicts."""

    def read_rows(self, s3_uri: str) -> list[dict[str, str]]:
        """Download the CSV object at *s3_uri* and return its data rows.

        The first CSV record is used as the header; repeated header
        names are made unique with ``_dedupe_fieldnames`` so no column
        is lost.  Each following record becomes one dict keyed by those
        names.  An object with no records at all yields an empty list.

        Raises:
            ValueError: if the bucket in *s3_uri* is not this client's
                bucket.
        """
        bucket, key = parse_s3_uri(s3_uri)
        if bucket != self.bucket:
            raise ValueError(
                f"s3_uri bucket {bucket!r} does not match client bucket {self.bucket!r}"
            )

        payload = self.get_object(key)
        try:
            # "utf-8-sig" also strips a leading BOM when one is present.
            content = payload.decode("utf-8-sig")
        except UnicodeDecodeError:
            # Some uploads are Windows-1252 (e.g. £ as byte 0xA3), not UTF-8.
            content = payload.decode("cp1252")

        stream = StringIO(content)
        try:
            raw_header = next(csv.reader(stream))
        except StopIteration:
            # Nothing to read, not even a header line.
            return []

        # The header was consumed above, so the DictReader is given the
        # (deduplicated) names explicitly and starts at the first data row.
        columns = _dedupe_fieldnames(raw_header)
        return [
            dict(record)
            for record in csv.DictReader(stream, fieldnames=columns)
        ]

    def save_rows(self, rows: list[dict[str, str]], key: str) -> str:
        """Serialise *rows* as UTF-8 CSV and upload them under *key*.

        The header is taken from the keys of the first row, in order.
        Returns whatever ``put_object`` returns for the upload.

        Raises:
            ValueError: if *rows* is empty, since no header can be
                derived from it.
        """
        if not rows:
            raise ValueError("Cannot save an empty rows list: header is unknown")

        columns = list(rows[0])
        out = StringIO()
        csv_writer = csv.DictWriter(out, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(rows)

        encoded = out.getvalue().encode("utf-8")
        return self.put_object(key, encoded)
|