package httputil

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestRequireMethod checks that RequireMethod reports true when the request
// method equals the required method, and otherwise reports false and responds
// with 405 Method Not Allowed.
func TestRequireMethod(t *testing.T) {
	tests := []struct {
		name           string
		requestMethod  string
		requiredMethod string
		wantOk         bool
		wantStatus     int
	}{
		// httptest.NewRecorder starts with Code == 200, so http.StatusOK for
		// the accepted cases asserts that no error status was written.
		{"POST matches POST", http.MethodPost, http.MethodPost, true, http.StatusOK},
		{"GET matches GET", http.MethodGet, http.MethodGet, true, http.StatusOK},
		{"GET doesn't match POST", http.MethodGet, http.MethodPost, false, http.StatusMethodNotAllowed},
		{"POST doesn't match GET", http.MethodPost, http.MethodGet, false, http.StatusMethodNotAllowed},
		{"PUT doesn't match POST", http.MethodPut, http.MethodPost, false, http.StatusMethodNotAllowed},
		{"DELETE doesn't match GET", http.MethodDelete, http.MethodGet, false, http.StatusMethodNotAllowed},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.requestMethod, "/", nil)
			w := httptest.NewRecorder()

			ok := RequireMethod(w, req, tt.requiredMethod)
			if ok != tt.wantOk {
				t.Errorf("RequireMethod() = %v, want %v", ok, tt.wantOk)
			}
			// Checked on both paths: an accepted request must leave the
			// recorder untouched, a rejected one must carry a 405.
			if w.Code != tt.wantStatus {
				t.Errorf("RequireMethod() status = %d, want %d", w.Code, tt.wantStatus)
			}
		})
	}
}

// TestRequirePost checks that RequirePost accepts POST requests and rejects
// other methods with 405 Method Not Allowed.
func TestRequirePost(t *testing.T) {
	tests := []struct {
		name       string
		method     string
		wantOk     bool
		wantStatus int
	}{
		{"POST request allowed", http.MethodPost, true, http.StatusOK},
		{"GET request rejected", http.MethodGet, false, http.StatusMethodNotAllowed},
		{"PUT request rejected", http.MethodPut, false, http.StatusMethodNotAllowed},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, "/", nil)
			w := httptest.NewRecorder()

			if ok := RequirePost(w, req); ok != tt.wantOk {
				t.Errorf("RequirePost() = %v for %s request, want %v", ok, tt.method, tt.wantOk)
			}
			if w.Code != tt.wantStatus {
				t.Errorf("RequirePost() status = %d, want %d", w.Code, tt.wantStatus)
			}
		})
	}
}

// TestRequireGet checks that RequireGet accepts GET requests and rejects
// other methods with 405 Method Not Allowed.
func TestRequireGet(t *testing.T) {
	tests := []struct {
		name       string
		method     string
		wantOk     bool
		wantStatus int
	}{
		{"GET request allowed", http.MethodGet, true, http.StatusOK},
		{"POST request rejected", http.MethodPost, false, http.StatusMethodNotAllowed},
		{"DELETE request rejected", http.MethodDelete, false, http.StatusMethodNotAllowed},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, "/", nil)
			w := httptest.NewRecorder()

			if ok := RequireGet(w, req); ok != tt.wantOk {
				t.Errorf("RequireGet() = %v for %s request, want %v", ok, tt.method, tt.wantOk)
			}
			if w.Code != tt.wantStatus {
				t.Errorf("RequireGet() status = %d, want %d", w.Code, tt.wantStatus)
			}
		})
	}
}