From 6e97085cb48d4d9fbcb279ca0c20e64058e64a0e Mon Sep 17 00:00:00 2001 From: Peter Sanchez Date: Tue, 3 Sep 2024 18:42:34 -0600 Subject: [PATCH] Making creating rate limit configs functional --- accounts/routes.go | 12 +++++- cmd/api/main.go | 87 ++++++++----------------------------------- cmd/links/main.go | 8 ++++ helpers.go | 93 ++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 112 insertions(+), 88 deletions(-) diff --git a/accounts/routes.go b/accounts/routes.go index a252c84..560c073 100644 --- a/accounts/routes.go +++ b/accounts/routes.go @@ -5,6 +5,7 @@ import ( "errors" "links" "net/http" + "time" "links/internal/localizer" "links/models" @@ -13,6 +14,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "golang.org/x/time/rate" "netlandish.com/x/gobwebs" formguard "netlandish.com/x/gobwebs-formguard" "netlandish.com/x/gobwebs/accounts" @@ -34,7 +36,15 @@ type Service struct { // RegisterRoutes ... func (s *Service) RegisterRoutes() { - rlConfig := links.RLConfig + rlConfig, _ := links.NewRateLimiterConfig(nil) + rlConfig.Store = middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{ + Rate: rate.Limit(3), + Burst: 5, + ExpiresIn: 3 * time.Minute, + }, + ) + s.Group.GET("/register", s.Register).Name = s.RouteName("register") s.Group.POST("/register", s.Register, middleware.RateLimiterWithConfig(rlConfig)).Name = s.RouteName("register_post") s.Group.GET("/register/:key", s.Register).Name = s.RouteName("register_invitation") diff --git a/cmd/api/main.go b/cmd/api/main.go index 0d1a3a7..3f6de7b 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -13,18 +13,15 @@ import ( "links/api/loaders" "links/cmd" "links/core" - "net" "net/url" "os" "strconv" - "strings" "time" work "git.sr.ht/~sircmpwn/dowork" "github.com/99designs/gqlgen/graphql" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" - "golang.org/x/time/rate" gobwebsgql "netlandish.com/x/gobwebs-graphql" oauth2 "netlandish.com/x/gobwebs-oauth2" feedback "netlandish.com/x/gobwebs-ses-feedback" @@ -75,46 +72,6 @@ func run() error { } } - var wlnets []*net.IPNet - if val, ok := config.File.Get("links", "rate-limit-whitelist"); ok { - for _, nr := range strings.Split(val, ",") { - nr = strings.TrimSpace(nr) - _, subnet, err := net.ParseCIDR(nr) - if err != nil { - return fmt.Errorf("links:rate-limit-whitelist %s is invalid", nr) - } - wlnets = append(wlnets, subnet) - } - } else { - _, subnet, _ := net.ParseCIDR("127.0.0.0/8") - wlnets = append(wlnets, subnet) - } - - rlnums := struct { - Limit int - Burst int - Expire time.Duration - }{20, 40, 3 * time.Minute} - if val, ok := config.File.Get("links", "rate-limit-limit"); ok { - rlnums.Limit, err = strconv.Atoi(val) - if err != nil { - return fmt.Errorf("links:rate-limit-limit must be an integer value") - } - } - if val, ok := config.File.Get("links", "rate-limit-burst"); ok { - rlnums.Burst, err = strconv.Atoi(val) - if err != nil { - return fmt.Errorf("links:rate-limit-burst must be an integer value") - } - } - if val, ok := config.File.Get("links", "rate-limit-expire"); ok { - expire, err := strconv.Atoi(val) - if err != nil { - return fmt.Errorf("links:rate-limit-expire must be an integer value") - } - rlnums.Expire = time.Duration(expire) * time.Minute - } - esvc, err := cmd.LoadEmailService(config) if err != nil { return fmt.Errorf("unable to load email service: %v", err) @@ -125,6 +82,21 @@ func run() error { return fmt.Errorf("unable to load storage service: %v", err) } + rlConfig, err := links.NewRateLimiterConfig(config) + if err != nil { + return err + } + rlConfig.IdentifierExtractor = func(c echo.Context) (string, error) { + tuser := oauth2.ForContext(c.Request().Context()) + if tuser != nil { + hashStr := hex.EncodeToString(tuser.TokenHash[:]) + if hashStr != "" { + return hashStr, nil + } + } + return c.RealIP(), nil + } + e := echo.New() // https://echo.labstack.com/docs/ip-address // Deployed via Caddy at the moment which uses X-Forwarded-For header by default @@ -145,35 +117,6 @@ func run() error { ServerContext: true, } - rlConfig := middleware.RateLimiterConfig{ - Skipper: func(c echo.Context) bool { - ip := net.ParseIP(c.RealIP()) - for _, subnet := range wlnets { - if subnet.Contains(ip) { - return true - } - } - return false - }, - Store: middleware.NewRateLimiterMemoryStoreWithConfig( - middleware.RateLimiterMemoryStoreConfig{ - Rate: rate.Limit(rlnums.Limit), - Burst: rlnums.Burst, - ExpiresIn: rlnums.Expire, - }, - ), - IdentifierExtractor: func(c echo.Context) (string, error) { - tuser := oauth2.ForContext(c.Request().Context()) - if tuser != nil { - hashStr := hex.EncodeToString(tuser.TokenHash[:]) - if hashStr != "" { - return hashStr, nil - } - } - return c.RealIP(), nil - }, - } - srv := server.New(e, db, config). Initialize(). WithAppInfo("links-api", Version). diff --git a/cmd/links/main.go b/cmd/links/main.go index 4e40302..640a5f5 100644 --- a/cmd/links/main.go +++ b/cmd/links/main.go @@ -29,6 +29,7 @@ import ( work "git.sr.ht/~sircmpwn/dowork" "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" formguard "netlandish.com/x/gobwebs-formguard" gobwebsgql "netlandish.com/x/gobwebs-graphql" oauth2 "netlandish.com/x/gobwebs-oauth2" @@ -161,6 +162,11 @@ func run() error { tlsman := cmd.LoadAutoTLS(config, db, models.DomainServiceLinks) + rlConfig, err := links.NewRateLimiterConfig(config) + if err != nil { + return err + } + e := echo.New() // https://echo.labstack.com/docs/ip-address // Deployed via Caddy at the moment which uses X-Forwarded-For header by default @@ -312,6 +318,8 @@ func run() error { } loadGQLDefaults(config, gqlConfig) gqlGroup := e.Group("") + // Rate limit the /graphql end point to avoid abuse + gqlGroup.Use(middleware.RateLimiterWithConfig(rlConfig)) gobwebsgql.NewService(gqlGroup, gqlConfig) slackService := e.Group("/slack") diff --git a/helpers.go b/helpers.go index 1949958..0824962 100644 --- a/helpers.go +++ b/helpers.go @@ -13,6 +13,7 @@ import ( "links/models" "links/valid" "mime/multipart" + "net" "net/http" "net/url" "path/filepath" @@ -1031,24 +1032,86 @@ func GetSEOData(c echo.Context) *SEOData { return seoData } -// RLConfig is a base rate limit config struct that can be used and altered -// in handlers as needed. -var RLConfig = middleware.RateLimiterConfig{ - Skipper: func(c echo.Context) bool { - gctx := c.(*server.Context) - if gctx.User.IsAuthenticated() && gctx.User.IsSuperUser() { - return true +// NewRateLimiterConfig will return a base rate limit config struct that can be used +// and altered in handlers as needed. +func NewRateLimiterConfig(conf *config.Config) (middleware.RateLimiterConfig, error) { + var ( + rlConfig middleware.RateLimiterConfig + wlnets []*net.IPNet + err error + ) + + rlnums := struct { + Limit int + Burst int + Expire time.Duration + }{20, 40, 3 * time.Minute} + + if conf != nil { + if val, ok := conf.File.Get("links", "rate-limit-whitelist"); ok { + for _, nr := range strings.Split(val, ",") { + nr = strings.TrimSpace(nr) + _, subnet, err := net.ParseCIDR(nr) + if err != nil { + return rlConfig, fmt.Errorf("links:rate-limit-whitelist %s is invalid", nr) + } + wlnets = append(wlnets, subnet) + } + } else { + _, subnet, _ := net.ParseCIDR("127.0.0.0/8") + wlnets = append(wlnets, subnet) + } + + if val, ok := conf.File.Get("links", "rate-limit-limit"); ok { + rlnums.Limit, err = strconv.Atoi(val) + if err != nil { + return rlConfig, fmt.Errorf("links:rate-limit-limit must be an integer value") + } + } + if val, ok := conf.File.Get("links", "rate-limit-burst"); ok { + rlnums.Burst, err = strconv.Atoi(val) + if err != nil { + return rlConfig, fmt.Errorf("links:rate-limit-burst must be an integer value") + } + } + if val, ok := conf.File.Get("links", "rate-limit-expire"); ok { + expire, err := strconv.Atoi(val) + if err != nil { + return rlConfig, fmt.Errorf("links:rate-limit-expire must be an integer value") + } + rlnums.Expire = time.Duration(expire) * time.Minute + } + } + + rlConfig.Skipper = func(c echo.Context) bool { + ip := net.ParseIP(c.RealIP()) + for _, subnet := range wlnets { + if subnet.Contains(ip) { + return true + } + } + + if gctx, ok := c.(*server.Context); ok { + if gctx.User.IsAuthenticated() && gctx.User.IsSuperUser() { + return true + } } return false - }, - Store: middleware.NewRateLimiterMemoryStoreWithConfig( + } + rlConfig.Store = middleware.NewRateLimiterMemoryStoreWithConfig( middleware.RateLimiterMemoryStoreConfig{ - Rate: rate.Limit(3), - Burst: 5, - ExpiresIn: 3 * time.Minute, + Rate: rate.Limit(rlnums.Limit), + Burst: rlnums.Burst, + ExpiresIn: rlnums.Expire, }, - ), - IdentifierExtractor: func(c echo.Context) (string, error) { + ) + rlConfig.IdentifierExtractor = func(c echo.Context) (string, error) { + if gctx, ok := c.(*server.Context); ok { + if gctx.User.IsAuthenticated() { + return fmt.Sprintf("user:%d", gctx.User.GetID()), nil + } + } return c.RealIP(), nil - }, + } + return rlConfig, nil } -- 2.45.2