M bearer.go => bearer.go +40 -0
@@ 136,6 136,9 @@ func DecodeGrants(grants string) Grants {
access = "RO"
} else {
access = parts[1]
+ if access != "RW" && access != "RO" {
+ access = "RO"
+ }
}
accessMap[scope] = access
}
@@ 174,6 177,35 @@ func (g *Grants) Encode() string {
return g.encoded
}
+// Validate ...
+func (g *Grants) Validate(scopes []string) []string {
+ var errors []string
+ for k := range g.grants {
+ if !contains(scopes, k) {
+ errors = append(errors, fmt.Sprintf("Invalid scope: %s", k))
+ }
+ }
+ return errors
+}
+
+// List ...
+func (g *Grants) List() []string {
+ var grants []string
+ for k, v := range g.grants {
+ grants = append(grants, fmt.Sprintf("%s:%s", k, v))
+ }
+ return grants
+}
+
+func contains(values []string, str string) bool {
+ for _, v := range values {
+ if v == str {
+ return true
+ }
+ }
+ return false
+}
+
// TokenUser wrapper for gobwebs.User and token grants
type TokenUser struct {
User gobwebs.User
@@ 181,3 213,11 @@ type TokenUser struct {
Grants *Grants
TokenHash [64]byte
}
+
+// AuthorizationPayload holds temporary approval while the oauth2 cycle is
+// in process
+type AuthorizationPayload struct {
+ Grants string
+ ClientKey string
+ UserID int
+}
M logic.go => logic.go +4 -4
@@ 25,7 25,10 @@ func OAuth2(ctx context.Context, token string, fetch gobwebs.UserFetch) (*TokenU
hash := sha512.Sum512([]byte(token))
hashStr := hex.EncodeToString(hash[:])
opts := &database.FilterOptions{
- Filter: sq.Eq{"token_hash": hashStr},
+ Filter: sq.And{
+ sq.Eq{"token_hash": hashStr},
+ sq.Expr("expires > NOW() at time zone 'UTC'"),
+ },
}
grants, err := GetGrants(ctx, opts)
if err != nil {
@@ 39,9 42,6 @@ func OAuth2(ctx context.Context, token string, fetch gobwebs.UserFetch) (*TokenU
return nil, fmt.Errorf("Error with provided OAuth 2.0 bearer token")
}
grant := grants[0]
- if grant.IsExpired() {
- return nil, fmt.Errorf("Invalid or expired OAuth 2.0 bearer token")
- }
bt.Issued = ToTimestamp(grant.Issued)
gt := DecodeGrants(bt.Grants)
M models.go => models.go +8 -0
@@ 31,3 31,11 @@ type Grant struct {
UserID int `db:"user_id"`
ClientID sql.NullInt64 `db:"client_id"`
}
+
+// Authorization ...
+type Authorization struct {
+ ID int `db:"id"`
+ Code string `db:"code"`
+ Payload string `db:"payload"`
+ CreatedOn time.Time `db:"created_on"`
+}
A oauth2_authorizations.go => oauth2_authorizations.go +92 -0
@@ 0,0 1,92 @@
+package oauth2
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ sq "github.com/Masterminds/squirrel"
+ "hg.code.netlandish.com/~netlandish/gobwebs/database"
+)
+
+// GetAuthorizations retuns oauth2 authorizations using the given filters
+func GetAuthorizations(ctx context.Context, opts *database.FilterOptions) ([]*Authorization, error) {
+ if opts == nil {
+ opts = &database.FilterOptions{}
+ }
+ auths := make([]*Authorization, 0)
+ if err := database.WithTx(ctx, database.TxOptionsRO, func(tx *sql.Tx) error {
+ q := opts.GetBuilder(nil)
+ rows, err := q.
+ Columns("id", "code", "payload", "created_on").
+ From("oauth2_authorizations").
+ PlaceholderFormat(sq.Dollar).
+ RunWith(tx).
+ QueryContext(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil
+ }
+ return err
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var a Authorization
+ if err = rows.Scan(&a.ID, &a.Code, &a.Payload, &a.CreatedOn); err != nil {
+ return err
+ }
+ auths = append(auths, &a)
+ }
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+ return auths, nil
+}
+
+// Store will save a client
+func (a *Authorization) Store(ctx context.Context) error {
+ err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
+ var err error
+ if a.ID == 0 {
+ err = sq.
+ Insert("oauth2_authorizations").
+ Columns("code", "payload").
+ Values(a.Code, a.Payload).
+ Suffix(`RETURNING (id)`).
+ PlaceholderFormat(sq.Dollar).
+ RunWith(tx).
+ ScanContext(ctx, &a.ID)
+ } else {
+ err = sq.
+ Update("oauth2_authorizations").
+ Set("code", a.Code).
+ Set("payload", a.Payload).
+ Where("id = ?", a.ID).
+ Suffix(`RETURNING (id)`).
+ PlaceholderFormat(sq.Dollar).
+ RunWith(tx).
+ ScanContext(ctx, &a.ID)
+ }
+ return err
+ })
+ return err
+}
+
+// Delete will delete this rate
+func (a *Authorization) Delete(ctx context.Context) error {
+ if a.ID == 0 {
+ return fmt.Errorf("Authorization object is not populated")
+ }
+ err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
+ _, err := sq.
+ Delete("oauth2_authorizations").
+ Where("id = ?", a.ID).
+ PlaceholderFormat(sq.Dollar).
+ RunWith(tx).
+ ExecContext(ctx)
+ return err
+ })
+ return err
+}
M oauth2_clients.go => oauth2_clients.go +15 -0
@@ 54,6 54,21 @@ func GetClients(ctx context.Context, opts *database.FilterOptions) ([]*Client, e
return clients, nil
}
+// GetClientByID will fetch a client by given client id
+func GetClientByID(ctx context.Context, clientID string) (*Client, error) {
+ opts := &database.FilterOptions{
+ Filter: sq.Eq{"key": clientID},
+ }
+ clients, err := GetClients(ctx, opts)
+ if err != nil {
+ return nil, err
+ }
+ if len(clients) == 0 {
+ return nil, nil
+ }
+ return clients[0], nil
+}
+
// Store will save a client
func (c *Client) Store(ctx context.Context) error {
err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
M oauth2_grants.go => oauth2_grants.go +6 -0
@@ 85,6 85,12 @@ func (g *Grant) IsExpired() bool {
return time.Now().UTC().After(g.Expires.UTC())
}
+// Revoke ...
+func (g *Grant) Revoke(ctx context.Context) error {
+ g.Expires = time.Now().UTC()
+ return g.Store(ctx)
+}
+
// Delete will delete this rate
func (g *Grant) Delete(ctx context.Context) error {
if g.ID == 0 {
M routes.go => routes.go +174 -3
@@ 1,10 1,14 @@
package oauth2
import (
+ "crypto/rand"
"crypto/sha512"
"encoding/hex"
+ "encoding/json"
"fmt"
"net/http"
+ "net/url"
+ "strings"
"time"
sq "github.com/Masterminds/squirrel"
@@ 15,11 19,18 @@ import (
"hg.code.netlandish.com/~netlandish/gobwebs/server"
)
+// ServiceConfig let's you add basic config variables to service
+type ServiceConfig struct {
+ DocumentationURL string
+ Scopes []string
+}
+
// Service is the base accounts service struct
type Service struct {
name string
eg *echo.Group
helper Helper
+ config *ServiceConfig
}
// RegisterRoutes ...
@@ 33,6 44,8 @@ func (s *Service) RegisterRoutes() {
s.eg.GET("/clients", s.ListClients).Name = s.RouteName("list_clients")
s.eg.GET("/clients/add", s.AddClient).Name = s.RouteName("add_client")
s.eg.POST("/clients/add", s.AddClient).Name = s.RouteName("add_client_post")
+ s.eg.GET("/authorize", s.Authorize).Name = s.RouteName("authorize")
+ s.eg.POST("/authorize", s.AuthorizePOST).Name = s.RouteName("authorize_post")
}
// ListPersonal ...
@@ 154,6 167,163 @@ func (s *Service) AddClient(c echo.Context) error {
return gctx.Render(http.StatusOK, "oauth2_add_client.html", gmap)
}
+func oauth2Redirect(c echo.Context, redirectURI string, params gobwebs.Map) error {
+ parts, err := url.Parse(redirectURI)
+ if err != nil {
+ return err
+ }
+ qs := parts.Query()
+ for k, v := range params {
+ qs.Set(k, v.(string))
+ }
+ parts.RawQuery = qs.Encode()
+ return c.Redirect(http.StatusMovedPermanently, parts.String())
+}
+
+func (s *Service) authorizeError(c echo.Context,
+ redirectURI, state, errorCode, errorDescription string) error {
+ if redirectURI == "" {
+ gctx := c.(*server.Context)
+ return gctx.Render(http.StatusOK, "oauth2_error.html", gobwebs.Map{
+ "code": errorCode,
+ "description": errorDescription,
+ })
+ }
+ return oauth2Redirect(c, redirectURI, gobwebs.Map{
+ "error": errorCode,
+ "error_description": errorDescription,
+ "error_uri": s.config.DocumentationURL,
+ "state": state,
+ })
+}
+
+// Authorize ...
+func (s *Service) Authorize(c echo.Context) error {
+ respType := c.QueryParam("response_type")
+ clientID := c.QueryParam("client_id")
+ scope := c.QueryParam("scope")
+ state := c.QueryParam("state")
+ redirectURL := c.QueryParam("redirect_uri")
+
+ if clientID == "" {
+ return s.authorizeError(c, "", state, "invalid_request",
+ "The client_id parameter is required")
+ }
+
+ client, err := GetClientByID(c.Request().Context(), clientID)
+ if err != nil {
+ return s.authorizeError(c, "", state, "server_error", err.Error())
+ }
+ if client == nil {
+ return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID")
+ }
+
+ if redirectURL != "" && redirectURL != client.RedirectURL {
+ return s.authorizeError(c, "", state, "invalid_request",
+ "The redirect_uri parameter doesn't match the registered client's")
+ }
+ if respType != "code" {
+ return s.authorizeError(c, redirectURL, state, "unsupported_response_type",
+ "The response_type parameter must be set to 'code'")
+ }
+ if scope == "" {
+ return s.authorizeError(c, redirectURL, state, "invalid_scope",
+ "The scope parameter is required")
+ }
+
+ grants := DecodeGrants(scope)
+ errors := grants.Validate(s.config.Scopes)
+ if len(errors) > 0 {
+ return s.authorizeError(c, redirectURL, state,
+ "invalid_scope", strings.Join(errors, ", "))
+ }
+
+ gctx := c.(*server.Context)
+ return gctx.Render(http.StatusOK, "oauth2_authorization.html", gobwebs.Map{
+ "client": client,
+ "grants": grants.List(),
+ "client_id": clientID,
+ "redirect_uri": redirectURL,
+ "state": state,
+ })
+}
+
+// AuthorisePOST ...
+func (s *Service) AuthorizePOST(c echo.Context) error {
+ params, err := c.FormParams()
+ if err != nil {
+ return err
+ }
+ clientID := params.Get("client_id")
+ redirectURL := params.Get("redirect_uri")
+ state := params.Get("state")
+
+ if params.Has("reject") {
+ return s.authorizeError(c, redirectURL, state, "access_denied",
+ "The resource owner denied the request.")
+ }
+
+ subgrants := []string{}
+ // XXX csrf shouldn't be hard coded here
+ skip := []string{"accept", "client_id", "redirect_uri", "state", "csrf"}
+ for grant := range params {
+ if contains(skip, grant) {
+ continue
+ }
+ subgrants = append(subgrants, grant)
+ }
+
+ grants := DecodeGrants(strings.Join(subgrants, " "))
+ errors := grants.Validate(s.config.Scopes)
+ if len(errors) > 0 {
+ return s.authorizeError(c, redirectURL, state,
+ "invalid_scope", strings.Join(errors, ", "))
+ }
+
+ client, err := GetClientByID(c.Request().Context(), clientID)
+ if err != nil {
+ return s.authorizeError(c, "", state, "server_error", err.Error())
+ }
+ if client == nil {
+ return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID")
+ }
+
+ var seed [64]byte
+ gctx := c.(*server.Context)
+ n, err := rand.Read(seed[:])
+ if err != nil || n != len(seed) {
+ panic(err)
+ }
+ hash := sha512.Sum512(seed[:])
+ code := hex.EncodeToString(hash[:])[:32]
+
+ payload := AuthorizationPayload{
+ Grants: grants.encoded,
+ ClientKey: clientID,
+ UserID: int(gctx.User.GetID()),
+ }
+ data, err := json.Marshal(&payload)
+ if err != nil {
+ panic(err)
+ }
+
+ auth := &Authorization{
+ Code: code,
+ Payload: string(data),
+ }
+ if err := auth.Store(c.Request().Context()); err != nil {
+ return s.authorizeError(c, "", state, "server_error", err.Error())
+ }
+
+ gmap := gobwebs.Map{
+ "code": code,
+ }
+ if state != "" {
+ gmap["state"] = state
+ }
+ return oauth2Redirect(c, client.RedirectURL, gmap)
+}
+
// Introspect ...
func (s *Service) Introspect(c echo.Context) error {
req := c.Request()
@@ 162,10 332,11 @@ func (s *Service) Introspect(c echo.Context) error {
retErr := struct {
Err string `json:"error"`
Desc string `json:"error_description"`
- URI string `json:"error_url"` // TODO Make this customizable
+ URI string `json:"error_url"`
}{
Err: "invalid request",
Desc: "Content-Type must be application/x-www-form-urlencoded",
+ URI: s.config.DocumentationURL,
}
return c.JSON(http.StatusBadRequest, &retErr)
}
@@ 207,11 378,11 @@ func (s *Service) RouteName(value string) string {
}
// NewService return service
-func NewService(eg *echo.Group, name string, helper Helper) *Service {
+func NewService(eg *echo.Group, name string, helper Helper, config *ServiceConfig) *Service {
if name == "" {
name = "oauth2"
}
- service := &Service{name: name, eg: eg, helper: helper}
+ service := &Service{name: name, eg: eg, helper: helper, config: config}
service.RegisterRoutes()
return service
}
M schema.sql => schema.sql +11 -0
@@ 45,3 45,14 @@ CREATE TABLE oauth2_grants (
CREATE INDEX oauth2_grants_id_idx ON oauth2_grants (id);
CREATE INDEX oauth2_grants_user_id_idx ON oauth2_grants (user_id);
CREATE INDEX oauth2_grants_client_id_idx ON oauth2_grants (client_id);
+
+
+CREATE TABLE oauth2_authorizations (
+ id serial PRIMARY KEY,
+ code character varying(128) NOT NULL,
+ payload character varying NOT NULL,
+ created_on TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE INDEX oauth2_authorizations_id_idx ON oauth2_authorizations (id);
+CREATE INDEX oauth2_authorizations_code_idx ON oauth2_authorizations (code);