~netlandish/gobwebs-oauth2

584364316d8d881098551601b277278466940a27 — Peter Sanchez 11 months ago 55b7b69
Adding base refresh_token support. Missing grant/scope modifying support
4 files changed, 178 insertions(+), 85 deletions(-)

M grants.go
M models.go
M routes.go
M schema.sql
M grants.go => grants.go +7 -5
@@ 20,7 20,7 @@ func GetGrants(ctx context.Context, opts *database.FilterOptions) ([]*Grant, err
		q := opts.GetBuilder(nil)
		rows, err := q.
			Columns("g.id", "g.issued", "g.expires", "g.comment", "g.grants", "g.token_hash",
				"g.user_id", "g.client_id", "c.name").
				"g.refresh_token_hash", "g.user_id", "g.client_id", "c.name").
			From("oauth2_grants g").
			LeftJoin("oauth2_clients c ON c.id = g.client_id").
			PlaceholderFormat(sq.Dollar).


@@ 37,7 37,7 @@ func GetGrants(ctx context.Context, opts *database.FilterOptions) ([]*Grant, err
		for rows.Next() {
			var g Grant
			if err = rows.Scan(&g.ID, &g.Issued, &g.Expires, &g.Comment, &g.Grants, &g.TokenHash,
				&g.UserID, &g.ClientID, &g.ClientName); err != nil {
				&g.RefreshTokenHash, &g.UserID, &g.ClientID, &g.ClientName); err != nil {
				return err
			}
			grants = append(grants, &g)


@@ 56,9 56,10 @@ func (g *Grant) Store(ctx context.Context) error {
		if g.ID == 0 {
			err = sq.
				Insert("oauth2_grants").
				Columns("issued", "expires", "comment", "grants", "token_hash", "user_id",
					"client_id").
				Values(g.Issued, g.Expires, g.Comment, g.Grants, g.TokenHash, g.UserID, g.ClientID).
				Columns("issued", "expires", "comment", "grants", "token_hash", "refresh_token_hash",
					"user_id", "client_id").
				Values(g.Issued, g.Expires, g.Comment, g.Grants, g.TokenHash, g.RefreshTokenHash,
					g.UserID, g.ClientID).
				Suffix(`RETURNING (id)`).
				PlaceholderFormat(sq.Dollar).
				RunWith(tx).


@@ 71,6 72,7 @@ func (g *Grant) Store(ctx context.Context) error {
				Set("comment", g.Comment).
				Set("grants", g.Grants).
				Set("token_hash", g.TokenHash).
				Set("refresh_token_hash", g.RefreshTokenHash).
				Set("user_id", g.UserID).
				Set("client_id", g.ClientID).
				Where("id = ?", g.ID).

M models.go => models.go +9 -8
@@ 23,14 23,15 @@ type Client struct {

// Grant ...
type Grant struct {
	ID        int           `db:"id"`
	Issued    time.Time     `db:"issued"`
	Expires   time.Time     `db:"expires"`
	Comment   string        `db:"comment"`
	Grants    string        `db:"grants"`
	TokenHash string        `db:"token_hash"`
	UserID    int           `db:"user_id"`
	ClientID  sql.NullInt64 `db:"client_id"`
	ID               int           `db:"id"`
	Issued           time.Time     `db:"issued"`
	Expires          time.Time     `db:"expires"`
	Comment          string        `db:"comment"`
	Grants           string        `db:"grants"`
	TokenHash        string        `db:"token_hash"`
	RefreshTokenHash string        `db:"refresh_token_hash"`
	UserID           int           `db:"user_id"`
	ClientID         sql.NullInt64 `db:"client_id"`

	ClientName sql.NullString `db:"-"`
}

M routes.go => routes.go +161 -72
@@ 347,7 347,8 @@ func (s *Service) authorizeError(c echo.Context,
	})
}

// Authorize ...
// Authorize will show an authorization form to the end user seeking their permission to
// grant oauth2 access to the provided scopes by the provided client.
func (s *Service) Authorize(c echo.Context) error {
	respType := c.QueryParam("response_type")
	clientID := c.QueryParam("client_id")


@@ 398,7 399,7 @@ func (s *Service) Authorize(c echo.Context) error {
	return gctx.Render(http.StatusOK, "oauth2_authorization.html", gmap)
}

// AuthorisePOST ...
// AuthorisePOST will authorize, or reject, oauth2 access request by an outside client.
func (s *Service) AuthorizePOST(c echo.Context) error {
	params, err := c.FormParams()
	if err != nil {


@@ 492,7 493,8 @@ func (s *Service) accessTokenError(c echo.Context, err, desc string, status int)
	return c.JSON(status, &retErr)
}

// AccessTokenPOST ...
// AccessTokenPOST will be used by remote party after a user has successfully authorized
// oauth2 account access. This is where the access token will be created and returned.
func (s *Service) AccessTokenPOST(c echo.Context) error {
	req := c.Request()
	ctype := req.Header.Get("Content-Type")


@@ 508,6 510,8 @@ func (s *Service) AccessTokenPOST(c echo.Context) error {
	grantType := params.Get("grant_type")
	code := params.Get("code")
	redirectURI := params.Get("redirect_uri")
	inRefreshToken := params.Get("refresh_token")
	scope := params.Get("scope")
	clientID := params.Get("client_id")
	clientSecret := params.Get("client_secret")



@@ 551,42 555,17 @@ func (s *Service) AccessTokenPOST(c echo.Context) error {
		return s.accessTokenError(c, "invalid_request",
			"The grant_type parameter is required", 400)
	}
	if grantType != "authorization_code" {
	if grantType != "authorization_code" && grantType != "refresh_token" {
		return s.accessTokenError(c, "unsupported_grant_type",
			fmt.Sprintf("Unsupported grant type %s", grantType), 400)
	}
	if code == "" {
		return s.accessTokenError(c, "invalid_request",
			"The code parameter is required", 400)
	}

	opts := &database.FilterOptions{
		Filter: sq.Eq{"code": code},
	}
	auths, err := GetAuthorizations(c.Request().Context(), opts)
	if err != nil {
		return s.accessTokenError(c, "server_error",
			"server error occurred, try again.", 400)
	}
	if len(auths) == 0 {
		return s.accessTokenError(c, "invalid_request",
			"Invalid authorization code", 400)
	}
	authCode := auths[0]
	if authCode.IsExpired() {
		authCode.Delete(c.Request().Context())
		return s.accessTokenError(c, "invalid_request",
			"Authorization code expired", 400)
	}

	var payload AuthorizationPayload
	if err = json.Unmarshal([]byte(authCode.Payload), &payload); err != nil {
		panic(err)
	}

	issued := time.Now().UTC()
	expires := issued.Add(366 * 24 * time.Hour)
	var (
		grant                                            *Grant
		token, tokenHash, refreshToken, refreshTokenHash string
	)

	// Fetch client by Client.Key hash
	client, err := GetClientByID(c.Request().Context(), clientID)
	if err != nil {
		return s.accessTokenError(c, "invalid_request",


@@ 597,50 576,160 @@ func (s *Service) AccessTokenPOST(c echo.Context) error {
			"Invalid client secret", 400)
	}

	if redirectURI != "" && redirectURI != client.RedirectURL {
		return s.accessTokenError(c, "invalid_request",
			"Invalid redirect URI", 400)
	}
	buf := make([]byte, 32)
	issued := time.Now().UTC()
	expires := issued.Add(366 * 24 * time.Hour)

	bt := BearerToken{
		Version:  TokenVersion,
		Type:     TypeOAuth2,
		Issued:   ToTimestamp(issued),
		Expires:  ToTimestamp(expires),
		Grants:   payload.Grants,
		UserID:   payload.UserID,
		ClientID: payload.ClientKey,
	}
	if grantType == "authorization_code" {
		// Adding a new grant
		if code == "" {
			return s.accessTokenError(c, "invalid_request",
				"The code parameter is required", 400)
		}

		opts := &database.FilterOptions{
			Filter: sq.Eq{"code": code},
		}
		auths, err := GetAuthorizations(c.Request().Context(), opts)
		if err != nil {
			return s.accessTokenError(c, "server_error",
				"server error occurred, try again.", 400)
		}
		if len(auths) == 0 {
			return s.accessTokenError(c, "invalid_request",
				"Invalid authorization code", 400)
		}
		authCode := auths[0]
		if authCode.IsExpired() {
			authCode.Delete(c.Request().Context())
			return s.accessTokenError(c, "invalid_request",
				"Authorization code expired", 400)
		}

	token := bt.Encode(c.Request().Context())
	hash := sha512.Sum512([]byte(token))
	tokenHash := hex.EncodeToString(hash[:])
		var payload AuthorizationPayload
		if err = json.Unmarshal([]byte(authCode.Payload), &payload); err != nil {
			panic(err)
		}

	grant := &Grant{
		Issued:    issued,
		Expires:   expires,
		Grants:    payload.Grants,
		TokenHash: tokenHash,
		UserID:    payload.UserID,
		ClientID:  sql.NullInt64{Int64: int64(client.ID), Valid: true},
	}
	if err := grant.Store(c.Request().Context()); err != nil {
		return s.accessTokenError(c, "server_error",
			"server error occurred storing token, try again.", 400)
	}
		if redirectURI != "" && redirectURI != client.RedirectURL {
			return s.accessTokenError(c, "invalid_request",
				"Invalid redirect URI", 400)
		}

	authCode.Delete(c.Request().Context())
		bt := BearerToken{
			Version:  TokenVersion,
			Type:     TypeOAuth2,
			Issued:   ToTimestamp(issued),
			Expires:  ToTimestamp(expires),
			Grants:   payload.Grants,
			UserID:   payload.UserID,
			ClientID: payload.ClientKey,
		}

		token = bt.Encode(c.Request().Context())
		hash := sha512.Sum512([]byte(token))
		tokenHash = hex.EncodeToString(hash[:])

		if _, err := rand.Read(buf); err != nil {
			panic(err)
		}
		refreshToken = base64.RawURLEncoding.EncodeToString(buf)
		hash = sha512.Sum512([]byte(refreshToken))
		refreshTokenHash = hex.EncodeToString(hash[:])

		grant = &Grant{
			Issued:           issued,
			Expires:          expires,
			Grants:           payload.Grants,
			TokenHash:        tokenHash,
			RefreshTokenHash: refreshTokenHash,
			UserID:           payload.UserID,
			ClientID:         sql.NullInt64{Int64: int64(client.ID), Valid: true},
		}
		if err := grant.Store(c.Request().Context()); err != nil {
			return s.accessTokenError(c, "server_error",
				"server error occurred storing token, try again.", 400)
		}

		authCode.Delete(c.Request().Context())

	} else if grantType == "refresh_token" {
		// Refreshing an existing token
		if inRefreshToken == "" {
			return s.accessTokenError(c, "invalid_request",
				"The refresh_token parameter is required", 400)
		}

		hash := sha512.Sum512([]byte(inRefreshToken))
		inRefreshTokenHash := hex.EncodeToString(hash[:])

		opts := &database.FilterOptions{
			Filter: sq.And{
				sq.Eq{"g.client_id": client.ID},
				sq.Eq{"g.refresh_token_hash": inRefreshTokenHash},
			},
		}
		grants, err := GetGrants(c.Request().Context(), opts)
		if err != nil {
			return s.accessTokenError(c, "server_error",
				"server error occurred, try again.", 400)
		}
		if len(grants) == 0 {
			return s.accessTokenError(c, "invalid_request",
				"Invalid refresh_token given", 400)
		}
		grant = grants[0]

		// TODO: Allow modifying grants
		if scope != "" && scope != grant.Grants {
			return s.accessTokenError(c, "invalid_request",
				"Invalid scope value given", 400)
		}

		bt := BearerToken{
			Version:  TokenVersion,
			Type:     TypeOAuth2,
			Issued:   ToTimestamp(issued),
			Expires:  ToTimestamp(expires),
			Grants:   grant.Grants, // scope in the future
			UserID:   grant.UserID,
			ClientID: client.Key,
		}

		token = bt.Encode(c.Request().Context())
		hash = sha512.Sum512([]byte(token))
		tokenHash = hex.EncodeToString(hash[:])

		if _, err := rand.Read(buf); err != nil {
			panic(err)
		}
		refreshToken = base64.RawURLEncoding.EncodeToString(buf)
		hash = sha512.Sum512([]byte(refreshToken))
		refreshTokenHash = hex.EncodeToString(hash[:])

		grant.Issued = issued
		grant.Expires = expires
		grant.TokenHash = tokenHash
		grant.RefreshTokenHash = refreshTokenHash

		if err := grant.Store(c.Request().Context()); err != nil {
			return s.accessTokenError(c, "server_error",
				"server error occurred storing token, try again.", 400)
		}
	}

	ret := struct {
		Token   string `json:"access_token"`
		Type    string `json:"token_type"`
		Expires int    `json:"expires_in"`
		Scope   string `json:"scope"`
		Token        string `json:"access_token"`
		Type         string `json:"token_type"`
		Expires      int    `json:"expires_in"`
		Scope        string `json:"scope"`
		RefreshToken string `json:"refresh_token"`
	}{
		Token:   token,
		Type:    "bearer",
		Expires: int(expires.Sub(time.Now().UTC()).Seconds()),
		Scope:   payload.Grants,
		Token:        token,
		Type:         "bearer",
		Expires:      int(expires.Sub(time.Now().UTC()).Seconds()),
		Scope:        grant.Grants,
		RefreshToken: refreshToken,
	}

	return c.JSON(http.StatusOK, &ret)


@@ 740,7 829,7 @@ func (s *Service) OAuthMetadata(c echo.Context) error {
		TokenEndpoint: tURL,
		Scopes:        scopes,
		Responses:     []string{"code"},
		Grants:        []string{"authorization_code"},
		Grants:        []string{"authorization_code", "refresh_token"},
		Doc:           s.config.DocumentationURL,
		IntroEndpoint: iURL,
		IntroAuth:     []string{"none"},

M schema.sql => schema.sql +1 -0
@@ 39,6 39,7 @@ CREATE TABLE oauth2_grants (
        comment character varying,
        grants character varying,
        token_hash character varying(128) NOT NULL,
        refresh_token_hash character varying(128) NOT NULL,
        user_id integer NOT NULL REFERENCES users (id) ON DELETE CASCADE,
        client_id integer REFERENCES oauth2_clients (id) ON DELETE CASCADE
);