From b4be136426a3445966ef25bf549c0ba6f3185d5a Mon Sep 17 00:00:00 2001 From: EwenQuim Date: Sun, 22 Dec 2024 09:09:17 +0100 Subject: [PATCH] Make ValidateParams take an interface instead of netHttpContext only --- ctx.go | 11 ++++++----- internal/common_context.go | 11 +++++++++++ serve.go | 3 ++- validate_params.go | 14 +++++++++++--- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/ctx.go b/ctx.go index 2ca4017b..6a8aefe3 100644 --- a/ctx.go +++ b/ctx.go @@ -120,11 +120,6 @@ func NewNetHTTPContext[B any](w http.ResponseWriter, r *http.Request, options re return c } -var ( - _ ContextWithBody[any] = &netHttpContext[any]{} // Check that ContextWithBody implements Ctx. - _ ContextWithBody[string] = &netHttpContext[string]{} // Check that ContextWithBody implements Ctx. -) - // ContextWithBody is the same as fuego.ContextNoBody, but // has a Body. The Body type parameter represents the expected data type // from http.Request.Body. Please do not use a pointer as a type parameter. @@ -142,6 +137,12 @@ type netHttpContext[Body any] struct { readOptions readOptions } +var ( + _ ContextWithBody[any] = &netHttpContext[any]{} // Check that ContextWithBody implements Ctx. + _ ContextWithBody[string] = &netHttpContext[string]{} // Check that ContextWithBody implements Ctx. + _ ValidableCtx = &netHttpContext[any]{} // Check that ContextWithBody implements ValidableCtx. +) + // SetStatus sets the status code of the response. // Alias to http.ResponseWriter.WriteHeader. func (c netHttpContext[B]) SetStatus(code int) { diff --git a/internal/common_context.go b/internal/common_context.go index bba70180..78315e0e 100644 --- a/internal/common_context.go +++ b/internal/common_context.go @@ -42,6 +42,11 @@ type CommonContext[B any] struct { type ParamType string // Query, Header, Cookie +// GetOpenAPIParams returns the OpenAPI parameters declared in the OpenAPI spec. +func (c CommonContext[B]) GetOpenAPIParams() map[string]OpenAPIParam { + return c.OpenAPIParams +} + func (c CommonContext[B]) Context() context.Context { return c.CommonCtx } @@ -71,6 +76,12 @@ func (c CommonContext[B]) QueryParams() url.Values { return c.UrlValues } +// HasQueryParam returns true if the query parameter with the given name exists. +func (c CommonContext[B]) HasQueryParam(name string) bool { + _, ok := c.UrlValues[name] + return ok +} + // QueryParam returns the query parameter with the given name. // If it does not exist, it returns an empty string, unless there is a default value declared in the OpenAPI spec. // diff --git a/serve.go b/serve.go index b963e6d7..516b92a4 100644 --- a/serve.go +++ b/serve.go @@ -94,7 +94,8 @@ func HTTPHandler[ReturnType, Body any](s *Server, controller func(c ContextWithB templates: templates, } - err := validateParams(*ctx) + // PARAMS VALIDATION + err := ValidateParams(ctx) if err != nil { err = s.ErrorHandler(err) s.SerializeError(w, r, err) diff --git a/validate_params.go b/validate_params.go index b545398d..7679a083 100644 --- a/validate_params.go +++ b/validate_params.go @@ -2,8 +2,16 @@ package fuego import "fmt" -func validateParams[B any](c netHttpContext[B]) error { - for k, param := range c.OpenAPIParams { +type ValidableCtx interface { + GetOpenAPIParams() map[string]OpenAPIParam + HasQueryParam(key string) bool + HasHeader(key string) bool + HasCookie(key string) bool +} + +// ValidateParams checks if all required parameters are present in the request. +func ValidateParams(c ValidableCtx) error { + for k, param := range c.GetOpenAPIParams() { if param.Default != nil { // skip: param has a default continue @@ -12,7 +20,7 @@ func validateParams[B any](c netHttpContext[B]) error { if param.Required { switch param.Type { case QueryParamType: - if !c.UrlValues.Has(k) { + if !c.HasQueryParam(k) { err := fmt.Errorf("%s is a required query param", k) return BadRequestError{ Title: "Query Param Not Found",