@@ 20,18 20,18 @@ var TxOptionsRO *sql.TxOptions = &sql.TxOptions{Isolation: 0, ReadOnly: true}
// ContextWithTimeout returns a context with query timeout
func ContextWithTimeout(
- ctx context.Context, db *sql.DB, timeout int) (context.Context, context.CancelFunc) {
+ ctx context.Context, db DBI, timeout int) (context.Context, context.CancelFunc) {
return context.WithTimeout(Context(ctx, db), time.Duration(timeout)*time.Second)
}
// Context adds db connection to context for immediate use
-func Context(ctx context.Context, db *sql.DB) context.Context {
+func Context(ctx context.Context, db DBI) context.Context {
return context.WithValue(ctx, dbCtxKey, db)
}
// DBFromContext pulls db pool from context
-func DBFromContext(ctx context.Context) *sql.DB {
- db, ok := ctx.Value(dbCtxKey).(*sql.DB)
+func DBFromContext(ctx context.Context) DBI {
+ db, ok := ctx.Value(dbCtxKey).(DBI)
if !ok {
panic(errors.New("Invalid database context"))
}
@@ 47,18 47,18 @@ func WithTx(ctx context.Context, opts *sql.TxOptions, fn func(tx *sql.Tx) error)
}
defer func() {
if r := recover(); r != nil {
- tx.Rollback()
+ db.RollbackTx()
panic(r)
}
}()
err = fn(tx)
if err != nil {
- err := tx.Rollback()
+ err := db.RollbackTx()
if err != nil && err != sql.ErrTxDone {
panic(err)
}
} else {
- err := tx.Commit()
+ err := db.CommitTx()
if err != nil && err != sql.ErrTxDone {
panic(err)
}