From 584364316d8d881098551601b277278466940a27 Mon Sep 17 00:00:00 2001 From: Peter Sanchez Date: Wed, 29 Nov 2023 18:13:41 -0600 Subject: [PATCH] Adding base refresh_token support. Missing grant/scope modifying support --- grants.go | 12 +-- models.go | 17 ++-- routes.go | 233 ++++++++++++++++++++++++++++++++++++----------------- schema.sql | 1 + 4 files changed, 178 insertions(+), 85 deletions(-) diff --git a/grants.go b/grants.go index 64fdc9f..57161b7 100644 --- a/grants.go +++ b/grants.go @@ -20,7 +20,7 @@ func GetGrants(ctx context.Context, opts *database.FilterOptions) ([]*Grant, err q := opts.GetBuilder(nil) rows, err := q. Columns("g.id", "g.issued", "g.expires", "g.comment", "g.grants", "g.token_hash", - "g.user_id", "g.client_id", "c.name"). + "g.refresh_token_hash", "g.user_id", "g.client_id", "c.name"). From("oauth2_grants g"). LeftJoin("oauth2_clients c ON c.id = g.client_id"). PlaceholderFormat(sq.Dollar). @@ -37,7 +37,7 @@ func GetGrants(ctx context.Context, opts *database.FilterOptions) ([]*Grant, err for rows.Next() { var g Grant if err = rows.Scan(&g.ID, &g.Issued, &g.Expires, &g.Comment, &g.Grants, &g.TokenHash, - &g.UserID, &g.ClientID, &g.ClientName); err != nil { + &g.RefreshTokenHash, &g.UserID, &g.ClientID, &g.ClientName); err != nil { return err } grants = append(grants, &g) @@ -56,9 +56,10 @@ func (g *Grant) Store(ctx context.Context) error { if g.ID == 0 { err = sq. Insert("oauth2_grants"). - Columns("issued", "expires", "comment", "grants", "token_hash", "user_id", - "client_id"). - Values(g.Issued, g.Expires, g.Comment, g.Grants, g.TokenHash, g.UserID, g.ClientID). + Columns("issued", "expires", "comment", "grants", "token_hash", "refresh_token_hash", + "user_id", "client_id"). + Values(g.Issued, g.Expires, g.Comment, g.Grants, g.TokenHash, g.RefreshTokenHash, + g.UserID, g.ClientID). Suffix(`RETURNING (id)`). PlaceholderFormat(sq.Dollar). RunWith(tx). @@ -71,6 +72,7 @@ func (g *Grant) Store(ctx context.Context) error { Set("comment", g.Comment). Set("grants", g.Grants). Set("token_hash", g.TokenHash). + Set("refresh_token_hash", g.RefreshTokenHash). Set("user_id", g.UserID). Set("client_id", g.ClientID). Where("id = ?", g.ID). diff --git a/models.go b/models.go index 26e4624..88f02da 100644 --- a/models.go +++ b/models.go @@ -23,14 +23,15 @@ type Client struct { // Grant ... type Grant struct { - ID int `db:"id"` - Issued time.Time `db:"issued"` - Expires time.Time `db:"expires"` - Comment string `db:"comment"` - Grants string `db:"grants"` - TokenHash string `db:"token_hash"` - UserID int `db:"user_id"` - ClientID sql.NullInt64 `db:"client_id"` + ID int `db:"id"` + Issued time.Time `db:"issued"` + Expires time.Time `db:"expires"` + Comment string `db:"comment"` + Grants string `db:"grants"` + TokenHash string `db:"token_hash"` + RefreshTokenHash string `db:"refresh_token_hash"` + UserID int `db:"user_id"` + ClientID sql.NullInt64 `db:"client_id"` ClientName sql.NullString `db:"-"` } diff --git a/routes.go b/routes.go index 9c88223..f63f50b 100644 --- a/routes.go +++ b/routes.go @@ -347,7 +347,8 @@ func (s *Service) authorizeError(c echo.Context, }) } -// Authorize ... +// Authorize will show an authorization form to the end user seeking their permission to +// grant oauth2 access to the provided scopes by the provided client. func (s *Service) Authorize(c echo.Context) error { respType := c.QueryParam("response_type") clientID := c.QueryParam("client_id") @@ -398,7 +399,7 @@ func (s *Service) Authorize(c echo.Context) error { return gctx.Render(http.StatusOK, "oauth2_authorization.html", gmap) } -// AuthorisePOST ... +// AuthorisePOST will authorize, or reject, oauth2 access request by an outside client. func (s *Service) AuthorizePOST(c echo.Context) error { params, err := c.FormParams() if err != nil { @@ -492,7 +493,8 @@ func (s *Service) accessTokenError(c echo.Context, err, desc string, status int) return c.JSON(status, &retErr) } -// AccessTokenPOST ... +// AccessTokenPOST will be used by remote party after a user has successfully authorized +// oauth2 account access. This is where the access token will be created and returned. func (s *Service) AccessTokenPOST(c echo.Context) error { req := c.Request() ctype := req.Header.Get("Content-Type") @@ -508,6 +510,8 @@ func (s *Service) AccessTokenPOST(c echo.Context) error { grantType := params.Get("grant_type") code := params.Get("code") redirectURI := params.Get("redirect_uri") + inRefreshToken := params.Get("refresh_token") + scope := params.Get("scope") clientID := params.Get("client_id") clientSecret := params.Get("client_secret") @@ -551,42 +555,17 @@ func (s *Service) AccessTokenPOST(c echo.Context) error { return s.accessTokenError(c, "invalid_request", "The grant_type parameter is required", 400) } - if grantType != "authorization_code" { + if grantType != "authorization_code" && grantType != "refresh_token" { return s.accessTokenError(c, "unsupported_grant_type", fmt.Sprintf("Unsupported grant type %s", grantType), 400) } - if code == "" { - return s.accessTokenError(c, "invalid_request", - "The code parameter is required", 400) - } - opts := &database.FilterOptions{ - Filter: sq.Eq{"code": code}, - } - auths, err := GetAuthorizations(c.Request().Context(), opts) - if err != nil { - return s.accessTokenError(c, "server_error", - "server error occurred, try again.", 400) - } - if len(auths) == 0 { - return s.accessTokenError(c, "invalid_request", - "Invalid authorization code", 400) - } - authCode := auths[0] - if authCode.IsExpired() { - authCode.Delete(c.Request().Context()) - return s.accessTokenError(c, "invalid_request", - "Authorization code expired", 400) - } - - var payload AuthorizationPayload - if err = json.Unmarshal([]byte(authCode.Payload), &payload); err != nil { - panic(err) - } - - issued := time.Now().UTC() - expires := issued.Add(366 * 24 * time.Hour) + var ( + grant *Grant + token, tokenHash, refreshToken, refreshTokenHash string + ) + // Fetch client by Client.Key hash client, err := GetClientByID(c.Request().Context(), clientID) if err != nil { return s.accessTokenError(c, "invalid_request", @@ -597,50 +576,160 @@ func (s *Service) AccessTokenPOST(c echo.Context) error { "Invalid client secret", 400) } - if redirectURI != "" && redirectURI != client.RedirectURL { - return s.accessTokenError(c, "invalid_request", - "Invalid redirect URI", 400) - } + buf := make([]byte, 32) + issued := time.Now().UTC() + expires := issued.Add(366 * 24 * time.Hour) - bt := BearerToken{ - Version: TokenVersion, - Type: TypeOAuth2, - Issued: ToTimestamp(issued), - Expires: ToTimestamp(expires), - Grants: payload.Grants, - UserID: payload.UserID, - ClientID: payload.ClientKey, - } + if grantType == "authorization_code" { + // Adding a new grant + if code == "" { + return s.accessTokenError(c, "invalid_request", + "The code parameter is required", 400) + } + + opts := &database.FilterOptions{ + Filter: sq.Eq{"code": code}, + } + auths, err := GetAuthorizations(c.Request().Context(), opts) + if err != nil { + return s.accessTokenError(c, "server_error", + "server error occurred, try again.", 400) + } + if len(auths) == 0 { + return s.accessTokenError(c, "invalid_request", + "Invalid authorization code", 400) + } + authCode := auths[0] + if authCode.IsExpired() { + authCode.Delete(c.Request().Context()) + return s.accessTokenError(c, "invalid_request", + "Authorization code expired", 400) + } - token := bt.Encode(c.Request().Context()) - hash := sha512.Sum512([]byte(token)) - tokenHash := hex.EncodeToString(hash[:]) + var payload AuthorizationPayload + if err = json.Unmarshal([]byte(authCode.Payload), &payload); err != nil { + panic(err) + } - grant := &Grant{ - Issued: issued, - Expires: expires, - Grants: payload.Grants, - TokenHash: tokenHash, - UserID: payload.UserID, - ClientID: sql.NullInt64{Int64: int64(client.ID), Valid: true}, - } - if err := grant.Store(c.Request().Context()); err != nil { - return s.accessTokenError(c, "server_error", - "server error occurred storing token, try again.", 400) - } + if redirectURI != "" && redirectURI != client.RedirectURL { + return s.accessTokenError(c, "invalid_request", + "Invalid redirect URI", 400) + } - authCode.Delete(c.Request().Context()) + bt := BearerToken{ + Version: TokenVersion, + Type: TypeOAuth2, + Issued: ToTimestamp(issued), + Expires: ToTimestamp(expires), + Grants: payload.Grants, + UserID: payload.UserID, + ClientID: payload.ClientKey, + } + + token = bt.Encode(c.Request().Context()) + hash := sha512.Sum512([]byte(token)) + tokenHash = hex.EncodeToString(hash[:]) + + if _, err := rand.Read(buf); err != nil { + panic(err) + } + refreshToken = base64.RawURLEncoding.EncodeToString(buf) + hash = sha512.Sum512([]byte(refreshToken)) + refreshTokenHash = hex.EncodeToString(hash[:]) + + grant = &Grant{ + Issued: issued, + Expires: expires, + Grants: payload.Grants, + TokenHash: tokenHash, + RefreshTokenHash: refreshTokenHash, + UserID: payload.UserID, + ClientID: sql.NullInt64{Int64: int64(client.ID), Valid: true}, + } + if err := grant.Store(c.Request().Context()); err != nil { + return s.accessTokenError(c, "server_error", + "server error occurred storing token, try again.", 400) + } + + authCode.Delete(c.Request().Context()) + + } else if grantType == "refresh_token" { + // Refreshing an existing token + if inRefreshToken == "" { + return s.accessTokenError(c, "invalid_request", + "The refresh_token parameter is required", 400) + } + + hash := sha512.Sum512([]byte(inRefreshToken)) + inRefreshTokenHash := hex.EncodeToString(hash[:]) + + opts := &database.FilterOptions{ + Filter: sq.And{ + sq.Eq{"g.client_id": client.ID}, + sq.Eq{"g.refresh_token_hash": inRefreshTokenHash}, + }, + } + grants, err := GetGrants(c.Request().Context(), opts) + if err != nil { + return s.accessTokenError(c, "server_error", + "server error occurred, try again.", 400) + } + if len(grants) == 0 { + return s.accessTokenError(c, "invalid_request", + "Invalid refresh_token given", 400) + } + grant = grants[0] + + // TODO: Allow modifying grants + if scope != "" && scope != grant.Grants { + return s.accessTokenError(c, "invalid_request", + "Invalid scope value given", 400) + } + + bt := BearerToken{ + Version: TokenVersion, + Type: TypeOAuth2, + Issued: ToTimestamp(issued), + Expires: ToTimestamp(expires), + Grants: grant.Grants, // scope in the future + UserID: grant.UserID, + ClientID: client.Key, + } + + token = bt.Encode(c.Request().Context()) + hash = sha512.Sum512([]byte(token)) + tokenHash = hex.EncodeToString(hash[:]) + + if _, err := rand.Read(buf); err != nil { + panic(err) + } + refreshToken = base64.RawURLEncoding.EncodeToString(buf) + hash = sha512.Sum512([]byte(refreshToken)) + refreshTokenHash = hex.EncodeToString(hash[:]) + + grant.Issued = issued + grant.Expires = expires + grant.TokenHash = tokenHash + grant.RefreshTokenHash = refreshTokenHash + + if err := grant.Store(c.Request().Context()); err != nil { + return s.accessTokenError(c, "server_error", + "server error occurred storing token, try again.", 400) + } + } ret := struct { - Token string `json:"access_token"` - Type string `json:"token_type"` - Expires int `json:"expires_in"` - Scope string `json:"scope"` + Token string `json:"access_token"` + Type string `json:"token_type"` + Expires int `json:"expires_in"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token"` }{ - Token: token, - Type: "bearer", - Expires: int(expires.Sub(time.Now().UTC()).Seconds()), - Scope: payload.Grants, + Token: token, + Type: "bearer", + Expires: int(expires.Sub(time.Now().UTC()).Seconds()), + Scope: grant.Grants, + RefreshToken: refreshToken, } return c.JSON(http.StatusOK, &ret) @@ -740,7 +829,7 @@ func (s *Service) OAuthMetadata(c echo.Context) error { TokenEndpoint: tURL, Scopes: scopes, Responses: []string{"code"}, - Grants: []string{"authorization_code"}, + Grants: []string{"authorization_code", "refresh_token"}, Doc: s.config.DocumentationURL, IntroEndpoint: iURL, IntroAuth: []string{"none"}, diff --git a/schema.sql b/schema.sql index 3539e4f..0793bec 100644 --- a/schema.sql +++ b/schema.sql @@ -39,6 +39,7 @@ CREATE TABLE oauth2_grants ( comment character varying, grants character varying, token_hash character varying(128) NOT NULL, + refresh_token_hash character varying(128) NOT NULL, user_id integer NOT NULL REFERENCES users (id) ON DELETE CASCADE, client_id integer REFERENCES oauth2_clients (id) ON DELETE CASCADE ); -- 2.45.2