Insert / upsert SQL example: persist strategies

September 17, 2024

def _persist_to_db(
    session_context: SessionContext, records: List[dict], strategy: PersistStrategy
) -> None:
    """Persist ``records`` using the writer selected by ``strategy``.

    ``PersistStrategy.INSERT`` delegates to ``_persist_to_db_via_insert`` and
    ``PersistStrategy.UPSERT`` delegates to ``_persist_to_db_via_upsert``.

    Raises:
        NotImplementedError: For any other ``strategy`` value.
    """
    if strategy == PersistStrategy.INSERT:
        writer = _persist_to_db_via_insert
    elif strategy == PersistStrategy.UPSERT:
        writer = _persist_to_db_via_upsert
    else:
        raise NotImplementedError(f"Persist strategy {strategy} not implemented")

    writer(session_context, records)
    return None
 
 
def _persist_to_db_via_insert(session_context: SessionContext, records: List[dict]) -> None:
    """Insert every record as a new row of ``tables.tbl_fundamental_estimates``.

    The session obtained from ``session_context`` is used only to reach its
    ``bind``. The INSERT itself runs on a connection opened with
    ``bind.begin()``. Passing a list of dicts to ``execute`` makes SQLAlchemy
    issue the statement as an executemany. An empty ``records`` list is a no-op.

    NOTE(review): assumes ``session.bind`` is an ``Engine``, whose ``begin()``
    yields a connection and commits on success. Confirm this for all callers.
    """
    if records:
        target = tables.tbl_fundamental_estimates
        with session_context.get_session() as session:
            bound_engine = session.bind
            with bound_engine.begin() as connection:
                connection.execute(target.insert(), records)
    return None
 
 
def _persist_to_db_via_update(session_context: SessionContext, records: List[dict]) -> None:
    """Update rows of ``tables.tbl_fundamental_estimates`` matched by ``hash_id``.

    A single parameterised UPDATE is built from the keys of the first record
    and executed once per record (executemany):
    ``SET <key> = :<key>, ... WHERE hash_id = :hash_id``. Each record must
    therefore carry a ``hash_id`` key and the same keys as ``records[0]``.
    The session is used only to reach its ``bind``. An empty ``records`` list
    is a no-op.

    NOTE(review): assumes ``session.bind`` is an ``Engine``. Confirm this.
    """
    if records:
        target = tables.tbl_fundamental_estimates
        with session_context.get_session() as session:
            bound_engine = session.bind

            # One named bind parameter per column key of the first record.
            set_clause = {}
            for column_name in records[0]:
                set_clause[column_name] = sa.bindparam(column_name)

            match_row = target.c.hash_id == sa.bindparam("hash_id")
            update_stmt = sa.update(target).where(match_row).values(set_clause)

            with bound_engine.begin() as connection:
                connection.execute(update_stmt, records)
    return None
 
 
def _persist_to_db_via_upsert(
    session_context: SessionContext, records: List[dict], *, chunk_size: int = 500
) -> None:
    """Insert new records and update existing ones, matching on ``hash_id``.

    Existing rows are found with ``SELECT hash_id ... WHERE hash_id IN (...)``
    against ``tables.tbl_fundamental_estimates``. Records whose ``hash_id`` is
    already present are passed to ``_persist_to_db_via_update``; the rest are
    passed to ``_persist_to_db_via_insert``.

    Args:
        session_context: Supplies the session used for the existence lookup.
            It is also passed through to the insert/update helpers.
        records: Row dicts, each containing a ``hash_id`` key. If several
            records share a ``hash_id``, only the last one is persisted. This
            is the same end state as upserting them one at a time.
        chunk_size: Maximum number of ``hash_id`` values bound into a single
            ``IN (...)`` lookup. This keeps each query within database
            bind-parameter limits.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
        KeyError: If a record has no ``hash_id`` key.

    NOTE(review): the lookup, the update and the insert run in separate
    sessions/transactions. Two problems follow:

    - A concurrent writer that inserts the same ``hash_id`` between the
      lookup and the insert can still cause a duplicate key error.
    - A failed insert does not roll back the already-committed update.

    A dialect-native upsert (e.g. PostgreSQL ``INSERT ... ON CONFLICT``)
    would make this atomic.
    """
    if not records:
        return None
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Collapse duplicate hash_ids within the batch (last occurrence wins).
    # Otherwise two new records with the same hash_id would both be routed
    # to the INSERT path and collide.
    latest_by_hash_id = {record["hash_id"]: record for record in records}
    hash_ids = list(latest_by_hash_id)

    table = tables.tbl_fundamental_estimates

    existing_hash_ids = set()
    with session_context.get_session() as session:
        # Each IN-list value is a separate bind parameter, and databases cap
        # how many a single statement may carry, so query in chunks.
        for start in range(0, len(hash_ids), chunk_size):
            chunk = hash_ids[start : start + chunk_size]
            stmt = sa.select(table.c.hash_id).where(table.c.hash_id.in_(chunk))
            existing_hash_ids.update(session.execute(stmt).scalars())

    records_to_update = []
    records_to_insert = []
    for record in latest_by_hash_id.values():
        if record["hash_id"] in existing_hash_ids:
            records_to_update.append(record)
        else:
            records_to_insert.append(record)

    if records_to_update:
        _persist_to_db_via_update(session_context, records_to_update)
    if records_to_insert:
        _persist_to_db_via_insert(session_context, records_to_insert)

    return None