added test modification to us postgres

This commit is contained in:
Jun-te Kim 2026-04-22 14:17:06 +00:00
parent 8c92eee448
commit f220c792e4

View file

@ -13,20 +13,34 @@ from backend.app.db.models.bulk_address_uploads import (
)
@pytest.fixture
def sqlite_session(monkeypatch):
engine = create_engine("sqlite:///:memory:")
@pytest.fixture(scope="function")
def pg_engine(postgresql):
connection_string = (
f"postgresql+psycopg://"
f"{postgresql.info.user}:"
f"{postgresql.info.password}@"
f"{postgresql.info.host}:"
f"{postgresql.info.port}/"
f"{postgresql.info.dbname}"
)
engine = create_engine(connection_string)
SQLModel.metadata.create_all(engine)
yield engine
SQLModel.metadata.drop_all(engine)
engine.dispose()
@pytest.fixture
def patched_session(pg_engine, monkeypatch):
sessions = []
def factory():
s = Session(engine)
s = Session(pg_engine)
sessions.append(s)
return s
monkeypatch.setattr(module, "get_db_session", factory)
yield engine
yield pg_engine
for s in sessions:
s.close()
@ -57,33 +71,33 @@ def _fetch(engine, task_id):
).first()
def test_set_combining_status_updates_row(sqlite_session):
def test_set_combining_status_updates_row(patched_session):
task_id = uuid4()
_insert_row(sqlite_session, task_id, status="processing")
_insert_row(patched_session, task_id, status="processing")
set_combining_status(task_id)
row = _fetch(sqlite_session, task_id)
row = _fetch(patched_session, task_id)
assert row.status == "combining"
assert row.combined_output_s3_uri is None
def test_set_combined_output_s3_uri_writes_uri_and_awaiting_review(sqlite_session):
def test_set_combined_output_s3_uri_writes_uri_and_awaiting_review(patched_session):
task_id = uuid4()
_insert_row(sqlite_session, task_id, status="combining")
_insert_row(patched_session, task_id, status="combining")
set_combined_output_s3_uri(task_id, "s3://bucket/bulk_final_outputs/abc/combined.csv")
row = _fetch(sqlite_session, task_id)
row = _fetch(patched_session, task_id)
assert row.status == "awaiting_review"
assert row.combined_output_s3_uri == "s3://bucket/bulk_final_outputs/abc/combined.csv"
def test_set_combining_status_missing_row_raises(sqlite_session):
def test_set_combining_status_missing_row_raises(patched_session):
with pytest.raises(ValueError, match="No bulk_address_uploads row"):
set_combining_status(uuid4())
def test_set_combined_output_s3_uri_missing_row_raises(sqlite_session):
def test_set_combined_output_s3_uri_missing_row_raises(patched_session):
with pytest.raises(ValueError, match="No bulk_address_uploads row"):
set_combined_output_s3_uri(uuid4(), "s3://x/y.csv")