diff --git a/httperror.go b/httperror.go index 6e14da3d9..db408b143 100644 --- a/httperror.go +++ b/httperror.go @@ -9,21 +9,6 @@ import ( "net/http" ) -// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response -type HTTPStatusCoder interface { - StatusCode() int -} - -// StatusCode returns status code from error if it implements HTTPStatusCoder interface. -// If error does not implement the interface it returns 0. -func StatusCode(err error) int { - var sc HTTPStatusCoder - if errors.As(err, &sc) { - return sc.StatusCode() - } - return 0 -} - // Following errors can produce HTTP status code by implementing HTTPStatusCoder interface var ( ErrBadRequest = &httpError{http.StatusBadRequest} // 400 @@ -50,6 +35,66 @@ var ( ErrInvalidListenerNetwork = errors.New("invalid listener network") ) +// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response +type HTTPStatusCoder interface { + StatusCode() int +} + +// StatusCode returns status code from error if it implements HTTPStatusCoder interface. +// If error does not implement the interface it returns 0. +func StatusCode(err error) int { + var sc HTTPStatusCoder + if errors.As(err, &sc) { + return sc.StatusCode() + } + return 0 +} + +// ResolveResponseStatus returns the Response and HTTP status code that should be (or has been) sent for rw, +// given an optional error. +// +// This function is useful for middleware and handlers that need to figure out the HTTP status +// code to return based on the error that occurred or what was set in the response. +// +// Precedence rules: +// 1. If the response has already been committed, the committed status wins (err is ignored). +// 2. Otherwise, start with 200 OK (net/http default if WriteHeader is never called). +// 3. If the response has a non-zero suggested status, use it. +// 4. If err != nil, it overrides the suggested status: +// - StatusCode(err) if non-zero +// - otherwise 500 Internal Server Error. +func ResolveResponseStatus(rw http.ResponseWriter, err error) (resp *Response, status int) { + resp, _ = UnwrapResponse(rw) + + // once committed (sent to the client), the wire status is fixed; err cannot change it. + if resp != nil && resp.Committed { + if resp.Status == 0 { + // unlikely path, but fall back to net/http implicit default if handler never calls WriteHeader + return resp, http.StatusOK + } + return resp, resp.Status + } + + // net/http implicit default if handler never calls WriteHeader. + status = http.StatusOK + + // suggested status written from middleware/handlers, if present. + if resp != nil && resp.Status != 0 { + status = resp.Status + } + + // error overrides suggested status (matches typical Echo error-handler semantics). + if err != nil { + if s := StatusCode(err); s != 0 { + status = s + } else { + status = http.StatusInternalServerError + } + } + + return resp, status +} + // NewHTTPError creates new instance of HTTPError func NewHTTPError(code int, message string) *HTTPError { return &HTTPError{ diff --git a/httperror_test.go b/httperror_test.go index 0a91bbc9c..778a186ce 100644 --- a/httperror_test.go +++ b/httperror_test.go @@ -107,3 +107,80 @@ func TestStatusCode(t *testing.T) { }) } } + +func TestResolveResponseStatus(t *testing.T) { + someErr := errors.New("some error") + + var testCases = []struct { + name string + whenResp http.ResponseWriter + whenErr error + expectStatus int + expectResp bool + }{ + { + name: "nil resp, nil err -> 200", + whenResp: nil, + whenErr: nil, + expectStatus: http.StatusOK, + expectResp: false, + }, + { + name: "resp suggested status used when no error", + whenResp: &Response{Status: http.StatusCreated}, + whenErr: nil, + expectStatus: http.StatusCreated, + expectResp: true, + }, + { + name: "error overrides suggested status with StatusCode(err)", + whenResp: &Response{Status: http.StatusAccepted}, + whenErr: ErrBadRequest, + expectStatus: http.StatusBadRequest, + expectResp: true, + }, + { + name: "error overrides suggested status with 500 when StatusCode(err)==0", + whenResp: &Response{Status: http.StatusAccepted}, + whenErr: ErrInternalServerError, + expectStatus: http.StatusInternalServerError, + expectResp: true, + }, + { + name: "nil resp, error -> 500 when StatusCode(err)==0", + whenResp: nil, + whenErr: someErr, + expectStatus: http.StatusInternalServerError, + expectResp: false, + }, + { + name: "committed response wins over error", + whenResp: &Response{Committed: true, Status: http.StatusNoContent}, + whenErr: someErr, + expectStatus: http.StatusNoContent, + expectResp: true, + }, + { + name: "committed response with status 0 falls back to 200 (defensive)", + whenResp: &Response{Committed: true, Status: 0}, + whenErr: someErr, + expectStatus: http.StatusOK, + expectResp: true, + }, + { + name: "resp with status 0 and no error -> 200", + whenResp: &Response{Status: 0}, + whenErr: nil, + expectStatus: http.StatusOK, + expectResp: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, status := ResolveResponseStatus(tc.whenResp, tc.whenErr) + + assert.Equal(t, tc.expectResp, resp != nil) + assert.Equal(t, tc.expectStatus, status) + }) + } +} diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 76903c62a..9e46bf5d6 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -160,8 +160,8 @@ type RequestLoggerConfig struct { LogReferer bool // LogUserAgent instructs logger to extract request user agent values. LogUserAgent bool - // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError, - // the status code is extracted from the echo.HTTPError returned + // LogStatus instructs logger to extract response status code. If handler chain returns an error, + // the status code is extracted from the error satisfying echo.StatusCoder interface. LogStatus bool // LogContentLength instructs logger to extract content length header value. Note: this value could be different from // actual request body size as it could be spoofed etc. @@ -211,7 +211,7 @@ type RequestLoggerValues struct { Referer string // UserAgent is request user agent values. UserAgent string - // Status is response status code. Then handler returns an echo.HTTPError then code from there. + // Status is a response status code. When the handler returns an error satisfying echo.StatusCoder interface, then code from it. Status int // Error is error returned from executed handler chain. Error error @@ -272,7 +272,6 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } req := c.Request() - res := c.Response() start := now() if config.BeforeNextFunc != nil { @@ -284,6 +283,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // checked with `c.Response().Committed` field. c.Echo().HTTPErrorHandler(c, err) } + res := c.Response() v := RequestLoggerValues{ StartTime: start, @@ -330,26 +330,16 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { v.UserAgent = req.UserAgent() } - var resp *echo.Response if config.LogStatus || config.LogResponseSize { - if r, err := echo.UnwrapResponse(res); err != nil { - c.Logger().Error("can not determine response status and/or size. ResponseWriter in context does not implement unwrapper interface") - } else { - resp = r - } - } + resp, status := echo.ResolveResponseStatus(res, err) - if config.LogStatus { - v.Status = -1 - if resp != nil { - v.Status = resp.Status + if config.LogStatus { + v.Status = status } - if err != nil && !config.HandleError { - // this block should not be executed in case of HandleError=true as the global error handler will decide - // the status code. In that case status code could be different from what err contains. - var hsc echo.HTTPStatusCoder - if errors.As(err, &hsc) { - v.Status = hsc.StatusCode() + if config.LogResponseSize { + v.ResponseSize = -1 + if resp != nil { + v.ResponseSize = resp.Size } } } @@ -359,12 +349,6 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.LogContentLength { v.ContentLength = req.Header.Get(echo.HeaderContentLength) } - if config.LogResponseSize { - v.ResponseSize = -1 - if resp != nil { - v.ResponseSize = resp.Size - } - } if logHeaders { v.Headers = map[string][]string{} for _, header := range headers { diff --git a/router_test.go b/router_test.go index 1dd306a36..7bddb4a15 100644 --- a/router_test.go +++ b/router_test.go @@ -3079,10 +3079,6 @@ func TestDefaultRouter_AddDuplicateRouteNotAllowed(t *testing.T) { assert.Equal(t, "OLD", body) } -func TestName(t *testing.T) { - -} - // See issue #1531, #1258 - there are cases when path parameter need to be unescaped func TestDefaultRouter_UnescapePathParamValues(t *testing.T) { var testCases = []struct {