From 183efed212c5d067b751834b762ecec326bb12a1 Mon Sep 17 00:00:00 2001 From: Dmitrii Prisacari Date: Wed, 13 Mar 2024 12:37:10 +0000 Subject: [PATCH 1/6] feat: create server instance using constructor, use interface-based logger --- examples_test.go | 12 +-- filter_test.go | 8 +- handlers.go | 112 ++++++++++++------------- handlers_test.go | 10 +-- internal/idp_test/azuread_util_test.go | 12 +-- internal/idp_test/okta_util_test.go | 8 +- logger.go | 11 ++- server.go | 39 ++++++--- server_options.go | 26 ++++++ server_test.go | 11 +-- 10 files changed, 144 insertions(+), 105 deletions(-) create mode 100644 server_options.go diff --git a/examples_test.go b/examples_test.go index 85dd419..36eb0d6 100644 --- a/examples_test.go +++ b/examples_test.go @@ -7,16 +7,16 @@ import ( func ExampleNewServer() { server := Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, + config: ServiceProviderConfig{}, + resourceTypes: nil, } logger.Fatal(http.ListenAndServe(":7643", server)) } func ExampleNewServer_basePath() { server := Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, + config: ServiceProviderConfig{}, + resourceTypes: nil, } // You can host the SCIM server on a custom path, make sure to strip the prefix, so only `/v2/` is left. http.Handle("/scim/", http.StripPrefix("/scim", server)) @@ -34,8 +34,8 @@ func ExampleNewServer_logger() { return http.HandlerFunc(fn) } server := Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, + config: ServiceProviderConfig{}, + resourceTypes: nil, } logger.Fatal(http.ListenAndServe(":7643", loggingMiddleware(server))) } diff --git a/filter_test.go b/filter_test.go index 6e758e3..9e4346d 100644 --- a/filter_test.go +++ b/filter_test.go @@ -130,8 +130,8 @@ func Test_User_Filter(t *testing.T) { } func newTestServerForFilter() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ + return scim.NewServer( + scim.WithResourceTypes([]scim.ResourceType{ { ID: optional.NewString("User"), Name: "User", @@ -160,6 +160,6 @@ func newTestServerForFilter() scim.Server { schema: schema.CoreGroupSchema(), }, }, - }, - } + }), + ) } diff --git a/handlers.go b/handlers.go index 530ad60..9009752 100644 --- a/handlers.go +++ b/handlers.go @@ -8,10 +8,10 @@ import ( "github.com/elimity-com/scim/schema" ) -func errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { +func (s Server) errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { raw, err := json.Marshal(scimErr) if err != nil { - log.Error( + s.log.Error( "failed marshaling scim error", "scimError", scimErr, "error", err, @@ -22,7 +22,7 @@ func errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { w.WriteHeader(scimErr.Status) _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -35,7 +35,7 @@ func (s Server) resourceDeleteHandler(w http.ResponseWriter, r *http.Request, id deleteErr := resourceType.Handler.Delete(r, id) if deleteErr != nil { scimErr := errors.CheckScimError(deleteErr, http.MethodDelete) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -48,14 +48,14 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st resource, getErr := resourceType.Handler.Get(r, id) if getErr != nil { scimErr := errors.CheckScimError(getErr, http.MethodGet) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -69,7 +69,7 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -81,14 +81,14 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id string, resourceType ResourceType) { patch, scimErr := resourceType.validatePatch(r) if scimErr != nil { - errorHandler(w, scimErr) + s.errorHandler(w, scimErr) return } resource, patchErr := resourceType.Handler.Patch(r, id, patch) if patchErr != nil { scimErr := errors.CheckScimError(patchErr, http.MethodPatch) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -99,8 +99,8 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -116,7 +116,7 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -130,21 +130,21 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso attributes, scimErr := resourceType.validate(data) if scimErr != nil { - errorHandler(w, scimErr) + s.errorHandler(w, scimErr) return } resource, postErr := resourceType.Handler.Create(r, attributes) if postErr != nil { scimErr := errors.CheckScimError(postErr, http.MethodPost) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -160,7 +160,7 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -174,21 +174,21 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st attributes, scimErr := resourceType.validate(data) if scimErr != nil { - errorHandler(w, scimErr) + s.errorHandler(w, scimErr) return } resource, putError := resourceType.Handler.Replace(r, id, attributes) if putError != nil { scimErr := errors.CheckScimError(putError, http.MethodPut) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -202,7 +202,7 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -210,10 +210,10 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st } // resourceTypeHandler receives an HTTP GET to retrieve individual resource types which can be returned by appending the -// resource types name to the /ResourceTypes endpoint. For example: "/ResourceTypes/User". +// resource types name to the /resourceTypes endpoint. For example: "/resourceTypes/User". func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name string) { var resourceType ResourceType - for _, r := range s.ResourceTypes { + for _, r := range s.resourceTypes { if r.Name == name { resourceType = r break @@ -222,14 +222,14 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name if resourceType.Name != name { scimErr := errors.ScimErrorResourceNotFound(name) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resourceType.getRaw()) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource type", "resourceType", resourceType, "error", err, @@ -238,39 +238,39 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) } } -// resourceTypesHandler receives an HTTP GET to this endpoint, "/ResourceTypes", which is used to discover the types of +// resourceTypesHandler receives an HTTP GET to this endpoint, "/resourceTypes", which is used to discover the types of // resources available on a SCIM service provider (e.g., Users and Groups). Each resource type defines the endpoints, // the core schema URI that defines the resource, and any supported schema extensions. func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { params, paramsErr := s.parseRequestParams(r, schema.ResourceTypeSchema()) if paramsErr != nil { - errorHandler(w, paramsErr) + s.errorHandler(w, paramsErr) return } - start, end := clamp(params.StartIndex-1, params.Count, len(s.ResourceTypes)) + start, end := clamp(params.StartIndex-1, params.Count, len(s.resourceTypes)) var resources []interface{} - for _, v := range s.ResourceTypes[start:end] { + for _, v := range s.resourceTypes[start:end] { resources = append(resources, v.getRaw()) } lr := listResponse{ - TotalResults: len(s.ResourceTypes), + TotalResults: len(s.resourceTypes), ItemsPerPage: params.Count, StartIndex: params.StartIndex, Resources: resources, } raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling list response", "listResponse", lr, "error", err, @@ -280,7 +280,7 @@ func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -292,14 +292,14 @@ func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, resourceType ResourceType) { params, paramsErr := s.parseRequestParams(r, resourceType.Schema, resourceType.getSchemaExtensions()...) if paramsErr != nil { - errorHandler(w, paramsErr) + s.errorHandler(w, paramsErr) return } page, getError := resourceType.Handler.GetAll(r, params) if getError != nil { scimErr := errors.CheckScimError(getError, http.MethodGet) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -311,8 +311,8 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso } raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling list response", "listResponse", lr, "error", err, @@ -322,7 +322,7 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -335,14 +335,14 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) getSchema := s.getSchema(id) if getSchema.ID != id { scimErr := errors.ScimErrorResourceNotFound(id) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(getSchema) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling schema", "schema", getSchema, "error", err, @@ -351,7 +351,7 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -363,7 +363,7 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { params, paramsErr := s.parseRequestParams(r, schema.Definition()) if paramsErr != nil { - errorHandler(w, paramsErr) + s.errorHandler(w, paramsErr) return } @@ -373,7 +373,7 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { ) if validator := params.FilterValidator; validator != nil { if err := validator.Validate(); err != nil { - errorHandler(w, &errors.ScimErrorInvalidFilter) + s.errorHandler(w, &errors.ScimErrorInvalidFilter) return } } @@ -395,8 +395,8 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { } raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling list response", "listResponse", lr, "error", err, @@ -406,7 +406,7 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -416,12 +416,12 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { // serviceProviderConfigHandler receives an HTTP GET to this endpoint will return a JSON structure that describes the // SCIM specification features available on a service provider. func (s Server) serviceProviderConfigHandler(w http.ResponseWriter, r *http.Request) { - raw, err := json.Marshal(s.Config.getRaw()) + raw, err := json.Marshal(s.config.getRaw()) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling service provider config", - "serviceProviderConfig", s.Config, + "serviceProviderConfig", s.config, "error", err, ) return @@ -429,7 +429,7 @@ func (s Server) serviceProviderConfigHandler(w http.ResponseWriter, r *http.Requ _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) diff --git a/handlers_test.go b/handlers_test.go index 4daca06..b65ed35 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -593,7 +593,7 @@ func TestServerResourceTypeHandlerValid(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/ResourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/resourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) rr := httptest.NewRecorder() newTestServer().ServeHTTP(rr, req) @@ -614,10 +614,10 @@ func TestServerResourceTypesHandler(t *testing.T) { }{ { name: "resource types request without version", - target: "/ResourceTypes", + target: "/resourceTypes", }, { name: "resource types request with version", - target: "/v2/ResourceTypes", + target: "/v2/resourceTypes", }, } @@ -934,8 +934,8 @@ func newTestServer() Server { userSchema := getUserSchema() userSchemaExtension := getUserExtensionSchema() return Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: []ResourceType{ + config: ServiceProviderConfig{}, + resourceTypes: []ResourceType{ { ID: optional.NewString("User"), Name: "User", diff --git a/internal/idp_test/azuread_util_test.go b/internal/idp_test/azuread_util_test.go index 5ca108f..fda3af5 100644 --- a/internal/idp_test/azuread_util_test.go +++ b/internal/idp_test/azuread_util_test.go @@ -17,11 +17,11 @@ var azureCreatedTime = time.Date( ) func newAzureADTestServer() scim.Server { - return scim.Server{ - Config: scim.ServiceProviderConfig{ + return scim.NewServer( + scim.WithServiceProviderConfig(scim.ServiceProviderConfig{ MaxResults: 20, - }, - ResourceTypes: []scim.ResourceType{ + }), + scim.WithResourceTypes([]scim.ResourceType{ { ID: optional.NewString("User"), Name: "User", @@ -42,8 +42,8 @@ func newAzureADTestServer() scim.Server { Schema: schema.CoreGroupSchema(), Handler: azureADGroupResourceHandler{}, }, - }, - } + }), + ) } type azureADGroupResourceHandler struct{} diff --git a/internal/idp_test/okta_util_test.go b/internal/idp_test/okta_util_test.go index 5cbf0a1..043d766 100644 --- a/internal/idp_test/okta_util_test.go +++ b/internal/idp_test/okta_util_test.go @@ -10,8 +10,8 @@ import ( ) func newOktaTestServer() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ + return scim.NewServer( + scim.WithResourceTypes([]scim.ResourceType{ { ID: optional.NewString("User"), Name: "User", @@ -29,8 +29,8 @@ func newOktaTestServer() scim.Server { Schema: schema.CoreGroupSchema(), Handler: oktaGroupResourceHandler{}, }, - }, - } + }), + ) } type oktaGroupResourceHandler struct{} diff --git a/logger.go b/logger.go index cd619c0..258ee80 100644 --- a/logger.go +++ b/logger.go @@ -1,10 +1,9 @@ package scim -import "log/slog" +type Logger interface { + Error(args ...interface{}) +} -var log *slog.Logger = slog.Default().WithGroup("scim") +type noopLogger struct{} -// SetLogger sets the logger for the scim package. -func SetLogger(l *slog.Logger) { - log = l -} +func (noopLogger) Error(...interface{}) {} diff --git a/server.go b/server.go index d8eac99..96cc320 100644 --- a/server.go +++ b/server.go @@ -52,11 +52,24 @@ func parseIdentifier(path, endpoint string) (string, error) { return url.PathUnescape(strings.TrimPrefix(path, endpoint+"/")) } -// Server represents a SCIM server which implements the HTTP-based SCIM protocol that makes managing identities in multi- -// domain scenarios easier to support via a standardized service. +// Server represents a SCIM server which implements the HTTP-based SCIM protocol +// that makes managing identities in multi-domain scenarios easier to support via a standardized service. type Server struct { - Config ServiceProviderConfig - ResourceTypes []ResourceType + config ServiceProviderConfig + resourceTypes []ResourceType + log Logger +} + +func NewServer(opts ...ServerOption) Server { + s := &Server{ + log: &noopLogger{}, + } + + for _, opt := range opts { + opt(s) + } + + return *s } // ServeHTTP dispatches the request to the handler whose pattern most closely matches the request URL. @@ -67,7 +80,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case path == "/Me": - errorHandler(w, &errors.ScimError{ + s.errorHandler(w, &errors.ScimError{ Status: http.StatusNotImplemented, }) return @@ -77,18 +90,18 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case strings.HasPrefix(path, "/Schemas/") && r.Method == http.MethodGet: s.schemaHandler(w, r, strings.TrimPrefix(path, "/Schemas/")) return - case path == "/ResourceTypes" && r.Method == http.MethodGet: + case path == "/resourceTypes" && r.Method == http.MethodGet: s.resourceTypesHandler(w, r) return - case strings.HasPrefix(path, "/ResourceTypes/") && r.Method == http.MethodGet: - s.resourceTypeHandler(w, r, strings.TrimPrefix(path, "/ResourceTypes/")) + case strings.HasPrefix(path, "/resourceTypes/") && r.Method == http.MethodGet: + s.resourceTypeHandler(w, r, strings.TrimPrefix(path, "/resourceTypes/")) return case path == "/ServiceProviderConfig": s.serviceProviderConfigHandler(w, r) return } - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if path == resourceType.Endpoint { switch r.Method { case http.MethodPost: @@ -123,7 +136,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - errorHandler(w, &errors.ScimError{ + s.errorHandler(w, &errors.ScimError{ Detail: "Specified endpoint does not exist.", Status: http.StatusNotFound, }) @@ -131,7 +144,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // getSchema extracts the schemas from the resources types defined in the server with given id. func (s Server) getSchema(id string) schema.Schema { - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if resourceType.Schema.ID == id { return resourceType.Schema } @@ -148,7 +161,7 @@ func (s Server) getSchema(id string) schema.Schema { func (s Server) getSchemas() []schema.Schema { ids := make([]string, 0) schemas := make([]schema.Schema, 0) - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if !contains(ids, resourceType.Schema.ID) { schemas = append(schemas, resourceType.Schema) } @@ -166,7 +179,7 @@ func (s Server) getSchemas() []schema.Schema { func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, refExtensions ...schema.Schema) (ListRequestParams, *errors.ScimError) { invalidParams := make([]string, 0) - defaultCount := s.Config.getItemsPerPage() + defaultCount := s.config.getItemsPerPage() count, countErr := getIntQueryParam(r, "count", defaultCount) if countErr != nil { invalidParams = append(invalidParams, "count") diff --git a/server_options.go b/server_options.go new file mode 100644 index 0000000..cc22d87 --- /dev/null +++ b/server_options.go @@ -0,0 +1,26 @@ +package scim + +type ServerOption func(*Server) + +// WithServiceProviderConfig sets the service provider config for the server. +func WithServiceProviderConfig(config ServiceProviderConfig) ServerOption { + return func(s *Server) { + s.config = config + } +} + +// WithResourceTypes sets the resource types for the server. +func WithResourceTypes(resourceTypes []ResourceType) ServerOption { + return func(s *Server) { + s.resourceTypes = resourceTypes + } +} + +// WithLogger sets the logger for the server. +func WithLogger(logger Logger) ServerOption { + return func(s *Server) { + if logger != nil { + s.log = logger + } + } +} diff --git a/server_test.go b/server_test.go index 57d5c56..fd35a4d 100644 --- a/server_test.go +++ b/server_test.go @@ -2,11 +2,12 @@ package scim_test import ( "fmt" - internal "github.com/elimity-com/scim/filter" "io" "net/http" "time" + internal "github.com/elimity-com/scim/filter" + "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/optional" @@ -41,8 +42,8 @@ func externalID(attributes scim.ResourceAttributes) optional.String { // e.g. if a member gets added, does this entity exist? func newTestServer() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ + return scim.NewServer( + scim.WithResourceTypes([]scim.ResourceType{ { ID: optional.NewString("User"), Name: "User", @@ -69,8 +70,8 @@ func newTestServer() scim.Server { schema: schema.CoreGroupSchema(), }, }, - }, - } + }), + ) } // testData represents a resource entity. From f074b5e6315e43cf7686579d5d105b72de9a0010 Mon Sep 17 00:00:00 2001 From: Dmitrii Prisacari Date: Wed, 13 Mar 2024 12:53:39 +0000 Subject: [PATCH 2/6] fix: update server construction code --- README.md | 7 ++++--- examples_test.go | 15 +++------------ handlers_test.go | 9 ++++----- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 695d4a8..a4dde13 100644 --- a/README.md +++ b/README.md @@ -119,9 +119,10 @@ resourceTypes := []ResourceType{ ### 4. Create Server ```go -server := Server{ - Config: config, - ResourceTypes: resourceTypes, +server := NewServer{ + WithServiceProviderConfig(config), + WithResourceTypes(resourceTypes), + WithLogger(logger), // optional, default is no logging } ``` diff --git a/examples_test.go b/examples_test.go index 36eb0d6..9fcc3a3 100644 --- a/examples_test.go +++ b/examples_test.go @@ -6,18 +6,12 @@ import ( ) func ExampleNewServer() { - server := Server{ - config: ServiceProviderConfig{}, - resourceTypes: nil, - } + server := NewServer() logger.Fatal(http.ListenAndServe(":7643", server)) } func ExampleNewServer_basePath() { - server := Server{ - config: ServiceProviderConfig{}, - resourceTypes: nil, - } + server := NewServer() // You can host the SCIM server on a custom path, make sure to strip the prefix, so only `/v2/` is left. http.Handle("/scim/", http.StripPrefix("/scim", server)) logger.Fatal(http.ListenAndServe(":7643", nil)) @@ -33,9 +27,6 @@ func ExampleNewServer_logger() { return http.HandlerFunc(fn) } - server := Server{ - config: ServiceProviderConfig{}, - resourceTypes: nil, - } + server := NewServer() logger.Fatal(http.ListenAndServe(":7643", loggingMiddleware(server))) } diff --git a/handlers_test.go b/handlers_test.go index b65ed35..41d2930 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -933,9 +933,8 @@ func newTestResourceHandler() ResourceHandler { func newTestServer() Server { userSchema := getUserSchema() userSchemaExtension := getUserExtensionSchema() - return Server{ - config: ServiceProviderConfig{}, - resourceTypes: []ResourceType{ + return NewServer( + WithResourceTypes([]ResourceType{ { ID: optional.NewString("User"), Name: "User", @@ -963,6 +962,6 @@ func newTestServer() Server { Schema: schema.CoreGroupSchema(), Handler: newTestResourceHandler(), }, - }, - } + }), + ) } From 483869ae0a08135bde705f8ce2d6c83d792bc393 Mon Sep 17 00:00:00 2001 From: Dmitrii Prisacari Date: Wed, 13 Mar 2024 12:55:46 +0000 Subject: [PATCH 3/6] fix: correct resource type auto renaming --- handlers.go | 2 +- handlers_test.go | 4 ++-- server.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/handlers.go b/handlers.go index 9009752..9c87ee0 100644 --- a/handlers.go +++ b/handlers.go @@ -245,7 +245,7 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name } } -// resourceTypesHandler receives an HTTP GET to this endpoint, "/resourceTypes", which is used to discover the types of +// resourceTypesHandler receives an HTTP GET to this endpoint, "/ResourceTypes", which is used to discover the types of // resources available on a SCIM service provider (e.g., Users and Groups). Each resource type defines the endpoints, // the core schema URI that defines the resource, and any supported schema extensions. func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { diff --git a/handlers_test.go b/handlers_test.go index 41d2930..5a84672 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -614,10 +614,10 @@ func TestServerResourceTypesHandler(t *testing.T) { }{ { name: "resource types request without version", - target: "/resourceTypes", + target: "/ResourceTypes", }, { name: "resource types request with version", - target: "/v2/resourceTypes", + target: "/v2/ResourceTypes", }, } diff --git a/server.go b/server.go index 96cc320..ace6690 100644 --- a/server.go +++ b/server.go @@ -90,7 +90,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case strings.HasPrefix(path, "/Schemas/") && r.Method == http.MethodGet: s.schemaHandler(w, r, strings.TrimPrefix(path, "/Schemas/")) return - case path == "/resourceTypes" && r.Method == http.MethodGet: + case path == "/ResourceTypes" && r.Method == http.MethodGet: s.resourceTypesHandler(w, r) return case strings.HasPrefix(path, "/resourceTypes/") && r.Method == http.MethodGet: From 613c3ee8f3ccd76f2e941ba2c4fa907eadb192e7 Mon Sep 17 00:00:00 2001 From: Dmitrii Prisacari Date: Wed, 13 Mar 2024 12:57:49 +0000 Subject: [PATCH 4/6] fix: correct resource type auto renaming --- handlers.go | 2 +- handlers_test.go | 2 +- server.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/handlers.go b/handlers.go index 9c87ee0..7f207ee 100644 --- a/handlers.go +++ b/handlers.go @@ -210,7 +210,7 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st } // resourceTypeHandler receives an HTTP GET to retrieve individual resource types which can be returned by appending the -// resource types name to the /resourceTypes endpoint. For example: "/resourceTypes/User". +// resource types name to the /ResourceTypes endpoint. For example: "/ResourceTypes/User". func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name string) { var resourceType ResourceType for _, r := range s.resourceTypes { diff --git a/handlers_test.go b/handlers_test.go index 5a84672..4f32f90 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -593,7 +593,7 @@ func TestServerResourceTypeHandlerValid(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/resourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/ResourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) rr := httptest.NewRecorder() newTestServer().ServeHTTP(rr, req) diff --git a/server.go b/server.go index ace6690..a70c7e3 100644 --- a/server.go +++ b/server.go @@ -93,8 +93,8 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case path == "/ResourceTypes" && r.Method == http.MethodGet: s.resourceTypesHandler(w, r) return - case strings.HasPrefix(path, "/resourceTypes/") && r.Method == http.MethodGet: - s.resourceTypeHandler(w, r, strings.TrimPrefix(path, "/resourceTypes/")) + case strings.HasPrefix(path, "/ResourceTypes/") && r.Method == http.MethodGet: + s.resourceTypeHandler(w, r, strings.TrimPrefix(path, "/ResourceTypes/")) return case path == "/ServiceProviderConfig": s.serviceProviderConfigHandler(w, r) From 30bb3090094db60bc8bab60931f614e8f83d8438 Mon Sep 17 00:00:00 2001 From: Dmitrii Prisacari Date: Wed, 13 Mar 2024 15:27:30 +0000 Subject: [PATCH 5/6] docs: correct the example in the docs --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a4dde13..f9f12b6 100644 --- a/README.md +++ b/README.md @@ -119,11 +119,11 @@ resourceTypes := []ResourceType{ ### 4. Create Server ```go -server := NewServer{ +server := NewServer( WithServiceProviderConfig(config), WithResourceTypes(resourceTypes), WithLogger(logger), // optional, default is no logging -} +) ``` ## Logging From 83c4e94430676df771e861b15add33a126b0303d Mon Sep 17 00:00:00 2001 From: Dmitrii Prisacari Date: Thu, 14 Mar 2024 15:46:25 +0000 Subject: [PATCH 6/6] feat: change server constructor to accept mandatory arguments --- README.md | 35 ++----- examples_test.go | 27 +++++- filter_test.go | 67 ++++++++------ handlers_test.go | 121 +++++++++++++------------ internal/idp_test/azuread_util_test.go | 56 +++++++----- internal/idp_test/idp_test.go | 20 ++-- internal/idp_test/okta_util_test.go | 47 ++++++---- internal/idp_test/util_test.go | 7 +- logger.go | 1 + patch_add_test.go | 8 +- server.go | 36 +++++++- server_options.go | 26 ------ server_test.go | 58 +++++++----- 13 files changed, 278 insertions(+), 231 deletions(-) delete mode 100644 server_options.go diff --git a/README.md b/README.md index f9f12b6..2aa6817 100644 --- a/README.md +++ b/README.md @@ -119,35 +119,16 @@ resourceTypes := []ResourceType{ ### 4. Create Server ```go -server := NewServer( - WithServiceProviderConfig(config), - WithResourceTypes(resourceTypes), - WithLogger(logger), // optional, default is no logging -) -``` - -## Logging - -No incoming or outgoing (incl. errors) requests are logged by default. It is up to the user to implement this. This can -either be done through middleware around the server or by implementing the `ResourceHandler` interface. - -### Internal - -The SCIM server uses the standard `slog` package for logging. - -There are two moments where the server logs: - -1. When it was not able to marshal the response, it will log the error. This should not happen, since these are - predefined structures, of which most have custom `MarshalJSON` methods. In these cases an `errors.ScimErrorInternal` - error is returned. -2. When the server was not able to `Write` the response. +serverArgs := &ServerArgs{ + ServiceProviderConfig: config, + ResourceTypes: resourceTypes, +} -This logger can be customized by overwriting the default `slog.Logger`. +serverOpts := []ServerOption{ + WithLogger(logger), // optional, default is no logging +} -```go -var scimLogger slog.Logger -// initialize w/ own implementation -scim.SetLogger(scimLogger) +server := NewServer(serverArgs, serverOpts...) ``` ## String Values for Attributes diff --git a/examples_test.go b/examples_test.go index 9fcc3a3..360764d 100644 --- a/examples_test.go +++ b/examples_test.go @@ -6,12 +6,26 @@ import ( ) func ExampleNewServer() { - server := NewServer() + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) + } logger.Fatal(http.ListenAndServe(":7643", server)) } func ExampleNewServer_basePath() { - server := NewServer() + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) + } // You can host the SCIM server on a custom path, make sure to strip the prefix, so only `/v2/` is left. http.Handle("/scim/", http.StripPrefix("/scim", server)) logger.Fatal(http.ListenAndServe(":7643", nil)) @@ -27,6 +41,13 @@ func ExampleNewServer_logger() { return http.HandlerFunc(fn) } - server := NewServer() + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) + } logger.Fatal(http.ListenAndServe(":7643", loggingMiddleware(server))) } diff --git a/filter_test.go b/filter_test.go index 9e4346d..3377adc 100644 --- a/filter_test.go +++ b/filter_test.go @@ -14,7 +14,7 @@ import ( ) func Test_Group_Filter(t *testing.T) { - s := newTestServerForFilter() + s := newTestServerForFilter(t) tests := []struct { name string @@ -72,7 +72,7 @@ func Test_Group_Filter(t *testing.T) { } func Test_User_Filter(t *testing.T) { - s := newTestServerForFilter() + s := newTestServerForFilter(t) tests := []struct { name string @@ -129,37 +129,46 @@ func Test_User_Filter(t *testing.T) { } } -func newTestServerForFilter() scim.Server { - return scim.NewServer( - scim.WithResourceTypes([]scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{"userName": "testUser"}}, - "0002": {attributes: map[string]interface{}{"userName": "testUser+test"}}, +// newTestServerForFilter creates a new test server with a User and Group resource type +// or fails the test if an error occurs. +func newTestServerForFilter(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{"userName": "testUser"}}, + "0002": {attributes: map[string]interface{}{"userName": "testUser+test"}}, + }, + schema: schema.CoreUserSchema(), }, - schema: schema.CoreUserSchema(), }, - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{"displayName": "testGroup"}}, - "0002": {attributes: map[string]interface{}{"displayName": "testGroup+test"}}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{"displayName": "testGroup"}}, + "0002": {attributes: map[string]interface{}{"displayName": "testGroup+test"}}, + }, + schema: schema.CoreGroupSchema(), }, - schema: schema.CoreGroupSchema(), }, }, - }), + }, ) + if err != nil { + t.Fatal(err) + } + return s } diff --git a/handlers_test.go b/handlers_test.go index 4f32f90..ca0fb33 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -84,7 +84,7 @@ func TestInvalidRequests(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(test.method, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, test.expectedStatus, rr.Code) }) @@ -94,7 +94,7 @@ func TestInvalidRequests(t *testing.T) { func TestServerMeEndpoint(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Me", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotImplemented, rr.Code) } @@ -102,7 +102,7 @@ func TestServerMeEndpoint(t *testing.T) { func TestServerResourceDeleteHandler(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/Users/0001", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -110,7 +110,7 @@ func TestServerResourceDeleteHandler(t *testing.T) { func TestServerResourceDeleteHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/Users/9999", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -156,7 +156,7 @@ func TestServerResourceGetHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -185,7 +185,7 @@ func TestServerResourceGetHandler(t *testing.T) { func TestServerResourceGetHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users/9999", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -210,7 +210,7 @@ func TestServerResourcePatchHandlerFailOnBadType(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) var resource map[string]interface{} assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) @@ -232,7 +232,7 @@ func TestServerResourcePatchHandlerInvalidPath(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) @@ -252,14 +252,14 @@ func TestServerResourcePatchHandlerInvalidRemoveOp(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) } func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { recorder := httptest.NewRecorder() - newTestServer().ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + newTestServer(t).ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], "Operations":[ { @@ -272,7 +272,7 @@ func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { assertEqualStatusCode(t, http.StatusOK, recorder.Code) recorder2 := httptest.NewRecorder() - newTestServer().ServeHTTP(recorder2, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + newTestServer(t).ServeHTTP(recorder2, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], "Operations":[ { @@ -319,7 +319,7 @@ func TestServerResourcePatchHandlerReturnsNoContent(t *testing.T) { } for _, req := range reqs { rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -358,7 +358,7 @@ func TestServerResourcePatchHandlerValid(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -408,7 +408,7 @@ func TestServerResourcePatchHandlerValidPathHasSubAttributes(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) } @@ -424,7 +424,7 @@ func TestServerResourcePatchHandlerValidRemoveOp(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -468,7 +468,7 @@ func TestServerResourcePostHandlerValid(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, test.target, test.body) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusCreated, rr.Code) @@ -498,7 +498,7 @@ func TestServerResourcePostHandlerValid(t *testing.T) { func TestServerResourcePutHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodPut, "/Users/9999", strings.NewReader(`{"userName": "other"}`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -549,7 +549,7 @@ func TestServerResourcePutHandlerValid(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, test.target, test.body) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -595,7 +595,7 @@ func TestServerResourceTypeHandlerValid(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/ResourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -625,7 +625,7 @@ func TestServerResourceTypesHandler(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -650,7 +650,7 @@ func TestServerResourceTypesHandler(t *testing.T) { func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=-1", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -663,7 +663,7 @@ func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) { func TestServerResourcesGetHandler(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -676,7 +676,7 @@ func TestServerResourcesGetHandler(t *testing.T) { func TestServerResourcesGetHandlerMaxCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=20000", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -688,7 +688,7 @@ func TestServerResourcesGetHandlerMaxCount(t *testing.T) { func TestServerResourcesGetHandlerPagination(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=2&startIndex=2", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -726,7 +726,7 @@ func TestServerSchemaEndpointValid(t *testing.T) { "%s/Schemas/%s", test.versionPrefix, test.schema, ), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -755,7 +755,7 @@ func TestServerSchemasEndpoint(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -791,7 +791,7 @@ func TestServerSchemasEndpointFilter(t *testing.T) { "/Schemas?%s", params.Encode(), ), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -819,7 +819,7 @@ func TestServerServiceProviderConfigHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) }) @@ -930,38 +930,45 @@ func newTestResourceHandler() ResourceHandler { } } -func newTestServer() Server { +func newTestServer(t *testing.T) Server { userSchema := getUserSchema() userSchemaExtension := getUserExtensionSchema() - return NewServer( - WithResourceTypes([]ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: userSchema, - Handler: newTestResourceHandler(), - }, - { - ID: optional.NewString("EnterpriseUser"), - Name: "EnterpriseUser", - Endpoint: "/EnterpriseUsers", - Description: optional.NewString("Enterprise User Account"), - Schema: userSchema, - SchemaExtensions: []SchemaExtension{ - {Schema: userSchemaExtension}, + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: userSchema, + Handler: newTestResourceHandler(), + }, + { + ID: optional.NewString("EnterpriseUser"), + Name: "EnterpriseUser", + Endpoint: "/EnterpriseUsers", + Description: optional.NewString("Enterprise User Account"), + Schema: userSchema, + SchemaExtensions: []SchemaExtension{ + {Schema: userSchemaExtension}, + }, + Handler: newTestResourceHandler(), + }, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: newTestResourceHandler(), }, - Handler: newTestResourceHandler(), - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: newTestResourceHandler(), }, - }), + }, ) + if err != nil { + t.Fatal(err) + } + return s } diff --git a/internal/idp_test/azuread_util_test.go b/internal/idp_test/azuread_util_test.go index fda3af5..e12863b 100644 --- a/internal/idp_test/azuread_util_test.go +++ b/internal/idp_test/azuread_util_test.go @@ -2,6 +2,7 @@ package idp_test import ( "net/http" + "testing" "time" "github.com/elimity-com/scim" @@ -16,34 +17,39 @@ var azureCreatedTime = time.Date( 19, 59, 26, 0, time.UTC, ) -func newAzureADTestServer() scim.Server { - return scim.NewServer( - scim.WithServiceProviderConfig(scim.ServiceProviderConfig{ - MaxResults: 20, - }), - scim.WithResourceTypes([]scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - SchemaExtensions: []scim.SchemaExtension{ - {Schema: schema.ExtensionEnterpriseUser()}, - }, - Handler: azureADUserResourceHandler{}, +func newAzureADTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{ + MaxResults: 20, }, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + SchemaExtensions: []scim.SchemaExtension{ + {Schema: schema.ExtensionEnterpriseUser()}, + }, + Handler: azureADUserResourceHandler{}, + }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: azureADGroupResourceHandler{}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: azureADGroupResourceHandler{}, + }, }, - }), - ) + }) + if err != nil { + t.Fatal(err) + } + return s } type azureADGroupResourceHandler struct{} diff --git a/internal/idp_test/idp_test.go b/internal/idp_test/idp_test.go index 881b2ea..f093ec1 100644 --- a/internal/idp_test/idp_test.go +++ b/internal/idp_test/idp_test.go @@ -27,7 +27,7 @@ func TestIdP(t *testing.T) { var test testCase _ = unmarshal(raw, &test) t.Run(strings.TrimSuffix(f.Name(), ".json"), func(t *testing.T) { - if err := testRequest(test, idp.Name()); err != nil { + if err := testRequest(t, test, idp.Name()); err != nil { t.Error(err) } }) @@ -36,23 +36,23 @@ func TestIdP(t *testing.T) { } } -func testRequest(t testCase, idpName string) error { +func testRequest(t *testing.T, tc testCase, idpName string) error { rr := httptest.NewRecorder() - br := bytes.NewReader(t.Request) - getNewServer(idpName).ServeHTTP( + br := bytes.NewReader(tc.Request) + getNewServer(t, idpName).ServeHTTP( rr, - httptest.NewRequest(t.Method, t.Path, br), + httptest.NewRequest(tc.Method, tc.Path, br), ) - if code := rr.Code; code != t.StatusCode { - return fmt.Errorf("expected %d, got %d", t.StatusCode, code) + if code := rr.Code; code != tc.StatusCode { + return fmt.Errorf("expected %d, got %d", tc.StatusCode, code) } - if len(t.Response) != 0 { + if len(tc.Response) != 0 { var response map[string]interface{} if err := unmarshal(rr.Body.Bytes(), &response); err != nil { return err } - if !reflect.DeepEqual(t.Response, response) { - return fmt.Errorf("expected, got:\n%v\n%v", t.Response, response) + if !reflect.DeepEqual(tc.Response, response) { + return fmt.Errorf("expected, got:\n%v\n%v", tc.Response, response) } } return nil diff --git a/internal/idp_test/okta_util_test.go b/internal/idp_test/okta_util_test.go index 043d766..066edc5 100644 --- a/internal/idp_test/okta_util_test.go +++ b/internal/idp_test/okta_util_test.go @@ -2,6 +2,7 @@ package idp_test import ( "net/http" + "testing" "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" @@ -9,28 +10,36 @@ import ( "github.com/elimity-com/scim/schema" ) -func newOktaTestServer() scim.Server { - return scim.NewServer( - scim.WithResourceTypes([]scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: oktaUserResourceHandler{}, - }, +func newOktaTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: oktaUserResourceHandler{}, + }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: oktaGroupResourceHandler{}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: oktaGroupResourceHandler{}, + }, }, - }), + }, ) + if err != nil { + t.Fatal(err) + } + + return s } type oktaGroupResourceHandler struct{} diff --git a/internal/idp_test/util_test.go b/internal/idp_test/util_test.go index c1a112b..ae9b1fa 100644 --- a/internal/idp_test/util_test.go +++ b/internal/idp_test/util_test.go @@ -3,16 +3,17 @@ package idp_test import ( "bytes" "encoding/json" + "testing" "github.com/elimity-com/scim" ) -func getNewServer(idpName string) scim.Server { +func getNewServer(t *testing.T, idpName string) scim.Server { switch idpName { case "okta": - return newOktaTestServer() + return newOktaTestServer(t) case "azuread": - return newAzureADTestServer() + return newAzureADTestServer(t) default: panic("unreachable") } diff --git a/logger.go b/logger.go index 258ee80..e8db123 100644 --- a/logger.go +++ b/logger.go @@ -1,5 +1,6 @@ package scim +// Logger defines and interface for logging errors. type Logger interface { Error(args ...interface{}) } diff --git a/patch_add_test.go b/patch_add_test.go index 009ef90..c7cca2e 100644 --- a/patch_add_test.go +++ b/patch_add_test.go @@ -18,7 +18,7 @@ func TestPatch_addAttributes(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Users/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } @@ -65,7 +65,7 @@ func TestPatch_addMember(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Groups/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } @@ -121,7 +121,7 @@ func TestPatch_alreadyExists(t *testing.T) { changed: false, }, } { - server := newTestServer() + server := newTestServer(t) raw, err := os.ReadFile(test.jsonFilePath) if err != nil { t.Fatal(err) @@ -158,7 +158,7 @@ func TestPatch_complex(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Users/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } diff --git a/server.go b/server.go index a70c7e3..e79a612 100644 --- a/server.go +++ b/server.go @@ -60,16 +60,46 @@ type Server struct { log Logger } -func NewServer(opts ...ServerOption) Server { +type ServerArgs struct { + ServiceProviderConfig *ServiceProviderConfig + ResourceTypes []ResourceType +} + +type ServerOption func(*Server) + +// WithLogger sets the logger for the server. +func WithLogger(logger Logger) ServerOption { + return func(s *Server) { + if logger != nil { + s.log = logger + } + } +} + +func NewServer(args *ServerArgs, opts ...ServerOption) (Server, error) { + if args == nil { + return Server{}, fmt.Errorf("arguments not provided") + } + + if args.ServiceProviderConfig == nil { + return Server{}, fmt.Errorf("service provider config not provided") + } + + if args.ResourceTypes == nil { + return Server{}, fmt.Errorf("resource types not provided") + } + s := &Server{ - log: &noopLogger{}, + config: *args.ServiceProviderConfig, + resourceTypes: args.ResourceTypes, + log: &noopLogger{}, } for _, opt := range opts { opt(s) } - return *s + return *s, nil } // ServeHTTP dispatches the request to the handler whose pattern most closely matches the request URL. diff --git a/server_options.go b/server_options.go deleted file mode 100644 index cc22d87..0000000 --- a/server_options.go +++ /dev/null @@ -1,26 +0,0 @@ -package scim - -type ServerOption func(*Server) - -// WithServiceProviderConfig sets the service provider config for the server. -func WithServiceProviderConfig(config ServiceProviderConfig) ServerOption { - return func(s *Server) { - s.config = config - } -} - -// WithResourceTypes sets the resource types for the server. -func WithResourceTypes(resourceTypes []ResourceType) ServerOption { - return func(s *Server) { - s.resourceTypes = resourceTypes - } -} - -// WithLogger sets the logger for the server. -func WithLogger(logger Logger) ServerOption { - return func(s *Server) { - if logger != nil { - s.log = logger - } - } -} diff --git a/server_test.go b/server_test.go index fd35a4d..726fea5 100644 --- a/server_test.go +++ b/server_test.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "testing" "time" internal "github.com/elimity-com/scim/filter" @@ -41,37 +42,44 @@ func externalID(attributes scim.ResourceAttributes) optional.String { // - Whether a reference to another entity really exists. // e.g. if a member gets added, does this entity exist? -func newTestServer() scim.Server { - return scim.NewServer( - scim.WithResourceTypes([]scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{}}, +func newTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{}}, + }, + schema: schema.CoreUserSchema(), }, - schema: schema.CoreUserSchema(), }, - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{}}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{}}, + }, + schema: schema.CoreGroupSchema(), }, - schema: schema.CoreGroupSchema(), }, }, - }), + }, ) + if err != nil { + t.Fatal(err) + } + return s } // testData represents a resource entity.