From 406d91afa142c493f74ee51228b6d17d1d6686fd Mon Sep 17 00:00:00 2001 From: Peter Sanchez Date: Tue, 14 Feb 2023 16:13:02 -0600 Subject: [PATCH] Updating accounts for DBI scheme --- accounts/helpers.go | 3 +-- accounts/input.go | 3 +-- accounts/middleware.go | 2 +- accounts/routes.go | 60 ++++++++++++++++++++---------------------- 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/accounts/helpers.go b/accounts/helpers.go index 679ca86..13f0318 100644 --- a/accounts/helpers.go +++ b/accounts/helpers.go @@ -30,8 +30,7 @@ func UserLogout(c echo.Context) error { } // UpdateLastLogin updates the last login time for a user -func UpdateLastLogin(db *sql.DB, user gobwebs.User) error { - ctx := database.Context(context.Background(), db) +func UpdateLastLogin(ctx context.Context, user gobwebs.User) error { err := database.WithTx(ctx, nil, func(tx *sql.Tx) error { // Null any current pending email changes _, err := tx.ExecContext(ctx, ` diff --git a/accounts/input.go b/accounts/input.go index c6b24a5..a8b953f 100644 --- a/accounts/input.go +++ b/accounts/input.go @@ -41,8 +41,7 @@ func (l *LoginForm) Validate(c echo.Context, fetch gobwebs.UserFetch) error { } lt := core.GetSessionLocalizer(c) - ctx := c.(*server.Context) - user, err := fetch.FromDB(ctx.Server.DB, l.Email, false) + user, err := fetch.FromDB(c.Request().Context(), l.Email, false) if err != nil { if err == sql.ErrNoRows { err = fmt.Errorf(lt.Translate("Invalid email and/or password")) diff --git a/accounts/middleware.go b/accounts/middleware.go index ce9cc17..dbe0144 100644 --- a/accounts/middleware.go +++ b/accounts/middleware.go @@ -30,7 +30,7 @@ func AuthMiddleware(fetch gobwebs.UserFetch) echo.MiddlewareFunc { return c.Redirect(http.StatusMovedPermanently, c.Echo().Reverse("accounts:login")) } - user, err := fetch.FromDB(gctx.Server.DB, uint(userID), true) + user, err := fetch.FromDB(ctx, uint(userID), true) if err != nil { return err } diff --git a/accounts/routes.go b/accounts/routes.go index 5038e17..da3d998 100644 --- a/accounts/routes.go +++ b/accounts/routes.go @@ -129,7 +129,7 @@ func (s *Service) LoginAuthPOST(c echo.Context) error { } UserLogin(c, form.user.GetID()) - UpdateLastLogin(gctx.Server.DB, form.user) + UpdateLastLogin(c.Request().Context(), form.user) if err := s.fetch.ProcessSuccessfulLogin(c, form.user); err != nil { return err @@ -160,7 +160,7 @@ func (s *Service) LoginEmailPOST(c echo.Context) error { } } - user, err := s.fetch.FromDB(gctx.Server.DB, form.Email, false) + user, err := s.fetch.FromDB(gctx.Request().Context(), form.Email, false) if err != nil { // Don't let the end user know whether or not the user exists if err != sql.ErrNoRows { @@ -172,7 +172,7 @@ func (s *Service) LoginEmailPOST(c echo.Context) error { user, time.Now().In(time.UTC).Add(1*time.Hour), ) - if err = conf.Store(gctx.Server.DB); err != nil { + if err = conf.Store(c.Request().Context()); err != nil { return err } @@ -211,8 +211,7 @@ func (s *Service) LoginEmailConf(c echo.Context) error { return echo.NotFoundHandler(c) } - gctx := c.(*server.Context) - conf, err := GetConfirmation(gctx.Server.DB, key) + conf, err := GetConfirmation(c.Request().Context(), key) if err != nil { if err == sql.ErrNoRows { return echo.NotFoundHandler(c) @@ -225,7 +224,7 @@ func (s *Service) LoginEmailConf(c echo.Context) error { } // Everything good, log user in - user, err := s.fetch.FromDB(gctx.Server.DB, conf.UserID, false) + user, err := s.fetch.FromDB(c.Request().Context(), conf.UserID, false) if err != nil { if err == sql.ErrNoRows { // Should never be reached @@ -238,12 +237,12 @@ func (s *Service) LoginEmailConf(c echo.Context) error { return err } - if err := conf.Confirm(gctx.Server.DB); err != nil { + if err := conf.Confirm(c.Request().Context()); err != nil { return err } UserLogin(c, user.GetID()) - UpdateLastLogin(gctx.Server.DB, user) + UpdateLastLogin(c.Request().Context(), user) if err := s.fetch.ProcessSuccessfulLogin(c, user); err != nil { return err @@ -277,7 +276,7 @@ func (s *Service) ChangePasswordPOST(c echo.Context) error { user := gctx.User user.SetPassword(form.NPassword) - if err := s.fetch.WritePassword(gctx.Server.DB, user); err != nil { + if err := s.fetch.WritePassword(c.Request().Context(), user); err != nil { return err } @@ -333,7 +332,7 @@ func (s *Service) ForgotPasswordPOST(c echo.Context) error { } } - user, err := s.fetch.FromDB(gctx.Server.DB, form.Email, false) + user, err := s.fetch.FromDB(c.Request().Context(), form.Email, false) if err != nil { // Don't let the end user know whether or not the user exists if err != sql.ErrNoRows { @@ -341,7 +340,7 @@ func (s *Service) ForgotPasswordPOST(c echo.Context) error { } } else { var pid uint - ctxt := database.Context(c.Request().Context(), gctx.Server.DB) + ctxt := c.Request().Context() if err := database.WithTx(ctxt, nil, func(tx *sql.Tx) error { row := tx.QueryRowContext(ctxt, ` SELECT id @@ -374,7 +373,7 @@ func (s *Service) ForgotPasswordPOST(c echo.Context) error { user, time.Now().In(time.UTC).Add(1*time.Hour), ) - if err = conf.Store(gctx.Server.DB); err != nil { + if err = conf.Store(c.Request().Context()); err != nil { return err } @@ -414,7 +413,7 @@ func (s *Service) ForgotPasswordConf(c echo.Context) error { } gctx := c.(*server.Context) - conf, err := GetConfirmation(gctx.Server.DB, key) + conf, err := GetConfirmation(c.Request().Context(), key) if err != nil { if err == sql.ErrNoRows { return echo.NotFoundHandler(c) @@ -427,7 +426,7 @@ func (s *Service) ForgotPasswordConf(c echo.Context) error { } // Everything good, load user for verification - user, err := s.fetch.FromDB(gctx.Server.DB, conf.UserID, false) + user, err := s.fetch.FromDB(c.Request().Context(), conf.UserID, false) if err != nil { if err == sql.ErrNoRows { // Should never be reached @@ -455,7 +454,7 @@ func (s *Service) ForgotPasswordConfPOST(c echo.Context) error { } gctx := c.(*server.Context) - conf, err := GetConfirmation(gctx.Server.DB, key) + conf, err := GetConfirmation(c.Request().Context(), key) if err != nil { if err == sql.ErrNoRows { return echo.NotFoundHandler(c) @@ -468,7 +467,7 @@ func (s *Service) ForgotPasswordConfPOST(c echo.Context) error { } // Everything good, log user in - user, err := s.fetch.FromDB(gctx.Server.DB, conf.UserID, false) + user, err := s.fetch.FromDB(c.Request().Context(), conf.UserID, false) if err != nil { if err == sql.ErrNoRows { // Should never be reached @@ -499,11 +498,11 @@ func (s *Service) ForgotPasswordConfPOST(c echo.Context) error { } user.SetPassword(form.NPassword) - if err := s.fetch.WritePassword(gctx.Server.DB, user); err != nil { + if err := s.fetch.WritePassword(c.Request().Context(), user); err != nil { return err } - if err := conf.Confirm(gctx.Server.DB); err != nil { + if err := conf.Confirm(c.Request().Context()); err != nil { return err } @@ -528,8 +527,7 @@ func (s *Service) ConfirmEmailConf(c echo.Context) error { return echo.NotFoundHandler(c) } - gctx := c.(*server.Context) - conf, err := GetConfirmation(gctx.Server.DB, key) + conf, err := GetConfirmation(c.Request().Context(), key) if err != nil { if err == sql.ErrNoRows { return echo.NotFoundHandler(c) @@ -542,7 +540,7 @@ func (s *Service) ConfirmEmailConf(c echo.Context) error { } // Everything good, log user in - user, err := s.fetch.FromDB(gctx.Server.DB, conf.UserID, false) + user, err := s.fetch.FromDB(c.Request().Context(), conf.UserID, false) if err != nil { if err == sql.ErrNoRows { // Should never be reached @@ -555,7 +553,7 @@ func (s *Service) ConfirmEmailConf(c echo.Context) error { return err } - if err := conf.Confirm(gctx.Server.DB); err != nil { + if err := conf.Confirm(c.Request().Context()); err != nil { return err } @@ -568,8 +566,7 @@ func (s *Service) ConfirmEmailConf(c echo.Context) error { func verifyUser(c echo.Context, user gobwebs.User) error { if !user.IsVerified() { - gctx := c.(*server.Context) - ctx := database.Context(c.Request().Context(), gctx.Server.DB) + ctx := c.Request().Context() return database.WithTx(ctx, nil, func(tx *sql.Tx) error { // Null any current pending email changes _, err := tx.ExecContext(ctx, ` @@ -617,7 +614,7 @@ func (s *Service) UpdateEmailPOST(c echo.Context) error { } } - _, err := s.fetch.FromDB(gctx.Server.DB, form.Email, false) + _, err := s.fetch.FromDB(c.Request().Context(), form.Email, false) if err != nil { if err != sql.ErrNoRows { return err @@ -636,7 +633,7 @@ func (s *Service) UpdateEmailPOST(c echo.Context) error { var pid uint - ctxt := database.Context(c.Request().Context(), gctx.Server.DB) + ctxt := c.Request().Context() if err := database.WithTx(ctxt, nil, func(tx *sql.Tx) error { row := tx.QueryRowContext(ctxt, ` SELECT id @@ -687,7 +684,7 @@ func (s *Service) UpdateEmailPOST(c echo.Context) error { time.Now().In(time.UTC).Add(2*time.Hour), ) conf.ConfirmationTarget = sql.NullString{String: form.Email, Valid: true} - if err = conf.Store(gctx.Server.DB); err != nil { + if err = conf.Store(c.Request().Context()); err != nil { return err } @@ -725,8 +722,7 @@ func (s *Service) UpdateEmailConf(c echo.Context) error { return echo.NotFoundHandler(c) } - gctx := c.(*server.Context) - conf, err := GetConfirmation(gctx.Server.DB, key) + conf, err := GetConfirmation(c.Request().Context(), key) if err != nil { if err == sql.ErrNoRows { return echo.NotFoundHandler(c) @@ -739,7 +735,7 @@ func (s *Service) UpdateEmailConf(c echo.Context) error { } // Everything good, load user for verification - user, err := s.fetch.FromDB(gctx.Server.DB, conf.UserID, false) + user, err := s.fetch.FromDB(c.Request().Context(), conf.UserID, false) if err != nil { if err == sql.ErrNoRows { // Should never be reached @@ -748,7 +744,7 @@ func (s *Service) UpdateEmailConf(c echo.Context) error { return err } - ctxt := database.Context(c.Request().Context(), gctx.Server.DB) + ctxt := c.Request().Context() if err := database.WithTx(ctxt, nil, func(tx *sql.Tx) error { _, err := tx.ExecContext(ctxt, ` UPDATE "users" @@ -764,7 +760,7 @@ func (s *Service) UpdateEmailConf(c echo.Context) error { return err } - if err := conf.Confirm(gctx.Server.DB); err != nil { + if err := conf.Confirm(c.Request().Context()); err != nil { return err } -- 2.45.2