~netlandish/links

6e97085cb48d4d9fbcb279ca0c20e64058e64a0e — Peter Sanchez a month ago c2d2a2d
Making creating rate limit configs functional
4 files changed, 112 insertions(+), 88 deletions(-)

M accounts/routes.go
M cmd/api/main.go
M cmd/links/main.go
M helpers.go
M accounts/routes.go => accounts/routes.go +11 -1
@@ 5,6 5,7 @@ import (
	"errors"
	"links"
	"net/http"
	"time"

	"links/internal/localizer"
	"links/models"


@@ 13,6 14,7 @@ import (
	sq "github.com/Masterminds/squirrel"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"golang.org/x/time/rate"
	"netlandish.com/x/gobwebs"
	formguard "netlandish.com/x/gobwebs-formguard"
	"netlandish.com/x/gobwebs/accounts"


@@ 34,7 36,15 @@ type Service struct {

// RegisterRoutes ...
func (s *Service) RegisterRoutes() {
	rlConfig := links.RLConfig
	rlConfig, _ := links.NewRateLimiterConfig(nil)
	rlConfig.Store = middleware.NewRateLimiterMemoryStoreWithConfig(
		middleware.RateLimiterMemoryStoreConfig{
			Rate:      rate.Limit(3),
			Burst:     5,
			ExpiresIn: 3 * time.Minute,
		},
	)

	s.Group.GET("/register", s.Register).Name = s.RouteName("register")
	s.Group.POST("/register", s.Register, middleware.RateLimiterWithConfig(rlConfig)).Name = s.RouteName("register_post")
	s.Group.GET("/register/:key", s.Register).Name = s.RouteName("register_invitation")

M cmd/api/main.go => cmd/api/main.go +15 -72
@@ 13,18 13,15 @@ import (
	"links/api/loaders"
	"links/cmd"
	"links/core"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	work "git.sr.ht/~sircmpwn/dowork"
	"github.com/99designs/gqlgen/graphql"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"golang.org/x/time/rate"
	gobwebsgql "netlandish.com/x/gobwebs-graphql"
	oauth2 "netlandish.com/x/gobwebs-oauth2"
	feedback "netlandish.com/x/gobwebs-ses-feedback"


@@ 75,46 72,6 @@ func run() error {
		}
	}

	var wlnets []*net.IPNet
	if val, ok := config.File.Get("links", "rate-limit-whitelist"); ok {
		for _, nr := range strings.Split(val, ",") {
			nr = strings.TrimSpace(nr)
			_, subnet, err := net.ParseCIDR(nr)
			if err != nil {
				return fmt.Errorf("links:rate-limit-whitelist %s is invalid", nr)
			}
			wlnets = append(wlnets, subnet)
		}
	} else {
		_, subnet, _ := net.ParseCIDR("127.0.0.0/8")
		wlnets = append(wlnets, subnet)
	}

	rlnums := struct {
		Limit  int
		Burst  int
		Expire time.Duration
	}{20, 40, 3 * time.Minute}
	if val, ok := config.File.Get("links", "rate-limit-limit"); ok {
		rlnums.Limit, err = strconv.Atoi(val)
		if err != nil {
			return fmt.Errorf("links:rate-limit-limit must be an integer value")
		}
	}
	if val, ok := config.File.Get("links", "rate-limit-burst"); ok {
		rlnums.Burst, err = strconv.Atoi(val)
		if err != nil {
			return fmt.Errorf("links:rate-limit-burst must be an integer value")
		}
	}
	if val, ok := config.File.Get("links", "rate-limit-expire"); ok {
		expire, err := strconv.Atoi(val)
		if err != nil {
			return fmt.Errorf("links:rate-limit-expire must be an integer value")
		}
		rlnums.Expire = time.Duration(expire) * time.Minute
	}

	esvc, err := cmd.LoadEmailService(config)
	if err != nil {
		return fmt.Errorf("unable to load email service: %v", err)


@@ 125,6 82,21 @@ func run() error {
		return fmt.Errorf("unable to load storage service: %v", err)
	}

	rlConfig, err := links.NewRateLimiterConfig(config)
	if err != nil {
		return err
	}
	rlConfig.IdentifierExtractor = func(c echo.Context) (string, error) {
		tuser := oauth2.ForContext(c.Request().Context())
		if tuser != nil {
			hashStr := hex.EncodeToString(tuser.TokenHash[:])
			if hashStr != "" {
				return hashStr, nil
			}
		}
		return c.RealIP(), nil
	}

	e := echo.New()
	// https://echo.labstack.com/docs/ip-address
	// Deployed via Caddy at the moment which uses X-Forwarded-For header by default


@@ 145,35 117,6 @@ func run() error {
		ServerContext: true,
	}

	rlConfig := middleware.RateLimiterConfig{
		Skipper: func(c echo.Context) bool {
			ip := net.ParseIP(c.RealIP())
			for _, subnet := range wlnets {
				if subnet.Contains(ip) {
					return true
				}
			}
			return false
		},
		Store: middleware.NewRateLimiterMemoryStoreWithConfig(
			middleware.RateLimiterMemoryStoreConfig{
				Rate:      rate.Limit(rlnums.Limit),
				Burst:     rlnums.Burst,
				ExpiresIn: rlnums.Expire,
			},
		),
		IdentifierExtractor: func(c echo.Context) (string, error) {
			tuser := oauth2.ForContext(c.Request().Context())
			if tuser != nil {
				hashStr := hex.EncodeToString(tuser.TokenHash[:])
				if hashStr != "" {
					return hashStr, nil
				}
			}
			return c.RealIP(), nil
		},
	}

	srv := server.New(e, db, config).
		Initialize().
		WithAppInfo("links-api", Version).

M cmd/links/main.go => cmd/links/main.go +8 -0
@@ 29,6 29,7 @@ import (

	work "git.sr.ht/~sircmpwn/dowork"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	formguard "netlandish.com/x/gobwebs-formguard"
	gobwebsgql "netlandish.com/x/gobwebs-graphql"
	oauth2 "netlandish.com/x/gobwebs-oauth2"


@@ 161,6 162,11 @@ func run() error {

	tlsman := cmd.LoadAutoTLS(config, db, models.DomainServiceLinks)

	rlConfig, err := links.NewRateLimiterConfig(config)
	if err != nil {
		return err
	}

	e := echo.New()
	// https://echo.labstack.com/docs/ip-address
	// Deployed via Caddy at the moment which uses X-Forwarded-For header by default


@@ 312,6 318,8 @@ func run() error {
	}
	loadGQLDefaults(config, gqlConfig)
	gqlGroup := e.Group("")
	// Rate limit the /graphql end point to avoid abuse
	gqlGroup.Use(middleware.RateLimiterWithConfig(rlConfig))
	gobwebsgql.NewService(gqlGroup, gqlConfig)

	slackService := e.Group("/slack")

M helpers.go => helpers.go +78 -15
@@ 13,6 13,7 @@ import (
	"links/models"
	"links/valid"
	"mime/multipart"
	"net"
	"net/http"
	"net/url"
	"path/filepath"


@@ 1031,24 1032,86 @@ func GetSEOData(c echo.Context) *SEOData {
	return seoData
}

// RLConfig is a base rate limit config struct that can be used and altered
// in handlers as needed.
var RLConfig = middleware.RateLimiterConfig{
	Skipper: func(c echo.Context) bool {
		gctx := c.(*server.Context)
		if gctx.User.IsAuthenticated() && gctx.User.IsSuperUser() {
			return true
// NewRateLimiterConfig will return a base rate limit config struct that can be used
// and altered in handlers as needed.
func NewRateLimiterConfig(conf *config.Config) (middleware.RateLimiterConfig, error) {
	var (
		rlConfig middleware.RateLimiterConfig
		wlnets   []*net.IPNet
		err      error
	)

	rlnums := struct {
		Limit  int
		Burst  int
		Expire time.Duration
	}{20, 40, 3 * time.Minute}

	if conf != nil {
		if val, ok := conf.File.Get("links", "rate-limit-whitelist"); ok {
			for _, nr := range strings.Split(val, ",") {
				nr = strings.TrimSpace(nr)
				_, subnet, err := net.ParseCIDR(nr)
				if err != nil {
					return rlConfig, fmt.Errorf("links:rate-limit-whitelist %s is invalid", nr)
				}
				wlnets = append(wlnets, subnet)
			}
		} else {
			_, subnet, _ := net.ParseCIDR("127.0.0.0/8")
			wlnets = append(wlnets, subnet)
		}

		if val, ok := conf.File.Get("links", "rate-limit-limit"); ok {
			rlnums.Limit, err = strconv.Atoi(val)
			if err != nil {
				return rlConfig, fmt.Errorf("links:rate-limit-limit must be an integer value")
			}
		}
		if val, ok := conf.File.Get("links", "rate-limit-burst"); ok {
			rlnums.Burst, err = strconv.Atoi(val)
			if err != nil {
				return rlConfig, fmt.Errorf("links:rate-limit-burst must be an integer value")
			}
		}
		if val, ok := conf.File.Get("links", "rate-limit-expire"); ok {
			expire, err := strconv.Atoi(val)
			if err != nil {
				return rlConfig, fmt.Errorf("links:rate-limit-expire must be an integer value")
			}
			rlnums.Expire = time.Duration(expire) * time.Minute
		}
	}

	rlConfig.Skipper = func(c echo.Context) bool {
		ip := net.ParseIP(c.RealIP())
		for _, subnet := range wlnets {
			if subnet.Contains(ip) {
				return true
			}
		}

		if gctx, ok := c.(*server.Context); ok {
			if gctx.User.IsAuthenticated() && gctx.User.IsSuperUser() {
				return true
			}
		}
		return false
	},
	Store: middleware.NewRateLimiterMemoryStoreWithConfig(
	}
	rlConfig.Store = middleware.NewRateLimiterMemoryStoreWithConfig(
		middleware.RateLimiterMemoryStoreConfig{
			Rate:      rate.Limit(3),
			Burst:     5,
			ExpiresIn: 3 * time.Minute,
			Rate:      rate.Limit(rlnums.Limit),
			Burst:     rlnums.Burst,
			ExpiresIn: rlnums.Expire,
		},
	),
	IdentifierExtractor: func(c echo.Context) (string, error) {
	)
	rlConfig.IdentifierExtractor = func(c echo.Context) (string, error) {
		if gctx, ok := c.(*server.Context); ok {
			if gctx.User.IsAuthenticated() {
				return fmt.Sprintf("user:%d", gctx.User.GetID()), nil
			}
		}
		return c.RealIP(), nil
	},
	}
	return rlConfig, nil
}