M database/db.go => database/db.go +8 -3
@@ 32,13 32,17 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
if d.tx != nil {
return d.tx, nil
}
- return d.db.BeginTx(ctx, opts)
+
+ var err error
+ d.tx, err = d.db.BeginTx(ctx, opts)
+ return d.tx, err
}
// CommitTx commits active transaction
func (d *DB) CommitTx() error {
if d.tx == nil {
- return errors.New("You have no active db transaction")
+ // No errors here
+ return nil
}
if d.commit {
err := d.tx.Commit()
@@ 51,7 55,8 @@ func (d *DB) CommitTx() error {
// RollbackTx rollsback active transaction
func (d *DB) RollbackTx() error {
if d.tx == nil {
- return errors.New("You have no active db transaction")
+ // No errors here
+ return nil
}
err := d.tx.Rollback()
d.tx = nil
M server/server.go => server/server.go +7 -0
@@ 293,6 293,13 @@ func (s *Server) WithDefaultMiddleware() *Server {
database.Context(c.Request().Context(), database.NewDB(s.DB)),
),
)
+ c.Response().Before(func() {
+ db := database.DBFromContext(c.Request().Context())
+ db.EnableCommit()
+ if err := db.CommitTx(); err != nil {
+ panic(err)
+ }
+ })
return next(ctx)
}
})