Better checking for api sort param to prevent sql injection

And moved filters out and cached object reflection
This commit is contained in:
Jamie Curnow
2023-07-24 11:49:08 +10:00
parent 9b32329f41
commit a0e17f9678
12 changed files with 312 additions and 223 deletions

View File

@ -7,6 +7,8 @@ var (
UserIDCtxKey = &contextKey{"UserID"}
// FiltersCtxKey is the name of the Filters value on the context
FiltersCtxKey = &contextKey{"Filters"}
// SortCtxKey is the name of the Sort value on the context
SortCtxKey = &contextKey{"Sort"}
// PrettyPrintCtxKey is the name of the pretty print context
PrettyPrintCtxKey = &contextKey{"Pretty"}
// ExpansionCtxKey is the name of the expansion context

View File

@ -3,9 +3,9 @@ package handler
import (
"net/http"
"strconv"
"strings"
"npm/internal/api/context"
"npm/internal/api/middleware"
"npm/internal/model"
"github.com/go-chi/chi/v5"
@ -23,50 +23,11 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
return pageInfo, err
}
pageInfo.Sort = getSortParameter(r)
pageInfo.Sort = middleware.GetSortFromContext(r)
return pageInfo, nil
}
func getSortParameter(r *http.Request) []model.Sort {
var sortFields []model.Sort
queryValues := r.URL.Query()
sortString := queryValues.Get("sort")
if sortString == "" {
return sortFields
}
// Split sort fields up in to slice
sorts := strings.Split(sortString, ",")
for _, sortItem := range sorts {
if strings.Contains(sortItem, ".") {
theseItems := strings.Split(sortItem, ".")
switch strings.ToLower(theseItems[1]) {
case "desc":
fallthrough
case "descending":
theseItems[1] = "DESC"
default:
theseItems[1] = "ASC"
}
sortFields = append(sortFields, model.Sort{
Field: theseItems[0],
Direction: theseItems[1],
})
} else {
sortFields = append(sortFields, model.Sort{
Field: sortItem,
Direction: "ASC",
})
}
}
return sortFields
}
func getQueryVarInt(r *http.Request, varName string, required bool, defaultValue int) (int, error) {
queryValues := r.URL.Query()
varValue := queryValues.Get(varName)

View File

@ -9,7 +9,7 @@ import (
)
// Expansion will determine whether the request should have objects expanded
// with ?expand=1 or ?expand=true
// with ?expand=item,item
func Expansion(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expandStr := r.URL.Query().Get("expand")

View File

@ -1,118 +0,0 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
c "npm/internal/api/context"
h "npm/internal/api/http"
"npm/internal/entity"
"npm/internal/model"
"npm/internal/util"
"github.com/qri-io/jsonschema"
)
// Filters will accept a pre-defined schemaData to validate against the GET query params
// passed in to this endpoint. This will ensure that the filters are not injecting SQL.
// After we have determined what the Filters are to be, they are saved on the Context
// to be used later in other endpoints.
func Filters(obj interface{}) func(http.Handler) http.Handler {
schemaData := entity.GetFilterSchema(obj, true)
reservedFilterKeys := []string{
"limit",
"offset",
"sort",
"order",
"expand",
"t", // This is used as a timestamp paramater in some clients and can be ignored
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var filters []model.Filter
for key, val := range r.URL.Query() {
key = strings.ToLower(key)
// Split out the modifier from the field name and set a default modifier
var keyParts []string
keyParts = strings.Split(key, ":")
if len(keyParts) == 1 {
// Default modifier
keyParts = append(keyParts, "equals")
}
// Only use this filter if it's not a reserved get param
if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) {
for _, valItem := range val {
// Check that the val isn't empty
if len(strings.TrimSpace(valItem)) > 0 {
valSlice := []string{valItem}
if keyParts[1] == "in" || keyParts[1] == "notin" {
valSlice = strings.Split(valItem, ",")
}
filters = append(filters, model.Filter{
Field: keyParts[0],
Modifier: keyParts[1],
Value: valSlice,
})
}
}
}
}
// Only validate schema if there are filters to validate
if len(filters) > 0 {
ctx := r.Context()
// Marshal the Filters in to a JSON string so that the Schema Validation works against it
filterData, marshalErr := json.MarshalIndent(filters, "", " ")
if marshalErr != nil {
h.ResultErrorJSON(w, r, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil)
return
}
// Create root schema
rs := &jsonschema.Schema{}
if err := json.Unmarshal([]byte(schemaData), rs); err != nil {
h.ResultErrorJSON(w, r, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil)
return
}
// Validate it
errors, jsonError := rs.ValidateBytes(ctx, filterData)
if jsonError != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, jsonError.Error(), nil)
return
}
if len(errors) > 0 {
h.ResultErrorJSON(w, r, http.StatusBadRequest, "Invalid Filters", errors)
return
}
// todo: populate filters object with the gorm database name
ctx = context.WithValue(ctx, c.FiltersCtxKey, filters)
next.ServeHTTP(w, r.WithContext(ctx))
} else {
next.ServeHTTP(w, r)
}
})
}
}
// GetFiltersFromContext returns the Filters
func GetFiltersFromContext(r *http.Request) []model.Filter {
filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter)
if !ok {
// the assertion failed
return nil
}
return filters
}

