From 14ecc44f91a2ca88598dfb2a9bab8f44b2278da9 Mon Sep 17 00:00:00 2001 From: Peter Sanchez Date: Mon, 8 May 2023 17:25:00 -0600 Subject: [PATCH] Adding oauth2 authorization handlers --- bearer.go | 40 +++++++++ logic.go | 8 +- models.go | 8 ++ oauth2_authorizations.go | 92 ++++++++++++++++++++ oauth2_clients.go | 15 ++++ oauth2_grants.go | 6 ++ routes.go | 177 ++++++++++++++++++++++++++++++++++++++- schema.sql | 11 +++ 8 files changed, 350 insertions(+), 7 deletions(-) create mode 100644 oauth2_authorizations.go diff --git a/bearer.go b/bearer.go index 0cafb06..98bcff5 100644 --- a/bearer.go +++ b/bearer.go @@ -136,6 +136,9 @@ func DecodeGrants(grants string) Grants { access = "RO" } else { access = parts[1] + if access != "RW" && access != "RO" { + access = "RO" + } } accessMap[scope] = access } @@ -174,6 +177,35 @@ func (g *Grants) Encode() string { return g.encoded } +// Validate ... +func (g *Grants) Validate(scopes []string) []string { + var errors []string + for k := range g.grants { + if !contains(scopes, k) { + errors = append(errors, fmt.Sprintf("Invalid scope: %s", k)) + } + } + return errors +} + +// List ... +func (g *Grants) List() []string { + var grants []string + for k, v := range g.grants { + grants = append(grants, fmt.Sprintf("%s:%s", k, v)) + } + return grants +} + +func contains(values []string, str string) bool { + for _, v := range values { + if v == str { + return true + } + } + return false +} + // TokenUser wrapper for gobwebs.User and token grants type TokenUser struct { User gobwebs.User @@ -181,3 +213,11 @@ type TokenUser struct { Grants *Grants TokenHash [64]byte } + +// AuthorizationPayload holds temporary approval while the oauth2 cycle is +// in process +type AuthorizationPayload struct { + Grants string + ClientKey string + UserID int +} diff --git a/logic.go b/logic.go index affec8f..c58216c 100644 --- a/logic.go +++ b/logic.go @@ -25,7 +25,10 @@ func OAuth2(ctx context.Context, token string, fetch gobwebs.UserFetch) (*TokenU hash := sha512.Sum512([]byte(token)) hashStr := hex.EncodeToString(hash[:]) opts := &database.FilterOptions{ - Filter: sq.Eq{"token_hash": hashStr}, + Filter: sq.And{ + sq.Eq{"token_hash": hashStr}, + sq.Expr("expires > NOW() at time zone 'UTC'"), + }, } grants, err := GetGrants(ctx, opts) if err != nil { @@ -39,9 +42,6 @@ func OAuth2(ctx context.Context, token string, fetch gobwebs.UserFetch) (*TokenU return nil, fmt.Errorf("Error with provided OAuth 2.0 bearer token") } grant := grants[0] - if grant.IsExpired() { - return nil, fmt.Errorf("Invalid or expired OAuth 2.0 bearer token") - } bt.Issued = ToTimestamp(grant.Issued) gt := DecodeGrants(bt.Grants) diff --git a/models.go b/models.go index 916874a..bb515b1 100644 --- a/models.go +++ b/models.go @@ -31,3 +31,11 @@ type Grant struct { UserID int `db:"user_id"` ClientID sql.NullInt64 `db:"client_id"` } + +// Authorization ... +type Authorization struct { + ID int `db:"id"` + Code string `db:"code"` + Payload string `db:"payload"` + CreatedOn time.Time `db:"created_on"` +} diff --git a/oauth2_authorizations.go b/oauth2_authorizations.go new file mode 100644 index 0000000..4aff011 --- /dev/null +++ b/oauth2_authorizations.go @@ -0,0 +1,92 @@ +package oauth2 + +import ( + "context" + "database/sql" + "fmt" + + sq "github.com/Masterminds/squirrel" + "hg.code.netlandish.com/~netlandish/gobwebs/database" +) + +// GetAuthorizations retuns oauth2 authorizations using the given filters +func GetAuthorizations(ctx context.Context, opts *database.FilterOptions) ([]*Authorization, error) { + if opts == nil { + opts = &database.FilterOptions{} + } + auths := make([]*Authorization, 0) + if err := database.WithTx(ctx, database.TxOptionsRO, func(tx *sql.Tx) error { + q := opts.GetBuilder(nil) + rows, err := q. + Columns("id", "code", "payload", "created_on"). + From("oauth2_authorizations"). + PlaceholderFormat(sq.Dollar). + RunWith(tx). + QueryContext(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + defer rows.Close() + + for rows.Next() { + var a Authorization + if err = rows.Scan(&a.ID, &a.Code, &a.Payload, &a.CreatedOn); err != nil { + return err + } + auths = append(auths, &a) + } + return nil + }); err != nil { + return nil, err + } + return auths, nil +} + +// Store will save a client +func (a *Authorization) Store(ctx context.Context) error { + err := database.WithTx(ctx, nil, func(tx *sql.Tx) error { + var err error + if a.ID == 0 { + err = sq. + Insert("oauth2_authorizations"). + Columns("code", "payload"). + Values(a.Code, a.Payload). + Suffix(`RETURNING (id)`). + PlaceholderFormat(sq.Dollar). + RunWith(tx). + ScanContext(ctx, &a.ID) + } else { + err = sq. + Update("oauth2_authorizations"). + Set("code", a.Code). + Set("payload", a.Payload). + Where("id = ?", a.ID). + Suffix(`RETURNING (id)`). + PlaceholderFormat(sq.Dollar). + RunWith(tx). + ScanContext(ctx, &a.ID) + } + return err + }) + return err +} + +// Delete will delete this rate +func (a *Authorization) Delete(ctx context.Context) error { + if a.ID == 0 { + return fmt.Errorf("Authorization object is not populated") + } + err := database.WithTx(ctx, nil, func(tx *sql.Tx) error { + _, err := sq. + Delete("oauth2_authorizations"). + Where("id = ?", a.ID). + PlaceholderFormat(sq.Dollar). + RunWith(tx). + ExecContext(ctx) + return err + }) + return err +} diff --git a/oauth2_clients.go b/oauth2_clients.go index 4530064..3f217e4 100644 --- a/oauth2_clients.go +++ b/oauth2_clients.go @@ -54,6 +54,21 @@ func GetClients(ctx context.Context, opts *database.FilterOptions) ([]*Client, e return clients, nil } +// GetClientByID will fetch a client by given client id +func GetClientByID(ctx context.Context, clientID string) (*Client, error) { + opts := &database.FilterOptions{ + Filter: sq.Eq{"key": clientID}, + } + clients, err := GetClients(ctx, opts) + if err != nil { + return nil, err + } + if len(clients) == 0 { + return nil, nil + } + return clients[0], nil +} + // Store will save a client func (c *Client) Store(ctx context.Context) error { err := database.WithTx(ctx, nil, func(tx *sql.Tx) error { diff --git a/oauth2_grants.go b/oauth2_grants.go index 2c93cb5..561964a 100644 --- a/oauth2_grants.go +++ b/oauth2_grants.go @@ -85,6 +85,12 @@ func (g *Grant) IsExpired() bool { return time.Now().UTC().After(g.Expires.UTC()) } +// Revoke ... +func (g *Grant) Revoke(ctx context.Context) error { + g.Expires = time.Now().UTC() + return g.Store(ctx) +} + // Delete will delete this rate func (g *Grant) Delete(ctx context.Context) error { if g.ID == 0 { diff --git a/routes.go b/routes.go index 99f24cf..5fb1234 100644 --- a/routes.go +++ b/routes.go @@ -1,10 +1,14 @@ package oauth2 import ( + "crypto/rand" "crypto/sha512" "encoding/hex" + "encoding/json" "fmt" "net/http" + "net/url" + "strings" "time" sq "github.com/Masterminds/squirrel" @@ -15,11 +19,18 @@ import ( "hg.code.netlandish.com/~netlandish/gobwebs/server" ) +// ServiceConfig let's you add basic config variables to service +type ServiceConfig struct { + DocumentationURL string + Scopes []string +} + // Service is the base accounts service struct type Service struct { name string eg *echo.Group helper Helper + config *ServiceConfig } // RegisterRoutes ... @@ -33,6 +44,8 @@ func (s *Service) RegisterRoutes() { s.eg.GET("/clients", s.ListClients).Name = s.RouteName("list_clients") s.eg.GET("/clients/add", s.AddClient).Name = s.RouteName("add_client") s.eg.POST("/clients/add", s.AddClient).Name = s.RouteName("add_client_post") + s.eg.GET("/authorize", s.Authorize).Name = s.RouteName("authorize") + s.eg.POST("/authorize", s.AuthorizePOST).Name = s.RouteName("authorize_post") } // ListPersonal ... @@ -154,6 +167,163 @@ func (s *Service) AddClient(c echo.Context) error { return gctx.Render(http.StatusOK, "oauth2_add_client.html", gmap) } +func oauth2Redirect(c echo.Context, redirectURI string, params gobwebs.Map) error { + parts, err := url.Parse(redirectURI) + if err != nil { + return err + } + qs := parts.Query() + for k, v := range params { + qs.Set(k, v.(string)) + } + parts.RawQuery = qs.Encode() + return c.Redirect(http.StatusMovedPermanently, parts.String()) +} + +func (s *Service) authorizeError(c echo.Context, + redirectURI, state, errorCode, errorDescription string) error { + if redirectURI == "" { + gctx := c.(*server.Context) + return gctx.Render(http.StatusOK, "oauth2_error.html", gobwebs.Map{ + "code": errorCode, + "description": errorDescription, + }) + } + return oauth2Redirect(c, redirectURI, gobwebs.Map{ + "error": errorCode, + "error_description": errorDescription, + "error_uri": s.config.DocumentationURL, + "state": state, + }) +} + +// Authorize ... +func (s *Service) Authorize(c echo.Context) error { + respType := c.QueryParam("response_type") + clientID := c.QueryParam("client_id") + scope := c.QueryParam("scope") + state := c.QueryParam("state") + redirectURL := c.QueryParam("redirect_uri") + + if clientID == "" { + return s.authorizeError(c, "", state, "invalid_request", + "The client_id parameter is required") + } + + client, err := GetClientByID(c.Request().Context(), clientID) + if err != nil { + return s.authorizeError(c, "", state, "server_error", err.Error()) + } + if client == nil { + return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID") + } + + if redirectURL != "" && redirectURL != client.RedirectURL { + return s.authorizeError(c, "", state, "invalid_request", + "The redirect_uri parameter doesn't match the registered client's") + } + if respType != "code" { + return s.authorizeError(c, redirectURL, state, "unsupported_response_type", + "The response_type parameter must be set to 'code'") + } + if scope == "" { + return s.authorizeError(c, redirectURL, state, "invalid_scope", + "The scope parameter is required") + } + + grants := DecodeGrants(scope) + errors := grants.Validate(s.config.Scopes) + if len(errors) > 0 { + return s.authorizeError(c, redirectURL, state, + "invalid_scope", strings.Join(errors, ", ")) + } + + gctx := c.(*server.Context) + return gctx.Render(http.StatusOK, "oauth2_authorization.html", gobwebs.Map{ + "client": client, + "grants": grants.List(), + "client_id": clientID, + "redirect_uri": redirectURL, + "state": state, + }) +} + +// AuthorisePOST ... +func (s *Service) AuthorizePOST(c echo.Context) error { + params, err := c.FormParams() + if err != nil { + return err + } + clientID := params.Get("client_id") + redirectURL := params.Get("redirect_uri") + state := params.Get("state") + + if params.Has("reject") { + return s.authorizeError(c, redirectURL, state, "access_denied", + "The resource owner denied the request.") + } + + subgrants := []string{} + // XXX csrf shouldn't be hard coded here + skip := []string{"accept", "client_id", "redirect_uri", "state", "csrf"} + for grant := range params { + if contains(skip, grant) { + continue + } + subgrants = append(subgrants, grant) + } + + grants := DecodeGrants(strings.Join(subgrants, " ")) + errors := grants.Validate(s.config.Scopes) + if len(errors) > 0 { + return s.authorizeError(c, redirectURL, state, + "invalid_scope", strings.Join(errors, ", ")) + } + + client, err := GetClientByID(c.Request().Context(), clientID) + if err != nil { + return s.authorizeError(c, "", state, "server_error", err.Error()) + } + if client == nil { + return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID") + } + + var seed [64]byte + gctx := c.(*server.Context) + n, err := rand.Read(seed[:]) + if err != nil || n != len(seed) { + panic(err) + } + hash := sha512.Sum512(seed[:]) + code := hex.EncodeToString(hash[:])[:32] + + payload := AuthorizationPayload{ + Grants: grants.encoded, + ClientKey: clientID, + UserID: int(gctx.User.GetID()), + } + data, err := json.Marshal(&payload) + if err != nil { + panic(err) + } + + auth := &Authorization{ + Code: code, + Payload: string(data), + } + if err := auth.Store(c.Request().Context()); err != nil { + return s.authorizeError(c, "", state, "server_error", err.Error()) + } + + gmap := gobwebs.Map{ + "code": code, + } + if state != "" { + gmap["state"] = state + } + return oauth2Redirect(c, client.RedirectURL, gmap) +} + // Introspect ... func (s *Service) Introspect(c echo.Context) error { req := c.Request() @@ -162,10 +332,11 @@ func (s *Service) Introspect(c echo.Context) error { retErr := struct { Err string `json:"error"` Desc string `json:"error_description"` - URI string `json:"error_url"` // TODO Make this customizable + URI string `json:"error_url"` }{ Err: "invalid request", Desc: "Content-Type must be application/x-www-form-urlencoded", + URI: s.config.DocumentationURL, } return c.JSON(http.StatusBadRequest, &retErr) } @@ -207,11 +378,11 @@ func (s *Service) RouteName(value string) string { } // NewService return service -func NewService(eg *echo.Group, name string, helper Helper) *Service { +func NewService(eg *echo.Group, name string, helper Helper, config *ServiceConfig) *Service { if name == "" { name = "oauth2" } - service := &Service{name: name, eg: eg, helper: helper} + service := &Service{name: name, eg: eg, helper: helper, config: config} service.RegisterRoutes() return service } diff --git a/schema.sql b/schema.sql index 201513c..0a4d4da 100644 --- a/schema.sql +++ b/schema.sql @@ -45,3 +45,14 @@ CREATE TABLE oauth2_grants ( CREATE INDEX oauth2_grants_id_idx ON oauth2_grants (id); CREATE INDEX oauth2_grants_user_id_idx ON oauth2_grants (user_id); CREATE INDEX oauth2_grants_client_id_idx ON oauth2_grants (client_id); + + +CREATE TABLE oauth2_authorizations ( + id serial PRIMARY KEY, + code character varying(128) NOT NULL, + payload character varying NOT NULL, + created_on TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX oauth2_authorizations_id_idx ON oauth2_authorizations (id); +CREATE INDEX oauth2_authorizations_code_idx ON oauth2_authorizations (code); -- 2.45.2