~netlandish/links

a78510736edc1272656bffa9382e285f8794cbf7 — Peter Sanchez 9 months ago ad705f6
Work to allow reverse proxys to check domain validity.

References: https://todo.code.netlandish.com/~netlandish/links/46
4 files changed, 63 insertions(+), 20 deletions(-)

M cmd/links/main.go
M config.example.ini
M core/domains.go
M core/middleware.go
M cmd/links/main.go => cmd/links/main.go +35 -5
@@ 131,6 131,14 @@ func run() error {
		return fmt.Errorf("Unknown storage service configured")
	}

	var domCheck bool
	if domCheckVal, ok := config.File.Get("links", "enable-domain-check"); ok {
		if domCheckVal == "true" {
			domCheck = true
		}
	}
	tlsman := cmd.LoadAutoTLS(config, db, models.DomainServiceLinks)

	e := echo.New()

	// email work queue and service, general task queue


@@ 175,13 183,35 @@ func run() error {
		WithMiddleware(
			database.Middleware(db),
			core.TimezoneContext(),
			crypto.Middleware(entropy),
			core.DomainContext(models.DomainServiceLinks),
			core.DomainRedirect,
			auth.AuthMiddleware(accounts.NewUserFetch()),
		)

	tlsman := cmd.LoadAutoTLS(config, db, models.DomainServiceLinks)
	// Split here to do as little middleware processing as needed
	// to serve the domain check.
	if tlsman == nil && domCheck {
		e.GET("/_check/domain", func(c echo.Context) error {
			domain := c.QueryParam("domain")
			if domain == "" {
				return c.NoContent(http.StatusBadRequest)
			}
			domains, err := core.ValidDomain(c.Request().Context(), domain, -1, true)
			if err != nil {
				return err
			}
			if len(domains) != 1 {
				return c.NoContent(http.StatusBadRequest)
			}
			return c.NoContent(http.StatusOK)
		}).Name = "domain_check"
	}

	// Continue with middlewares...
	srv.WithMiddleware(
		crypto.Middleware(entropy),
		core.DomainContext(models.DomainServiceLinks),
		core.DomainRedirect,
		auth.AuthMiddleware(accounts.NewUserFetch()),
	)

	if tlsman != nil {
		srv = srv.WithCertManager(tlsman)
	}

M config.example.ini => config.example.ini +6 -1
@@ 113,9 113,14 @@ max-upload-size=10737418
api-origin=http://127.0.0.1:8080/query

# Enable AutoTLS / SSL Cert management?
# Default true
auto-tls=true
# Where will SSL certs be stored
# Where will SSL certs be stored. If empty, the value of `./cache` is used.
ssl-cert-cachedir=/var/www/.cache
# Enable domain TLS support check. If set to true then the 
# /_check/domain route will be added.
# Default false
enable-domain-check=false

## DNS CHECKS


M core/domains.go => core/domains.go +20 -8
@@ 17,14 17,30 @@ import (

var ipMap map[string]net.IP

func getDomains(ctx context.Context, host string, service int) ([]*models.Domain, error) {
func ValidDomain(ctx context.Context, host string, service int, active bool) ([]*models.Domain, error) {
	h, err := idna.Lookup.ToASCII(host)
	if err != nil {
		// To account for IP:Port situations we allow the host unaltered.
		h = host
	}
	opts := &database.FilterOptions{
		Filter: sq.And{
			sq.Eq{"d.lookup_name": strings.ToLower(host)},
			sq.Eq{"d.service": service},
			sq.Eq{"d.lookup_name": strings.ToLower(h)},
			sq.Eq{"d.status": models.DomainStatusApproved},
		},
	}
	if service >= 0 {
		opts.Filter = sq.And{
			opts.Filter,
			sq.Eq{"d.service": service},
		}
	}
	if active {
		opts.Filter = sq.And{
			opts.Filter,
			sq.Eq{"d.is_active": active},
		}
	}
	domains, err := models.GetDomains(ctx, opts)
	if err != nil {
		return nil, err


@@ 36,12 52,8 @@ func getDomains(ctx context.Context, host string, service int) ([]*models.Domain
// with confiugred domains for various services
func DomainHostPolicy(db *sql.DB, service int) autocert.HostPolicy {
	return func(ctx context.Context, host string) error {
		h, err := idna.Lookup.ToASCII(host)
		if err != nil {
			return err
		}
		ctx = database.Context(ctx, db)
		domains, err := getDomains(ctx, h, service)
		domains, err := ValidDomain(ctx, host, service, true)
		if err != nil {
			return err
		}

M core/middleware.go => core/middleware.go +2 -6
@@ 41,12 41,8 @@ type contextKey struct {
func DomainContext(service int) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			_host := c.Request().Host
			if _host == "" {
				return fmt.Errorf("No HOST header sent")
			}
			host, _ := strings.CutSuffix(_host, ".")
			domains, err := getDomains(c.Request().Context(), host, service)
			req := c.Request()
			domains, err := ValidDomain(req.Context(), req.Host, service, false)
			if err != nil {
				return err
			}