package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestRecovery checks that the Recovery middleware turns a panic raised by
// the wrapped handler into a 500 response, and that it leaves the response
// of a handler that completes normally untouched.
func TestRecovery(t *testing.T) {
	// serve wraps next with Recovery, sends it a GET / request and returns the
	// recorded response. A panic that escapes the middleware is reported as a
	// failure of the calling subtest; without this, an escaping panic would
	// abort the whole test binary and skip every remaining test in the package.
	serve := func(t *testing.T, next http.HandlerFunc) *httptest.ResponseRecorder {
		t.Helper()
		handler := Recovery(next)
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		w := httptest.NewRecorder()
		defer func() {
			if rec := recover(); rec != nil {
				t.Fatalf("panic escaped Recovery middleware: %v", rec)
			}
		}()
		handler.ServeHTTP(w, req)
		return w
	}

	t.Run("Recovers from panic", func(t *testing.T) {
		w := serve(t, func(w http.ResponseWriter, r *http.Request) {
			panic("test panic")
		})
		if w.Code != http.StatusInternalServerError {
			t.Errorf("Status = %d, want %d", w.Code, http.StatusInternalServerError)
		}
	})

	t.Run("Passes through normal requests", func(t *testing.T) {
		w := serve(t, func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("OK")) // write error is irrelevant on a recorder
		})
		if w.Code != http.StatusOK {
			t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
		}
		if got := w.Body.String(); got != "OK" {
			t.Errorf("Body = %q, want %q", got, "OK")
		}
	})

	t.Run("Recovers from nil pointer dereference", func(t *testing.T) {
		w := serve(t, func(w http.ResponseWriter, r *http.Request) {
			var p *int
			// Dereferencing a nil pointer raises a runtime-error panic
			// (not panic(nil)); Recovery must handle it like any other.
			_ = *p
		})
		if w.Code != http.StatusInternalServerError {
			t.Errorf("Status = %d, want %d", w.Code, http.StatusInternalServerError)
		}
	})
}