Oauth2 support

This commit is contained in:
Jamie Curnow
2024-11-06 20:33:51 +10:00
parent f23299f793
commit 208037946f
25 changed files with 529 additions and 30 deletions

View File

@ -70,8 +70,6 @@ func NewToken() func(http.ResponseWriter, *http.Request) {
switch payload.Type {
case "ldap":
newTokenLDAP(w, r, payload)
case "oidc":
newTokenOIDC(w, r, payload)
case "local":
newTokenLocal(w, r, payload)
}
@ -199,10 +197,6 @@ func newTokenLDAP(w http.ResponseWriter, r *http.Request, payload tokenPayload)
}
}
func newTokenOIDC(w http.ResponseWriter, r *http.Request, _ tokenPayload) {
h.ResultErrorJSON(w, r, http.StatusInternalServerError, "NOT YET SUPPORTED", nil)
}
// RefreshToken an existing token by given them a new one with the same claims
// Route: POST /auth/refresh
func RefreshToken() func(http.ResponseWriter, *http.Request) {

View File

@ -21,11 +21,22 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
return pageInfo, err
}
// pageInfo.Sort = middleware.GetSortFromContext(r)
return pageInfo, nil
}
func getQueryVarString(r *http.Request, varName string, required bool, defaultValue string) (string, error) {
queryValues := r.URL.Query()
varValue := queryValues.Get(varName)
if varValue == "" && required {
return "", eris.Errorf("%v was not supplied in the request", varName)
} else if varValue == "" {
return defaultValue, nil
}
return varValue, nil
}
func getQueryVarInt(r *http.Request, varName string, required bool, defaultValue int) (int, error) {
queryValues := r.URL.Query()
varValue := queryValues.Get(varName)

View File

@ -0,0 +1,156 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
h "npm/internal/api/http"
"npm/internal/entity/auth"
"npm/internal/entity/setting"
"npm/internal/entity/user"
"npm/internal/errors"
njwt "npm/internal/jwt"
"npm/internal/logger"
"gorm.io/gorm"
)
// getRequestIPAddress will use X-FORWARDED-FOR header if it exists
// otherwise it will use RemoteAddr
func getRequestIPAddress(r *http.Request) string {
// this Get is case insensitive
xff := r.Header.Get("X-FORWARDED-FOR")
if xff != "" {
ip, _, _ := strings.Cut(xff, ",")
return strings.TrimSpace(ip)
}
return r.RemoteAddr
}
// OAuthLogin ...
// Route: GET /oauth/login
func OAuthLogin() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !setting.AuthMethodEnabled(auth.TypeOAuth) {
h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil)
return
}
redirectBase, _ := getQueryVarString(r, "redirect_base", false, "")
url, err := auth.OAuthLogin(redirectBase, getRequestIPAddress(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
}
h.ResultResponseJSON(w, r, http.StatusOK, url)
}
}
// OAuthRedirect ...
// Route: GET /oauth/redirect
func OAuthRedirect() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !setting.AuthMethodEnabled(auth.TypeOAuth) {
h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil)
return
}
code, err := getQueryVarString(r, "code", true, "")
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
}
ou, err := auth.OAuthReturn(r.Context(), code, getRequestIPAddress(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
}
if ou.Identifier == "" {
h.ResultErrorJSON(w, r, http.StatusBadRequest, "User found, but OAuth identifier seems misconfigured", nil)
return
}
jwt, err := newTokenOAuth(ou)
if err != nil {
h.ResultErrorJSON(w, r, http.StatusInternalServerError, err.Error(), nil)
return
}
// encode jwt to json
j, _ := json.Marshal(jwt)
// Redirect to frontend with success
http.Redirect(w, r, fmt.Sprintf("/?token_response=%s", url.QueryEscape(string(j))), http.StatusSeeOther)
}
}
// newTokenOAuth takes a OAuthUser and creates a new token,
// optionally creating a new user if one does not exist
func newTokenOAuth(ou *auth.OAuthUser) (*njwt.GeneratedResponse, error) {
// Get OAuth settings
oAuthSettings, err := setting.GetOAuthSettings()
if err != nil {
logger.Error("OAuth settings not found", err)
return nil, err
}
// Get Auth by identity
authObj, authErr := auth.GetByIdenityType(ou.GetID(), auth.TypeOAuth)
if authErr == gorm.ErrRecordNotFound {
// Auth is not found for this identity. We can create it
if !oAuthSettings.AutoCreateUser {
// user does not have an auth record
// and auto create is disabled. Showing account disabled error
// for the time being
return nil, errors.ErrUserDisabled
}
// Attempt to find user by email
foundUser, err := user.GetByEmail(ou.GetEmail())
if err == gorm.ErrRecordNotFound {
// User not found, create user
foundUser, err = user.CreateFromOAuthUser(ou)
if err != nil {
logger.Error("user.CreateFromOAuthUser", err)
return nil, err
}
logger.Info("Created user from OAuth: %s, %s", ou.GetID(), foundUser.Email)
} else if err != nil {
logger.Error("user.GetByEmail", err)
return nil, err
}
// Create auth record and attach to this user
authObj = auth.Model{
UserID: foundUser.ID,
Type: auth.TypeOAuth,
Identity: ou.GetID(),
}
if err := authObj.Save(); err != nil {
logger.Error("auth.Save", err)
return nil, err
}
logger.Info("Created OAuth auth for user: %s, %s", ou.GetID(), foundUser.Email)
} else if authErr != nil {
logger.Error("auth.GetByIdenityType", err)
return nil, authErr
}
userObj, userErr := user.GetByID(authObj.UserID)
if userErr != nil {
return nil, userErr
}
if userObj.IsDisabled {
return nil, errors.ErrUserDisabled
}
jwt, err := njwt.Generate(&userObj, false)
return &jwt, err
}

