86 lines
2.3 KiB
Go
86 lines
2.3 KiB
Go
|
|
package httputil
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestRequireMethod(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
requestMethod string
|
||
|
|
requiredMethod string
|
||
|
|
wantOk bool
|
||
|
|
wantStatus int
|
||
|
|
}{
|
||
|
|
{"POST matches POST", http.MethodPost, http.MethodPost, true, 0},
|
||
|
|
{"GET matches GET", http.MethodGet, http.MethodGet, true, 0},
|
||
|
|
{"GET doesn't match POST", http.MethodGet, http.MethodPost, false, http.StatusMethodNotAllowed},
|
||
|
|
{"POST doesn't match GET", http.MethodPost, http.MethodGet, false, http.StatusMethodNotAllowed},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(tt.requestMethod, "/", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
|
||
|
|
ok := RequireMethod(w, req, tt.requiredMethod)
|
||
|
|
|
||
|
|
if ok != tt.wantOk {
|
||
|
|
t.Errorf("RequireMethod() = %v, want %v", ok, tt.wantOk)
|
||
|
|
}
|
||
|
|
|
||
|
|
if !ok && w.Code != tt.wantStatus {
|
||
|
|
t.Errorf("RequireMethod() status = %d, want %d", w.Code, tt.wantStatus)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequirePost(t *testing.T) {
|
||
|
|
t.Run("POST request allowed", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
|
||
|
|
if !RequirePost(w, req) {
|
||
|
|
t.Error("RequirePost() should return true for POST request")
|
||
|
|
}
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("GET request rejected", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
|
||
|
|
if RequirePost(w, req) {
|
||
|
|
t.Error("RequirePost() should return false for GET request")
|
||
|
|
}
|
||
|
|
if w.Code != http.StatusMethodNotAllowed {
|
||
|
|
t.Errorf("RequirePost() status = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequireGet(t *testing.T) {
|
||
|
|
t.Run("GET request allowed", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
|
||
|
|
if !RequireGet(w, req) {
|
||
|
|
t.Error("RequireGet() should return true for GET request")
|
||
|
|
}
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("POST request rejected", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
|
||
|
|
if RequireGet(w, req) {
|
||
|
|
t.Error("RequireGet() should return false for POST request")
|
||
|
|
}
|
||
|
|
if w.Code != http.StatusMethodNotAllowed {
|
||
|
|
t.Errorf("RequireGet() status = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|