Compare commits

..

9 Commits

Author SHA1 Message Date
Jamie Curnow
4d3d37eaed Update s6-overlay 2024-09-11 20:26:22 +10:00
Jamie Curnow
0cd2e07fd3 Fix docker warnings 2024-09-11 20:24:31 +10:00
Jamie Curnow
0a18a565c7 Remove cache from filters, was duplicating incorrect data 2024-09-11 20:21:15 +10:00
Jamie Curnow
c4db4a2647 Fix quotes 2024-09-11 19:38:41 +10:00
Jamie Curnow
e78dd069f1 Quote filter fields 2024-09-11 15:22:10 +10:00
Jamie Curnow
9a2e5c92d5 Improvements to enforce middleware, linting, returning 404 properly 2024-09-11 15:03:00 +10:00
Jamie Curnow
833dd23dce Support table names in filter tags 2024-09-11 15:01:28 +10:00
Jamie Curnow
514520ce1a Linter hacks 2024-09-11 15:01:05 +10:00
Jamie Curnow
21e3bce95d Adds tests for schema json 2024-09-11 15:00:49 +10:00
31 changed files with 324 additions and 158 deletions

View File

@@ -21,6 +21,13 @@ linters:
- unconvert - unconvert
- unparam - unparam
linters-settings: linters-settings:
gosec:
excludes:
- G115
errcheck:
exclude-functions:
- fmt.Fprint
- fmt.Fprintf
goconst: goconst:
# minimal length of string constant # minimal length of string constant
# default: 3 # default: 3

View File

