mirror of
				https://github.com/NginxProxyManager/nginx-proxy-manager.git
				synced 2025-11-03 17:13:33 +00:00 
			
		
		
		
	Improvements to enforce middleware, linting, returning 404 properly
This commit is contained in:
		@@ -11,6 +11,8 @@ import (
 | 
			
		||||
	"npm/internal/entity/user"
 | 
			
		||||
	"npm/internal/errors"
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type setAuthModel struct {
 | 
			
		||||
@@ -41,7 +43,10 @@ func SetAuth() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		// Load user
 | 
			
		||||
		thisUser, thisUserErr := user.GetByID(userID)
 | 
			
		||||
		if thisUserErr != nil {
 | 
			
		||||
		if thisUserErr == gorm.ErrRecordNotFound {
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
			return
 | 
			
		||||
		} else if thisUserErr != nil {
 | 
			
		||||
			h.ResultErrorJSON(w, r, http.StatusBadRequest, thisUserErr.Error(), nil)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -12,6 +11,8 @@ import (
 | 
			
		||||
	"npm/internal/api/middleware"
 | 
			
		||||
	"npm/internal/entity/certificateauthority"
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetCertificateAuthorities will return a list of Certificate Authorities
 | 
			
		||||
@@ -46,7 +47,7 @@ func GetCertificateAuthority() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := certificateauthority.GetByID(caID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item)
 | 
			
		||||
@@ -100,7 +101,7 @@ func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		ca, err := certificateauthority.GetByID(caID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
 | 
			
		||||
@@ -140,7 +141,7 @@ func DeleteCertificateAuthority() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := certificateauthority.GetByID(caID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -14,6 +13,8 @@ import (
 | 
			
		||||
	"npm/internal/entity/host"
 | 
			
		||||
	"npm/internal/jobqueue"
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetCertificates will return a list of Certificates
 | 
			
		||||
@@ -138,7 +139,7 @@ func getCertificateFromRequest(w http.ResponseWriter, r *http.Request) *certific
 | 
			
		||||
 | 
			
		||||
	certificateObject, err := certificate.GetByID(certificateID)
 | 
			
		||||
	switch err {
 | 
			
		||||
	case sql.ErrNoRows:
 | 
			
		||||
	case gorm.ErrRecordNotFound:
 | 
			
		||||
		h.NotFound(w, r)
 | 
			
		||||
	case nil:
 | 
			
		||||
		return &certificateObject
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -12,6 +11,8 @@ import (
 | 
			
		||||
	"npm/internal/dnsproviders"
 | 
			
		||||
	"npm/internal/entity/dnsprovider"
 | 
			
		||||
	"npm/internal/errors"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetDNSProviders will return a list of DNS Providers
 | 
			
		||||
@@ -46,7 +47,7 @@ func GetDNSProvider() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := dnsprovider.GetByID(providerID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item)
 | 
			
		||||
@@ -95,7 +96,7 @@ func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := dnsprovider.GetByID(providerID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
 | 
			
		||||
@@ -130,7 +131,7 @@ func DeleteDNSProvider() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := dnsprovider.GetByID(providerID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -14,6 +13,8 @@ import (
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
	"npm/internal/nginx"
 | 
			
		||||
	"npm/internal/validator"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetHosts will return a list of Hosts
 | 
			
		||||
@@ -48,7 +49,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := host.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// nolint: errcheck,gosec
 | 
			
		||||
@@ -111,7 +112,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		hostObject, err := host.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
 | 
			
		||||
@@ -156,7 +157,7 @@ func DeleteHost() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := host.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
 | 
			
		||||
@@ -181,7 +182,7 @@ func GetHostNginxConfig(format string) func(http.ResponseWriter, *http.Request)
 | 
			
		||||
 | 
			
		||||
		item, err := host.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// Get the config from disk
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -10,6 +9,8 @@ import (
 | 
			
		||||
	h "npm/internal/api/http"
 | 
			
		||||
	"npm/internal/api/middleware"
 | 
			
		||||
	"npm/internal/entity/nginxtemplate"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetNginxTemplates will return a list of Nginx Templates
 | 
			
		||||
@@ -44,7 +45,7 @@ func GetNginxTemplate() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := nginxtemplate.GetByID(templateID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item)
 | 
			
		||||
@@ -95,7 +96,7 @@ func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		nginxTemplate, err := nginxtemplate.GetByID(templateID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
 | 
			
		||||
@@ -130,7 +131,7 @@ func DeleteNginxTemplate() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := nginxtemplate.GetByID(templateID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -12,6 +11,7 @@ import (
 | 
			
		||||
	"npm/internal/entity/setting"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-chi/chi/v5"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetSettings will return a list of Settings
 | 
			
		||||
@@ -41,7 +41,7 @@ func GetSetting() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := setting.GetByName(name)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item)
 | 
			
		||||
@@ -81,7 +81,7 @@ func UpdateSetting() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		setting, err := setting.GetByName(settingName)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -10,6 +9,8 @@ import (
 | 
			
		||||
	h "npm/internal/api/http"
 | 
			
		||||
	"npm/internal/api/middleware"
 | 
			
		||||
	"npm/internal/entity/stream"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetStreams will return a list of Streams
 | 
			
		||||
@@ -44,7 +45,7 @@ func GetStream() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := stream.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item)
 | 
			
		||||
@@ -93,7 +94,7 @@ func UpdateStream() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		host, err := stream.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
 | 
			
		||||
@@ -128,7 +129,7 @@ func DeleteStream() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := stream.GetByID(hostID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -15,6 +14,8 @@ import (
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
	"npm/internal/nginx"
 | 
			
		||||
	"npm/internal/validator"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetUpstreams will return a list of Upstreams
 | 
			
		||||
@@ -49,7 +50,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := upstream.GetByID(upstreamID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// nolint: errcheck,gosec
 | 
			
		||||
@@ -149,7 +150,7 @@ func DeleteUpstream() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := upstream.GetByID(upstreamID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// Ensure that this upstream isn't in use by a host
 | 
			
		||||
@@ -180,7 +181,7 @@ func GetUpstreamNginxConfig(format string) func(http.ResponseWriter, *http.Reque
 | 
			
		||||
 | 
			
		||||
		item, err := upstream.GetByID(upstreamID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// Get the config from disk
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
@@ -15,6 +14,7 @@ import (
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-chi/chi/v5"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetUsers returns all users
 | 
			
		||||
@@ -48,7 +48,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := user.GetByID(userID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// nolint: errcheck,gosec
 | 
			
		||||
@@ -72,7 +72,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		userObject, err := user.GetByID(userID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			// nolint: errcheck,gosec
 | 
			
		||||
@@ -136,7 +136,7 @@ func DeleteUser() func(http.ResponseWriter, *http.Request) {
 | 
			
		||||
 | 
			
		||||
		item, err := user.GetByID(userID)
 | 
			
		||||
		switch err {
 | 
			
		||||
		case sql.ErrNoRows:
 | 
			
		||||
		case gorm.ErrRecordNotFound:
 | 
			
		||||
			h.NotFound(w, r)
 | 
			
		||||
		case nil:
 | 
			
		||||
			if err := item.Delete(); err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"slices"
 | 
			
		||||
 | 
			
		||||
	c "npm/internal/api/context"
 | 
			
		||||
	h "npm/internal/api/http"
 | 
			
		||||
@@ -11,7 +12,6 @@ import (
 | 
			
		||||
	"npm/internal/entity/user"
 | 
			
		||||
	njwt "npm/internal/jwt"
 | 
			
		||||
	"npm/internal/logger"
 | 
			
		||||
	"npm/internal/util"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-chi/jwtauth/v5"
 | 
			
		||||
)
 | 
			
		||||
@@ -35,7 +35,7 @@ func DecodeAuth() func(http.Handler) http.Handler {
 | 
			
		||||
// Enforce is a authentication middleware to enforce access from the
 | 
			
		||||
// jwtauth.Verifier middleware request context values. The Authenticator sends a 401 Unauthorised
 | 
			
		||||
// response for any unverified tokens and passes the good ones through.
 | 
			
		||||
func Enforce(permission string) func(http.Handler) http.Handler {
 | 
			
		||||
func Enforce(permissions ...string) func(http.Handler) http.Handler {
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			ctx := r.Context()
 | 
			
		||||
@@ -56,7 +56,7 @@ func Enforce(permission string) func(http.Handler) http.Handler {
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Check if permissions exist for this user
 | 
			
		||||
				if permission != "" {
 | 
			
		||||
				if len(permissions) > 0 {
 | 
			
		||||
					// Since the permission that we require is not on the token, we have to get it from the DB
 | 
			
		||||
					// So we don't go crazy with hits, we will use a memory cache
 | 
			
		||||
					cacheKey := fmt.Sprintf("userCapabilties.%v", userID)
 | 
			
		||||
@@ -75,9 +75,16 @@ func Enforce(permission string) func(http.Handler) http.Handler {
 | 
			
		||||
 | 
			
		||||
					// Now check that they have the permission in their admin capabilities
 | 
			
		||||
					// full-admin can do anything
 | 
			
		||||
					if !util.SliceContainsItem(userCapabilities, user.CapabilityFullAdmin) && !util.SliceContainsItem(userCapabilities, permission) {
 | 
			
		||||
					hasOnePermission := false
 | 
			
		||||
					for _, permission := range permissions {
 | 
			
		||||
						if slices.Contains(userCapabilities, user.CapabilityFullAdmin) || slices.Contains(userCapabilities, permission) {
 | 
			
		||||
							hasOnePermission = true
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if !hasOnePermission {
 | 
			
		||||
						// Access denied
 | 
			
		||||
						logger.Debug("User has: %+v but needs %s", userCapabilities, permission)
 | 
			
		||||
						logger.Debug("Enforce Failed: User has %v but needs %v", userCapabilities, permissions)
 | 
			
		||||
						h.ResultErrorJSON(w, r, http.StatusForbidden, "Forbidden", nil)
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	h "npm/internal/api/http"
 | 
			
		||||
@@ -9,15 +8,11 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// EnforceSetup will error if the config setup doesn't match what is required
 | 
			
		||||
func EnforceSetup(shouldBeSetup bool) func(http.Handler) http.Handler {
 | 
			
		||||
func EnforceSetup() func(http.Handler) http.Handler {
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			if config.IsSetup != shouldBeSetup {
 | 
			
		||||
				state := "during"
 | 
			
		||||
				if config.IsSetup {
 | 
			
		||||
					state = "after"
 | 
			
		||||
				}
 | 
			
		||||
				h.ResultErrorJSON(w, r, http.StatusForbidden, fmt.Sprintf("Not available %s setup phase", state), nil)
 | 
			
		||||
			if !config.IsSetup {
 | 
			
		||||
				h.ResultErrorJSON(w, r, http.StatusForbidden, "Not available during setup phase", nil)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -17,34 +17,19 @@ func TestEnforceSetup(t *testing.T) {
 | 
			
		||||
	defer goleak.VerifyNone(t, goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener"))
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		shouldBeSetup bool
 | 
			
		||||
		isSetup       bool
 | 
			
		||||
		expectedCode  int
 | 
			
		||||
		name         string
 | 
			
		||||
		isSetup      bool
 | 
			
		||||
		expectedCode int
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:          "should allow request when setup is expected and is setup",
 | 
			
		||||
			shouldBeSetup: true,
 | 
			
		||||
			isSetup:       true,
 | 
			
		||||
			expectedCode:  http.StatusOK,
 | 
			
		||||
			name:         "should allow request when setup is expected and is setup",
 | 
			
		||||
			isSetup:      true,
 | 
			
		||||
			expectedCode: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "should error when setup is expected but not setup",
 | 
			
		||||
			shouldBeSetup: true,
 | 
			
		||||
			isSetup:       false,
 | 
			
		||||
			expectedCode:  http.StatusForbidden,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "should allow request when setup is not expected and not setup",
 | 
			
		||||
			shouldBeSetup: false,
 | 
			
		||||
			isSetup:       false,
 | 
			
		||||
			expectedCode:  http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "should error when setup is not expected but is setup",
 | 
			
		||||
			shouldBeSetup: false,
 | 
			
		||||
			isSetup:       true,
 | 
			
		||||
			expectedCode:  http.StatusForbidden,
 | 
			
		||||
			name:         "should error when setup is expected but not setup",
 | 
			
		||||
			isSetup:      false,
 | 
			
		||||
			expectedCode: http.StatusForbidden,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -52,7 +37,7 @@ func TestEnforceSetup(t *testing.T) {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			config.IsSetup = tt.isSetup
 | 
			
		||||
 | 
			
		||||
			handler := middleware.EnforceSetup(tt.shouldBeSetup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			handler := middleware.EnforceSetup()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				w.WriteHeader(http.StatusOK)
 | 
			
		||||
			}))
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,7 @@ import (
 | 
			
		||||
// to be used later in other endpoints.
 | 
			
		||||
func ListQuery(obj interface{}) func(http.Handler) http.Handler {
 | 
			
		||||
	schemaData := tags.GetFilterSchema(obj)
 | 
			
		||||
	filterMap := tags.GetFilterMap(obj)
 | 
			
		||||
	filterMap := tags.GetFilterMap(obj, "")
 | 
			
		||||
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
 
 | 
			
		||||
@@ -64,48 +64,49 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
	// SSE - requires a sse token as the `jwt` get parameter
 | 
			
		||||
	// Exists inside /api but it's here so that we can skip the Timeout middleware
 | 
			
		||||
	// that applies to other endpoints.
 | 
			
		||||
	r.With(middleware.EnforceSetup(true), middleware.SSEAuth).
 | 
			
		||||
	r.With(middleware.EnforceSetup(), middleware.SSEAuth).
 | 
			
		||||
		Mount("/api/sse", serverevents.Get())
 | 
			
		||||
 | 
			
		||||
	// API
 | 
			
		||||
	r.With(chiMiddleware.Timeout(30*time.Second)).Route("/api", func(r chi.Router) {
 | 
			
		||||
		r.Get("/", handler.Health())
 | 
			
		||||
		r.Get("/schema", handler.Schema())
 | 
			
		||||
		r.With(middleware.EnforceSetup(true), middleware.Enforce("")).
 | 
			
		||||
		r.With(middleware.EnforceSetup(), middleware.Enforce()).
 | 
			
		||||
			Get("/config", handler.Config())
 | 
			
		||||
 | 
			
		||||
		// Tokens
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/tokens", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/tokens", func(r chi.Router) {
 | 
			
		||||
			r.With(middleware.EnforceRequestSchema(schema.GetToken())).
 | 
			
		||||
				Post("/", handler.NewToken())
 | 
			
		||||
			r.With(middleware.Enforce("")).
 | 
			
		||||
			r.With(middleware.Enforce()).
 | 
			
		||||
				Get("/", handler.RefreshToken())
 | 
			
		||||
			r.With(middleware.Enforce("")).
 | 
			
		||||
			r.With(middleware.Enforce()).
 | 
			
		||||
				Post("/sse", handler.NewSSEToken())
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Users
 | 
			
		||||
		r.Route("/users", func(r chi.Router) {
 | 
			
		||||
			// Create - can be done in Setup stage as well
 | 
			
		||||
			r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.EnforceRequestSchema(schema.CreateUser())).
 | 
			
		||||
				Post("/", handler.CreateUser())
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityUsersManage),
 | 
			
		||||
				middleware.EnforceRequestSchema(schema.CreateUser()),
 | 
			
		||||
			).Post("/", handler.CreateUser())
 | 
			
		||||
 | 
			
		||||
			// Requires Setup stage to be completed
 | 
			
		||||
			r.With(middleware.EnforceSetup(true)).Route("/", func(r chi.Router) {
 | 
			
		||||
			r.With(middleware.EnforceSetup()).Route("/", func(r chi.Router) {
 | 
			
		||||
				// Get yourself, requires a login but no other permissions
 | 
			
		||||
				r.With(middleware.Enforce("")).
 | 
			
		||||
				r.With(middleware.Enforce()).
 | 
			
		||||
					Get("/{userID:me}", handler.GetUser())
 | 
			
		||||
 | 
			
		||||
				// Update yourself, requires a login but no other permissions
 | 
			
		||||
				r.With(middleware.Enforce(""), middleware.EnforceRequestSchema(schema.UpdateUser())).
 | 
			
		||||
					Put("/{userID:me}", handler.UpdateUser())
 | 
			
		||||
				r.With(
 | 
			
		||||
					middleware.Enforce(),
 | 
			
		||||
					middleware.EnforceRequestSchema(schema.UpdateUser()),
 | 
			
		||||
				).Put("/{userID:me}", handler.UpdateUser())
 | 
			
		||||
 | 
			
		||||
				r.With(middleware.Enforce(user.CapabilityUsersManage)).Route("/", func(r chi.Router) {
 | 
			
		||||
					// List
 | 
			
		||||
					r.With(
 | 
			
		||||
						middleware.Enforce(user.CapabilityUsersManage),
 | 
			
		||||
						middleware.ListQuery(user.Model{}),
 | 
			
		||||
					).Get("/", handler.GetUsers())
 | 
			
		||||
					r.With(middleware.ListQuery(user.Model{})).Get("/", handler.GetUsers())
 | 
			
		||||
 | 
			
		||||
					// Specific Item
 | 
			
		||||
					r.Get("/{userID:[0-9]+}", handler.GetUser())
 | 
			
		||||
@@ -117,10 +118,14 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
				})
 | 
			
		||||
 | 
			
		||||
				// Auth - sets passwords
 | 
			
		||||
				r.With(middleware.Enforce(""), middleware.EnforceRequestSchema(schema.SetAuth())).
 | 
			
		||||
					Post("/{userID:me}/auth", handler.SetAuth())
 | 
			
		||||
				r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.EnforceRequestSchema(schema.SetAuth())).
 | 
			
		||||
					Post("/{userID:[0-9]+}/auth", handler.SetAuth())
 | 
			
		||||
				r.With(
 | 
			
		||||
					middleware.Enforce(),
 | 
			
		||||
					middleware.EnforceRequestSchema(schema.SetAuth()),
 | 
			
		||||
				).Post("/{userID:me}/auth", handler.SetAuth())
 | 
			
		||||
				r.With(
 | 
			
		||||
					middleware.Enforce(user.CapabilityUsersManage),
 | 
			
		||||
					middleware.EnforceRequestSchema(schema.SetAuth()),
 | 
			
		||||
				).Post("/{userID:[0-9]+}/auth", handler.SetAuth())
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
@@ -133,7 +138,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Settings
 | 
			
		||||
		r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup(), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.ListQuery(setting.Model{}),
 | 
			
		||||
@@ -147,7 +152,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Access Lists
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/access-lists", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/access-lists", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityAccessListsView),
 | 
			
		||||
@@ -171,7 +176,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// DNS Providers
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/dns-providers", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/dns-providers", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityDNSProvidersView),
 | 
			
		||||
@@ -201,7 +206,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Certificate Authorities
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/certificate-authorities", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/certificate-authorities", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityCertificateAuthoritiesView),
 | 
			
		||||
@@ -231,7 +236,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Certificates
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/certificates", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/certificates", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityCertificatesView),
 | 
			
		||||
@@ -258,7 +263,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Hosts
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/hosts", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/hosts", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityHostsView),
 | 
			
		||||
@@ -284,7 +289,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Nginx Templates
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/nginx-templates", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/nginx-templates", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityNginxTemplatesView),
 | 
			
		||||
@@ -308,7 +313,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Streams
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/streams", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/streams", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityStreamsView),
 | 
			
		||||
@@ -332,7 +337,7 @@ func applyRoutes(r chi.Router) chi.Router {
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// Upstreams
 | 
			
		||||
		r.With(middleware.EnforceSetup(true)).Route("/upstreams", func(r chi.Router) {
 | 
			
		||||
		r.With(middleware.EnforceSetup()).Route("/upstreams", func(r chi.Router) {
 | 
			
		||||
			// List
 | 
			
		||||
			r.With(
 | 
			
		||||
				middleware.Enforce(user.CapabilityHostsView),
 | 
			
		||||
 
 | 
			
		||||
@@ -29,6 +29,7 @@ func CreateDataFolders() {
 | 
			
		||||
			path = fmt.Sprintf("%s/%s", Configuration.DataFolder, folder)
 | 
			
		||||
		}
 | 
			
		||||
		logger.Debug("Creating folder: %s", path)
 | 
			
		||||
		// nolint: gosec
 | 
			
		||||
		if err := os.MkdirAll(path, os.ModePerm); err != nil {
 | 
			
		||||
			logger.Error("CreateDataFolderError", err)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -244,6 +244,7 @@ func (m *Model) Request() error {
 | 
			
		||||
	certKeyFile, certFullchainFile, certFolder := m.GetCertificateLocations()
 | 
			
		||||
 | 
			
		||||
	// ensure certFolder is created
 | 
			
		||||
	// nolint: gosec
 | 
			
		||||
	if err := os.MkdirAll(certFolder, os.ModePerm); err != nil {
 | 
			
		||||
		logger.Error("CreateFolderError", err)
 | 
			
		||||
		return err
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package jwt
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
@@ -14,6 +13,7 @@ import (
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"github.com/stretchr/testify/suite"
 | 
			
		||||
	"go.uber.org/goleak"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// +------------+
 | 
			
		||||
@@ -136,7 +136,7 @@ func (s *testsuite) TestLoadKeys() {
 | 
			
		||||
	s.mock.
 | 
			
		||||
		ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "jwt_keys" WHERE "jwt_keys"."is_deleted" = $1`)).
 | 
			
		||||
		WithArgs(0).
 | 
			
		||||
		WillReturnError(sql.ErrNoRows)
 | 
			
		||||
		WillReturnError(gorm.ErrRecordNotFound)
 | 
			
		||||
 | 
			
		||||
	// insert row
 | 
			
		||||
	s.mock.ExpectBegin()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user