package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestResponseWriter_WriteHeader checks that the first WriteHeader call is
// recorded in the wrapper's status and wroteHeader fields, and that a second
// call leaves the recorded status untouched.
func TestResponseWriter_WriteHeader(t *testing.T) {
	const want = http.StatusNotFound

	wrapped := &responseWriter{
		ResponseWriter: httptest.NewRecorder(),
		status:         http.StatusOK,
	}

	// The initial call is the one that must stick.
	wrapped.WriteHeader(want)
	if got := wrapped.status; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
	if wroteHeader := wrapped.wroteHeader; !wroteHeader {
		t.Error("wroteHeader should be true after WriteHeader")
	}

	// Any later call must not overwrite the recorded status.
	wrapped.WriteHeader(http.StatusInternalServerError)
	if got := wrapped.status; got != want {
		t.Errorf("status = %d, want %d (should not change)", got, want)
	}
}

// TestResponseWriter_Write checks that Write reports the number of bytes
// written, accumulates the running total in the written field, marks the
// header as written, and keeps the initial 200 status when no status was set
// explicitly.
func TestResponseWriter_Write(t *testing.T) {
	wrapped := &responseWriter{
		ResponseWriter: httptest.NewRecorder(),
		status:         http.StatusOK,
	}

	// First chunk: no WriteHeader call has happened yet.
	count, writeErr := wrapped.Write([]byte("Hello"))
	if writeErr != nil {
		t.Errorf("Write() error = %v", writeErr)
	}
	if count != 5 {
		t.Errorf("Write() n = %d, want 5", count)
	}
	if total := wrapped.written; total != 5 {
		t.Errorf("written = %d, want 5", total)
	}
	if wroteHeader := wrapped.wroteHeader; !wroteHeader {
		t.Error("wroteHeader should be true after Write")
	}
	if got := wrapped.status; got != http.StatusOK {
		t.Errorf("status = %d, want %d", got, http.StatusOK)
	}

	// Second chunk: the byte counter must keep accumulating.
	count, writeErr = wrapped.Write([]byte(" World"))
	if writeErr != nil {
		t.Errorf("Write() error = %v", writeErr)
	}
	if count != 6 {
		t.Errorf("Write() n = %d, want 6", count)
	}
	if total := wrapped.written; total != 11 {
		t.Errorf("written = %d, want 11", total)
	}
}

// TestResponseWriter_WriteWithExplicitStatus checks that a status set through
// WriteHeader is still the recorded status after a subsequent Write.
func TestResponseWriter_WriteWithExplicitStatus(t *testing.T) {
	const want = http.StatusCreated

	wrapped := &responseWriter{
		ResponseWriter: httptest.NewRecorder(),
		status:         http.StatusOK,
	}

	// Status goes first, body second; the Write result is deliberately
	// ignored because only the recorded status is under test here.
	wrapped.WriteHeader(want)
	_, _ = wrapped.Write([]byte("Created"))

	if got := wrapped.status; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}
}

// TestLogger runs several handlers through the Logger middleware and checks
// that the status code (and, for the first case, the body) observed on the
// recorder matches what the wrapped handler produced.
func TestLogger(t *testing.T) {
	scenarios := []struct {
		name      string
		method    string
		target    string
		handler   http.HandlerFunc
		wantCode  int
		checkBody bool // only the first scenario asserts on the body
		wantBody  string
	}{
		{
			name:   "Logs successful request",
			method: http.MethodGet,
			target: "/test",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("OK"))
			},
			wantCode:  http.StatusOK,
			checkBody: true,
			wantBody:  "OK",
		},
		{
			name:   "Logs error response",
			method: http.MethodGet,
			target: "/notfound",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				http.Error(w, "Not Found", http.StatusNotFound)
			},
			wantCode: http.StatusNotFound,
		},
		{
			name:   "Handles POST request",
			method: http.MethodPost,
			target: "/create",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusCreated)
			},
			wantCode: http.StatusCreated,
		},
		{
			name:   "Handles request with no explicit status",
			method: http.MethodGet,
			target: "/implicit",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				// Body only; the handler never calls WriteHeader.
				_, _ = w.Write([]byte("Implicit OK"))
			},
			// Default status should be 200.
			wantCode: http.StatusOK,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			wrapped := Logger(sc.handler)

			request := httptest.NewRequest(sc.method, sc.target, nil)
			recorder := httptest.NewRecorder()
			wrapped.ServeHTTP(recorder, request)

			if recorder.Code != sc.wantCode {
				t.Errorf("Code = %d, want %d", recorder.Code, sc.wantCode)
			}
			if !sc.checkBody {
				return
			}
			if recorder.Body.String() != sc.wantBody {
				t.Errorf("Body = %q, want %q", recorder.Body.String(), sc.wantBody)
			}
		})
	}
}