diff --git a/contrib/internal/httptrace/make_responsewriter.go b/contrib/internal/httptrace/make_responsewriter.go index 962b67e343..2458ef7fbf 100644 --- a/contrib/internal/httptrace/make_responsewriter.go +++ b/contrib/internal/httptrace/make_responsewriter.go @@ -66,6 +66,7 @@ func wrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *responseWr http.ResponseWriter Status() int Block() + Blocked() bool Unwrap() http.ResponseWriter } switch { diff --git a/contrib/internal/httptrace/response_writer.go b/contrib/internal/httptrace/response_writer.go index da04f7a6ee..964a9a83a9 100644 --- a/contrib/internal/httptrace/response_writer.go +++ b/contrib/internal/httptrace/response_writer.go @@ -39,6 +39,11 @@ func (w *responseWriter) Status() int { return w.status } +// Blocked returns whether the response has been blocked. +func (w *responseWriter) Blocked() bool { + return w.blocked +} + // Block is supposed only once, after a response (one made by appsec code) as been sent. If it not the case, the function will do nothing. // All subsequent calls to Write and WriteHeader will be trigger a log warning users that the response has been blocked. func (w *responseWriter) Block() { diff --git a/contrib/internal/httptrace/trace_gen.go b/contrib/internal/httptrace/trace_gen.go index 81c7e7df8f..36a7f6f2bf 100644 --- a/contrib/internal/httptrace/trace_gen.go +++ b/contrib/internal/httptrace/trace_gen.go @@ -32,6 +32,7 @@ func wrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *responseWr http.ResponseWriter Status() int Block() + Blocked() bool Unwrap() http.ResponseWriter } switch { diff --git a/contrib/labstack/echo.v4/appsec.go b/contrib/labstack/echo.v4/appsec.go index 9cd849cc20..14d7eee319 100644 --- a/contrib/labstack/echo.v4/appsec.go +++ b/contrib/labstack/echo.v4/appsec.go @@ -21,7 +21,10 @@ func withAppSec(next echo.HandlerFunc, span tracer.Span) echo.HandlerFunc { for _, n := range c.ParamNames() { params[n] = c.Param(n) } - var err error + var ( + err error + writer = &statusResponseWriter{Response: c.Response()} + ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.SetRequest(r) err = next(c) @@ -32,7 +35,7 @@ func withAppSec(next echo.HandlerFunc, span tracer.Span) echo.HandlerFunc { } }) // Wrap the echo response to allow monitoring of the response status code in httpsec.WrapHandler() - httpsec.WrapHandler(handler, span, params, nil).ServeHTTP(&statusResponseWriter{Response: c.Response()}, c.Request()) + httpsec.WrapHandler(handler, span, params, nil).ServeHTTP(, c.Request()) // If an error occurred, wrap it under an echo.HTTPError. We need to do this so that APM doesn't override // the response code tag with 500 in case it doesn't recognize the error type. if _, ok := err.(*echo.HTTPError); !ok && err != nil { diff --git a/internal/appsec/emitter/waf/actions/block.go b/internal/appsec/emitter/waf/actions/block.go index 405e9d8164..f2c86a2b49 100644 --- a/internal/appsec/emitter/waf/actions/block.go +++ b/internal/appsec/emitter/waf/actions/block.go @@ -141,13 +141,12 @@ func newHTTPBlockRequestAction(statusCode int, template blockingTemplateType) *B template = blockingTemplateTypeFromHeaders(request.Header) } - if UnwrapGetStatusCode(writer) != 0 { + if code, found := UnwrapGetStatusCode(writer); found && code != 0 { // The status code has already been set, so we can't change it, do nothing return } - blocker, found := UnwrapBlocker(writer) - if found { + if blocker, found := UnwrapBlocker(writer); found { // We found our custom response writer, so we can block futur calls to Write and WriteHeader defer blocker() } @@ -203,15 +202,15 @@ func UnwrapBlocker(writer http.ResponseWriter) (func(), bool) { // UnwrapGetStatusCode unwraps the right struct method from contrib/internal/httptrace.responseWriter // and calls it to know if a call to WriteHeader has been made and returns the status code. -func UnwrapGetStatusCode(writer http.ResponseWriter) int { +func UnwrapGetStatusCode(writer http.ResponseWriter) (int, bool) { // this is part of the contrib/internal/httptrace.responseWriter interface wrapped, ok := writer.(interface { Status() int }) if !ok { // Somehow we can't access the wrapped response writer, so we can't get the status code - return 0 + return 0, false } - return wrapped.Status() + return wrapped.Status(), true }