View File

@ -0,0 +1,16 @@
package middleware
import (
"net/http"
"npm/internal/logger"
)
// Log will print out route information to the logger
// only when debug is enabled
func Log(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Debug("Request: %s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
})
}

View File

@ -50,6 +50,7 @@ func NewRouter() http.Handler {
middleware.Expansion,
middleware.DecodeAuth(),
middleware.BodyContext(),
middleware.Log,
)
return applyRoutes(r)
@ -61,6 +62,12 @@ func applyRoutes(r chi.Router) chi.Router {
r.NotFound(handler.NotFound())
r.MethodNotAllowed(handler.NotAllowed())
// OAuth endpoints aren't technically API endpoints
r.With(middleware.EnforceSetup()).Route("/oauth", func(r chi.Router) {
r.Get("/login", handler.OAuthLogin())
r.Get("/redirect", handler.OAuthRedirect())
})
// SSE - requires a sse token as the `jwt` get parameter
// Exists inside /api but it's here so that we can skip the Timeout middleware
// that applies to other endpoints.

View File

@ -18,7 +18,7 @@ func GetToken() string {
"properties": {
"type": {
"type": "string",
"enum": ["local", "ldap", "oidc"]
"enum": ["local", "ldap"]
},
"identity": %s,
"secret": %s

View File

@ -70,7 +70,7 @@ func ldapSearchUser(l *ldap3.Conn, ldapSettings setting.LDAPSettings, username s
0,
false,
strings.Replace(ldapSettings.SelfFilter, "{{USERNAME}}", username, 1),
nil, // []string{"name"},
nil,
nil,
)

View File

@ -12,7 +12,7 @@ import (
const (
TypeLocal = "local"
TypeLDAP = "ldap"
TypeOIDC = "oidc"
TypeOAuth = "oauth"
)
// Model is the model

View File

@ -0,0 +1,216 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"npm/internal/entity/setting"
"npm/internal/logger"
cache "github.com/patrickmn/go-cache"
"github.com/rotisserie/eris"
"golang.org/x/oauth2"
)
// AuthCache is a cache item that stores the Admin API data for each admin that has been requesting endpoints
var OAuthCache *cache.Cache
// OAuthCacheInit will create a new Memory Cache
func OAuthCacheInit() {
if OAuthCache == nil {
logger.Debug("Creating a new OAuthCache")
OAuthCache = cache.New(5*time.Minute, 5*time.Minute)
}
}
// OAuthUser is the OAuth User
type OAuthUser struct {
Identifier string `json:"identifier"`
Token string `json:"token"`
Resource map[string]interface{} `json:"resource"`
}
// GetEmail will return an email address even if it can't be known in the
// Resource
func (m *OAuthUser) GetResourceField(field string) string {
if m.Resource != nil {
if value, ok := m.Resource[field]; ok {
return value.(string)
}
}
return ""
}
// GetEmail will return an email address even if it can't be known in the
// Resource
func (m *OAuthUser) GetID() string {
if m.Identifier != "" {
return m.Identifier
}
fields := []string{
"uid",
"user_id",
"username",
"preferred_username",
"email",
"mail",
}
for _, field := range fields {
if val := m.GetResourceField(field); val != "" {
return val
}
}
return ""
}
// GetName attempts to get a name from the resource
// using different fields
func (m *OAuthUser) GetName() string {
fields := []string{
"nickname",
"given_name",
"name",
"preferred_username",
"username",
}
for _, field := range fields {
if name := m.GetResourceField(field); name != "" {
return name
}
}
// Fallback:
return m.Identifier
}
// GetEmail will return an email address even if it can't be known in the
// Resource
func (m *OAuthUser) GetEmail() string {
// See if there's an email field first
if email := m.GetResourceField("email"); email != "" {
return email
}
// Return the identifier if it looks like an email
if m.Identifier != "" {
if strings.Contains(m.Identifier, "@") {
return m.Identifier
}
return fmt.Sprintf("%s@oauth", m.Identifier)
}
return ""
}
func getOAuth2Config() (*oauth2.Config, *setting.OAuthSettings, error) {
oauthSettings, err := setting.GetOAuthSettings()
if err != nil {
return nil, nil, err
}
if oauthSettings.ClientID == "" || oauthSettings.ClientSecret == "" || oauthSettings.AuthURL == "" || oauthSettings.TokenURL == "" {
return nil, nil, eris.New("oauth-settings-incorrect")
}
return &oauth2.Config{
ClientID: oauthSettings.ClientID,
ClientSecret: oauthSettings.ClientSecret,
Scopes: oauthSettings.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauthSettings.AuthURL,
TokenURL: oauthSettings.TokenURL,
},
}, &oauthSettings, nil
}
// OAuthLogin ...
func OAuthLogin(redirectBase, ipAddress string) (string, error) {
OAuthCacheInit()
conf, _, err := getOAuth2Config()
if err != nil {
return "", err
}
verifier := oauth2.GenerateVerifier()
OAuthCache.Set(getCacheKey(ipAddress), verifier, cache.DefaultExpiration)
// todo: state should be unique to the incoming IP address of the requester, I guess
url := conf.AuthCodeURL("state", oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
if redirectBase != "" {
url = url + "&redirect_uri=" + redirectBase + "/oauth/redirect"
}
logger.Debug("URL: %s", url)
return url, nil
}
// OAuthReturn ...
func OAuthReturn(ctx context.Context, code, ipAddress string) (*OAuthUser, error) {
// Just in case...
OAuthCacheInit()
conf, oauthSettings, err := getOAuth2Config()
if err != nil {
return nil, err
}
verifier, found := OAuthCache.Get(getCacheKey(ipAddress))
if !found {
return nil, eris.New("oauth-verifier-not-found")
}
// Use the authorization code that is pushed to the redirect
// URL. Exchange will do the handshake to retrieve the
// initial access token. The HTTP Client returned by
// conf.Client will refresh the token as necessary.
tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier.(string)))
if err != nil {
return nil, err
}
// At this stage, the token is the JWT as given by the oauth server.
// we need to use that to get more info about this user,
// and then we'll create our own jwt for use later.
client := conf.Client(ctx, tok)
resp, err := client.Get(oauthSettings.ResourceURL)
if err != nil {
return nil, err
}
// nolint: errcheck, gosec
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
ou := OAuthUser{
Token: tok.AccessToken,
}
// unmarshal the body into a interface
if err := json.Unmarshal(body, &ou.Resource); err != nil {
return nil, err
}
// Attempt to get the identifier from the resource
if oauthSettings.Identifier != "" {
ou.Identifier = ou.GetResourceField(oauthSettings.Identifier)
}
return &ou, nil
}
func getCacheKey(ipAddress string) string {
return fmt.Sprintf("oauth-%s", ipAddress)
}

View File

@ -2,19 +2,31 @@ package setting
import (
"encoding/json"
"slices"
)
// GetAuthMethods returns the authentication methods enabled for this site
func GetAuthMethods() ([]string, error) {
var l []string
var m Model
if err := m.LoadByName("auth-methods"); err != nil {
return l, err
return nil, err
}
if err := json.Unmarshal([]byte(m.Value.String()), &l); err != nil {
return l, err
var r []string
if err := json.Unmarshal([]byte(m.Value.String()), &r); err != nil {
return nil, err
}
return l, nil
return r, nil
}
// AuthMethodEnabled checks that the auth method given is
// enabled in the db setting
func AuthMethodEnabled(method string) bool {
r, err := GetAuthMethods()
if err != nil {
return false
}
return slices.Contains(r, method)
}

View File

@ -0,0 +1,42 @@
package setting
import (
"encoding/json"
)
// OAuthSettings are the settings for OAuth that come from
// the `oauth-auth` setting value
type OAuthSettings struct {
AutoCreateUser bool `json:"auto_create_user"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
AuthURL string `json:"authorization_url"`
TokenURL string `json:"token_url"`
Identifier string `json:"identifier"`
LogoutURL string `json:"logout_url"`
Scopes []string `json:"scopes"`
ResourceURL string `json:"resource_url"`
}
// GetOAuthSettings will return the OAuth settings
func GetOAuthSettings() (OAuthSettings, error) {
var o OAuthSettings
var m Model
if err := m.LoadByName("oauth-auth"); err != nil {
return o, err
}
if err := json.Unmarshal([]byte(m.Value.String()), &o); err != nil {
return o, err
}
o.ApplyDefaults()
return o, nil
}
// ApplyDefaults will ensure there are defaults set
func (m *OAuthSettings) ApplyDefaults() {
if m.Identifier == "" {
m.Identifier = "email"
}
}

View File

@ -117,3 +117,14 @@ func CreateFromLDAPUser(ldapUser *auth.LDAPUser) (Model, error) {
user.generateGravatar()
return user, err
}
// CreateFromOAuthUser will create a user from an OAuth user object
func CreateFromOAuthUser(ou *auth.OAuthUser) (Model, error) {
user := Model{
Email: ou.GetEmail(),
Name: ou.GetName(),
}
err := user.Save()
user.generateGravatar()
return user, err
}