diff --git a/backend/go.mod b/backend/go.mod index 65789d94..11c8de16 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -69,6 +69,7 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/segmentio/asm v1.2.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index d220e47c..df501867 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -163,6 +163,8 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/backend/internal/entity/auth/entity_test.go b/backend/internal/entity/auth/entity_test.go index 29a4f39e..8cf612a1 100644 --- a/backend/internal/entity/auth/entity_test.go +++ b/backend/internal/entity/auth/entity_test.go @@ -92,6 +92,21 @@ func (s *testsuite) TestGetByUserIDType() { assertModel(s.T(), m) } +func (s *testsuite) TestGetByIdenityType() { + // goleak is used to detect goroutine leaks + defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) + + s.mock. + ExpectQuery(regexp.QuoteMeta(`SELECT * FROM "auth" WHERE identity = $1 AND type = $2 AND "auth"."is_deleted" = $3 ORDER BY "auth"."id" LIMIT $4`)). + WithArgs("johndoe", TypeLocal, 0, 1). + WillReturnRows(s.singleRow) + + m, err := GetByIdenityType("johndoe", TypeLocal) + require.NoError(s.T(), err) + require.NoError(s.T(), s.mock.ExpectationsWereMet()) + assertModel(s.T(), m) +} + func (s *testsuite) TestSave() { // goleak is used to detect goroutine leaks defer goleak.VerifyNone(s.T(), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener")) diff --git a/backend/internal/entity/auth/oauth.go b/backend/internal/entity/auth/oauth.go index 1507690d..0866a732 100644 --- a/backend/internal/entity/auth/oauth.go +++ b/backend/internal/entity/auth/oauth.go @@ -17,7 +17,10 @@ import ( ) // AuthCache is a cache item that stores the Admin API data for each admin that has been requesting endpoints -var OAuthCache *cache.Cache +var ( + OAuthCache *cache.Cache + settingGetOAuthSettings = setting.GetOAuthSettings +) // OAuthCacheInit will create a new Memory Cache func OAuthCacheInit() { @@ -34,8 +37,7 @@ type OAuthUser struct { Resource map[string]interface{} `json:"resource"` } -// GetEmail will return an email address even if it can't be known in the -// Resource +// GetResourceField will attempt to get a field from the resource func (m *OAuthUser) GetResourceField(field string) string { if m.Resource != nil { if value, ok := m.Resource[field]; ok { @@ -45,8 +47,7 @@ func (m *OAuthUser) GetResourceField(field string) string { return "" } -// GetEmail will return an email address even if it can't be known in the -// Resource +// GetID attempts to get an ID from the resource func (m *OAuthUser) GetID() string { if m.Identifier != "" { return m.Identifier @@ -110,7 +111,7 @@ func (m *OAuthUser) GetEmail() string { } func getOAuth2Config() (*oauth2.Config, *setting.OAuthSettings, error) { - oauthSettings, err := setting.GetOAuthSettings() + oauthSettings, err := settingGetOAuthSettings() if err != nil { return nil, nil, err } @@ -130,7 +131,8 @@ func getOAuth2Config() (*oauth2.Config, *setting.OAuthSettings, error) { }, &oauthSettings, nil } -// OAuthLogin ... +// OAuthLogin is hit by the client to generate a URL to redirect to +// and start the oauth process func OAuthLogin(redirectBase, ipAddress string) (string, error) { OAuthCacheInit() diff --git a/backend/internal/entity/auth/oauth_test.go b/backend/internal/entity/auth/oauth_test.go new file mode 100644 index 00000000..6b9729c0 --- /dev/null +++ b/backend/internal/entity/auth/oauth_test.go @@ -0,0 +1,430 @@ +package auth + +import ( + "context" + "testing" + + "npm/internal/entity/setting" + + cache "github.com/patrickmn/go-cache" + "github.com/rotisserie/eris" + "github.com/stretchr/testify/assert" +) + +func TestGetOAuth2Config(t *testing.T) { + tests := []struct { + name string + mockSettings setting.OAuthSettings + expectedError error + }{ + { + name: "Valid settings", + mockSettings: setting.OAuthSettings{ + ClientID: "valid-client-id", + ClientSecret: "valid-client-secret", + AuthURL: "https://auth.url", + TokenURL: "https://token.url", + Scopes: []string{"scope1", "scope2"}, + }, + expectedError: nil, + }, + { + name: "Missing ClientID", + mockSettings: setting.OAuthSettings{ + ClientSecret: "valid-client-secret", + AuthURL: "https://auth.url", + TokenURL: "https://token.url", + Scopes: []string{"scope1", "scope2"}, + }, + expectedError: eris.New("oauth-settings-incorrect"), + }, + { + name: "Missing ClientSecret", + mockSettings: setting.OAuthSettings{ + ClientID: "valid-client-id", + AuthURL: "https://auth.url", + TokenURL: "https://token.url", + Scopes: []string{"scope1", "scope2"}, + }, + expectedError: eris.New("oauth-settings-incorrect"), + }, + { + name: "Missing AuthURL", + mockSettings: setting.OAuthSettings{ + ClientID: "valid-client-id", + ClientSecret: "valid-client-secret", + TokenURL: "https://token.url", + Scopes: []string{"scope1", "scope2"}, + }, + expectedError: eris.New("oauth-settings-incorrect"), + }, + { + name: "Missing TokenURL", + mockSettings: setting.OAuthSettings{ + ClientID: "valid-client-id", + ClientSecret: "valid-client-secret", + AuthURL: "https://auth.url", + Scopes: []string{"scope1", "scope2"}, + }, + expectedError: eris.New("oauth-settings-incorrect"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the GetOAuthSettings function + settingGetOAuthSettings = func() (setting.OAuthSettings, error) { + return tt.mockSettings, nil + } + + config, settings, err := getOAuth2Config() + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + assert.NotNil(t, settings) + assert.Equal(t, tt.mockSettings.ClientID, config.ClientID) + assert.Equal(t, tt.mockSettings.ClientSecret, config.ClientSecret) + assert.Equal(t, tt.mockSettings.AuthURL, config.Endpoint.AuthURL) + assert.Equal(t, tt.mockSettings.TokenURL, config.Endpoint.TokenURL) + assert.Equal(t, tt.mockSettings.Scopes, config.Scopes) + } + }) + } +} + +func TestGetEmail(t *testing.T) { + tests := []struct { + name string + oauthUser OAuthUser + expected string + }{ + { + name: "Email in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "email": "user@example.com", + }, + }, + expected: "user@example.com", + }, + { + name: "Identifier is email", + oauthUser: OAuthUser{ + Identifier: "user@example.com", + }, + expected: "user@example.com", + }, + { + name: "Identifier is not email", + oauthUser: OAuthUser{ + Identifier: "user123", + }, + expected: "user123@oauth", + }, + { + name: "No email or identifier", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + email := tt.oauthUser.GetEmail() + assert.Equal(t, tt.expected, email) + }) + } +} + +func TestGetName(t *testing.T) { + tests := []struct { + name string + oauthUser OAuthUser + expected string + }{ + { + name: "Nickname in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "nickname": "user_nick", + }, + }, + expected: "user_nick", + }, + { + name: "Given name in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "given_name": "User Given", + }, + }, + expected: "User Given", + }, + { + name: "Name in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "name": "User Name", + }, + }, + expected: "User Name", + }, + { + name: "Preferred username in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "preferred_username": "preferred_user", + }, + }, + expected: "preferred_user", + }, + { + name: "Username in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "username": "user123", + }, + }, + expected: "user123", + }, + { + name: "No name fields in resource, fallback to identifier", + oauthUser: OAuthUser{ + Identifier: "fallback_identifier", + Resource: map[string]interface{}{}, + }, + expected: "fallback_identifier", + }, + { + name: "No name fields and no identifier", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{}, + }, + expected: "", + }, + { + name: "All fields", + oauthUser: OAuthUser{ + Identifier: "fallback_identifier", + Resource: map[string]interface{}{ + "nickname": "user_nick", + "given_name": "User Given", + "name": "User Name", + "preferred_username": "preferred_user", + "username": "user123", + }, + }, + expected: "user_nick", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := tt.oauthUser.GetName() + assert.Equal(t, tt.expected, name) + }) + } +} + +func TestGetID(t *testing.T) { + tests := []struct { + name string + oauthUser OAuthUser + expected string + }{ + { + name: "Identifier is set", + oauthUser: OAuthUser{ + Identifier: "user123", + }, + expected: "user123", + }, + { + name: "UID in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "uid": "uid123", + }, + }, + expected: "uid123", + }, + { + name: "User ID in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "user_id": "user_id123", + }, + }, + expected: "user_id123", + }, + { + name: "Username in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "username": "username123", + }, + }, + expected: "username123", + }, + { + name: "Preferred username in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "preferred_username": "preferred_user", + }, + }, + expected: "preferred_user", + }, + { + name: "Email in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "email": "user@example.com", + }, + }, + expected: "user@example.com", + }, + { + name: "Mail in resource", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{ + "mail": "mail@example.com", + }, + }, + expected: "mail@example.com", + }, + { + name: "No identifier or resource fields", + oauthUser: OAuthUser{ + Resource: map[string]interface{}{}, + }, + expected: "", + }, + { + name: "All fields", + oauthUser: OAuthUser{ + Identifier: "user123", + Resource: map[string]interface{}{ + "uid": "uid123", + "user_id": "user_id123", + "username": "username123", + "preferred_username": "preferred_user", + "mail": "mail@example.com", + "email": "email@example.com", + }, + }, + expected: "user123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id := tt.oauthUser.GetID() + assert.Equal(t, tt.expected, id) + }) + } +} + +func TestOAuthLogin(t *testing.T) { + tests := []struct { + name string + redirectBase string + ipAddress string + expectedError error + }{ + { + name: "Valid redirect base", + redirectBase: "https://redirect.base", + ipAddress: "127.0.0.1", + expectedError: nil, + }, + { + name: "Empty redirect base", + redirectBase: "", + ipAddress: "127.0.0.1", + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the GetOAuthSettings function + settingGetOAuthSettings = func() (setting.OAuthSettings, error) { + return setting.OAuthSettings{ + ClientID: "valid-client-id", + ClientSecret: "valid-client-secret", + AuthURL: "https://auth.url", + TokenURL: "https://token.url", + Scopes: []string{"scope1", "scope2"}, + }, nil + } + + url, err := OAuthLogin(tt.redirectBase, tt.ipAddress) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, url) + } + }) + } +} + +func TestOAuthReturn(t *testing.T) { + var errNotFound = eris.New("oauth-verifier-not-found") + tests := []struct { + name string + code string + ipAddress string + expectedError error + }{ + { + name: "Invalid code", + code: "invalid-code", + ipAddress: "127.0.0.100", + expectedError: errNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the GetOAuthSettings function + settingGetOAuthSettings = func() (setting.OAuthSettings, error) { + return setting.OAuthSettings{ + ClientID: "valid-client-id", + ClientSecret: "valid-client-secret", + AuthURL: "https://auth.url", + TokenURL: "https://token.url", + Scopes: []string{"scope1", "scope2"}, + ResourceURL: "https://resource.url", + Identifier: "id", + }, nil + } + + // Initialise the cache and set a verifier + OAuthCacheInit() + if tt.expectedError != errNotFound { + OAuthCache.Set(getCacheKey(tt.ipAddress), "valid-verifier", cache.DefaultExpiration) + } + + ctx := context.Background() + user, err := OAuthReturn(ctx, tt.code, tt.ipAddress) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.NotNil(t, user) + } + }) + } +} diff --git a/backend/internal/entity/filters.go b/backend/internal/entity/filters.go index 7b94e75a..16617660 100644 --- a/backend/internal/entity/filters.go +++ b/backend/internal/entity/filters.go @@ -16,14 +16,3 @@ func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]model.Filter return filterMap } - -// func mergeFilterMaps(m1 map[string]model.FilterMapValue, m2 map[string]model.FilterMapValue) map[string]model.FilterMapValue { -// merged := make(map[string]model.FilterMapValue, 0) -// for k, v := range m1 { -// merged[k] = v -// } -// for key, value := range m2 { -// merged[key] = value -// } -// return merged -// } diff --git a/backend/internal/entity/scopes_test.go b/backend/internal/entity/scopes_test.go new file mode 100644 index 00000000..b0c61813 --- /dev/null +++ b/backend/internal/entity/scopes_test.go @@ -0,0 +1,33 @@ +package entity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseBoolValue(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"yes", []string{"1"}}, + {"true", []string{"1"}}, + {"on", []string{"1"}}, + {"t", []string{"1"}}, + {"1", []string{"1"}}, + {"y", []string{"1"}}, + {"no", []string{"0"}}, + {"false", []string{"0"}}, + {"off", []string{"0"}}, + {"f", []string{"0"}}, + {"0", []string{"0"}}, + {"n", []string{"0"}}, + {"random", []string{"0"}}, + } + + for _, test := range tests { + result := parseBoolValue(test.input) + assert.Equal(t, test.expected, result, "Input: %s", test.input) + } +} diff --git a/backend/internal/jobqueue/main.go b/backend/internal/jobqueue/main.go index bed99953..9e52129f 100644 --- a/backend/internal/jobqueue/main.go +++ b/backend/internal/jobqueue/main.go @@ -34,6 +34,8 @@ func Shutdown() error { return eris.New("Unable to shutdown, jobqueue has not been started") } cancel() + worker = nil + cancel = nil return nil } diff --git a/backend/internal/jobqueue/main_test.go b/backend/internal/jobqueue/main_test.go new file mode 100644 index 00000000..8a781d3c --- /dev/null +++ b/backend/internal/jobqueue/main_test.go @@ -0,0 +1,66 @@ +package jobqueue + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/rotisserie/eris" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type MockJob struct { + done chan bool +} + +func (m *MockJob) Execute() { + time.Sleep(1 * time.Second) + m.done <- true +} + +func TestStart(t *testing.T) { + Start() + assert.NotNil(t, ctx, "Context should not be nil after Start") + assert.NotNil(t, cancel, "Cancel function should not be nil after Start") + assert.NotNil(t, worker, "Worker should not be nil after Start") + Shutdown() +} + +func TestShutdown(t *testing.T) { + Start() + err := Shutdown() + require.Nil(t, err, "Shutdown should not return an error when jobqueue is started") + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + fmt.Println("context timeout exceeded") + case context.Canceled: + fmt.Println("context cancelled by force. whole process is complete") + default: + require.Nil(t, ctx.Err(), "Context done state has unexpected value") + } + } + + require.Nil(t, cancel, "Cancel function should be nil after Shutdown") + require.Nil(t, worker, "Worker should be nil after Shutdown") + + err = Shutdown() + require.NotNil(t, err, "Shutdown should return an error when jobqueue is not started") + require.Equal(t, eris.New("Unable to shutdown, jobqueue has not been started").Error(), err.Error()) +} + +func TestAddJobWithoutStart(t *testing.T) { + mockJob := Job{ + Name: "mockJob", + Action: func() error { + return nil + }, + } + err := AddJob(mockJob) + assert.NotNil(t, err, "AddJob should return an error when jobqueue is not started") + assert.Equal(t, eris.New("Unable to add job, jobqueue has not been started").Error(), err.Error()) +} diff --git a/backend/internal/jobqueue/models.go b/backend/internal/jobqueue/models.go index 7d1c60e4..a7ea2035 100644 --- a/backend/internal/jobqueue/models.go +++ b/backend/internal/jobqueue/models.go @@ -16,6 +16,7 @@ type Queue struct { type Job struct { Name string Action func() error // A function that should be executed when the job is running. + Done chan bool // A channel that should be closed when the job is done. } // AddJobs adds jobs to the queue and cancels channel. @@ -44,11 +45,13 @@ func (q *Queue) AddJob(job Job) { } // Run performs job execution. -func (j Job) Run() error { +func (j *Job) Run() error { err := j.Action() if err != nil { + j.Done <- true return err } + j.Done <- true return nil } diff --git a/backend/internal/validator/hosts.go b/backend/internal/validator/hosts.go index 0eb191a8..685a18b4 100644 --- a/backend/internal/validator/hosts.go +++ b/backend/internal/validator/hosts.go @@ -9,6 +9,12 @@ import ( "github.com/rotisserie/eris" ) +var ( + certificateGetByID = certificate.GetByID + upstreamGetByID = upstream.GetByID + nginxtemplateGetByID = nginxtemplate.GetByID +) + // ValidateHost will check if associated objects exist and other checks // will return a nil error if things are OK func ValidateHost(h host.Model) error { @@ -16,14 +22,14 @@ func ValidateHost(h host.Model) error { // Check certificate exists and is valid // This will not determine if the certificate is Ready to use, // as this validation only cares that the row exists. - if _, cErr := certificate.GetByID(h.CertificateID.Uint); cErr != nil { + if _, cErr := certificateGetByID(h.CertificateID.Uint); cErr != nil { return eris.Wrapf(cErr, "Certificate #%d does not exist", h.CertificateID.Uint) } } if h.UpstreamID.Uint > 0 { // Check upstream exists - if _, uErr := upstream.GetByID(h.UpstreamID.Uint); uErr != nil { + if _, uErr := upstreamGetByID(h.UpstreamID.Uint); uErr != nil { return eris.Wrapf(uErr, "Upstream #%d does not exist", h.UpstreamID.Uint) } } @@ -37,7 +43,7 @@ func ValidateHost(h host.Model) error { } // Check the nginx template exists and has the same type. - nginxTemplate, tErr := nginxtemplate.GetByID(h.NginxTemplateID) + nginxTemplate, tErr := nginxtemplateGetByID(h.NginxTemplateID) if tErr != nil { return eris.Wrapf(tErr, "Host Template #%d does not exist", h.NginxTemplateID) } diff --git a/backend/internal/validator/hosts_test.go b/backend/internal/validator/hosts_test.go new file mode 100644 index 00000000..94676292 --- /dev/null +++ b/backend/internal/validator/hosts_test.go @@ -0,0 +1,145 @@ +package validator + +import ( + "testing" + + "npm/internal/entity/certificate" + "npm/internal/entity/host" + "npm/internal/entity/nginxtemplate" + "npm/internal/entity/upstream" + "npm/internal/types" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +// Mocking the dependencies +type MockCertificate struct { + mock.Mock +} + +func (m *MockCertificate) GetByID(id uint) (certificate.Model, error) { + args := m.Called(id) + return args.Get(0).(certificate.Model), args.Error(1) +} + +type MockUpstream struct { + mock.Mock +} + +func (m *MockUpstream) GetByID(id uint) (upstream.Model, error) { + args := m.Called(id) + return args.Get(0).(upstream.Model), args.Error(1) +} + +type MockNginxTemplate struct { + mock.Mock +} + +func (m *MockNginxTemplate) GetByID(id uint) (nginxtemplate.Model, error) { + args := m.Called(id) + return args.Get(0).(nginxtemplate.Model), args.Error(1) +} + +func TestValidateHost(t *testing.T) { + tests := []struct { + name string + host host.Model + wantErr string + }{ + { + name: "valid host with certificate and upstream", + host: host.Model{ + CertificateID: types.NullableDBUint{Uint: 1}, + UpstreamID: types.NullableDBUint{Uint: 1}, + NginxTemplateID: 1, + Type: "some-type", + }, + wantErr: "", + }, + { + name: "certificate does not exist", + host: host.Model{ + CertificateID: types.NullableDBUint{Uint: 9}, + }, + wantErr: "Certificate #9 does not exist: record not found", + }, + { + name: "upstream does not exist", + host: host.Model{ + UpstreamID: types.NullableDBUint{Uint: 9}, + }, + wantErr: "Upstream #9 does not exist: record not found", + }, + { + name: "proxy host and port set with upstream", + host: host.Model{ + UpstreamID: types.NullableDBUint{Uint: 1}, + ProxyHost: "proxy", + ProxyPort: 8080, + }, + wantErr: "Proxy Host or Port cannot be set when using an Upstream", + }, + { + name: "proxy host and port not set without upstream", + host: host.Model{ + ProxyHost: "", + ProxyPort: 0, + }, + wantErr: "Proxy Host and Port must be specified, unless using an Upstream", + }, + { + name: "nginx template does not exist", + host: host.Model{ + ProxyHost: "proxy", + ProxyPort: 8080, + NginxTemplateID: 9, + }, + wantErr: "Host Template #9 does not exist: record not found", + }, + { + name: "nginx template type mismatch", + host: host.Model{ + ProxyHost: "proxy", + ProxyPort: 8080, + NginxTemplateID: 8, + Type: "some-type", + }, + wantErr: "Host Template #8 is not valid for this host type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockCert := new(MockCertificate) + mockUpstream := new(MockUpstream) + mockNginxTemplate := new(MockNginxTemplate) + + certificateGetByID = mockCert.GetByID + upstreamGetByID = mockUpstream.GetByID + nginxtemplateGetByID = mockNginxTemplate.GetByID + + // id 1 is valid + mockCert.On("GetByID", uint(1)).Return(certificate.Model{}, nil) + mockUpstream.On("GetByID", uint(1)).Return(upstream.Model{}, nil) + mockNginxTemplate.On("GetByID", uint(1)).Return(nginxtemplate.Model{Type: "some-type"}, nil) + + // id 9 is errors + mockCert.On("GetByID", uint(9)).Return(certificate.Model{}, gorm.ErrRecordNotFound) + mockUpstream.On("GetByID", uint(9)).Return(upstream.Model{}, gorm.ErrRecordNotFound) + mockNginxTemplate.On("GetByID", uint(9)).Return(nginxtemplate.Model{}, gorm.ErrRecordNotFound) + + // 8 is special + mockNginxTemplate.On("GetByID", uint(8)).Return(nginxtemplate.Model{Type: "different-type"}, nil) + + err := ValidateHost(tt.host) + if tt.wantErr != "" { + require.NotNil(t, err) + require.Equal(t, tt.wantErr, err.Error()) + } else { + require.Nil(t, err) + } + }) + } +} diff --git a/backend/internal/validator/upstreams.go b/backend/internal/validator/upstreams.go index 6697555f..5d099f41 100644 --- a/backend/internal/validator/upstreams.go +++ b/backend/internal/validator/upstreams.go @@ -1,7 +1,6 @@ package validator import ( - "npm/internal/entity/nginxtemplate" "npm/internal/entity/upstream" "github.com/rotisserie/eris" @@ -26,7 +25,7 @@ func ValidateUpstream(u upstream.Model) error { } // Check the nginx template exists and has the same type. - nginxTemplate, err := nginxtemplate.GetByID(u.NginxTemplateID) + nginxTemplate, err := nginxtemplateGetByID(u.NginxTemplateID) if err != nil { return eris.Errorf("Nginx Template #%d does not exist", u.NginxTemplateID) } diff --git a/backend/internal/validator/upstreams_test.go b/backend/internal/validator/upstreams_test.go new file mode 100644 index 00000000..051828ed --- /dev/null +++ b/backend/internal/validator/upstreams_test.go @@ -0,0 +1,93 @@ +package validator + +import ( + "npm/internal/entity/nginxtemplate" + "npm/internal/entity/upstream" + "npm/internal/entity/upstreamserver" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestValidateUpstream(t *testing.T) { + tests := []struct { + name string + upstreamModel upstream.Model + expectedError string + }{ + { + name: "less than 2 servers", + upstreamModel: upstream.Model{ + Servers: []upstreamserver.Model{ + {Server: "192.168.1.1"}, + }, + }, + expectedError: "Upstreams require at least 2 servers", + }, + { + name: "backup server with IP hash", + upstreamModel: upstream.Model{ + Servers: []upstreamserver.Model{ + {Server: "192.168.1.1", Backup: true}, + {Server: "192.168.1.2"}, + }, + IPHash: true, + }, + expectedError: "Backup servers cannot be used with hash balancing", + }, + { + name: "nginx template does not exist", + upstreamModel: upstream.Model{ + Servers: []upstreamserver.Model{ + {Server: "192.168.1.1"}, + {Server: "192.168.1.2"}, + }, + NginxTemplateID: 999, + }, + expectedError: "Nginx Template #999 does not exist", + }, + { + name: "nginx template type mismatch", + upstreamModel: upstream.Model{ + Servers: []upstreamserver.Model{ + {Server: "192.168.1.1"}, + {Server: "192.168.1.2"}, + }, + NginxTemplateID: 2, + }, + expectedError: "Host Template #2 is not valid for this upstream", + }, + { + name: "valid upstream", + upstreamModel: upstream.Model{ + Servers: []upstreamserver.Model{ + {Server: "192.168.1.1"}, + {Server: "192.168.1.2"}, + }, + NginxTemplateID: 1, + }, + expectedError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockNginxTemplate := new(MockNginxTemplate) + nginxtemplateGetByID = mockNginxTemplate.GetByID + + mockNginxTemplate.On("GetByID", uint(1)).Return(nginxtemplate.Model{Type: "upstream"}, nil) + mockNginxTemplate.On("GetByID", uint(2)).Return(nginxtemplate.Model{Type: "redirect"}, nil) + mockNginxTemplate.On("GetByID", uint(999)).Return(nginxtemplate.Model{}, gorm.ErrRecordNotFound) + + err := ValidateUpstream(tt.upstreamModel) + if tt.expectedError != "" { + require.NotNil(t, err) + assert.Equal(t, tt.expectedError, err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +}