~netlandish/gobwebs

c231306edca9df5de71b172cb0dcc654d95a355a — Peter Sanchez 1 year, 9 months ago 116b816
Removing DBI from database functions
3 files changed, 32 insertions(+), 27 deletions(-)

M database/db.go
M database/sql.go
M server/server.go
M database/db.go => database/db.go +4 -6
@@ 1,5 1,7 @@
package database

// DO NOT USE THESE TYPES OR INTERFACES

import (
	"context"
	"database/sql"


@@ 42,15 44,11 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) 
			// Current read only transaction. Commit it and create a new transaction.
			// Should be harmless as no data has been written as of yet.
			cval := d.commit
			if d.commit == false {
				d.commit = true
			}
			d.commit = true
			if err := d.CommitTx(); err != nil {
				return nil, err
			}
			if d.commit != cval {
				d.commit = cval
			}
			d.commit = cval
		} else {
			return d.tx, nil
		}

M database/sql.go => database/sql.go +20 -8
@@ 5,6 5,8 @@ import (
	"database/sql"
	"errors"
	"time"

	"github.com/labstack/echo/v4"
)

// This entire concept is inspired, ripped, and edited, from git.sr.ht/~sircmpwn/core-go


@@ 20,18 22,18 @@ var TxOptionsRO *sql.TxOptions = &sql.TxOptions{Isolation: 0, ReadOnly: true}

// ContextWithTimeout returns a context with query timeout
func ContextWithTimeout(
	ctx context.Context, db DBI, timeout int) (context.Context, context.CancelFunc) {
	ctx context.Context, db *sql.DB, timeout int) (context.Context, context.CancelFunc) {
	return context.WithTimeout(Context(ctx, db), time.Duration(timeout)*time.Second)
}

// Context adds db connection to context for immediate use
func Context(ctx context.Context, db DBI) context.Context {
func Context(ctx context.Context, db *sql.DB) context.Context {
	return context.WithValue(ctx, dbCtxKey, db)
}

// ForContext pulls DBI obj for context
func ForContext(ctx context.Context) DBI {
	db, ok := ctx.Value(dbCtxKey).(DBI)
// ForContext pulls *sql.DB obj for context
func ForContext(ctx context.Context) *sql.DB {
	db, ok := ctx.Value(dbCtxKey).(*sql.DB)
	if !ok {
		panic(errors.New("Invalid database context"))
	}


@@ 47,21 49,31 @@ func WithTx(ctx context.Context, opts *sql.TxOptions, fn func(tx *sql.Tx) error)
	}
	defer func() {
		if r := recover(); r != nil {
			db.RollbackTx()
			tx.Rollback()
			panic(r)
		}
	}()
	err = fn(tx)
	if err != nil {
		err := db.RollbackTx()
		err := tx.Rollback()
		if err != nil && err != sql.ErrTxDone {
			panic(err)
		}
	} else {
		err := db.CommitTx()
		err := tx.Commit()
		if err != nil && err != sql.ErrTxDone {
			panic(err)
		}
	}
	return err
}

// Middleware will place the stripe client in the request context
func Middleware(db *sql.DB) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			c.SetRequest(c.Request().WithContext(Context(c.Request().Context(), db)))
			return next(c)
		}
	}
}

M server/server.go => server/server.go +8 -13
@@ 20,7 20,6 @@ import (
	"github.com/labstack/echo/v4/middleware"
	"hg.code.netlandish.com/~netlandish/gobwebs"
	"hg.code.netlandish.com/~netlandish/gobwebs/config"
	"hg.code.netlandish.com/~netlandish/gobwebs/database"
	"hg.code.netlandish.com/~netlandish/gobwebs/email"
	"hg.code.netlandish.com/~netlandish/gobwebs/internal/localizer"
	"hg.code.netlandish.com/~netlandish/gobwebs/storage"


@@ 288,18 287,14 @@ func (s *Server) WithDefaultMiddleware() *Server {
	s.e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			ctx := &Context{Context: c, Server: s}
			c.SetRequest(
				c.Request().WithContext(
					database.Context(c.Request().Context(), database.NewDB(s.DB)),
				),
			)
			c.Response().Before(func() {
				db := database.ForContext(c.Request().Context())
				db.EnableCommit()
				if err := db.CommitTx(); err != nil {
					panic(err)
				}
			})
			// XXX Revisit when we have more time to debug the DBI transaction interface
			//c.Response().Before(func() {
			//    db := database.ForContext(c.Request().Context())
			//    db.EnableCommit()
			//    if err := db.CommitTx(); err != nil {
			//        panic(err)
			//    }
			//})
			return next(ctx)
		}
	})