From fe8ba157801efb8bdbd973beb9a4b0f96d8c979b Mon Sep 17 00:00:00 2001 From: Peter Sanchez Date: Tue, 14 Feb 2023 17:25:36 -0600 Subject: [PATCH] Fixing transaction handling --- database/db.go | 11 ++++++++--- server/server.go | 7 +++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/database/db.go b/database/db.go index 6bccaa3..64d8c1b 100644 --- a/database/db.go +++ b/database/db.go @@ -32,13 +32,17 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) if d.tx != nil { return d.tx, nil } - return d.db.BeginTx(ctx, opts) + + var err error + d.tx, err = d.db.BeginTx(ctx, opts) + return d.tx, err } // CommitTx commits active transaction func (d *DB) CommitTx() error { if d.tx == nil { - return errors.New("You have no active db transaction") + // No errors here + return nil } if d.commit { err := d.tx.Commit() @@ -51,7 +55,8 @@ func (d *DB) CommitTx() error { // RollbackTx rollsback active transaction func (d *DB) RollbackTx() error { if d.tx == nil { - return errors.New("You have no active db transaction") + // No errors here + return nil } err := d.tx.Rollback() d.tx = nil diff --git a/server/server.go b/server/server.go index a361034..6cb8d5f 100644 --- a/server/server.go +++ b/server/server.go @@ -293,6 +293,13 @@ func (s *Server) WithDefaultMiddleware() *Server { database.Context(c.Request().Context(), database.NewDB(s.DB)), ), ) + c.Response().Before(func() { + db := database.DBFromContext(c.Request().Context()) + db.EnableCommit() + if err := db.CommitTx(); err != nil { + panic(err) + } + }) return next(ctx) } }) -- 2.45.2