Skip to content

Instantly share code, notes, and snippets.

@stevenferrer
Forked from pseudomuto/main_1.go
Created May 5, 2022 10:22
Show Gist options
  • Select an option

  • Save stevenferrer/a4d776c9e134d17a5dae398ce5845791 to your computer and use it in GitHub Desktop.

Select an option

Save stevenferrer/a4d776c9e134d17a5dae398ce5845791 to your computer and use it in GitHub Desktop.

Revisions

  1. @pseudomuto pseudomuto revised this gist Jan 28, 2018. 2 changed files with 15 additions and 5 deletions.
    10 changes: 10 additions & 0 deletions pipeline.go
    Original file line number Diff line number Diff line change
    @@ -6,6 +6,8 @@ import (
    "strings"
    )

    // A PipelineStmt is a simple wrapper for creating a statement consisting of
    // a query and a set of arguments to be passed to that query.
    type PipelineStmt struct {
    query string
    args []interface{}
    @@ -15,11 +17,19 @@ func NewPipelineStmt(query string, args ...interface{}) *PipelineStmt {
    return &PipelineStmt{query, args}
    }

    // Executes the statement within supplied transaction. The literal string `{LAST_INS_ID}`
    // will be replaced with the supplied value to make chaining `PipelineStmt` objects together
    // simple.
    func (ps *PipelineStmt) Exec(tx Transaction, lastInsertId int64) (sql.Result, error) {
    query := strings.Replace(ps.query, "{LAST_INS_ID}", strconv.Itoa(int(lastInsertId)), -1)
    return tx.Exec(query, ps.args...)
    }

    // Runs the supplied statements within the transaction. If any statement fails, the transaction
    // is rolled back, and the original error is returned.
    //
    // The `LastInsertId` from the previous statement will be passed to `Exec`. The zero-value (0) is
    // used initially.
    func RunPipeline(tx Transaction, stmts ...*PipelineStmt) (sql.Result, error) {
    var res sql.Result
    var err error
    10 changes: 5 additions & 5 deletions transaction.go
    Original file line number Diff line number Diff line change
    @@ -7,7 +7,7 @@ import (
    // Transaction is an interface that models the standard transaction in
    // `database/sql`.
    //
    // To ensure `innerTxFn` funcs cannot commit or rollback a transaction (which is
    // To ensure `TxFn` funcs cannot commit or rollback a transaction (which is
    // handled by `WithTransaction`), those methods are not included here.
    type Transaction interface {
    Exec(query string, args ...interface{}) (sql.Result, error)
    @@ -16,12 +16,12 @@ type Transaction interface {
    QueryRow(query string, args ...interface{}) *sql.Row
    }

    // A Txfn is a function that will be called with an initialized `Transaction` object that can be used for executing
    // statements and queries against a database.
    // A Txfn is a function that will be called with an initialized `Transaction` object
    // that can be used for executing statements and queries against a database.
    type TxFn func(Transaction) error

    // WithTransaction creates a new transaction and handles rollback/commit based on the error object returned by the
    // `TxFn`
    // WithTransaction creates a new transaction and handles rollback/commit based on the
    // error object returned by the `TxFn`
    func WithTransaction(db *sql.DB, fn TxFn) (err error) {
    tx, err := db.Begin()
    if err != nil {
  2. @pseudomuto pseudomuto revised this gist Jan 28, 2018. 2 changed files with 88 additions and 0 deletions.
    41 changes: 41 additions & 0 deletions pipeline.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,41 @@
    package main

    import (
    "database/sql"
    "strconv"
    "strings"
    )

    type PipelineStmt struct {
    query string
    args []interface{}
    }

    func NewPipelineStmt(query string, args ...interface{}) *PipelineStmt {
    return &PipelineStmt{query, args}
    }

    func (ps *PipelineStmt) Exec(tx Transaction, lastInsertId int64) (sql.Result, error) {
    query := strings.Replace(ps.query, "{LAST_INS_ID}", strconv.Itoa(int(lastInsertId)), -1)
    return tx.Exec(query, ps.args...)
    }

    func RunPipeline(tx Transaction, stmts ...*PipelineStmt) (sql.Result, error) {
    var res sql.Result
    var err error
    var lastInsId int64

    for _, ps := range stmts {
    res, err = ps.Exec(tx, lastInsId)
    if err != nil {
    return nil, err
    }

    lastInsId, err = res.LastInsertId()
    if err != nil {
    return nil, err
    }
    }

    return res, nil
    }
    47 changes: 47 additions & 0 deletions transaction.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,47 @@
    package main

    import (
    "database/sql"
    )

    // Transaction is an interface that models the standard transaction in
    // `database/sql`.
    //
    // To ensure `innerTxFn` funcs cannot commit or rollback a transaction (which is
    // handled by `WithTransaction`), those methods are not included here.
    type Transaction interface {
    Exec(query string, args ...interface{}) (sql.Result, error)
    Prepare(query string) (*sql.Stmt, error)
    Query(query string, args ...interface{}) (*sql.Rows, error)
    QueryRow(query string, args ...interface{}) *sql.Row
    }

    // A Txfn is a function that will be called with an initialized `Transaction` object that can be used for executing
    // statements and queries against a database.
    type TxFn func(Transaction) error

    // WithTransaction creates a new transaction and handles rollback/commit based on the error object returned by the
    // `TxFn`
    func WithTransaction(db *sql.DB, fn TxFn) (err error) {
    tx, err := db.Begin()
    if err != nil {
    return
    }

    defer func() {
    if p := recover(); p != nil {
    // a panic occurred, rollback and repanic
    tx.Rollback()
    panic(p)
    } else if err != nil {
    // something went wrong, rollback
    tx.Rollback()
    } else {
    // all good, commit
    err = tx.Commit()
    }
    }()

    err = fn(tx)
    return err
    }
  3. @pseudomuto pseudomuto revised this gist Jan 28, 2018. 2 changed files with 72 additions and 0 deletions.
    40 changes: 40 additions & 0 deletions main_2.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,40 @@
    package main

    import (
    "database/sql"
    "log"
    )

    func main() {
    db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
    handleError(err)

    defer db.Close()

    err = WithTransaction(db, func(tx Transaction) error {
    // insert a record into table1
    res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
    if err != nil {
    return err
    }

    id, err := res.LastInsertId()
    if err != nil {
    return err
    }

    res, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name")
    if err != nil {
    return err
    }
    })

    handleError(err)
    log.Println("Done.")
    }

    func handleError(err error) {
    if err != nil {
    log.Fatal(err)
    }
    }
    32 changes: 32 additions & 0 deletions main_3.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,32 @@
    package main

    import (
    "database/sql"
    "log"
    )

    func main() {
    db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
    handleError(err)

    defer db.Close()

    stmts := []*PipelineStmt{
    NewPipelineStmt("INSERT INTO table1(name) VALUES(?)", "some name"),
    NewPipelineStmt("INSERT INTO table2(table1_id, name) VALUES({LAST_INS_ID}, ?)", "other name"),
    }

    err = WithTransaction(db, func(tx Transaction) error {
    _, err := RunPipeline(tx, stmts...)
    return err
    })

    handleError(err)
    log.Println("Done.")
    }

    func handleError(err error) {
    if err != nil {
    log.Fatal(err)
    }
    }
  4. @pseudomuto pseudomuto revised this gist Jan 28, 2018. No changes.
  5. @pseudomuto pseudomuto created this gist Jan 28, 2018.
    45 changes: 45 additions & 0 deletions main_1.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,45 @@
    package main

    import (
    "database/sql"
    "log"
    )

    func main() {
    db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
    handleError(err)

    defer db.Close()

    tx, err := db.Begin()
    handleError(err)

    // insert a record into table1
    res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
    if err != nil {
    tx.Rollback()
    log.Fatal(err)
    }

    // fetch the auto incremented id
    id, err := res.LastInsertId()
    handleError(err)

    // insert record into table2, referencing the first record from table1
    res, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name")
    if err != nil {
    tx.Rollback()
    log.Fatal(err)
    }

    // commit the transaction
    handleError(tx.Commit())

    log.Println("Done.")
    }

    func handleError(err error) {
    if err != nil {
    log.Fatal(err)
    }
    }