M analytics/helpers.go => analytics/helpers.go +2 -1
@@ 4,6 4,7 @@ import (
"context"
"database/sql"
"fmt"
+ "links"
"links/models"
"net"
"net/http"
@@ 116,7 117,7 @@ func AddMetaAnalytics(ctx context.Context, req *http.Request, dailyTotalID int,
}
}
- ip := req.Header.Get("X-FORWARDED-FOR")
+ ip := links.IPForContext(req.Context())
if ip != "" {
db, err := geoip2.Open(geoPath)
if err != nil {
M analytics/routes_test.go => analytics/routes_test.go +11 -5
@@ 148,16 148,22 @@ func TestAPI(t *testing.T) {
c.NoError(err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
+ path, ok := srv.Config.File.Get("geo", "path")
+ if ok {
+ // test config has geodb path set. Let's set an IP so we can test the db integration
+ req = req.WithContext(links.IPContext(request.Context(), "142.250.217.196")) // www.google.com
+ }
+
// Add fake analytics entries
- err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, "")
+ err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, path)
c.NoError(err)
- err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, "")
+ err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, path)
c.NoError(err)
- err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, "")
+ err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, path)
c.NoError(err)
- err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, "")
+ err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, path)
c.NoError(err)
- err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, "")
+ err = analytics.AddAnalytics(dbCtx, req, short.ID, analytics.LinkShortAnalyticsFilter, path)
c.NoError(err)
today := time.Now().UTC()
M cmd/api/main.go => cmd/api/main.go +1 -0
@@ 122,6 122,7 @@ func run() error {
WithQueues(eq, wq).
WithMiddleware(
database.Middleware(db),
+ core.RemoteIPMiddleware,
loaders.Middleware(),
core.TimezoneContext(),
crypto.Middleware(entropy),
M cmd/links/main.go => cmd/links/main.go +1 -0
@@ 213,6 213,7 @@ func run() error {
WithQueues(eq, wq, wqi).
WithMiddleware(
database.Middleware(db),
+ core.RemoteIPMiddleware,
core.TimezoneContext(),
crypto.Middleware(entropy),
domain.DomainContext(models.DomainServiceLinks),
M cmd/list/main.go => cmd/list/main.go +1 -0
@@ 84,6 84,7 @@ func run() error {
WithQueues(eq, wq).
WithMiddleware(
database.Middleware(db),
+ core.RemoteIPMiddleware,
core.TimezoneContext(),
crypto.Middleware(entropy),
domain.DomainContext(models.DomainServiceList),
M cmd/short/main.go => cmd/short/main.go +1 -0
@@ 82,6 82,7 @@ func run() error {
WithQueues(eq, wq).
WithMiddleware(
database.Middleware(db),
+ core.RemoteIPMiddleware,
core.TimezoneContext(),
crypto.Middleware(entropy),
domain.DomainContext(models.DomainServiceShort),
M cmd/test/helpers.go => cmd/test/helpers.go +2 -0
@@ 170,6 170,7 @@ func NewAPITestServer(t *testing.T) (*server.Server, *echo.Echo, string) {
DefaultMiddlewareWithConfig(mwConf).
WithMiddleware(
database.Middleware(db),
+ core.RemoteIPMiddleware,
core.TimezoneContext(),
crypto.Middleware(entropy),
core.InternalAuthMiddleware(accounts.NewUserFetch()),
@@ 245,6 246,7 @@ func getMWChain(s *server.Server, f echo.HandlerFunc, user *models.User) echo.Ha
entropy, _ := s.Config.File.Get("access", "entropy")
cryptoMiddleware := crypto.Middleware(entropy)
handlerFunc := cryptoMiddleware(f)
+ handlerFunc = core.RemoteIPMiddleware(handlerFunc)
timezoneMiddleware := core.TimezoneContext()
handlerFunc = timezoneMiddleware(handlerFunc)
serverMiddleware := server.Middleware(s)
M core/middleware.go => core/middleware.go +11 -0
@@ 91,3 91,14 @@ func CORSReadOnlyMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return next(c)
}
}
+
+func RemoteIPMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ c.SetRequest(
+ c.Request().WithContext(
+ links.IPContext(c.Request().Context(), c.RealIP()),
+ ),
+ )
+ return next(c)
+ }
+}
M helpers.go => helpers.go +21 -4
@@ 5,6 5,7 @@ import (
"context"
"encoding/json"
"encoding/xml"
+ "errors"
"fmt"
"html/template"
"io"
@@ 43,6 44,10 @@ import (
"netlandish.com/x/gobwebs/validate"
)
+type contextKey struct {
+ name string
+}
+
type errorExtension struct {
Code int
Field string
@@ 686,10 691,6 @@ func ArchiveURLSnapshot(ctx context.Context, orgLink *models.OrgLink) error {
return nil
}
-type contextKey struct {
- name string
-}
-
var langCtxKey = &contextKey{"userLang"}
func LangContext(c echo.Context) context.Context {
@@ 1135,3 1136,19 @@ func NewRateLimiterConfig(conf *config.Config) (middleware.RateLimiterConfig, er
}
return rlConfig, nil
}
+
+var IPCtxKey = &contextKey{"remote_ip"}
+
+// IPContext adds a domain model to context for immediate use
+func IPContext(ctx context.Context, ip string) context.Context {
+ return context.WithValue(ctx, IPCtxKey, ip)
+}
+
+// IPForContext fetches current domain from the request context
+func IPForContext(ctx context.Context) string {
+ ip, ok := ctx.Value(IPCtxKey).(string)
+ if !ok {
+ panic(errors.New("Invalid IP context"))
+ }
+ return ip
+}