View File

@ -0,0 +1,196 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
c "npm/internal/api/context"
h "npm/internal/api/http"
"npm/internal/entity"
"npm/internal/model"
"npm/internal/tags"
"npm/internal/util"
"github.com/qri-io/jsonschema"
)
// ListQuery will accept a pre-defined schemaData to validate against the GET query params
// passed in to this endpoint. This will ensure that the filters are not injecting SQL
// and the sort parameter is valid as well.
// After we have determined what the Filters are to be, they are saved on the Context
// to be used later in other endpoints.
func ListQuery(obj interface{}) func(http.Handler) http.Handler {
schemaData := entity.GetFilterSchema(obj, true)
filterMap := tags.GetFilterMap(obj)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, statusCode, errMsg, errors := listQueryFilters(r, ctx, schemaData)
if statusCode > 0 {
h.ResultErrorJSON(w, r, statusCode, errMsg, errors)
return
}
ctx, statusCode, errMsg = listQuerySort(r, filterMap, ctx)
if statusCode > 0 {
h.ResultErrorJSON(w, r, statusCode, errMsg, nil)
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func listQuerySort(
r *http.Request,
filterMap map[string]model.FilterMapValue,
ctx context.Context,
) (context.Context, int, string) {
var sortFields []model.Sort
sortString := r.URL.Query().Get("sort")
if sortString == "" {
return ctx, 0, ""
}
// Split sort fields up in to slice
sorts := strings.Split(sortString, ",")
for _, sortItem := range sorts {
if strings.Contains(sortItem, ".") {
theseItems := strings.Split(sortItem, ".")
switch strings.ToLower(theseItems[1]) {
case "desc":
fallthrough
case "descending":
theseItems[1] = "DESC"
default:
theseItems[1] = "ASC"
}
sortFields = append(sortFields, model.Sort{
Field: theseItems[0],
Direction: theseItems[1],
})
} else {
sortFields = append(sortFields, model.Sort{
Field: sortItem,
Direction: "ASC",
})
}
}
// check against filter schema
for _, f := range sortFields {
if _, exists := filterMap[f.Field]; !exists {
return ctx, http.StatusBadRequest, "Invalid sort field"
}
}
ctx = context.WithValue(ctx, c.SortCtxKey, sortFields)
// No problems!
return ctx, 0, ""
}
func listQueryFilters(
r *http.Request,
ctx context.Context,
schemaData string,
) (context.Context, int, string, interface{}) {
reservedFilterKeys := []string{
"limit",
"offset",
"sort",
"expand",
"t", // This is used as a timestamp paramater in some clients and can be ignored
}
var filters []model.Filter
for key, val := range r.URL.Query() {
key = strings.ToLower(key)
// Split out the modifier from the field name and set a default modifier
var keyParts []string
keyParts = strings.Split(key, ":")
if len(keyParts) == 1 {
// Default modifier
keyParts = append(keyParts, "equals")
}
// Only use this filter if it's not a reserved get param
if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) {
for _, valItem := range val {
// Check that the val isn't empty
if len(strings.TrimSpace(valItem)) > 0 {
valSlice := []string{valItem}
if keyParts[1] == "in" || keyParts[1] == "notin" {
valSlice = strings.Split(valItem, ",")
}
filters = append(filters, model.Filter{
Field: keyParts[0],
Modifier: keyParts[1],
Value: valSlice,
})
}
}
}
}
// Only validate schema if there are filters to validate
if len(filters) > 0 {
// Marshal the Filters in to a JSON string so that the Schema Validation works against it
filterData, marshalErr := json.MarshalIndent(filters, "", " ")
if marshalErr != nil {
return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil
}
// Create root schema
rs := &jsonschema.Schema{}
if err := json.Unmarshal([]byte(schemaData), rs); err != nil {
return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil
}
// Validate it
errors, jsonError := rs.ValidateBytes(ctx, filterData)
if jsonError != nil {
return ctx, http.StatusBadRequest, jsonError.Error(), nil
}
if len(errors) > 0 {
return ctx, http.StatusBadRequest, "Invalid Filters", errors
}
ctx = context.WithValue(ctx, c.FiltersCtxKey, filters)
}
// No problems!
return ctx, 0, "", nil
}
// GetFiltersFromContext returns the Filters
func GetFiltersFromContext(r *http.Request) []model.Filter {
filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter)
if !ok {
// the assertion failed
return nil
}
return filters
}
// GetSortFromContext returns the Sort
func GetSortFromContext(r *http.Request) []model.Sort {
sorts, ok := r.Context().Value(c.SortCtxKey).([]model.Sort)
if !ok {
// the assertion failed
return nil
}
return sorts
}