@@ -11,6 +11,8 @@ import (
"npm/internal/entity/user" "npm/internal/entity/user"
"npm/internal/errors" "npm/internal/errors"
"npm/internal/logger" "npm/internal/logger"
"gorm.io/gorm"
) )
type setAuthModel struct { type setAuthModel struct {
@@ -41,7 +43,10 @@ func SetAuth() func(http.ResponseWriter, *http.Request) {
// Load user // Load user
thisUser, thisUserErr := user.GetByID(userID) thisUser, thisUserErr := user.GetByID(userID)
if thisUserErr != nil { if thisUserErr == gorm.ErrRecordNotFound {
h.NotFound(w, r)
return
} else if thisUserErr != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, thisUserErr.Error(), nil) h.ResultErrorJSON(w, r, http.StatusBadRequest, thisUserErr.Error(), nil)
return return
} }

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -12,6 +11,8 @@ import (
"npm/internal/api/middleware" "npm/internal/api/middleware"
"npm/internal/entity/certificateauthority" "npm/internal/entity/certificateauthority"
"npm/internal/logger" "npm/internal/logger"
"gorm.io/gorm"
) )
// GetCertificateAuthorities will return a list of Certificate Authorities // GetCertificateAuthorities will return a list of Certificate Authorities
@@ -46,7 +47,7 @@ func GetCertificateAuthority() func(http.ResponseWriter, *http.Request) {
item, err := certificateauthority.GetByID(caID) item, err := certificateauthority.GetByID(caID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item) h.ResultResponseJSON(w, r, http.StatusOK, item)
@@ -100,7 +101,7 @@ func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) {
ca, err := certificateauthority.GetByID(caID) ca, err := certificateauthority.GetByID(caID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
@@ -140,7 +141,7 @@ func DeleteCertificateAuthority() func(http.ResponseWriter, *http.Request) {
item, err := certificateauthority.GetByID(caID) item, err := certificateauthority.GetByID(caID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -14,6 +13,8 @@ import (
"npm/internal/entity/host" "npm/internal/entity/host"
"npm/internal/jobqueue" "npm/internal/jobqueue"
"npm/internal/logger" "npm/internal/logger"
"gorm.io/gorm"
) )
// GetCertificates will return a list of Certificates // GetCertificates will return a list of Certificates
@@ -138,7 +139,7 @@ func getCertificateFromRequest(w http.ResponseWriter, r *http.Request) *certific
certificateObject, err := certificate.GetByID(certificateID) certificateObject, err := certificate.GetByID(certificateID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
return &certificateObject return &certificateObject

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -12,6 +11,8 @@ import (
"npm/internal/dnsproviders" "npm/internal/dnsproviders"
"npm/internal/entity/dnsprovider" "npm/internal/entity/dnsprovider"
"npm/internal/errors" "npm/internal/errors"
"gorm.io/gorm"
) )
// GetDNSProviders will return a list of DNS Providers // GetDNSProviders will return a list of DNS Providers
@@ -46,7 +47,7 @@ func GetDNSProvider() func(http.ResponseWriter, *http.Request) {
item, err := dnsprovider.GetByID(providerID) item, err := dnsprovider.GetByID(providerID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item) h.ResultResponseJSON(w, r, http.StatusOK, item)
@@ -95,7 +96,7 @@ func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) {
item, err := dnsprovider.GetByID(providerID) item, err := dnsprovider.GetByID(providerID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
@@ -130,7 +131,7 @@ func DeleteDNSProvider() func(http.ResponseWriter, *http.Request) {
item, err := dnsprovider.GetByID(providerID) item, err := dnsprovider.GetByID(providerID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -14,6 +13,8 @@ import (
"npm/internal/logger" "npm/internal/logger"
"npm/internal/nginx" "npm/internal/nginx"
"npm/internal/validator" "npm/internal/validator"
"gorm.io/gorm"
) )
// GetHosts will return a list of Hosts // GetHosts will return a list of Hosts
@@ -48,7 +49,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) {
item, err := host.GetByID(hostID) item, err := host.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// nolint: errcheck,gosec // nolint: errcheck,gosec
@@ -111,7 +112,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) {
hostObject, err := host.GetByID(hostID) hostObject, err := host.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
@@ -156,7 +157,7 @@ func DeleteHost() func(http.ResponseWriter, *http.Request) {
item, err := host.GetByID(hostID) item, err := host.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
@@ -181,7 +182,7 @@ func GetHostNginxConfig(format string) func(http.ResponseWriter, *http.Request)
item, err := host.GetByID(hostID) item, err := host.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// Get the config from disk // Get the config from disk

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -10,6 +9,8 @@ import (
h "npm/internal/api/http" h "npm/internal/api/http"
"npm/internal/api/middleware" "npm/internal/api/middleware"
"npm/internal/entity/nginxtemplate" "npm/internal/entity/nginxtemplate"
"gorm.io/gorm"
) )
// GetNginxTemplates will return a list of Nginx Templates // GetNginxTemplates will return a list of Nginx Templates
@@ -44,7 +45,7 @@ func GetNginxTemplate() func(http.ResponseWriter, *http.Request) {
item, err := nginxtemplate.GetByID(templateID) item, err := nginxtemplate.GetByID(templateID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item) h.ResultResponseJSON(w, r, http.StatusOK, item)
@@ -95,7 +96,7 @@ func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) {
nginxTemplate, err := nginxtemplate.GetByID(templateID) nginxTemplate, err := nginxtemplate.GetByID(templateID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
@@ -130,7 +131,7 @@ func DeleteNginxTemplate() func(http.ResponseWriter, *http.Request) {
item, err := nginxtemplate.GetByID(templateID) item, err := nginxtemplate.GetByID(templateID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -12,6 +11,7 @@ import (
"npm/internal/entity/setting" "npm/internal/entity/setting"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"gorm.io/gorm"
) )
// GetSettings will return a list of Settings // GetSettings will return a list of Settings
@@ -41,7 +41,7 @@ func GetSetting() func(http.ResponseWriter, *http.Request) {
item, err := setting.GetByName(name) item, err := setting.GetByName(name)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item) h.ResultResponseJSON(w, r, http.StatusOK, item)
@@ -81,7 +81,7 @@ func UpdateSetting() func(http.ResponseWriter, *http.Request) {
setting, err := setting.GetByName(settingName) setting, err := setting.GetByName(settingName)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -10,6 +9,8 @@ import (
h "npm/internal/api/http" h "npm/internal/api/http"
"npm/internal/api/middleware" "npm/internal/api/middleware"
"npm/internal/entity/stream" "npm/internal/entity/stream"
"gorm.io/gorm"
) )
// GetStreams will return a list of Streams // GetStreams will return a list of Streams
@@ -44,7 +45,7 @@ func GetStream() func(http.ResponseWriter, *http.Request) {
item, err := stream.GetByID(hostID) item, err := stream.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item) h.ResultResponseJSON(w, r, http.StatusOK, item)
@@ -93,7 +94,7 @@ func UpdateStream() func(http.ResponseWriter, *http.Request) {
host, err := stream.GetByID(hostID) host, err := stream.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte)
@@ -128,7 +129,7 @@ func DeleteStream() func(http.ResponseWriter, *http.Request) {
item, err := stream.GetByID(hostID) item, err := stream.GetByID(hostID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -15,6 +14,8 @@ import (
"npm/internal/logger" "npm/internal/logger"
"npm/internal/nginx" "npm/internal/nginx"
"npm/internal/validator" "npm/internal/validator"
"gorm.io/gorm"
) )
// GetUpstreams will return a list of Upstreams // GetUpstreams will return a list of Upstreams
@@ -49,7 +50,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) {
item, err := upstream.GetByID(upstreamID) item, err := upstream.GetByID(upstreamID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// nolint: errcheck,gosec // nolint: errcheck,gosec
@@ -149,7 +150,7 @@ func DeleteUpstream() func(http.ResponseWriter, *http.Request) {
item, err := upstream.GetByID(upstreamID) item, err := upstream.GetByID(upstreamID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// Ensure that this upstream isn't in use by a host // Ensure that this upstream isn't in use by a host
@@ -180,7 +181,7 @@ func GetUpstreamNginxConfig(format string) func(http.ResponseWriter, *http.Reque
item, err := upstream.GetByID(upstreamID) item, err := upstream.GetByID(upstreamID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// Get the config from disk // Get the config from disk

View File

@@ -1,7 +1,6 @@
package handler package handler
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"net/http" "net/http"
@@ -15,6 +14,7 @@ import (
"npm/internal/logger" "npm/internal/logger"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"gorm.io/gorm"
) )
// GetUsers returns all users // GetUsers returns all users
@@ -48,7 +48,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) {
item, err := user.GetByID(userID) item, err := user.GetByID(userID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// nolint: errcheck,gosec // nolint: errcheck,gosec
@@ -72,7 +72,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) {
userObject, err := user.GetByID(userID) userObject, err := user.GetByID(userID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
// nolint: errcheck,gosec // nolint: errcheck,gosec
@@ -136,7 +136,7 @@ func DeleteUser() func(http.ResponseWriter, *http.Request) {
item, err := user.GetByID(userID) item, err := user.GetByID(userID)
switch err { switch err {
case sql.ErrNoRows: case gorm.ErrRecordNotFound:
h.NotFound(w, r) h.NotFound(w, r)
case nil: case nil:
if err := item.Delete(); err != nil { if err := item.Delete(); err != nil {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"slices"
c "npm/internal/api/context" c "npm/internal/api/context"
h "npm/internal/api/http" h "npm/internal/api/http"
@@ -11,7 +12,6 @@ import (
"npm/internal/entity/user" "npm/internal/entity/user"
njwt "npm/internal/jwt" njwt "npm/internal/jwt"
"npm/internal/logger" "npm/internal/logger"
"npm/internal/util"
"github.com/go-chi/jwtauth/v5" "github.com/go-chi/jwtauth/v5"
) )
@@ -35,7 +35,7 @@ func DecodeAuth() func(http.Handler) http.Handler {
// Enforce is a authentication middleware to enforce access from the // Enforce is a authentication middleware to enforce access from the
// jwtauth.Verifier middleware request context values. The Authenticator sends a 401 Unauthorised // jwtauth.Verifier middleware request context values. The Authenticator sends a 401 Unauthorised
// response for any unverified tokens and passes the good ones through. // response for any unverified tokens and passes the good ones through.
func Enforce(permission string) func(http.Handler) http.Handler { func Enforce(permissions ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
@@ -56,7 +56,7 @@ func Enforce(permission string) func(http.Handler) http.Handler {
} }
// Check if permissions exist for this user // Check if permissions exist for this user
if permission != "" { if len(permissions) > 0 {
// Since the permission that we require is not on the token, we have to get it from the DB // Since the permission that we require is not on the token, we have to get it from the DB
// So we don't go crazy with hits, we will use a memory cache // So we don't go crazy with hits, we will use a memory cache
cacheKey := fmt.Sprintf("userCapabilties.%v", userID) cacheKey := fmt.Sprintf("userCapabilties.%v", userID)
@@ -75,9 +75,16 @@ func Enforce(permission string) func(http.Handler) http.Handler {
// Now check that they have the permission in their admin capabilities // Now check that they have the permission in their admin capabilities
// full-admin can do anything // full-admin can do anything
if !util.SliceContainsItem(userCapabilities, user.CapabilityFullAdmin) && !util.SliceContainsItem(userCapabilities, permission) { hasOnePermission := false
for _, permission := range permissions {
if slices.Contains(userCapabilities, user.CapabilityFullAdmin) || slices.Contains(userCapabilities, permission) {
hasOnePermission = true
}
}
if !hasOnePermission {
// Access denied // Access denied
logger.Debug("User has: %+v but needs %s", userCapabilities, permission) logger.Debug("Enforce Failed: User has %v but needs %v", userCapabilities, permissions)
h.ResultErrorJSON(w, r, http.StatusForbidden, "Forbidden", nil) h.ResultErrorJSON(w, r, http.StatusForbidden, "Forbidden", nil)
return return
} }

View File

@@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"fmt"
"net/http" "net/http"
h "npm/internal/api/http" h "npm/internal/api/http"
@@ -9,15 +8,11 @@ import (
) )
// EnforceSetup will error if the config setup doesn't match what is required // EnforceSetup will error if the config setup doesn't match what is required
func EnforceSetup(shouldBeSetup bool) func(http.Handler) http.Handler { func EnforceSetup() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if config.IsSetup != shouldBeSetup { if !config.IsSetup {
state := "during" h.ResultErrorJSON(w, r, http.StatusForbidden, "Not available during setup phase", nil)
if config.IsSetup {
state = "after"
}
h.ResultErrorJSON(w, r, http.StatusForbidden, fmt.Sprintf("Not available %s setup phase", state), nil)
return return
} }

View File

@@ -18,41 +18,26 @@ func TestEnforceSetup(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
shouldBeSetup bool
isSetup bool isSetup bool
expectedCode int expectedCode int
}{ }{
{ {
name: "should allow request when setup is expected and is setup", name: "should allow request when setup is expected and is setup",
shouldBeSetup: true,
isSetup: true, isSetup: true,
expectedCode: http.StatusOK, expectedCode: http.StatusOK,
}, },
{ {
name: "should error when setup is expected but not setup", name: "should error when setup is expected but not setup",
shouldBeSetup: true,
isSetup: false, isSetup: false,
expectedCode: http.StatusForbidden, expectedCode: http.StatusForbidden,
}, },
{
name: "should allow request when setup is not expected and not setup",
shouldBeSetup: false,
isSetup: false,
expectedCode: http.StatusOK,
},
{
name: "should error when setup is not expected but is setup",
shouldBeSetup: false,
isSetup: true,
expectedCode: http.StatusForbidden,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
config.IsSetup = tt.isSetup config.IsSetup = tt.isSetup
handler := middleware.EnforceSetup(tt.shouldBeSetup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := middleware.EnforceSetup()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})) }))

View File

@@ -23,7 +23,7 @@ import (
// to be used later in other endpoints. // to be used later in other endpoints.
func ListQuery(obj interface{}) func(http.Handler) http.Handler { func ListQuery(obj interface{}) func(http.Handler) http.Handler {
schemaData := tags.GetFilterSchema(obj) schemaData := tags.GetFilterSchema(obj)
filterMap := tags.GetFilterMap(obj) filterMap := tags.GetFilterMap(obj, "")
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -64,48 +64,49 @@ func applyRoutes(r chi.Router) chi.Router {
// SSE - requires a sse token as the `jwt` get parameter // SSE - requires a sse token as the `jwt` get parameter
// Exists inside /api but it's here so that we can skip the Timeout middleware // Exists inside /api but it's here so that we can skip the Timeout middleware
// that applies to other endpoints. // that applies to other endpoints.
r.With(middleware.EnforceSetup(true), middleware.SSEAuth). r.With(middleware.EnforceSetup(), middleware.SSEAuth).
Mount("/api/sse", serverevents.Get()) Mount("/api/sse", serverevents.Get())
// API // API
r.With(chiMiddleware.Timeout(30*time.Second)).Route("/api", func(r chi.Router) { r.With(chiMiddleware.Timeout(30*time.Second)).Route("/api", func(r chi.Router) {
r.Get("/", handler.Health()) r.Get("/", handler.Health())
r.Get("/schema", handler.Schema()) r.Get("/schema", handler.Schema())
r.With(middleware.EnforceSetup(true), middleware.Enforce("")). r.With(middleware.EnforceSetup(), middleware.Enforce()).
Get("/config", handler.Config()) Get("/config", handler.Config())
// Tokens // Tokens
r.With(middleware.EnforceSetup(true)).Route("/tokens", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/tokens", func(r chi.Router) {
r.With(middleware.EnforceRequestSchema(schema.GetToken())). r.With(middleware.EnforceRequestSchema(schema.GetToken())).
Post("/", handler.NewToken()) Post("/", handler.NewToken())
r.With(middleware.Enforce("")). r.With(middleware.Enforce()).
Get("/", handler.RefreshToken()) Get("/", handler.RefreshToken())
r.With(middleware.Enforce("")). r.With(middleware.Enforce()).
Post("/sse", handler.NewSSEToken()) Post("/sse", handler.NewSSEToken())
}) })
// Users // Users
r.Route("/users", func(r chi.Router) { r.Route("/users", func(r chi.Router) {
// Create - can be done in Setup stage as well // Create - can be done in Setup stage as well
r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.EnforceRequestSchema(schema.CreateUser())). r.With(
Post("/", handler.CreateUser()) middleware.Enforce(user.CapabilityUsersManage),
middleware.EnforceRequestSchema(schema.CreateUser()),
).Post("/", handler.CreateUser())
// Requires Setup stage to be completed // Requires Setup stage to be completed
r.With(middleware.EnforceSetup(true)).Route("/", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/", func(r chi.Router) {
// Get yourself, requires a login but no other permissions // Get yourself, requires a login but no other permissions
r.With(middleware.Enforce("")). r.With(middleware.Enforce()).
Get("/{userID:me}", handler.GetUser()) Get("/{userID:me}", handler.GetUser())
// Update yourself, requires a login but no other permissions // Update yourself, requires a login but no other permissions
r.With(middleware.Enforce(""), middleware.EnforceRequestSchema(schema.UpdateUser())). r.With(
Put("/{userID:me}", handler.UpdateUser()) middleware.Enforce(),
middleware.EnforceRequestSchema(schema.UpdateUser()),
).Put("/{userID:me}", handler.UpdateUser())
r.With(middleware.Enforce(user.CapabilityUsersManage)).Route("/", func(r chi.Router) { r.With(middleware.Enforce(user.CapabilityUsersManage)).Route("/", func(r chi.Router) {
// List // List
r.With( r.With(middleware.ListQuery(user.Model{})).Get("/", handler.GetUsers())
middleware.Enforce(user.CapabilityUsersManage),
middleware.ListQuery(user.Model{}),
).Get("/", handler.GetUsers())
// Specific Item // Specific Item
r.Get("/{userID:[0-9]+}", handler.GetUser()) r.Get("/{userID:[0-9]+}", handler.GetUser())
@@ -117,10 +118,14 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Auth - sets passwords // Auth - sets passwords
r.With(middleware.Enforce(""), middleware.EnforceRequestSchema(schema.SetAuth())). r.With(
Post("/{userID:me}/auth", handler.SetAuth()) middleware.Enforce(),
r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.EnforceRequestSchema(schema.SetAuth())). middleware.EnforceRequestSchema(schema.SetAuth()),
Post("/{userID:[0-9]+}/auth", handler.SetAuth()) ).Post("/{userID:me}/auth", handler.SetAuth())
r.With(
middleware.Enforce(user.CapabilityUsersManage),
middleware.EnforceRequestSchema(schema.SetAuth()),
).Post("/{userID:[0-9]+}/auth", handler.SetAuth())
}) })
}) })
@@ -133,7 +138,7 @@ func applyRoutes(r chi.Router) chi.Router {
} }
// Settings // Settings
r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) { r.With(middleware.EnforceSetup(), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.ListQuery(setting.Model{}), middleware.ListQuery(setting.Model{}),
@@ -147,7 +152,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Access Lists // Access Lists
r.With(middleware.EnforceSetup(true)).Route("/access-lists", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/access-lists", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityAccessListsView), middleware.Enforce(user.CapabilityAccessListsView),
@@ -171,7 +176,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// DNS Providers // DNS Providers
r.With(middleware.EnforceSetup(true)).Route("/dns-providers", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/dns-providers", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityDNSProvidersView), middleware.Enforce(user.CapabilityDNSProvidersView),
@@ -201,7 +206,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Certificate Authorities // Certificate Authorities
r.With(middleware.EnforceSetup(true)).Route("/certificate-authorities", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/certificate-authorities", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityCertificateAuthoritiesView), middleware.Enforce(user.CapabilityCertificateAuthoritiesView),
@@ -231,7 +236,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Certificates // Certificates
r.With(middleware.EnforceSetup(true)).Route("/certificates", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/certificates", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityCertificatesView), middleware.Enforce(user.CapabilityCertificatesView),
@@ -258,7 +263,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Hosts // Hosts
r.With(middleware.EnforceSetup(true)).Route("/hosts", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/hosts", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityHostsView), middleware.Enforce(user.CapabilityHostsView),
@@ -284,7 +289,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Nginx Templates // Nginx Templates
r.With(middleware.EnforceSetup(true)).Route("/nginx-templates", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/nginx-templates", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityNginxTemplatesView), middleware.Enforce(user.CapabilityNginxTemplatesView),
@@ -308,7 +313,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Streams // Streams
r.With(middleware.EnforceSetup(true)).Route("/streams", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/streams", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityStreamsView), middleware.Enforce(user.CapabilityStreamsView),
@@ -332,7 +337,7 @@ func applyRoutes(r chi.Router) chi.Router {
}) })
// Upstreams // Upstreams
r.With(middleware.EnforceSetup(true)).Route("/upstreams", func(r chi.Router) { r.With(middleware.EnforceSetup()).Route("/upstreams", func(r chi.Router) {
// List // List
r.With( r.With(
middleware.Enforce(user.CapabilityHostsView), middleware.Enforce(user.CapabilityHostsView),

View File

@@ -0,0 +1,132 @@
package schema
import (
"bytes"
"encoding/json"
"npm/internal/entity/certificate"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSchemas(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "CreateCertificate",
schema: CreateCertificate(),
},
{
name: "UpdateCertificate TypeHTTP",
schema: UpdateCertificate(certificate.TypeHTTP),
},
{
name: "UpdateCertificate TypeDNS",
schema: UpdateCertificate(certificate.TypeDNS),
},
{
name: "UpdateCertificate TypeCustom",
schema: UpdateCertificate(certificate.TypeCustom),
},
{
name: "UpdateCertificate TypeMkcert",
schema: UpdateCertificate(certificate.TypeMkcert),
},
{
name: "UpdateCertificate default",
schema: UpdateCertificate(""),
},
{
name: "CreateAccessList",
schema: CreateAccessList(),
},
{
name: "CreateCertificateAuthority",
schema: CreateCertificateAuthority(),
},
{
name: "CreateDNSProvider",
schema: CreateDNSProvider(),
},
{
name: "CreateHost",
schema: CreateHost(),
},
{
name: "CreateNginxTemplate",
schema: CreateNginxTemplate(),
},
{
name: "CreateSetting",
schema: CreateSetting(),
},
{
name: "CreateStream",
schema: CreateStream(),
},
{
name: "CreateUpstream",
schema: CreateUpstream(),
},
{
name: "CreateUser",
schema: CreateUser(),
},
{
name: "GetToken",
schema: GetToken(),
},
{
name: "SetAuth",
schema: SetAuth(),
},
{
name: "UpdateAccessList",
schema: UpdateAccessList(),
},
{
name: "UpdateCertificateAuthority",
schema: UpdateCertificateAuthority(),
},
{
name: "UpdateDNSProvider",
schema: UpdateDNSProvider(),
},
{
name: "UpdateHost",
schema: UpdateHost(),
},
{
name: "UpdateNginxTemplate",
schema: UpdateNginxTemplate(),
},
{
name: "UpdateSetting",
schema: UpdateSetting(),
},
{
name: "UpdateStream",
schema: UpdateStream(),
},
{
name: "UpdateUpstream",
schema: UpdateUpstream(),
},
{
name: "UpdateUser",
schema: UpdateUser(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
byt := []byte(tt.schema)
var prettyJSON bytes.Buffer
err := json.Indent(&prettyJSON, byt, "", " ")
assert.NoError(t, err)
assert.Greater(t, len(prettyJSON.String()), 0)
})
}
}

View File

@@ -29,6 +29,7 @@ func CreateDataFolders() {
path = fmt.Sprintf("%s/%s", Configuration.DataFolder, folder) path = fmt.Sprintf("%s/%s", Configuration.DataFolder, folder)
} }
logger.Debug("Creating folder: %s", path) logger.Debug("Creating folder: %s", path)
// nolint: gosec
if err := os.MkdirAll(path, os.ModePerm); err != nil { if err := os.MkdirAll(path, os.ModePerm); err != nil {
logger.Error("CreateDataFolderError", err) logger.Error("CreateDataFolderError", err)
} }

View File

@@ -18,11 +18,12 @@ const (
// is for special cases where we run raw sql // is for special cases where we run raw sql
func QuoteTableName(tbl string) string { func QuoteTableName(tbl string) string {
switch strings.ToLower(config.Configuration.DB.Driver) { switch strings.ToLower(config.Configuration.DB.Driver) {
case config.DatabasePostgres: case config.DatabaseMysql:
return fmt.Sprintf(`"%s"`, tbl) // backticks for mysql
default:
// This is the same for Mysql and Sqlite
return fmt.Sprintf("`%s`", tbl) return fmt.Sprintf("`%s`", tbl)
default:
// double quotes for everything else
return fmt.Sprintf(`"%s"`, tbl)
} }
} }

View File

@@ -244,6 +244,7 @@ func (m *Model) Request() error {
certKeyFile, certFullchainFile, certFolder := m.GetCertificateLocations() certKeyFile, certFullchainFile, certFolder := m.GetCertificateLocations()
// ensure certFolder is created // ensure certFolder is created
// nolint: gosec
if err := os.MkdirAll(certFolder, os.ModePerm); err != nil { if err := os.MkdirAll(certFolder, os.ModePerm); err != nil {
logger.Error("CreateFolderError", err) logger.Error("CreateFolderError", err)
return err return err

View File

@@ -122,12 +122,12 @@ func (s *testsuite) TestList() {
defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener"))
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT count(*) FROM "certificate_authority" WHERE name LIKE $1 AND "certificate_authority"."is_deleted" = $2`)). ExpectQuery(regexp.QuoteMeta(`SELECT count(*) FROM "certificate_authority" WHERE "certificate_authority"."name" LIKE $1 AND "certificate_authority"."is_deleted" = $2`)).
WithArgs("%test%", 0). WithArgs("%test%", 0).
WillReturnRows(s.listCountRows) WillReturnRows(s.listCountRows)
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "certificate_authority" WHERE name LIKE $1 AND "certificate_authority"."is_deleted" = $2 ORDER BY name asc LIMIT $3`)). ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "certificate_authority" WHERE "certificate_authority"."name" LIKE $1 AND "certificate_authority"."is_deleted" = $2 ORDER BY name asc LIMIT $3`)).
WithArgs("%test%", 0, 8). WithArgs("%test%", 0, 8).
WillReturnRows(s.listRows) WillReturnRows(s.listRows)

View File

@@ -204,12 +204,12 @@ func (s *testsuite) TestList() {
defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener"))
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT count(*) FROM "dns_provider" WHERE acmesh_name LIKE $1 AND "dns_provider"."is_deleted" = $2`)). ExpectQuery(regexp.QuoteMeta(`SELECT count(*) FROM "dns_provider" WHERE "dns_provider"."acmesh_name" LIKE $1 AND "dns_provider"."is_deleted" = $2`)).
WithArgs("dns%", 0). WithArgs("dns%", 0).
WillReturnRows(s.listCountRows) WillReturnRows(s.listCountRows)
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "dns_provider" WHERE acmesh_name LIKE $1 AND "dns_provider"."is_deleted" = $2 ORDER BY name asc LIMIT $3`)). ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "dns_provider" WHERE "dns_provider"."acmesh_name" LIKE $1 AND "dns_provider"."is_deleted" = $2 ORDER BY name asc LIMIT $3`)).
WithArgs("dns%", 0, 8). WithArgs("dns%", 0, 8).
WillReturnRows(s.listRows) WillReturnRows(s.listRows)

View File

@@ -7,21 +7,23 @@ import (
// GetFilterMap returns the filter map // GetFilterMap returns the filter map
func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]model.FilterMapValue { func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]model.FilterMapValue {
filterMap := tags.GetFilterMap(m) filterMap := tags.GetFilterMap(m, "")
if includeBaseEntity {
return mergeFilterMaps(tags.GetFilterMap(model.ModelBase{}), filterMap) // TODO: this is done in GetFilterMap isn't it?
} // if includeBaseEntity {
// return mergeFilterMaps(tags.GetFilterMap(model.ModelBase{}, ""), filterMap)
// }
return filterMap return filterMap
} }
func mergeFilterMaps(m1 map[string]model.FilterMapValue, m2 map[string]model.FilterMapValue) map[string]model.FilterMapValue { // func mergeFilterMaps(m1 map[string]model.FilterMapValue, m2 map[string]model.FilterMapValue) map[string]model.FilterMapValue {
merged := make(map[string]model.FilterMapValue, 0) // merged := make(map[string]model.FilterMapValue, 0)
for k, v := range m1 { // for k, v := range m1 {
merged[k] = v // merged[k] = v
} // }
for key, value := range m2 { // for key, value := range m2 {
merged[key] = value // merged[key] = value
} // }
return merged // return merged
} // }

View File

@@ -258,7 +258,7 @@ func (s *testsuite) TestDeleteAll() {
defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener"))
s.mock. s.mock.
ExpectExec(regexp.QuoteMeta("DELETE FROM `user` WHERE is_system = $1")). ExpectExec(regexp.QuoteMeta(`DELETE FROM "user" WHERE is_system = $1`)).
WithArgs(false). WithArgs(false).
WillReturnResult(sqlmock.NewResult(0, 1)) WillReturnResult(sqlmock.NewResult(0, 1))
@@ -307,12 +307,12 @@ func (s *testsuite) TestList() {
defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener"))
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT count(*) FROM "user" WHERE name LIKE $1 AND "user"."is_deleted" = $2`)). ExpectQuery(regexp.QuoteMeta(`SELECT count(*) FROM "user" WHERE "user"."name" LIKE $1 AND "user"."is_deleted" = $2`)).
WithArgs("%jon%", 0). WithArgs("%jon%", 0).
WillReturnRows(s.listCountRows) WillReturnRows(s.listCountRows)
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "user" WHERE name LIKE $1 AND "user"."is_deleted" = $2 ORDER BY name asc LIMIT $3`)). ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "user" WHERE "user"."name" LIKE $1 AND "user"."is_deleted" = $2 ORDER BY name asc LIMIT $3`)).
WithArgs("%jon%", 0, 8). WithArgs("%jon%", 0, 8).
WillReturnRows(s.listRows) WillReturnRows(s.listRows)

View File

@@ -86,7 +86,6 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent
// DeleteAll will do just that, and should only be used for testing purposes. // DeleteAll will do just that, and should only be used for testing purposes.
func DeleteAll() error { func DeleteAll() error {
db := database.GetDB() db := database.GetDB()
// nolint errcheck
result := db.Exec(fmt.Sprintf(`DELETE FROM %s WHERE is_system = ?`, database.QuoteTableName("user")), false) result := db.Exec(fmt.Sprintf(`DELETE FROM %s WHERE is_system = ?`, database.QuoteTableName("user")), false)
return result.Error return result.Error
} }

View File

@@ -1,7 +1,6 @@
package jwt package jwt
import ( import (
"database/sql"
"regexp" "regexp"
"testing" "testing"
@@ -14,6 +13,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/goleak" "go.uber.org/goleak"
"gorm.io/gorm"
) )
// +------------+ // +------------+
@@ -136,7 +136,7 @@ func (s *testsuite) TestLoadKeys() {
s.mock. s.mock.
ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "jwt_keys" WHERE "jwt_keys"."is_deleted" = $1`)). ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "jwt_keys" WHERE "jwt_keys"."is_deleted" = $1`)).
WithArgs(0). WithArgs(0).
WillReturnError(sql.ErrNoRows) WillReturnError(gorm.ErrRecordNotFound)
// insert row // insert row
s.mock.ExpectBegin() s.mock.ExpectBegin()

View File

@@ -6,6 +6,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"npm/internal/database"
"npm/internal/logger" "npm/internal/logger"
"npm/internal/model" "npm/internal/model"
"npm/internal/util" "npm/internal/util"
@@ -13,24 +14,35 @@ import (
"github.com/rotisserie/eris" "github.com/rotisserie/eris"
) )
func GetFilterMap(m interface{}) map[string]model.FilterMapValue { func GetFilterMap(m interface{}, globalTablePrefix string) map[string]model.FilterMapValue {
name := getName(m) name := getName(m)
if val, exists := getCache(name); exists { filterMap := make(map[string]model.FilterMapValue)
return val
}
var filterMap = make(map[string]model.FilterMapValue)
// If this is an entity model (and it probably is)
// then include the base model as well
if strings.Contains(name, ".Model") && !strings.Contains(name, "ModelBase") {
filterMap = GetFilterMap(model.ModelBase{})
}
// TypeOf returns the reflection Type that represents the dynamic type of variable. // TypeOf returns the reflection Type that represents the dynamic type of variable.
// If variable is a nil interface value, TypeOf returns nil. // If variable is a nil interface value, TypeOf returns nil.
t := reflect.TypeOf(m) t := reflect.TypeOf(m)
// Get the table name from the model function, if it exists
if globalTablePrefix == "" {
v := reflect.ValueOf(m)
tableNameFunc, ok := t.MethodByName("TableName")
if ok {
n := tableNameFunc.Func.Call([]reflect.Value{v})
if len(n) > 0 {
globalTablePrefix = fmt.Sprintf(
`%s.`,
database.QuoteTableName(n[0].String()),
)
}
}
}
// If this is an entity model (and it probably is)
// then include the base model as well
if strings.Contains(name, ".Model") && !strings.Contains(name, "ModelBase") {
filterMap = GetFilterMap(model.ModelBase{}, globalTablePrefix)
}
if t.Kind() != reflect.Struct { if t.Kind() != reflect.Struct {
logger.Error("GetFilterMapError", eris.Errorf("%v type can't have attributes inspected", t.Kind())) logger.Error("GetFilterMapError", eris.Errorf("%v type can't have attributes inspected", t.Kind()))
return nil return nil
@@ -44,35 +56,41 @@ func GetFilterMap(m interface{}) map[string]model.FilterMapValue {
// Get the field tag value // Get the field tag value
filterTag := field.Tag.Get("filter") filterTag := field.Tag.Get("filter")
dbTag := field.Tag.Get("gorm") dbTag := field.Tag.Get("gorm")
// Filter -> Schema mapping
if filterTag != "" && filterTag != "-" {
f := model.FilterMapValue{ f := model.FilterMapValue{
Model: name, Model: name,
} }
// Filter -> Schema mapping
if filterTag != "" && filterTag != "-" {
f.Schema = getFilterTagSchema(filterTag) f.Schema = getFilterTagSchema(filterTag)
parts := strings.Split(filterTag, ",") parts := strings.Split(filterTag, ",")
// Filter -> DB Field mapping // Filter -> DB Field mapping
if dbTag != "" && dbTag != "-" { if dbTag != "" && dbTag != "-" {
// db can have many parts, we need to pull out the "column:value" part
f.Field = field.Name
r := regexp.MustCompile(`(?:^|;)column:([^;|$]+)(?:$|;)`)
if matches := r.FindStringSubmatch(dbTag); len(matches) > 1 {
f.Field = matches[1]
}
// Filter tag can be a 2 part thing: name,type // Filter tag can be a 2 part thing: name,type
// ie: account_id,integer // ie: account_id,integer
// So we need to split and use the first part // So we need to split and use the first part
tablePrefix := globalTablePrefix
if len(parts) > 1 { if len(parts) > 1 {
f.Type = parts[1] f.Type = parts[1]
if len(parts) > 2 {
tablePrefix = fmt.Sprintf(`"%s".`, parts[2])
} }
} }
// db can have many parts, we need to pull out the "column:value" part
f.Field = database.QuoteTableName(field.Name)
r := regexp.MustCompile(`(?:^|;)column:([^;|$]+)(?:$|;)`)
if matches := r.FindStringSubmatch(dbTag); len(matches) > 1 {
f.Field = fmt.Sprintf("%s%s", tablePrefix, database.QuoteTableName(matches[1]))
}
}
filterMap[parts[0]] = f filterMap[parts[0]] = f
} }
} }
setCache(name, filterMap)
return filterMap return filterMap
} }
@@ -111,7 +129,7 @@ func getFilterTagSchema(filterTag string) string {
// GetFilterSchema creates a jsonschema for validating filters, based on the model // GetFilterSchema creates a jsonschema for validating filters, based on the model
// object given and by reading the struct "filter" tags. // object given and by reading the struct "filter" tags.
func GetFilterSchema(m interface{}) string { func GetFilterSchema(m interface{}) string {
filterMap := GetFilterMap(m) filterMap := GetFilterMap(m, "")
schemas := make([]string, 0) schemas := make([]string, 0)
for _, f := range filterMap { for _, f := range filterMap {

View File

@@ -14,6 +14,7 @@ if [[ -n "$INCOMPLETE_COMMENTS" ]]; then
# RESULT=1 # RESULT=1
fi fi
echo -e "${YELLOW}golangci-lint ...${RESET}"
if ! golangci-lint run -E goimports ./...; then if ! golangci-lint run -E goimports ./...; then
exit 1 exit 1
fi fi

View File

@@ -4,8 +4,8 @@
# This file assumes that these scripts have been run first: # This file assumes that these scripts have been run first:
# - ./scripts/ci/build-frontend # - ./scripts/ci/build-frontend
FROM nginxproxymanager/testca as testca FROM nginxproxymanager/testca AS testca
FROM letsencrypt/pebble as pebbleca FROM letsencrypt/pebble AS pebbleca
FROM jc21/gotools:latest AS gobuild FROM jc21/gotools:latest AS gobuild
SHELL ["/bin/bash", "-o", "pipefail", "-c"] SHELL ["/bin/bash", "-o", "pipefail", "-c"]

View File

@@ -1,5 +1,5 @@
FROM nginxproxymanager/testca as testca FROM nginxproxymanager/testca AS testca
FROM letsencrypt/pebble as pebbleca FROM letsencrypt/pebble AS pebbleca
FROM nginxproxymanager/nginx-full:acmesh-golang FROM nginxproxymanager/nginx-full:acmesh-golang
LABEL maintainer="Jamie Curnow <jc@jc21.com>" LABEL maintainer="Jamie Curnow <jc@jc21.com>"

View File

@@ -8,7 +8,7 @@ BLUE='\E[1;34m'
GREEN='\E[1;32m' GREEN='\E[1;32m'
RESET='\E[0m' RESET='\E[0m'
S6_OVERLAY_VERSION=3.1.6.2 S6_OVERLAY_VERSION=3.2.0.0
TARGETPLATFORM=${1:-linux/amd64} TARGETPLATFORM=${1:-linux/amd64}
# Determine the correct binary file for the architecture given # Determine the correct binary file for the architecture given