M accounts/routes.go => accounts/routes.go +11 -1
@@ 5,6 5,7 @@ import (
"errors"
"links"
"net/http"
+ "time"
"links/internal/localizer"
"links/models"
@@ 13,6 14,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
+ "golang.org/x/time/rate"
"netlandish.com/x/gobwebs"
formguard "netlandish.com/x/gobwebs-formguard"
"netlandish.com/x/gobwebs/accounts"
@@ 34,7 36,15 @@ type Service struct {
// RegisterRoutes ...
func (s *Service) RegisterRoutes() {
- rlConfig := links.RLConfig
+ rlConfig, _ := links.NewRateLimiterConfig(nil)
+ rlConfig.Store = middleware.NewRateLimiterMemoryStoreWithConfig(
+ middleware.RateLimiterMemoryStoreConfig{
+ Rate: rate.Limit(3),
+ Burst: 5,
+ ExpiresIn: 3 * time.Minute,
+ },
+ )
+
s.Group.GET("/register", s.Register).Name = s.RouteName("register")
s.Group.POST("/register", s.Register, middleware.RateLimiterWithConfig(rlConfig)).Name = s.RouteName("register_post")
s.Group.GET("/register/:key", s.Register).Name = s.RouteName("register_invitation")
M cmd/api/main.go => cmd/api/main.go +15 -72
@@ 13,18 13,15 @@ import (
"links/api/loaders"
"links/cmd"
"links/core"
- "net"
"net/url"
"os"
"strconv"
- "strings"
"time"
work "git.sr.ht/~sircmpwn/dowork"
"github.com/99designs/gqlgen/graphql"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
- "golang.org/x/time/rate"
gobwebsgql "netlandish.com/x/gobwebs-graphql"
oauth2 "netlandish.com/x/gobwebs-oauth2"
feedback "netlandish.com/x/gobwebs-ses-feedback"
@@ 75,46 72,6 @@ func run() error {
}
}
- var wlnets []*net.IPNet
- if val, ok := config.File.Get("links", "rate-limit-whitelist"); ok {
- for _, nr := range strings.Split(val, ",") {
- nr = strings.TrimSpace(nr)
- _, subnet, err := net.ParseCIDR(nr)
- if err != nil {
- return fmt.Errorf("links:rate-limit-whitelist %s is invalid", nr)
- }
- wlnets = append(wlnets, subnet)
- }
- } else {
- _, subnet, _ := net.ParseCIDR("127.0.0.0/8")
- wlnets = append(wlnets, subnet)
- }
-
- rlnums := struct {
- Limit int
- Burst int
- Expire time.Duration
- }{20, 40, 3 * time.Minute}
- if val, ok := config.File.Get("links", "rate-limit-limit"); ok {
- rlnums.Limit, err = strconv.Atoi(val)
- if err != nil {
- return fmt.Errorf("links:rate-limit-limit must be an integer value")
- }
- }
- if val, ok := config.File.Get("links", "rate-limit-burst"); ok {
- rlnums.Burst, err = strconv.Atoi(val)
- if err != nil {
- return fmt.Errorf("links:rate-limit-burst must be an integer value")
- }
- }
- if val, ok := config.File.Get("links", "rate-limit-expire"); ok {
- expire, err := strconv.Atoi(val)
- if err != nil {
- return fmt.Errorf("links:rate-limit-expire must be an integer value")
- }
- rlnums.Expire = time.Duration(expire) * time.Minute
- }
-
esvc, err := cmd.LoadEmailService(config)
if err != nil {
return fmt.Errorf("unable to load email service: %v", err)
@@ 125,6 82,21 @@ func run() error {
return fmt.Errorf("unable to load storage service: %v", err)
}
+ rlConfig, err := links.NewRateLimiterConfig(config)
+ if err != nil {
+ return err
+ }
+ rlConfig.IdentifierExtractor = func(c echo.Context) (string, error) {
+ tuser := oauth2.ForContext(c.Request().Context())
+ if tuser != nil {
+ hashStr := hex.EncodeToString(tuser.TokenHash[:])
+ if hashStr != "" {
+ return hashStr, nil
+ }
+ }
+ return c.RealIP(), nil
+ }
+
e := echo.New()
// https://echo.labstack.com/docs/ip-address
// Deployed via Caddy at the moment which uses X-Forwarded-For header by default
@@ 145,35 117,6 @@ func run() error {
ServerContext: true,
}
- rlConfig := middleware.RateLimiterConfig{
- Skipper: func(c echo.Context) bool {
- ip := net.ParseIP(c.RealIP())
- for _, subnet := range wlnets {
- if subnet.Contains(ip) {
- return true
- }
- }
- return false
- },
- Store: middleware.NewRateLimiterMemoryStoreWithConfig(
- middleware.RateLimiterMemoryStoreConfig{
- Rate: rate.Limit(rlnums.Limit),
- Burst: rlnums.Burst,
- ExpiresIn: rlnums.Expire,
- },
- ),
- IdentifierExtractor: func(c echo.Context) (string, error) {
- tuser := oauth2.ForContext(c.Request().Context())
- if tuser != nil {
- hashStr := hex.EncodeToString(tuser.TokenHash[:])
- if hashStr != "" {
- return hashStr, nil
- }
- }
- return c.RealIP(), nil
- },
- }
-
srv := server.New(e, db, config).
Initialize().
WithAppInfo("links-api", Version).
M cmd/links/main.go => cmd/links/main.go +8 -0
@@ 29,6 29,7 @@ import (
work "git.sr.ht/~sircmpwn/dowork"
"github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v4/middleware"
formguard "netlandish.com/x/gobwebs-formguard"
gobwebsgql "netlandish.com/x/gobwebs-graphql"
oauth2 "netlandish.com/x/gobwebs-oauth2"
@@ 161,6 162,11 @@ func run() error {
tlsman := cmd.LoadAutoTLS(config, db, models.DomainServiceLinks)
+ rlConfig, err := links.NewRateLimiterConfig(config)
+ if err != nil {
+ return err
+ }
+
e := echo.New()
// https://echo.labstack.com/docs/ip-address
// Deployed via Caddy at the moment which uses X-Forwarded-For header by default
@@ 312,6 318,8 @@ func run() error {
}
loadGQLDefaults(config, gqlConfig)
gqlGroup := e.Group("")
+ // Rate limit the /graphql end point to avoid abuse
+ gqlGroup.Use(middleware.RateLimiterWithConfig(rlConfig))
gobwebsgql.NewService(gqlGroup, gqlConfig)
slackService := e.Group("/slack")
M helpers.go => helpers.go +78 -15
@@ 13,6 13,7 @@ import (
"links/models"
"links/valid"
"mime/multipart"
+ "net"
"net/http"
"net/url"
"path/filepath"
@@ 1031,24 1032,86 @@ func GetSEOData(c echo.Context) *SEOData {
return seoData
}
-// RLConfig is a base rate limit config struct that can be used and altered
-// in handlers as needed.
-var RLConfig = middleware.RateLimiterConfig{
- Skipper: func(c echo.Context) bool {
- gctx := c.(*server.Context)
- if gctx.User.IsAuthenticated() && gctx.User.IsSuperUser() {
- return true
+// NewRateLimiterConfig will return a base rate limit config struct that can be used
+// and altered in handlers as needed.
+func NewRateLimiterConfig(conf *config.Config) (middleware.RateLimiterConfig, error) {
+ var (
+ rlConfig middleware.RateLimiterConfig
+ wlnets []*net.IPNet
+ err error
+ )
+
+ rlnums := struct {
+ Limit int
+ Burst int
+ Expire time.Duration
+ }{20, 40, 3 * time.Minute}
+
+ if conf != nil {
+ if val, ok := conf.File.Get("links", "rate-limit-whitelist"); ok {
+ for _, nr := range strings.Split(val, ",") {
+ nr = strings.TrimSpace(nr)
+ _, subnet, err := net.ParseCIDR(nr)
+ if err != nil {
+ return rlConfig, fmt.Errorf("links:rate-limit-whitelist %s is invalid", nr)
+ }
+ wlnets = append(wlnets, subnet)
+ }
+ } else {
+ _, subnet, _ := net.ParseCIDR("127.0.0.0/8")
+ wlnets = append(wlnets, subnet)
+ }
+
+ if val, ok := conf.File.Get("links", "rate-limit-limit"); ok {
+ rlnums.Limit, err = strconv.Atoi(val)
+ if err != nil {
+ return rlConfig, fmt.Errorf("links:rate-limit-limit must be an integer value")
+ }
+ }
+ if val, ok := conf.File.Get("links", "rate-limit-burst"); ok {
+ rlnums.Burst, err = strconv.Atoi(val)
+ if err != nil {
+ return rlConfig, fmt.Errorf("links:rate-limit-burst must be an integer value")
+ }
+ }
+ if val, ok := conf.File.Get("links", "rate-limit-expire"); ok {
+ expire, err := strconv.Atoi(val)
+ if err != nil {
+ return rlConfig, fmt.Errorf("links:rate-limit-expire must be an integer value")
+ }
+ rlnums.Expire = time.Duration(expire) * time.Minute
+ }
+ }
+
+ rlConfig.Skipper = func(c echo.Context) bool {
+ ip := net.ParseIP(c.RealIP())
+ for _, subnet := range wlnets {
+ if subnet.Contains(ip) {
+ return true
+ }
+ }
+
+ if gctx, ok := c.(*server.Context); ok {
+ if gctx.User.IsAuthenticated() && gctx.User.IsSuperUser() {
+ return true
+ }
}
return false
- },
- Store: middleware.NewRateLimiterMemoryStoreWithConfig(
+ }
+ rlConfig.Store = middleware.NewRateLimiterMemoryStoreWithConfig(
middleware.RateLimiterMemoryStoreConfig{
- Rate: rate.Limit(3),
- Burst: 5,
- ExpiresIn: 3 * time.Minute,
+ Rate: rate.Limit(rlnums.Limit),
+ Burst: rlnums.Burst,
+ ExpiresIn: rlnums.Expire,
},
- ),
- IdentifierExtractor: func(c echo.Context) (string, error) {
+ )
+ rlConfig.IdentifierExtractor = func(c echo.Context) (string, error) {
+ if gctx, ok := c.(*server.Context); ok {
+ if gctx.User.IsAuthenticated() {
+ return fmt.Sprintf("user:%d", gctx.User.GetID()), nil
+ }
+ }
return c.RealIP(), nil
- },
+ }
+ return rlConfig, nil
}