~netlandish/gobwebs-oauth2

14ecc44f91a2ca88598dfb2a9bab8f44b2278da9 — Peter Sanchez 1 year, 7 months ago 2e8c79b
Adding oauth2 authorization handlers
8 files changed, 350 insertions(+), 7 deletions(-)

M bearer.go
M logic.go
M models.go
A oauth2_authorizations.go
M oauth2_clients.go
M oauth2_grants.go
M routes.go
M schema.sql
M bearer.go => bearer.go +40 -0
@@ 136,6 136,9 @@ func DecodeGrants(grants string) Grants {
			access = "RO"
		} else {
			access = parts[1]
			if access != "RW" && access != "RO" {
				access = "RO"
			}
		}
		accessMap[scope] = access
	}


@@ 174,6 177,35 @@ func (g *Grants) Encode() string {
	return g.encoded
}

// Validate ...
func (g *Grants) Validate(scopes []string) []string {
	var errors []string
	for k := range g.grants {
		if !contains(scopes, k) {
			errors = append(errors, fmt.Sprintf("Invalid scope: %s", k))
		}
	}
	return errors
}

// List ...
func (g *Grants) List() []string {
	var grants []string
	for k, v := range g.grants {
		grants = append(grants, fmt.Sprintf("%s:%s", k, v))
	}
	return grants
}

func contains(values []string, str string) bool {
	for _, v := range values {
		if v == str {
			return true
		}
	}
	return false
}

// TokenUser wrapper for gobwebs.User and token grants
type TokenUser struct {
	User      gobwebs.User


@@ 181,3 213,11 @@ type TokenUser struct {
	Grants    *Grants
	TokenHash [64]byte
}

// AuthorizationPayload holds temporary approval while the oauth2 cycle is
// in process
type AuthorizationPayload struct {
	Grants    string
	ClientKey string
	UserID    int
}

M logic.go => logic.go +4 -4
@@ 25,7 25,10 @@ func OAuth2(ctx context.Context, token string, fetch gobwebs.UserFetch) (*TokenU
	hash := sha512.Sum512([]byte(token))
	hashStr := hex.EncodeToString(hash[:])
	opts := &database.FilterOptions{
		Filter: sq.Eq{"token_hash": hashStr},
		Filter: sq.And{
			sq.Eq{"token_hash": hashStr},
			sq.Expr("expires > NOW() at time zone 'UTC'"),
		},
	}
	grants, err := GetGrants(ctx, opts)
	if err != nil {


@@ 39,9 42,6 @@ func OAuth2(ctx context.Context, token string, fetch gobwebs.UserFetch) (*TokenU
		return nil, fmt.Errorf("Error with provided OAuth 2.0 bearer token")
	}
	grant := grants[0]
	if grant.IsExpired() {
		return nil, fmt.Errorf("Invalid or expired OAuth 2.0 bearer token")
	}

	bt.Issued = ToTimestamp(grant.Issued)
	gt := DecodeGrants(bt.Grants)

M models.go => models.go +8 -0
@@ 31,3 31,11 @@ type Grant struct {
	UserID    int           `db:"user_id"`
	ClientID  sql.NullInt64 `db:"client_id"`
}

// Authorization ...
type Authorization struct {
	ID        int       `db:"id"`
	Code      string    `db:"code"`
	Payload   string    `db:"payload"`
	CreatedOn time.Time `db:"created_on"`
}

A oauth2_authorizations.go => oauth2_authorizations.go +92 -0
@@ 0,0 1,92 @@
package oauth2

import (
	"context"
	"database/sql"
	"fmt"

	sq "github.com/Masterminds/squirrel"
	"hg.code.netlandish.com/~netlandish/gobwebs/database"
)

// GetAuthorizations retuns oauth2 authorizations using the given filters
func GetAuthorizations(ctx context.Context, opts *database.FilterOptions) ([]*Authorization, error) {
	if opts == nil {
		opts = &database.FilterOptions{}
	}
	auths := make([]*Authorization, 0)
	if err := database.WithTx(ctx, database.TxOptionsRO, func(tx *sql.Tx) error {
		q := opts.GetBuilder(nil)
		rows, err := q.
			Columns("id", "code", "payload", "created_on").
			From("oauth2_authorizations").
			PlaceholderFormat(sq.Dollar).
			RunWith(tx).
			QueryContext(ctx)
		if err != nil {
			if err == sql.ErrNoRows {
				return nil
			}
			return err
		}
		defer rows.Close()

		for rows.Next() {
			var a Authorization
			if err = rows.Scan(&a.ID, &a.Code, &a.Payload, &a.CreatedOn); err != nil {
				return err
			}
			auths = append(auths, &a)
		}
		return nil
	}); err != nil {
		return nil, err
	}
	return auths, nil
}

// Store will save a client
func (a *Authorization) Store(ctx context.Context) error {
	err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
		var err error
		if a.ID == 0 {
			err = sq.
				Insert("oauth2_authorizations").
				Columns("code", "payload").
				Values(a.Code, a.Payload).
				Suffix(`RETURNING (id)`).
				PlaceholderFormat(sq.Dollar).
				RunWith(tx).
				ScanContext(ctx, &a.ID)
		} else {
			err = sq.
				Update("oauth2_authorizations").
				Set("code", a.Code).
				Set("payload", a.Payload).
				Where("id = ?", a.ID).
				Suffix(`RETURNING (id)`).
				PlaceholderFormat(sq.Dollar).
				RunWith(tx).
				ScanContext(ctx, &a.ID)
		}
		return err
	})
	return err
}

// Delete will delete this rate
func (a *Authorization) Delete(ctx context.Context) error {
	if a.ID == 0 {
		return fmt.Errorf("Authorization object is not populated")
	}
	err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
		_, err := sq.
			Delete("oauth2_authorizations").
			Where("id = ?", a.ID).
			PlaceholderFormat(sq.Dollar).
			RunWith(tx).
			ExecContext(ctx)
		return err
	})
	return err
}

