6ed6c7780b
Links to PROJECT-MEMORY.md and DECISIONS.md for development rules and architectural decisions, plus quick commands and doc index.
386 lines
9.6 KiB
Go
386 lines
9.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
c "github.com/juanatsap/cv-site/internal/constants"
|
|
)
|
|
|
|
func TestNewCSRFProtection(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
if csrf == nil {
|
|
t.Fatal("NewCSRFProtection should return non-nil")
|
|
}
|
|
|
|
if csrf.tokens == nil {
|
|
t.Error("tokens map should be initialized")
|
|
}
|
|
}
|
|
|
|
func TestCSRFProtection_generateToken(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
token, err := csrf.generateToken()
|
|
if err != nil {
|
|
t.Errorf("generateToken() error = %v", err)
|
|
}
|
|
|
|
if token == "" {
|
|
t.Error("generateToken() should return non-empty token")
|
|
}
|
|
|
|
// Token should be stored
|
|
csrf.mu.RLock()
|
|
entry, exists := csrf.tokens[token]
|
|
csrf.mu.RUnlock()
|
|
|
|
if !exists {
|
|
t.Error("Generated token should be stored")
|
|
}
|
|
|
|
if entry.token != token {
|
|
t.Errorf("Stored token = %q, want %q", entry.token, token)
|
|
}
|
|
|
|
// Token should have expiration in the future
|
|
if !entry.expiresAt.After(time.Now()) {
|
|
t.Error("Token expiration should be in the future")
|
|
}
|
|
}
|
|
|
|
func TestCSRFProtection_generateToken_Unique(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
tokens := make(map[string]bool)
|
|
for i := 0; i < 100; i++ {
|
|
token, err := csrf.generateToken()
|
|
if err != nil {
|
|
t.Fatalf("generateToken() error = %v", err)
|
|
}
|
|
if tokens[token] {
|
|
t.Error("Tokens should be unique")
|
|
}
|
|
tokens[token] = true
|
|
}
|
|
}
|
|
|
|
func TestCSRFProtection_GetToken_NewToken(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
token, err := csrf.GetToken(rec, req)
|
|
if err != nil {
|
|
t.Errorf("GetToken() error = %v", err)
|
|
}
|
|
|
|
if token == "" {
|
|
t.Error("GetToken() should return non-empty token")
|
|
}
|
|
|
|
// Check cookie was set
|
|
cookies := rec.Result().Cookies()
|
|
var found bool
|
|
for _, cookie := range cookies {
|
|
if cookie.Name == c.CSRFCookieName {
|
|
found = true
|
|
if cookie.Value != token {
|
|
t.Errorf("Cookie value = %q, want %q", cookie.Value, token)
|
|
}
|
|
if !cookie.HttpOnly {
|
|
t.Error("Cookie should be HttpOnly")
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
t.Error("CSRF cookie should be set")
|
|
}
|
|
}
|
|
|
|
func TestCSRFProtection_GetToken_ExistingToken(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
// First request to get token
|
|
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec1 := httptest.NewRecorder()
|
|
token1, _ := csrf.GetToken(rec1, req1)
|
|
|
|
// Second request with existing token cookie
|
|
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req2.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: token1,
|
|
})
|
|
rec2 := httptest.NewRecorder()
|
|
|
|
token2, err := csrf.GetToken(rec2, req2)
|
|
if err != nil {
|
|
t.Errorf("GetToken() error = %v", err)
|
|
}
|
|
|
|
// Should return same token
|
|
if token2 != token1 {
|
|
t.Errorf("GetToken() = %q, want %q (same token)", token2, token1)
|
|
}
|
|
}
|
|
|
|
func TestCSRFProtection_validateToken(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
// Generate a token first
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
token, _ := csrf.GetToken(rec, req)
|
|
|
|
t.Run("Valid token in form", func(t *testing.T) {
|
|
form := url.Values{}
|
|
form.Set(c.CSRFFormField, token)
|
|
|
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
postReq.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: token,
|
|
})
|
|
|
|
if !csrf.validateToken(postReq) {
|
|
t.Error("validateToken should return true for valid token")
|
|
}
|
|
})
|
|
|
|
t.Run("Valid token in header", func(t *testing.T) {
|
|
postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
postReq.Header.Set(c.HeaderXCSRFToken, token)
|
|
postReq.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: token,
|
|
})
|
|
|
|
if !csrf.validateToken(postReq) {
|
|
t.Error("validateToken should return true for valid token in header")
|
|
}
|
|
})
|
|
|
|
t.Run("Missing form token", func(t *testing.T) {
|
|
postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
postReq.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: token,
|
|
})
|
|
|
|
if csrf.validateToken(postReq) {
|
|
t.Error("validateToken should return false for missing form token")
|
|
}
|
|
})
|
|
|
|
t.Run("Missing cookie", func(t *testing.T) {
|
|
form := url.Values{}
|
|
form.Set(c.CSRFFormField, token)
|
|
|
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
if csrf.validateToken(postReq) {
|
|
t.Error("validateToken should return false for missing cookie")
|
|
}
|
|
})
|
|
|
|
t.Run("Token mismatch", func(t *testing.T) {
|
|
form := url.Values{}
|
|
form.Set(c.CSRFFormField, "wrong-token")
|
|
|
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
postReq.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: token,
|
|
})
|
|
|
|
if csrf.validateToken(postReq) {
|
|
t.Error("validateToken should return false for mismatched tokens")
|
|
}
|
|
})
|
|
|
|
t.Run("Token not in store", func(t *testing.T) {
|
|
unknownToken := "unknown-token"
|
|
form := url.Values{}
|
|
form.Set(c.CSRFFormField, unknownToken)
|
|
|
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
postReq.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: unknownToken,
|
|
})
|
|
|
|
if csrf.validateToken(postReq) {
|
|
t.Error("validateToken should return false for token not in store")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCSRFProtection_Middleware(t *testing.T) {
|
|
csrf := NewCSRFProtection()
|
|
|
|
// Generate a valid token
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
token, _ := csrf.GetToken(rec, req)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("OK"))
|
|
})
|
|
|
|
protected := csrf.Middleware(handler)
|
|
|
|
t.Run("GET request passes through", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("POST with valid token passes", func(t *testing.T) {
|
|
form := url.Values{}
|
|
form.Set(c.CSRFFormField, token)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: token,
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
})
|
|
|
|
t.Run("POST without token fails", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
|
|
t.Run("PUT without token fails", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPut, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
|
|
t.Run("DELETE without token fails", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodDelete, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
|
}
|
|
})
|
|
|
|
t.Run("HTMX request gets HTML error", func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
req.Header.Set(c.HeaderHXRequest, "true")
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if !strings.Contains(body, "Security Error") {
|
|
t.Error("HTMX response should contain HTML error message")
|
|
}
|
|
if !strings.Contains(body, "alert") {
|
|
t.Error("HTMX response should contain alert class")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCSRFTokenEntry_Expiration(t *testing.T) {
|
|
csrf := &CSRFProtection{
|
|
tokens: make(map[string]*csrfTokenEntry),
|
|
}
|
|
|
|
// Add expired token
|
|
expiredToken := "expired-token"
|
|
csrf.tokens[expiredToken] = &csrfTokenEntry{
|
|
token: expiredToken,
|
|
expiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
|
}
|
|
|
|
// Validation should fail for expired token
|
|
form := url.Values{}
|
|
form.Set(c.CSRFFormField, expiredToken)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: expiredToken,
|
|
})
|
|
|
|
if csrf.validateToken(req) {
|
|
t.Error("validateToken should return false for expired token")
|
|
}
|
|
}
|
|
|
|
func TestGetToken_ExpiredTokenInCookie(t *testing.T) {
|
|
csrf := &CSRFProtection{
|
|
tokens: make(map[string]*csrfTokenEntry),
|
|
}
|
|
|
|
// Add expired token to store
|
|
expiredToken := "expired-token"
|
|
csrf.tokens[expiredToken] = &csrfTokenEntry{
|
|
token: expiredToken,
|
|
expiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
|
}
|
|
|
|
// Request with expired token cookie
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: c.CSRFCookieName,
|
|
Value: expiredToken,
|
|
})
|
|
rec := httptest.NewRecorder()
|
|
|
|
newToken, err := csrf.GetToken(rec, req)
|
|
if err != nil {
|
|
t.Errorf("GetToken() error = %v", err)
|
|
}
|
|
|
|
// Should generate new token
|
|
if newToken == expiredToken {
|
|
t.Error("GetToken() should generate new token when existing is expired")
|
|
}
|
|
}
|