~netlandish/links

9d22bf19a851c434a3578711c8bf29e963d6330c — Peter Sanchez 9 months ago 3f43ebf
Moving domain middleware and logic to it's own module
M cmd/global.go => cmd/global.go +2 -1
@@ 13,6 13,7 @@ import (
	"links/api/graph"
	"links/api/graph/model"
	"links/core"
	ldomain "links/domain"
	"links/models"
	"net/http"
	"net/url"


@@ 293,7 294,7 @@ func MakeRequestWithDomain(s *server.Server, f echo.HandlerFunc, ctx echo.Contex
	gctx := ctx.(*server.Context)
	user := gctx.User.(*models.User)
	gctx.User = nil
	domainMW := core.DomainContext(domain.Service)
	domainMW := ldomain.DomainContext(domain.Service)
	handlerFunc := domainMW(f)
	mwChain := getMWChain(s, handlerFunc, user)
	ctx.Request().Host = domain.LookupName

M cmd/links/main.go => cmd/links/main.go +3 -2
@@ 16,6 16,7 @@ import (
	"links/billing"
	"links/cmd"
	"links/core"
	"links/domain"
	"links/internal/localizer"
	"links/list"
	"links/mattermost"


@@ 178,8 179,8 @@ func run() error {
			database.Middleware(db),
			core.TimezoneContext(),
			crypto.Middleware(entropy),
			core.DomainContext(models.DomainServiceLinks),
			core.DomainRedirect,
			domain.DomainContext(models.DomainServiceLinks),
			domain.DomainRedirect,
			auth.AuthMiddleware(accounts.NewUserFetch()),
		)


M cmd/list/main.go => cmd/list/main.go +2 -1
@@ 5,6 5,7 @@ import (
	"links"
	"links/cmd"
	"links/core"
	"links/domain"
	"links/list"
	"links/models"
	"net/http"


@@ 81,7 82,7 @@ func run() error {
			database.Middleware(db),
			core.TimezoneContext(),
			crypto.Middleware(entropy),
			core.DomainContext(models.DomainServiceList),
			domain.DomainContext(models.DomainServiceList),
			core.CORSReadOnlyMiddleware,
		)


M cmd/short/main.go => cmd/short/main.go +2 -1
@@ 5,6 5,7 @@ import (
	"links"
	"links/cmd"
	"links/core"
	"links/domain"
	"links/models"
	"links/short"
	"net/http"


@@ 79,7 80,7 @@ func run() error {
			database.Middleware(db),
			core.TimezoneContext(),
			crypto.Middleware(entropy),
			core.DomainContext(models.DomainServiceShort),
			domain.DomainContext(models.DomainServiceShort),
			core.CORSReadOnlyMiddleware,
		)


M core/middleware.go => core/middleware.go +0 -96
@@ 1,14 1,9 @@
package core

import (
	"context"
	"errors"
	"fmt"
	"links"
	"links/internal/localizer"
	"links/models"
	"net/http"
	"net/url"
	"strings"

	"github.com/labstack/echo/v4"


@@ 16,7 11,6 @@ import (
	"netlandish.com/x/gobwebs"
	oauth2 "netlandish.com/x/gobwebs-oauth2"
	"netlandish.com/x/gobwebs/auth"
	"netlandish.com/x/gobwebs/messages"
	"netlandish.com/x/gobwebs/server"
	"netlandish.com/x/gobwebs/timezone"
)


@@ 31,58 25,6 @@ func TimezoneContext() echo.MiddlewareFunc {
	}
}

var domainCtxKey = &contextKey{"domain"}

type contextKey struct {
	name string
}

// DomainContext adds the current domain to request context.
func DomainContext(service int) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			req := c.Request()
			domains, err := ValidDomain(req.Context(), req.Host, service, false)
			if err != nil {
				return err
			}
			if len(domains) != 1 {
				return fmt.Errorf("Invalid domain")
			}
			domain := domains[0]
			if domain.IsActive {
				ctx := context.WithValue(c.Request().Context(), domainCtxKey, domain)
				c.SetRequest(c.Request().WithContext(ctx))
				return next(c)
			}
			// If the domain is disabled
			gctx := c.(*server.Context)
			mainSchema, ok := gctx.Server.Config.File.Get("gobwebs", "scheme")
			if !ok {
				return fmt.Errorf("schema not found")
			}
			mainDomain := gctx.Server.Config.Domain
			nextURL := &url.URL{
				Scheme: mainSchema,
				Host:   mainDomain,
			}
			lt := localizer.GetSessionLocalizer(c)
			messages.Error(
				c, lt.Translate("The domain is currently inactive. Please subscribe to activate this domain"))
			return c.Redirect(http.StatusMovedPermanently, nextURL.String())
		}
	}
}

// ForDomainContext fetches current domain from the request context
func ForDomainContext(ctx context.Context) *models.Domain {
	domain, ok := ctx.Value(domainCtxKey).(*models.Domain)
	if !ok {
		panic(errors.New("Invalid domain context"))
	}
	return domain
}

func authError(c echo.Context, reason string, code int) error {
	gqlerr := gqlerror.Errorf("Authentication error: %s", reason)
	ret := struct {


@@ 140,44 82,6 @@ func InternalAuthMiddleware(fetch gobwebs.UserFetch) echo.MiddlewareFunc {
	}
}

func DomainRedirect(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		gctx := c.(*server.Context)
		mainDomain, ok := gctx.Server.Config.File.Get("links", "links-service-domain")
		if !ok {
			return fmt.Errorf("links-service-domain not found")
		}
		mainDomain = strings.ToLower(mainDomain)
		domain := ForDomainContext(c.Request().Context())
		if strings.ToLower(domain.LookupName) != strings.ToLower(mainDomain) {
			redirectPaths := [3]string{
				"/accounts",
				"/popular",
				"/recent",
			}
			req := c.Request()
			rPath := req.URL.Path
			for _, path := range redirectPaths {
				if strings.HasPrefix(rPath, path) {
					mainSchema, ok := gctx.Server.Config.File.Get("gobwebs", "scheme")
					if !ok {
						return fmt.Errorf("schema not found")
					}
					nextURL := &url.URL{
						Scheme: mainSchema,
						Host:   mainDomain,
						Path:   rPath,
					}
					qs := req.URL.Query()
					nextURL.RawQuery = qs.Encode()
					return c.Redirect(http.StatusMovedPermanently, nextURL.String())
				}
			}
		}
		return next(c)
	}
}

func CORSReadOnlyMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		c.Response().Before(func() {

A domain/logic.go => domain/logic.go +127 -0
@@ 0,0 1,127 @@
package domain

import (
	"context"
	"database/sql"
	"fmt"
	"links/models"
	"net"
	"strings"

	sq "github.com/Masterminds/squirrel"
	"golang.org/x/crypto/acme/autocert"
	"golang.org/x/net/idna"
	"netlandish.com/x/gobwebs/database"
	"netlandish.com/x/gobwebs/server"
)

var ipMap map[string]net.IP

func ValidDomain(ctx context.Context, host string, service int, active bool) ([]*models.Domain, error) {
	h, err := idna.Lookup.ToASCII(host)
	if err != nil {
		// To account for IP:Port situations we allow the host unaltered.
		h = host
	}
	opts := &database.FilterOptions{
		Filter: sq.And{
			sq.Eq{"d.lookup_name": strings.ToLower(h)},
			sq.Eq{"d.status": models.DomainStatusApproved},
		},
	}
	if service >= 0 {
		opts.Filter = sq.And{
			opts.Filter,
			sq.Eq{"d.service": service},
		}
	}
	if active {
		opts.Filter = sq.And{
			opts.Filter,
			sq.Eq{"d.is_active": active},
		}
	}
	domains, err := models.GetDomains(ctx, opts)
	if err != nil {
		return nil, err
	}
	return domains, nil
}

// DomainHostPolicy returns a autocert manager HostPolicy instance to work
// with confiugred domains for various services
func DomainHostPolicy(db *sql.DB, service int) autocert.HostPolicy {
	return func(ctx context.Context, host string) error {
		ctx = database.Context(ctx, db)
		domains, err := ValidDomain(ctx, host, service, true)
		if err != nil {
			return err
		}
		if len(domains) != 1 {
			return fmt.Errorf("Invalid domain")
		}
		return nil
	}
}

// CheckDomainDNS will verify a domain has a proper CNAME set
func CheckDomainDNS(ctx context.Context, domain string, service int) (bool, error) {
	if service < models.DomainServiceLinks || service > models.DomainServiceList {
		return false, fmt.Errorf("invalid service type given")
	}

	var (
		expName, cval, chk string
		ok                 bool
	)

	srv := server.ForContext(ctx)

	chk, ok = srv.Config.File.Get("links", "domain-check-cname")
	if ok && strings.ToLower(chk) == "false" {
		return true, nil
	}

	switch service {
	case models.DomainServiceLinks:
		cval = "links-cname-domain"
	case models.DomainServiceShort:
		cval = "short-cname-domain"
	case models.DomainServiceList:
		cval = "list-cname-domain"
	}

	expName, ok = srv.Config.File.Get("links", cval)
	if !ok {
		return false, fmt.Errorf("No config.ini value set for [links] %s", cval)
	}

	if ipMap == nil {
		ipMap = make(map[string]net.IP)
	}

	_, ok = ipMap[cval]
	if !ok {
		ips, err := net.LookupIP(expName)
		if err != nil {
			return false, err
		}
		ipMap[cval] = ips[0]
	}

	ips, err := net.LookupIP(domain)
	if err != nil {
		return false, err
	}

	if len(ips) > 1 {
		return false, fmt.Errorf("Multiple IP's defined for %s", domain)
	}

	if !ips[0].Equal(ipMap[cval]) {
		return false, fmt.Errorf(
			"Domain %s IP is incorrect. It should be %s", domain, ipMap[cval].String())
	}

	return true, nil
}

A domain/middleware.go => domain/middleware.go +117 -0
@@ 0,0 1,117 @@
package domain

import (
	"context"
	"errors"
	"fmt"
	"links/internal/localizer"
	"links/models"
	"net/http"
	"net/url"
	"strings"

	"github.com/labstack/echo/v4"
	"netlandish.com/x/gobwebs/messages"
	"netlandish.com/x/gobwebs/server"
)

var domainCtxKey = &contextKey{"domain"}

type contextKey struct {
	name string
}

// Context adds a domain model to context for immediate use
func Context(ctx context.Context, domain *models.Domain) context.Context {
	return context.WithValue(ctx, domainCtxKey, domain)
}

// ForContext fetches current domain from the request context
func ForContext(ctx context.Context) *models.Domain {
	domain, ok := ctx.Value(domainCtxKey).(*models.Domain)
	if !ok {
		panic(errors.New("Invalid domain context"))
	}
	return domain
}

// DomainContext adds the current domain to request context.
func DomainContext(service int) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			req := c.Request()
			domains, err := ValidDomain(req.Context(), req.Host, service, false)
			if err != nil {
				return err
			}
			if len(domains) != 1 {
				return fmt.Errorf("Invalid domain")
			}
			domain := domains[0]
			if domain.IsActive {
				c.SetRequest(
					c.Request().WithContext(
						Context(c.Request().Context(), domain),
					),
				)
				return next(c)
			}
			// If the domain is disabled
			gctx := c.(*server.Context)
			mainSchema, ok := gctx.Server.Config.File.Get("gobwebs", "scheme")
			if !ok {
				return fmt.Errorf("schema not found")
			}
			mainDomain := gctx.Server.Config.Domain
			nextURL := &url.URL{
				Scheme: mainSchema,
				Host:   mainDomain,
			}
			lt := localizer.GetSessionLocalizer(c)
			messages.Error(
				c, lt.Translate("The domain is currently inactive. Please subscribe to activate this domain"))
			return c.Redirect(http.StatusMovedPermanently, nextURL.String())
		}
	}
}

// DomainRedirect will redirect to the main links service domain if, for some reason,
// a user is trying to view service specific pages on a user domain.
func DomainRedirect(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		gctx := c.(*server.Context)
		mainDomain, ok := gctx.Server.Config.File.Get("links", "links-service-domain")
		if !ok {
			return fmt.Errorf("links-service-domain not found")
		}
		mainDomain = strings.ToLower(mainDomain)
		domain := ForContext(c.Request().Context())
		if strings.ToLower(domain.LookupName) != strings.ToLower(mainDomain) {
			redirectPaths := []string{
				"/accounts",
				"/popular",
				"/recent",
				"/tour",
			}
			req := c.Request()
			rPath := req.URL.Path
			for _, path := range redirectPaths {
				if rPath == path {
					mainSchema, ok := gctx.Server.Config.File.Get("gobwebs", "scheme")
					if !ok {
						return fmt.Errorf("schema not found")
					}
					nextURL := &url.URL{
						Scheme: mainSchema,
						Host:   mainDomain,
						Path:   rPath,
					}
					qs := req.URL.Query()
					nextURL.RawQuery = qs.Encode()
					return c.Redirect(http.StatusMovedPermanently, nextURL.String())
				}
			}
		}
		return next(c)
	}
}