M oauth2_clients.go => oauth2_clients.go +15 -0
@@ 54,6 54,21 @@ func GetClients(ctx context.Context, opts *database.FilterOptions) ([]*Client, e
	return clients, nil
}

// GetClientByID will fetch a client by given client id
func GetClientByID(ctx context.Context, clientID string) (*Client, error) {
	opts := &database.FilterOptions{
		Filter: sq.Eq{"key": clientID},
	}
	clients, err := GetClients(ctx, opts)
	if err != nil {
		return nil, err
	}
	if len(clients) == 0 {
		return nil, nil
	}
	return clients[0], nil
}

// Store will save a client
func (c *Client) Store(ctx context.Context) error {
	err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {

M oauth2_grants.go => oauth2_grants.go +6 -0
@@ 85,6 85,12 @@ func (g *Grant) IsExpired() bool {
	return time.Now().UTC().After(g.Expires.UTC())
}

// Revoke ...
func (g *Grant) Revoke(ctx context.Context) error {
	g.Expires = time.Now().UTC()
	return g.Store(ctx)
}

// Delete will delete this rate
func (g *Grant) Delete(ctx context.Context) error {
	if g.ID == 0 {

M routes.go => routes.go +174 -3
@@ 1,10 1,14 @@
package oauth2

import (
	"crypto/rand"
	"crypto/sha512"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	sq "github.com/Masterminds/squirrel"


@@ 15,11 19,18 @@ import (
	"hg.code.netlandish.com/~netlandish/gobwebs/server"
)

// ServiceConfig let's you add basic config variables to service
type ServiceConfig struct {
	DocumentationURL string
	Scopes           []string
}

// Service is the base accounts service struct
type Service struct {
	name   string
	eg     *echo.Group
	helper Helper
	config *ServiceConfig
}

// RegisterRoutes ...


@@ 33,6 44,8 @@ func (s *Service) RegisterRoutes() {
	s.eg.GET("/clients", s.ListClients).Name = s.RouteName("list_clients")
	s.eg.GET("/clients/add", s.AddClient).Name = s.RouteName("add_client")
	s.eg.POST("/clients/add", s.AddClient).Name = s.RouteName("add_client_post")
	s.eg.GET("/authorize", s.Authorize).Name = s.RouteName("authorize")
	s.eg.POST("/authorize", s.AuthorizePOST).Name = s.RouteName("authorize_post")
}

// ListPersonal ...


@@ 154,6 167,163 @@ func (s *Service) AddClient(c echo.Context) error {
	return gctx.Render(http.StatusOK, "oauth2_add_client.html", gmap)
}

func oauth2Redirect(c echo.Context, redirectURI string, params gobwebs.Map) error {
	parts, err := url.Parse(redirectURI)
	if err != nil {
		return err
	}
	qs := parts.Query()
	for k, v := range params {
		qs.Set(k, v.(string))
	}
	parts.RawQuery = qs.Encode()
	return c.Redirect(http.StatusMovedPermanently, parts.String())
}

func (s *Service) authorizeError(c echo.Context,
	redirectURI, state, errorCode, errorDescription string) error {
	if redirectURI == "" {
		gctx := c.(*server.Context)
		return gctx.Render(http.StatusOK, "oauth2_error.html", gobwebs.Map{
			"code":        errorCode,
			"description": errorDescription,
		})
	}
	return oauth2Redirect(c, redirectURI, gobwebs.Map{
		"error":             errorCode,
		"error_description": errorDescription,
		"error_uri":         s.config.DocumentationURL,
		"state":             state,
	})
}

// Authorize ...
func (s *Service) Authorize(c echo.Context) error {
	respType := c.QueryParam("response_type")
	clientID := c.QueryParam("client_id")
	scope := c.QueryParam("scope")
	state := c.QueryParam("state")
	redirectURL := c.QueryParam("redirect_uri")

	if clientID == "" {
		return s.authorizeError(c, "", state, "invalid_request",
			"The client_id parameter is required")
	}

	client, err := GetClientByID(c.Request().Context(), clientID)
	if err != nil {
		return s.authorizeError(c, "", state, "server_error", err.Error())
	}
	if client == nil {
		return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID")
	}

	if redirectURL != "" && redirectURL != client.RedirectURL {
		return s.authorizeError(c, "", state, "invalid_request",
			"The redirect_uri parameter doesn't match the registered client's")
	}
	if respType != "code" {
		return s.authorizeError(c, redirectURL, state, "unsupported_response_type",
			"The response_type parameter must be set to 'code'")
	}
	if scope == "" {
		return s.authorizeError(c, redirectURL, state, "invalid_scope",
			"The scope parameter is required")
	}

	grants := DecodeGrants(scope)
	errors := grants.Validate(s.config.Scopes)
	if len(errors) > 0 {
		return s.authorizeError(c, redirectURL, state,
			"invalid_scope", strings.Join(errors, ", "))
	}

	gctx := c.(*server.Context)
	return gctx.Render(http.StatusOK, "oauth2_authorization.html", gobwebs.Map{
		"client":       client,
		"grants":       grants.List(),
		"client_id":    clientID,
		"redirect_uri": redirectURL,
		"state":        state,
	})
}

// AuthorisePOST ...
func (s *Service) AuthorizePOST(c echo.Context) error {
	params, err := c.FormParams()
	if err != nil {
		return err
	}
	clientID := params.Get("client_id")
	redirectURL := params.Get("redirect_uri")
	state := params.Get("state")

	if params.Has("reject") {
		return s.authorizeError(c, redirectURL, state, "access_denied",
			"The resource owner denied the request.")
	}

	subgrants := []string{}
	// XXX csrf shouldn't be hard coded here
	skip := []string{"accept", "client_id", "redirect_uri", "state", "csrf"}
	for grant := range params {
		if contains(skip, grant) {
			continue
		}
		subgrants = append(subgrants, grant)
	}

	grants := DecodeGrants(strings.Join(subgrants, " "))
	errors := grants.Validate(s.config.Scopes)
	if len(errors) > 0 {
		return s.authorizeError(c, redirectURL, state,
			"invalid_scope", strings.Join(errors, ", "))
	}

	client, err := GetClientByID(c.Request().Context(), clientID)
	if err != nil {
		return s.authorizeError(c, "", state, "server_error", err.Error())
	}
	if client == nil {
		return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID")
	}

	var seed [64]byte
	gctx := c.(*server.Context)
	n, err := rand.Read(seed[:])
	if err != nil || n != len(seed) {
		panic(err)
	}
	hash := sha512.Sum512(seed[:])
	code := hex.EncodeToString(hash[:])[:32]

	payload := AuthorizationPayload{
		Grants:    grants.encoded,
		ClientKey: clientID,
		UserID:    int(gctx.User.GetID()),
	}
	data, err := json.Marshal(&payload)
	if err != nil {
		panic(err)
	}

	auth := &Authorization{
		Code:    code,
		Payload: string(data),
	}
	if err := auth.Store(c.Request().Context()); err != nil {
		return s.authorizeError(c, "", state, "server_error", err.Error())
	}

	gmap := gobwebs.Map{
		"code": code,
	}
	if state != "" {
		gmap["state"] = state
	}
	return oauth2Redirect(c, client.RedirectURL, gmap)
}

// Introspect ...
func (s *Service) Introspect(c echo.Context) error {
	req := c.Request()


@@ 162,10 332,11 @@ func (s *Service) Introspect(c echo.Context) error {
		retErr := struct {
			Err  string `json:"error"`
			Desc string `json:"error_description"`
			URI  string `json:"error_url"` // TODO Make this customizable
			URI  string `json:"error_url"`
		}{
			Err:  "invalid request",
			Desc: "Content-Type must be application/x-www-form-urlencoded",
			URI:  s.config.DocumentationURL,
		}
		return c.JSON(http.StatusBadRequest, &retErr)
	}


@@ 207,11 378,11 @@ func (s *Service) RouteName(value string) string {
}

// NewService return service
func NewService(eg *echo.Group, name string, helper Helper) *Service {
func NewService(eg *echo.Group, name string, helper Helper, config *ServiceConfig) *Service {
	if name == "" {
		name = "oauth2"
	}
	service := &Service{name: name, eg: eg, helper: helper}
	service := &Service{name: name, eg: eg, helper: helper, config: config}
	service.RegisterRoutes()
	return service
}

M schema.sql => schema.sql +11 -0
@@ 45,3 45,14 @@ CREATE TABLE oauth2_grants (
CREATE INDEX oauth2_grants_id_idx ON oauth2_grants (id);
CREATE INDEX oauth2_grants_user_id_idx ON oauth2_grants (user_id);
CREATE INDEX oauth2_grants_client_id_idx ON oauth2_grants (client_id);


CREATE TABLE oauth2_authorizations (
        id serial PRIMARY KEY,
	code character varying(128) NOT NULL,
	payload character varying NOT NULL,
	created_on TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE INDEX oauth2_authorizations_id_idx ON oauth2_authorizations (id);
CREATE INDEX oauth2_authorizations_code_idx ON oauth2_authorizations (code);