mirror of
				https://github.com/NginxProxyManager/nginx-proxy-manager.git
				synced 2025-10-30 15:23:34 +00:00 
			
		
		
		
	Improvements to enforce middleware, linting, returning 404 properly
This commit is contained in:
		| @@ -11,6 +11,8 @@ import ( | ||||
| 	"npm/internal/entity/user" | ||||
| 	"npm/internal/errors" | ||||
| 	"npm/internal/logger" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| type setAuthModel struct { | ||||
| @@ -41,7 +43,10 @@ func SetAuth() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		// Load user | ||||
| 		thisUser, thisUserErr := user.GetByID(userID) | ||||
| 		if thisUserErr != nil { | ||||
| 		if thisUserErr == gorm.ErrRecordNotFound { | ||||
| 			h.NotFound(w, r) | ||||
| 			return | ||||
| 		} else if thisUserErr != nil { | ||||
| 			h.ResultErrorJSON(w, r, http.StatusBadRequest, thisUserErr.Error(), nil) | ||||
| 			return | ||||
| 		} | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -12,6 +11,8 @@ import ( | ||||
| 	"npm/internal/api/middleware" | ||||
| 	"npm/internal/entity/certificateauthority" | ||||
| 	"npm/internal/logger" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetCertificateAuthorities will return a list of Certificate Authorities | ||||
| @@ -46,7 +47,7 @@ func GetCertificateAuthority() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := certificateauthority.GetByID(caID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item) | ||||
| @@ -100,7 +101,7 @@ func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		ca, err := certificateauthority.GetByID(caID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) | ||||
| @@ -140,7 +141,7 @@ func DeleteCertificateAuthority() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := certificateauthority.GetByID(caID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -14,6 +13,8 @@ import ( | ||||
| 	"npm/internal/entity/host" | ||||
| 	"npm/internal/jobqueue" | ||||
| 	"npm/internal/logger" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetCertificates will return a list of Certificates | ||||
| @@ -138,7 +139,7 @@ func getCertificateFromRequest(w http.ResponseWriter, r *http.Request) *certific | ||||
|  | ||||
| 	certificateObject, err := certificate.GetByID(certificateID) | ||||
| 	switch err { | ||||
| 	case sql.ErrNoRows: | ||||
| 	case gorm.ErrRecordNotFound: | ||||
| 		h.NotFound(w, r) | ||||
| 	case nil: | ||||
| 		return &certificateObject | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -12,6 +11,8 @@ import ( | ||||
| 	"npm/internal/dnsproviders" | ||||
| 	"npm/internal/entity/dnsprovider" | ||||
| 	"npm/internal/errors" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetDNSProviders will return a list of DNS Providers | ||||
| @@ -46,7 +47,7 @@ func GetDNSProvider() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := dnsprovider.GetByID(providerID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item) | ||||
| @@ -95,7 +96,7 @@ func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := dnsprovider.GetByID(providerID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) | ||||
| @@ -130,7 +131,7 @@ func DeleteDNSProvider() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := dnsprovider.GetByID(providerID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -14,6 +13,8 @@ import ( | ||||
| 	"npm/internal/logger" | ||||
| 	"npm/internal/nginx" | ||||
| 	"npm/internal/validator" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetHosts will return a list of Hosts | ||||
| @@ -48,7 +49,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := host.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// nolint: errcheck,gosec | ||||
| @@ -111,7 +112,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		hostObject, err := host.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) | ||||
| @@ -156,7 +157,7 @@ func DeleteHost() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := host.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) | ||||
| @@ -181,7 +182,7 @@ func GetHostNginxConfig(format string) func(http.ResponseWriter, *http.Request) | ||||
|  | ||||
| 		item, err := host.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// Get the config from disk | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -10,6 +9,8 @@ import ( | ||||
| 	h "npm/internal/api/http" | ||||
| 	"npm/internal/api/middleware" | ||||
| 	"npm/internal/entity/nginxtemplate" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetNginxTemplates will return a list of Nginx Templates | ||||
| @@ -44,7 +45,7 @@ func GetNginxTemplate() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := nginxtemplate.GetByID(templateID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item) | ||||
| @@ -95,7 +96,7 @@ func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		nginxTemplate, err := nginxtemplate.GetByID(templateID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) | ||||
| @@ -130,7 +131,7 @@ func DeleteNginxTemplate() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := nginxtemplate.GetByID(templateID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -12,6 +11,7 @@ import ( | ||||
| 	"npm/internal/entity/setting" | ||||
|  | ||||
| 	"github.com/go-chi/chi/v5" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetSettings will return a list of Settings | ||||
| @@ -41,7 +41,7 @@ func GetSetting() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := setting.GetByName(name) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item) | ||||
| @@ -81,7 +81,7 @@ func UpdateSetting() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		setting, err := setting.GetByName(settingName) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -10,6 +9,8 @@ import ( | ||||
| 	h "npm/internal/api/http" | ||||
| 	"npm/internal/api/middleware" | ||||
| 	"npm/internal/entity/stream" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetStreams will return a list of Streams | ||||
| @@ -44,7 +45,7 @@ func GetStream() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := stream.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item) | ||||
| @@ -93,7 +94,7 @@ func UpdateStream() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		host, err := stream.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) | ||||
| @@ -128,7 +129,7 @@ func DeleteStream() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := stream.GetByID(hostID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @@ -15,6 +14,8 @@ import ( | ||||
| 	"npm/internal/logger" | ||||
| 	"npm/internal/nginx" | ||||
| 	"npm/internal/validator" | ||||
|  | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetUpstreams will return a list of Upstreams | ||||
| @@ -49,7 +50,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := upstream.GetByID(upstreamID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// nolint: errcheck,gosec | ||||
| @@ -149,7 +150,7 @@ func DeleteUpstream() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := upstream.GetByID(upstreamID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// Ensure that this upstream isn't in use by a host | ||||
| @@ -180,7 +181,7 @@ func GetUpstreamNginxConfig(format string) func(http.ResponseWriter, *http.Reque | ||||
|  | ||||
| 		item, err := upstream.GetByID(upstreamID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// Get the config from disk | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package handler | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"net/http" | ||||
|  | ||||
| @@ -15,6 +14,7 @@ import ( | ||||
| 	"npm/internal/logger" | ||||
|  | ||||
| 	"github.com/go-chi/chi/v5" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // GetUsers returns all users | ||||
| @@ -48,7 +48,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := user.GetByID(userID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// nolint: errcheck,gosec | ||||
| @@ -72,7 +72,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		userObject, err := user.GetByID(userID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			// nolint: errcheck,gosec | ||||
| @@ -136,7 +136,7 @@ func DeleteUser() func(http.ResponseWriter, *http.Request) { | ||||
|  | ||||
| 		item, err := user.GetByID(userID) | ||||
| 		switch err { | ||||
| 		case sql.ErrNoRows: | ||||
| 		case gorm.ErrRecordNotFound: | ||||
| 			h.NotFound(w, r) | ||||
| 		case nil: | ||||
| 			if err := item.Delete(); err != nil { | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"slices" | ||||
|  | ||||
| 	c "npm/internal/api/context" | ||||
| 	h "npm/internal/api/http" | ||||
| @@ -11,7 +12,6 @@ import ( | ||||
| 	"npm/internal/entity/user" | ||||
| 	njwt "npm/internal/jwt" | ||||
| 	"npm/internal/logger" | ||||
| 	"npm/internal/util" | ||||
|  | ||||
| 	"github.com/go-chi/jwtauth/v5" | ||||
| ) | ||||
| @@ -35,7 +35,7 @@ func DecodeAuth() func(http.Handler) http.Handler { | ||||
| // Enforce is a authentication middleware to enforce access from the | ||||
| // jwtauth.Verifier middleware request context values. The Authenticator sends a 401 Unauthorised | ||||
| // response for any unverified tokens and passes the good ones through. | ||||
| func Enforce(permission string) func(http.Handler) http.Handler { | ||||
| func Enforce(permissions ...string) func(http.Handler) http.Handler { | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			ctx := r.Context() | ||||
| @@ -56,7 +56,7 @@ func Enforce(permission string) func(http.Handler) http.Handler { | ||||
| 				} | ||||
|  | ||||
| 				// Check if permissions exist for this user | ||||
| 				if permission != "" { | ||||
| 				if len(permissions) > 0 { | ||||
| 					// Since the permission that we require is not on the token, we have to get it from the DB | ||||
| 					// So we don't go crazy with hits, we will use a memory cache | ||||
| 					cacheKey := fmt.Sprintf("userCapabilties.%v", userID) | ||||
| @@ -75,9 +75,16 @@ func Enforce(permission string) func(http.Handler) http.Handler { | ||||
|  | ||||
| 					// Now check that they have the permission in their admin capabilities | ||||
| 					// full-admin can do anything | ||||
| 					if !util.SliceContainsItem(userCapabilities, user.CapabilityFullAdmin) && !util.SliceContainsItem(userCapabilities, permission) { | ||||
| 					hasOnePermission := false | ||||
| 					for _, permission := range permissions { | ||||
| 						if slices.Contains(userCapabilities, user.CapabilityFullAdmin) || slices.Contains(userCapabilities, permission) { | ||||
| 							hasOnePermission = true | ||||
| 						} | ||||
| 					} | ||||
|  | ||||
| 					if !hasOnePermission { | ||||
| 						// Access denied | ||||
| 						logger.Debug("User has: %+v but needs %s", userCapabilities, permission) | ||||
| 						logger.Debug("Enforce Failed: User has %v but needs %v", userCapabilities, permissions) | ||||
| 						h.ResultErrorJSON(w, r, http.StatusForbidden, "Forbidden", nil) | ||||
| 						return | ||||
| 					} | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
|  | ||||
| 	h "npm/internal/api/http" | ||||
| @@ -9,15 +8,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| // EnforceSetup will error if the config setup doesn't match what is required | ||||
| func EnforceSetup(shouldBeSetup bool) func(http.Handler) http.Handler { | ||||
| func EnforceSetup() func(http.Handler) http.Handler { | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			if config.IsSetup != shouldBeSetup { | ||||
| 				state := "during" | ||||
| 				if config.IsSetup { | ||||
| 					state = "after" | ||||
| 				} | ||||
| 				h.ResultErrorJSON(w, r, http.StatusForbidden, fmt.Sprintf("Not available %s setup phase", state), nil) | ||||
| 			if !config.IsSetup { | ||||
| 				h.ResultErrorJSON(w, r, http.StatusForbidden, "Not available during setup phase", nil) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
|   | ||||
| @@ -17,34 +17,19 @@ func TestEnforceSetup(t *testing.T) { | ||||
| 	defer goleak.VerifyNone(t, goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		name          string | ||||
| 		shouldBeSetup bool | ||||
| 		isSetup       bool | ||||
| 		expectedCode  int | ||||
| 		name         string | ||||
| 		isSetup      bool | ||||
| 		expectedCode int | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:          "should allow request when setup is expected and is setup", | ||||
| 			shouldBeSetup: true, | ||||
| 			isSetup:       true, | ||||
| 			expectedCode:  http.StatusOK, | ||||
| 			name:         "should allow request when setup is expected and is setup", | ||||
| 			isSetup:      true, | ||||
| 			expectedCode: http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "should error when setup is expected but not setup", | ||||
| 			shouldBeSetup: true, | ||||
| 			isSetup:       false, | ||||
| 			expectedCode:  http.StatusForbidden, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "should allow request when setup is not expected and not setup", | ||||
| 			shouldBeSetup: false, | ||||
| 			isSetup:       false, | ||||
| 			expectedCode:  http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "should error when setup is not expected but is setup", | ||||
| 			shouldBeSetup: false, | ||||
| 			isSetup:       true, | ||||
| 			expectedCode:  http.StatusForbidden, | ||||
| 			name:         "should error when setup is expected but not setup", | ||||
| 			isSetup:      false, | ||||
| 			expectedCode: http.StatusForbidden, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -52,7 +37,7 @@ func TestEnforceSetup(t *testing.T) { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			config.IsSetup = tt.isSetup | ||||
|  | ||||
| 			handler := middleware.EnforceSetup(tt.shouldBeSetup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			handler := middleware.EnforceSetup()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 				w.WriteHeader(http.StatusOK) | ||||
| 			})) | ||||
|  | ||||
|   | ||||
| @@ -23,7 +23,7 @@ import ( | ||||
| // to be used later in other endpoints. | ||||
| func ListQuery(obj interface{}) func(http.Handler) http.Handler { | ||||
| 	schemaData := tags.GetFilterSchema(obj) | ||||
| 	filterMap := tags.GetFilterMap(obj) | ||||
| 	filterMap := tags.GetFilterMap(obj, "") | ||||
|  | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
|   | ||||
| @@ -64,48 +64,49 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 	// SSE - requires a sse token as the `jwt` get parameter | ||||
| 	// Exists inside /api but it's here so that we can skip the Timeout middleware | ||||
| 	// that applies to other endpoints. | ||||
| 	r.With(middleware.EnforceSetup(true), middleware.SSEAuth). | ||||
| 	r.With(middleware.EnforceSetup(), middleware.SSEAuth). | ||||
| 		Mount("/api/sse", serverevents.Get()) | ||||
|  | ||||
| 	// API | ||||
| 	r.With(chiMiddleware.Timeout(30*time.Second)).Route("/api", func(r chi.Router) { | ||||
| 		r.Get("/", handler.Health()) | ||||
| 		r.Get("/schema", handler.Schema()) | ||||
| 		r.With(middleware.EnforceSetup(true), middleware.Enforce("")). | ||||
| 		r.With(middleware.EnforceSetup(), middleware.Enforce()). | ||||
| 			Get("/config", handler.Config()) | ||||
|  | ||||
| 		// Tokens | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/tokens", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/tokens", func(r chi.Router) { | ||||
| 			r.With(middleware.EnforceRequestSchema(schema.GetToken())). | ||||
| 				Post("/", handler.NewToken()) | ||||
| 			r.With(middleware.Enforce("")). | ||||
| 			r.With(middleware.Enforce()). | ||||
| 				Get("/", handler.RefreshToken()) | ||||
| 			r.With(middleware.Enforce("")). | ||||
| 			r.With(middleware.Enforce()). | ||||
| 				Post("/sse", handler.NewSSEToken()) | ||||
| 		}) | ||||
|  | ||||
| 		// Users | ||||
| 		r.Route("/users", func(r chi.Router) { | ||||
| 			// Create - can be done in Setup stage as well | ||||
| 			r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.EnforceRequestSchema(schema.CreateUser())). | ||||
| 				Post("/", handler.CreateUser()) | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityUsersManage), | ||||
| 				middleware.EnforceRequestSchema(schema.CreateUser()), | ||||
| 			).Post("/", handler.CreateUser()) | ||||
|  | ||||
| 			// Requires Setup stage to be completed | ||||
| 			r.With(middleware.EnforceSetup(true)).Route("/", func(r chi.Router) { | ||||
| 			r.With(middleware.EnforceSetup()).Route("/", func(r chi.Router) { | ||||
| 				// Get yourself, requires a login but no other permissions | ||||
| 				r.With(middleware.Enforce("")). | ||||
| 				r.With(middleware.Enforce()). | ||||
| 					Get("/{userID:me}", handler.GetUser()) | ||||
|  | ||||
| 				// Update yourself, requires a login but no other permissions | ||||
| 				r.With(middleware.Enforce(""), middleware.EnforceRequestSchema(schema.UpdateUser())). | ||||
| 					Put("/{userID:me}", handler.UpdateUser()) | ||||
| 				r.With( | ||||
| 					middleware.Enforce(), | ||||
| 					middleware.EnforceRequestSchema(schema.UpdateUser()), | ||||
| 				).Put("/{userID:me}", handler.UpdateUser()) | ||||
|  | ||||
| 				r.With(middleware.Enforce(user.CapabilityUsersManage)).Route("/", func(r chi.Router) { | ||||
| 					// List | ||||
| 					r.With( | ||||
| 						middleware.Enforce(user.CapabilityUsersManage), | ||||
| 						middleware.ListQuery(user.Model{}), | ||||
| 					).Get("/", handler.GetUsers()) | ||||
| 					r.With(middleware.ListQuery(user.Model{})).Get("/", handler.GetUsers()) | ||||
|  | ||||
| 					// Specific Item | ||||
| 					r.Get("/{userID:[0-9]+}", handler.GetUser()) | ||||
| @@ -117,10 +118,14 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 				}) | ||||
|  | ||||
| 				// Auth - sets passwords | ||||
| 				r.With(middleware.Enforce(""), middleware.EnforceRequestSchema(schema.SetAuth())). | ||||
| 					Post("/{userID:me}/auth", handler.SetAuth()) | ||||
| 				r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.EnforceRequestSchema(schema.SetAuth())). | ||||
| 					Post("/{userID:[0-9]+}/auth", handler.SetAuth()) | ||||
| 				r.With( | ||||
| 					middleware.Enforce(), | ||||
| 					middleware.EnforceRequestSchema(schema.SetAuth()), | ||||
| 				).Post("/{userID:me}/auth", handler.SetAuth()) | ||||
| 				r.With( | ||||
| 					middleware.Enforce(user.CapabilityUsersManage), | ||||
| 					middleware.EnforceRequestSchema(schema.SetAuth()), | ||||
| 				).Post("/{userID:[0-9]+}/auth", handler.SetAuth()) | ||||
| 			}) | ||||
| 		}) | ||||
|  | ||||
| @@ -133,7 +138,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		} | ||||
|  | ||||
| 		// Settings | ||||
| 		r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup(), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.ListQuery(setting.Model{}), | ||||
| @@ -147,7 +152,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Access Lists | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/access-lists", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/access-lists", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityAccessListsView), | ||||
| @@ -171,7 +176,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// DNS Providers | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/dns-providers", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/dns-providers", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityDNSProvidersView), | ||||
| @@ -201,7 +206,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Certificate Authorities | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/certificate-authorities", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/certificate-authorities", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityCertificateAuthoritiesView), | ||||
| @@ -231,7 +236,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Certificates | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/certificates", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/certificates", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityCertificatesView), | ||||
| @@ -258,7 +263,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Hosts | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/hosts", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/hosts", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityHostsView), | ||||
| @@ -284,7 +289,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Nginx Templates | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/nginx-templates", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/nginx-templates", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityNginxTemplatesView), | ||||
| @@ -308,7 +313,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Streams | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/streams", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/streams", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityStreamsView), | ||||
| @@ -332,7 +337,7 @@ func applyRoutes(r chi.Router) chi.Router { | ||||
| 		}) | ||||
|  | ||||
| 		// Upstreams | ||||
| 		r.With(middleware.EnforceSetup(true)).Route("/upstreams", func(r chi.Router) { | ||||
| 		r.With(middleware.EnforceSetup()).Route("/upstreams", func(r chi.Router) { | ||||
| 			// List | ||||
| 			r.With( | ||||
| 				middleware.Enforce(user.CapabilityHostsView), | ||||
|   | ||||
| @@ -29,6 +29,7 @@ func CreateDataFolders() { | ||||
| 			path = fmt.Sprintf("%s/%s", Configuration.DataFolder, folder) | ||||
| 		} | ||||
| 		logger.Debug("Creating folder: %s", path) | ||||
| 		// nolint: gosec | ||||
| 		if err := os.MkdirAll(path, os.ModePerm); err != nil { | ||||
| 			logger.Error("CreateDataFolderError", err) | ||||
| 		} | ||||
|   | ||||
| @@ -244,6 +244,7 @@ func (m *Model) Request() error { | ||||
| 	certKeyFile, certFullchainFile, certFolder := m.GetCertificateLocations() | ||||
|  | ||||
| 	// ensure certFolder is created | ||||
| 	// nolint: gosec | ||||
| 	if err := os.MkdirAll(certFolder, os.ModePerm); err != nil { | ||||
| 		logger.Error("CreateFolderError", err) | ||||
| 		return err | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package jwt | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
|  | ||||
| @@ -14,6 +13,7 @@ import ( | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"github.com/stretchr/testify/suite" | ||||
| 	"go.uber.org/goleak" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // +------------+ | ||||
| @@ -136,7 +136,7 @@ func (s *testsuite) TestLoadKeys() { | ||||
| 	s.mock. | ||||
| 		ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "jwt_keys" WHERE "jwt_keys"."is_deleted" = $1`)). | ||||
| 		WithArgs(0). | ||||
| 		WillReturnError(sql.ErrNoRows) | ||||
| 		WillReturnError(gorm.ErrRecordNotFound) | ||||
|  | ||||
| 	// insert row | ||||
| 	s.mock.ExpectBegin() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user