Insert / upsert SQL example: choosing a persist strategy
September 17, 2024
def _persist_to_db(
    session_context: SessionContext, records: List[dict], strategy: PersistStrategy
) -> None:
    """Persist ``records`` using the writer selected by ``strategy``.

    Args:
        session_context: Forwarded unchanged to the selected writer.
        records: Row dicts to persist; forwarded unchanged to the selected writer.
        strategy: ``PersistStrategy.INSERT`` selects ``_persist_to_db_via_insert``;
            ``PersistStrategy.UPSERT`` selects ``_persist_to_db_via_upsert``.

    Raises:
        NotImplementedError: if ``strategy`` matches neither of the above.
    """
    # Resolve the writer first, then call it once; unknown strategies fail fast
    # before any database work happens.
    if strategy == PersistStrategy.INSERT:
        writer = _persist_to_db_via_insert
    elif strategy == PersistStrategy.UPSERT:
        writer = _persist_to_db_via_upsert
    else:
        raise NotImplementedError(f"Persist strategy {strategy} not implemented")
    writer(session_context, records)
def _persist_to_db_via_insert(session_context: SessionContext, records: List[dict]) -> None:
    """Bulk-INSERT ``records`` into ``tables.tbl_fundamental_estimates``.

    The rows are sent as a single executemany inside a transaction opened with
    ``begin()`` on the session's ``bind``. The session object itself is used
    only to obtain that bind.

    Args:
        session_context: Provides ``get_session()``.
        records: Row dicts whose keys name the table columns to fill. An empty
            list is a no-op and no session is opened.
    """
    if records:
        target = tables.tbl_fundamental_estimates
        with session_context.get_session() as session:
            # NOTE(review): assumes session.bind is an Engine, whose begin()
            # yields a Connection — confirm against SessionContext.
            with session.bind.begin() as connection:
                connection.execute(target.insert(), records)
def _persist_to_db_via_update(session_context: SessionContext, records: List[dict]) -> None:
    """Batch-UPDATE rows of ``tables.tbl_fundamental_estimates`` matched on ``hash_id``.

    Every key of ``records[0]`` becomes a SET column bound to the same-named
    parameter. That includes ``hash_id``, which is therefore set to its own
    value. The WHERE clause is ``hash_id = :hash_id``. The statement runs once
    as an executemany over all records, in a transaction opened on the
    session's ``bind``.

    Args:
        session_context: Provides ``get_session()``. Only the session's ``bind``
            is used.
        records: Row dicts, each containing ``"hash_id"``. Column names come
            from the first record only, so all records presumably share one
            key set (TODO confirm). An empty list is a no-op.
    """
    if not records:
        return None
    target = tables.tbl_fundamental_estimates
    with session_context.get_session() as session:
        bind = session.bind
        set_clause = {name: sa.bindparam(name) for name in records[0].keys()}
        # Reusing the "hash_id" bind name in WHERE is intentional: the same
        # per-row parameter both locates the row and fills the SET value.
        match_on_hash_id = target.c.hash_id == sa.bindparam("hash_id")
        update_stmt = sa.update(target).where(match_on_hash_id).values(set_clause)
        with bind.begin() as connection:
            connection.execute(update_stmt, records)
    return None
# Maximum number of hash_ids bound into one ``IN (...)`` lookup. Each id becomes
# its own bind parameter and drivers cap parameters per statement (e.g. SQLite
# builds before 3.32 allow 999, SQL Server allows 2100), so large batches are
# looked up in chunks of this size.
_HASH_ID_LOOKUP_CHUNK_SIZE = 500


def _persist_to_db_via_upsert(session_context: SessionContext, records: List[dict]) -> None:
    """Insert new rows and update existing rows of ``tables.tbl_fundamental_estimates``.

    Records are matched on ``hash_id``:

    * If the ``hash_id`` already exists in the table, the record is written with
      an UPDATE. Every key of the first such record is SET, and rows are matched
      by ``hash_id``.
    * Otherwise the record is INSERTed.

    The existence lookup, the UPDATE batch and the INSERT batch all run on one
    connection inside one transaction. A failure in any step therefore rolls
    back the whole batch, rather than leaving the updates committed and the
    inserts missing. One such failure is an IntegrityError raised because
    another writer inserted one of the "new" hash_ids between the lookup and
    the insert. Such a concurrent race is not prevented, only made all-or-nothing.

    Args:
        session_context: Provides ``get_session()``. The session's ``bind`` is
            used to open the transaction.
        records: Row dicts, each containing ``"hash_id"``. UPDATE columns are
            taken from the first record to update, so records presumably share
            one key set (TODO confirm). An empty list is a no-op.
    """
    if not records:
        return None
    table = tables.tbl_fundamental_estimates
    # Dedupe (order-preserving) so repeated ids don't consume extra bind params.
    unique_hash_ids = list(dict.fromkeys(record["hash_id"] for record in records))
    with session_context.get_session() as session:
        engine = session.bind
        # Single transaction for lookup + update + insert: commit together or
        # roll back together.
        with engine.begin() as conn:
            existing_hash_ids = set()
            for start in range(0, len(unique_hash_ids), _HASH_ID_LOOKUP_CHUNK_SIZE):
                chunk = unique_hash_ids[start : start + _HASH_ID_LOOKUP_CHUNK_SIZE]
                lookup = sa.select(table.c.hash_id).where(table.c.hash_id.in_(chunk))
                existing_hash_ids.update(row[0] for row in conn.execute(lookup))

            records_to_update = [r for r in records if r["hash_id"] in existing_hash_ids]
            records_to_insert = [r for r in records if r["hash_id"] not in existing_hash_ids]

            if records_to_update:
                # Same statement shape as _persist_to_db_via_update, built here so
                # it executes on this transaction's connection.
                set_clause = {
                    key: sa.bindparam(key) for key in records_to_update[0].keys()
                }
                update_stmt = (
                    sa.update(table)
                    .where(table.c.hash_id == sa.bindparam("hash_id"))
                    .values(set_clause)
                )
                conn.execute(update_stmt, records_to_update)
            if records_to_insert:
                conn.execute(table.insert(), records_to_insert)
    return None