View File

@ -104,7 +104,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityUsersManage),
middleware.Filters(user.Model{}),
middleware.ListQuery(user.Model{}),
).Get("/", handler.GetUsers())
// Specific Item
@ -136,7 +136,7 @@ func applyRoutes(r chi.Router) chi.Router {
r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
// List
r.With(
middleware.Filters(setting.Model{}),
middleware.ListQuery(setting.Model{}),
).Get("/", handler.GetSettings())
r.Get("/{name}", handler.GetSetting())
@ -151,7 +151,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityAccessListsView),
middleware.Filters(accesslist.Model{}),
middleware.ListQuery(accesslist.Model{}),
).Get("/", handler.GetAccessLists())
// Create
@ -175,7 +175,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityDNSProvidersView),
middleware.Filters(dnsprovider.Model{}),
middleware.ListQuery(dnsprovider.Model{}),
).Get("/", handler.GetDNSProviders())
// Create
@ -205,7 +205,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityCertificateAuthoritiesView),
middleware.Filters(certificateauthority.Model{}),
middleware.ListQuery(certificateauthority.Model{}),
).Get("/", handler.GetCertificateAuthorities())
// Create
@ -235,7 +235,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityCertificatesView),
middleware.Filters(certificate.Model{}),
middleware.ListQuery(certificate.Model{}),
).Get("/", handler.GetCertificates())
// Create
@ -262,7 +262,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityHostsView),
middleware.Filters(host.Model{}),
middleware.ListQuery(host.Model{}),
).Get("/", handler.GetHosts())
// Create
@ -288,7 +288,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityNginxTemplatesView),
middleware.Filters(nginxtemplate.Model{}),
middleware.ListQuery(nginxtemplate.Model{}),
).Get("/", handler.GetNginxTemplates())
// Create
@ -312,7 +312,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityStreamsView),
middleware.Filters(stream.Model{}),
middleware.ListQuery(stream.Model{}),
).Get("/", handler.GetStreams())
// Create
@ -336,7 +336,7 @@ func applyRoutes(r chi.Router) chi.Router {
// List
r.With(
middleware.Enforce(user.CapabilityHostsView),
middleware.Filters(upstream.Model{}),
middleware.ListQuery(upstream.Model{}),
).Get("/", handler.GetUpstreams())
// Create