Convert db backend to use Gorm, with basis for support

for Mysql and Postgres in addition to existing Sqlite
This commit is contained in:
Jamie Curnow
2023-05-26 11:04:43 +10:00
parent b4e5b8b6db
commit 29990110b1
93 changed files with 1215 additions and 3075 deletions

View File

@ -35,7 +35,7 @@ func GetAccessLists() func(http.ResponseWriter, *http.Request) {
func GetAccessList() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var accessListID int
var accessListID uint
if accessListID, err = getURLParamInt(r, "accessListID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -81,7 +81,7 @@ func CreateAccessList() func(http.ResponseWriter, *http.Request) {
func UpdateAccessList() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var accessListID int
var accessListID uint
if accessListID, err = getURLParamInt(r, "accessListID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -113,7 +113,7 @@ func UpdateAccessList() func(http.ResponseWriter, *http.Request) {
func DeleteAccessList() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var accessListID int
var accessListID uint
if accessListID, err = getURLParamInt(r, "accessListID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -14,9 +14,9 @@ import (
)
type setAuthModel struct {
Type string `json:"type" db:"type"`
Secret string `json:"secret,omitempty" db:"secret"`
CurrentSecret string `json:"current_secret,omitempty"`
Type string
Secret string
CurrentSecret string
}
// SetAuth sets a auth method. This can be used for "me" and `2` for example

View File

@ -38,7 +38,7 @@ func GetCertificateAuthorities() func(http.ResponseWriter, *http.Request) {
func GetCertificateAuthority() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var caID int
var caID uint
if caID, err = getURLParamInt(r, "caID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -92,7 +92,7 @@ func CreateCertificateAuthority() func(http.ResponseWriter, *http.Request) {
func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var caID int
var caID uint
if caID, err = getURLParamInt(r, "caID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -132,7 +132,7 @@ func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) {
func DeleteCertificateAuthority() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var caID int
var caID uint
if caID, err = getURLParamInt(r, "caID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -55,7 +55,7 @@ func CreateCertificate() func(http.ResponseWriter, *http.Request) {
var item certificate.Model
if fillObjectFromBody(w, r, "", &item) {
// Get userID from token
userID, _ := r.Context().Value(c.UserIDCtxKey).(int)
userID, _ := r.Context().Value(c.UserIDCtxKey).(uint)
item.UserID = userID
if err := item.Save(); err != nil {
@ -131,7 +131,7 @@ func DownloadCertificate() func(http.ResponseWriter, *http.Request) {
// have a certificate id in the url. it will write errors to the output.
func getCertificateFromRequest(w http.ResponseWriter, r *http.Request) *certificate.Model {
var err error
var certificateID int
var certificateID uint
if certificateID, err = getURLParamInt(r, "certificateID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return nil

View File

@ -38,7 +38,7 @@ func GetDNSProviders() func(http.ResponseWriter, *http.Request) {
func GetDNSProvider() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var providerID int
var providerID uint
if providerID, err = getURLParamInt(r, "providerID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -87,7 +87,7 @@ func CreateDNSProvider() func(http.ResponseWriter, *http.Request) {
func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var providerID int
var providerID uint
if providerID, err = getURLParamInt(r, "providerID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -122,7 +122,7 @@ func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) {
func DeleteDNSProvider() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var providerID int
var providerID uint
if providerID, err = getURLParamInt(r, "providerID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -4,7 +4,6 @@ import (
"net/http"
"strconv"
"strings"
"time"
"npm/internal/api/context"
"npm/internal/model"
@ -19,11 +18,6 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
var pageInfo model.PageInfo
var err error
pageInfo.FromDate, pageInfo.ToDate, err = getDateRanges(r)
if err != nil {
return pageInfo, err
}
pageInfo.Offset, pageInfo.Limit, err = getPagination(r)
if err != nil {
return pageInfo, err
@ -34,32 +28,6 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
return pageInfo, nil
}
func getDateRanges(r *http.Request) (time.Time, time.Time, error) {
queryValues := r.URL.Query()
from := queryValues.Get("from")
fromDate := time.Now().AddDate(0, -1, 0) // 1 month ago by default
to := queryValues.Get("to")
toDate := time.Now()
if from != "" {
var fromErr error
fromDate, fromErr = time.Parse(time.RFC3339, from)
if fromErr != nil {
return fromDate, toDate, eris.Errorf("From date is not in correct format: %v", strings.ReplaceAll(time.RFC3339, "Z", "+"))
}
}
if to != "" {
var toErr error
toDate, toErr = time.Parse(time.RFC3339, to)
if toErr != nil {
return fromDate, toDate, eris.Errorf("To date is not in correct format: %v", strings.ReplaceAll(time.RFC3339, "Z", "+"))
}
}
return fromDate, toDate, nil
}
func getSortParameter(r *http.Request) []model.Sort {
var sortFields []model.Sort
@ -132,12 +100,11 @@ func getQueryVarBool(r *http.Request, varName string, required bool, defaultValu
}
*/
func getURLParamInt(r *http.Request, varName string) (int, error) {
func getURLParamInt(r *http.Request, varName string) (uint, error) {
var defaultValue uint = 0
required := true
defaultValue := 0
paramStr := chi.URLParam(r, varName)
var err error
var paramInt int
if paramStr == "" && required {
return 0, eris.Errorf("%v was not supplied in the request", varName)
@ -145,11 +112,13 @@ func getURLParamInt(r *http.Request, varName string) (int, error) {
return defaultValue, nil
}
if paramInt, err = strconv.Atoi(paramStr); err != nil {
// func ParseUint(s string, base int, bitSize int) (n uint64, err error)
paramUint, err := strconv.ParseUint(paramStr, 10, 32)
if err != nil {
return 0, eris.Wrapf(err, "%v is not a valid number", varName)
}
return paramInt, nil
return uint(paramUint), nil
}
func getURLParamString(r *http.Request, varName string) (string, error) {

View File

@ -40,7 +40,7 @@ func GetHosts() func(http.ResponseWriter, *http.Request) {
func GetHost() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -74,7 +74,7 @@ func CreateHost() func(http.ResponseWriter, *http.Request) {
}
// Get userID from token
userID, _ := r.Context().Value(c.UserIDCtxKey).(int)
userID, _ := r.Context().Value(c.UserIDCtxKey).(uint)
newHost.UserID = userID
if err = validator.ValidateHost(newHost); err != nil {
@ -103,7 +103,7 @@ func CreateHost() func(http.ResponseWriter, *http.Request) {
func UpdateHost() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -148,7 +148,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) {
func DeleteHost() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -173,7 +173,7 @@ func DeleteHost() func(http.ResponseWriter, *http.Request) {
func GetHostNginxConfig(format string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -36,7 +36,7 @@ func GetNginxTemplates() func(http.ResponseWriter, *http.Request) {
func GetNginxTemplate() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var templateID int
var templateID uint
if templateID, err = getURLParamInt(r, "templateID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -85,7 +85,7 @@ func CreateNginxTemplate() func(http.ResponseWriter, *http.Request) {
func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var templateID int
var templateID uint
if templateID, err = getURLParamInt(r, "templateID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -122,7 +122,7 @@ func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) {
func DeleteNginxTemplate() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var templateID int
var templateID uint
if templateID, err = getURLParamInt(r, "templateID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -36,7 +36,7 @@ func GetStreams() func(http.ResponseWriter, *http.Request) {
func GetStream() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -85,7 +85,7 @@ func CreateStream() func(http.ResponseWriter, *http.Request) {
func UpdateStream() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -120,7 +120,7 @@ func UpdateStream() func(http.ResponseWriter, *http.Request) {
func DeleteStream() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var hostID int
var hostID uint
if hostID, err = getURLParamInt(r, "hostID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -93,7 +93,7 @@ func RefreshToken() func(http.ResponseWriter, *http.Request) {
// Route: POST /tokens/sse
func NewSSEToken() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
userID := r.Context().Value(c.UserIDCtxKey).(int)
userID := r.Context().Value(c.UserIDCtxKey).(uint)
// Find user
userObj, userErr := user.GetByID(userID)

View File

@ -41,7 +41,7 @@ func GetUpstreams() func(http.ResponseWriter, *http.Request) {
func GetUpstream() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var upstreamID int
var upstreamID uint
if upstreamID, err = getURLParamInt(r, "upstreamID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -75,7 +75,7 @@ func CreateUpstream() func(http.ResponseWriter, *http.Request) {
}
// Get userID from token
userID, _ := r.Context().Value(c.UserIDCtxKey).(int)
userID, _ := r.Context().Value(c.UserIDCtxKey).(uint)
newUpstream.UserID = userID
if err = validator.ValidateUpstream(newUpstream); err != nil {
@ -99,7 +99,7 @@ func CreateUpstream() func(http.ResponseWriter, *http.Request) {
func UpdateUpstream() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var upstreamID int
var upstreamID uint
if upstreamID, err = getURLParamInt(r, "upstreamID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -141,7 +141,7 @@ func UpdateUpstream() func(http.ResponseWriter, *http.Request) {
func DeleteUpstream() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var upstreamID int
var upstreamID uint
if upstreamID, err = getURLParamInt(r, "upstreamID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
@ -172,7 +172,7 @@ func DeleteUpstream() func(http.ResponseWriter, *http.Request) {
func GetUpstreamNginxConfig(format string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var err error
var upstreamID int
var upstreamID uint
if upstreamID, err = getURLParamInt(r, "upstreamID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return

View File

@ -121,14 +121,14 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) {
// Route: DELETE /users/{userID}
func DeleteUser() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var userID int
var userID uint
var err error
if userID, err = getURLParamInt(r, "userID"); err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
return
}
myUserID, _ := r.Context().Value(c.UserIDCtxKey).(int)
myUserID, _ := r.Context().Value(c.UserIDCtxKey).(uint)
if myUserID == userID {
h.ResultErrorJSON(w, r, http.StatusBadRequest, "You cannot delete yourself!", nil)
return
@ -224,11 +224,11 @@ func DeleteUsers() func(http.ResponseWriter, *http.Request) {
}
}
func getUserIDFromRequest(r *http.Request) (int, bool, error) {
func getUserIDFromRequest(r *http.Request) (uint, bool, error) {
userIDstr := chi.URLParam(r, "userID")
selfUserID, _ := r.Context().Value(c.UserIDCtxKey).(int)
selfUserID, _ := r.Context().Value(c.UserIDCtxKey).(uint)
var userID int
var userID uint
self := false
if userIDstr == "me" {
// Get user id from Token

View File

@ -48,7 +48,7 @@ func Enforce(permission string) func(http.Handler) http.Handler {
return
}
userID := int(claims["uid"].(float64))
userID := uint(claims["uid"].(float64))
_, enabled := user.IsEnabled(userID)
if token == nil || !token.Valid || !enabled {
h.ResultErrorJSON(w, r, http.StatusUnauthorized, "Unauthorised", nil)

View File

@ -21,7 +21,7 @@ func SSEAuth(next http.Handler) http.Handler {
return
}
userID := int(claims["uid"].(float64))
userID := uint(claims["uid"].(float64))
_, enabled := user.IsEnabled(userID)
if token == nil || !token.Valid || !enabled || !claims.VerifyIssuer("sse", true) {
h.ResultErrorJSON(w, r, http.StatusUnauthorized, "Unauthorised", nil)

View File

@ -8,15 +8,6 @@ import (
"npm/internal/api/middleware"
"npm/internal/api/schema"
"npm/internal/config"
"npm/internal/entity/accesslist"
"npm/internal/entity/certificate"
"npm/internal/entity/certificateauthority"
"npm/internal/entity/dnsprovider"
"npm/internal/entity/host"
"npm/internal/entity/nginxtemplate"
"npm/internal/entity/setting"
"npm/internal/entity/stream"
"npm/internal/entity/upstream"
"npm/internal/entity/user"
"npm/internal/logger"
"npm/internal/serverevents"
@ -102,7 +93,8 @@ func applyRoutes(r chi.Router) chi.Router {
r.With(middleware.Enforce(user.CapabilityUsersManage)).Route("/", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.Filters(user.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityUsersManage), middleware.Filters(user.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityUsersManage)).
Get("/", handler.GetUsers())
// Specific Item
@ -132,8 +124,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Settings
r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
r.With(middleware.Filters(setting.GetFilterSchema())).
Get("/", handler.GetSettings())
// r.With(middleware.Filters(setting.GetFilterSchema())).
r.Get("/", handler.GetSettings())
r.Get("/{name}", handler.GetSetting())
r.With(middleware.EnforceRequestSchema(schema.CreateSetting())).
Post("/", handler.CreateSetting())
@ -144,7 +136,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Access Lists
r.With(middleware.EnforceSetup(true)).Route("/access-lists", func(r chi.Router) {
// List
r.With(middleware.Filters(accesslist.GetFilterSchema()), middleware.Enforce(user.CapabilityAccessListsView)).
// r.With(middleware.Filters(accesslist.GetFilterSchema()), middleware.Enforce(user.CapabilityAccessListsView)).
r.With(middleware.Enforce(user.CapabilityAccessListsView)).
Get("/", handler.GetAccessLists())
// Create
@ -166,7 +159,8 @@ func applyRoutes(r chi.Router) chi.Router {
// DNS Providers
r.With(middleware.EnforceSetup(true)).Route("/dns-providers", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityDNSProvidersView), middleware.Filters(dnsprovider.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityDNSProvidersView), middleware.Filters(dnsprovider.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityDNSProvidersView)).
Get("/", handler.GetDNSProviders())
// Create
@ -194,7 +188,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Certificate Authorities
r.With(middleware.EnforceSetup(true)).Route("/certificate-authorities", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityCertificateAuthoritiesView), middleware.Filters(certificateauthority.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityCertificateAuthoritiesView), middleware.Filters(certificateauthority.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityCertificateAuthoritiesView)).
Get("/", handler.GetCertificateAuthorities())
// Create
@ -216,7 +211,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Certificates
r.With(middleware.EnforceSetup(true)).Route("/certificates", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityCertificatesView), middleware.Filters(certificate.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityCertificatesView), middleware.Filters(certificate.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityCertificatesView)).
Get("/", handler.GetCertificates())
// Create
@ -241,7 +237,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Hosts
r.With(middleware.EnforceSetup(true)).Route("/hosts", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityHostsView), middleware.Filters(host.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityHostsView), middleware.Filters(host.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityHostsView)).
Get("/", handler.GetHosts())
// Create
@ -265,7 +262,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Nginx Templates
r.With(middleware.EnforceSetup(true)).Route("/nginx-templates", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityNginxTemplatesView), middleware.Filters(nginxtemplate.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityNginxTemplatesView), middleware.Filters(nginxtemplate.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityNginxTemplatesView)).
Get("/", handler.GetNginxTemplates())
// Create
@ -287,7 +285,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Streams
r.With(middleware.EnforceSetup(true)).Route("/streams", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityStreamsView), middleware.Filters(stream.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityStreamsView), middleware.Filters(stream.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityStreamsView)).
Get("/", handler.GetStreams())
// Create
@ -309,7 +308,8 @@ func applyRoutes(r chi.Router) chi.Router {
// Upstreams
r.With(middleware.EnforceSetup(true)).Route("/upstreams", func(r chi.Router) {
// List
r.With(middleware.Enforce(user.CapabilityHostsView), middleware.Filters(upstream.GetFilterSchema())).
// r.With(middleware.Enforce(user.CapabilityHostsView), middleware.Filters(upstream.GetFilterSchema())).
r.With(middleware.Enforce(user.CapabilityHostsView)).
Get("/", handler.GetUpstreams())
// Create