M helpers.go => helpers.go +7 -0
@@ 6,6 6,7 @@ import (
	"fmt"
	"html/template"
	"io"
	"links/domain"
	"links/internal/localizer"
	"links/models"
	"links/valid"


@@ 210,10 211,16 @@ func GetOrgSelection(c echo.Context) string {

// PullOrgSlug will check url and session for the current org slug
func PullOrgSlug(c echo.Context) string {
	dom := domain.ForContext(c.Request().Context())
	if dom.OrgID.Valid {
		return dom.OrgSlug.String
	}

	slug := c.Param("slug")
	if slug != "" {
		return slug
	}

	return GetOrgSelection(c)
}


M list/routes.go => list/routes.go +3 -3
@@ 6,7 6,7 @@ import (
	"html/template"
	"links"
	"links/analytics"
	"links/core"
	"links/domain"
	"links/internal/localizer"
	"links/models"
	"net/http"


@@ 1386,7 1386,7 @@ func (r *DetailService) ListLink(c echo.Context) error {
	if err != nil {
		return echo.NotFoundHandler(c)
	}
	domain := core.ForDomainContext(c.Request().Context())
	domain := domain.ForContext(c.Request().Context())
	type GraphQLResponse struct {
		Link models.ListingLink `json:"getListingLink"`
	}


@@ 1435,7 1435,7 @@ func (r *DetailService) ListLink(c echo.Context) error {
func (r *DetailService) ListDetail(c echo.Context) error {
	slug := c.Param("slug")
	gctx := c.(*server.Context)
	domain := core.ForDomainContext(c.Request().Context())
	domain := domain.ForContext(c.Request().Context())

	if slug == "" && domain.Level == models.DomainLevelSystem {
		mainDomain := gctx.Server.Config.Domain

M short/routes.go => short/routes.go +2 -2
@@ 5,7 5,7 @@ import (
	"fmt"
	"links"
	"links/analytics"
	"links/core"
	"links/domain"
	"links/internal/localizer"
	"links/models"
	"net/http"


@@ 576,7 576,7 @@ func (r *RedirectService) LinkShort(c echo.Context) error {
		}`)

	op.Var("code", code)
	domain := core.ForDomainContext(c.Request().Context())
	domain := domain.ForContext(c.Request().Context())
	op.Var("domain", domain.ID)

	err := links.Execute(c.Request().Context(), op, &result)