package middleware

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	c "github.com/juanatsap/cv-site/internal/constants"
)

// issueCSRFToken obtains a token from csrf via a plain GET request and fails
// the test immediately if GetToken errors or returns an empty token, so that
// setup problems are not misreported as failures of later assertions.
func issueCSRFToken(t *testing.T, csrf *CSRFProtection) string {
	t.Helper()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	token, err := csrf.GetToken(rec, req)
	if err != nil {
		t.Fatalf("GetToken() error = %v", err)
	}
	if token == "" {
		t.Fatal("GetToken() returned an empty token")
	}
	return token
}

// newCSRFFormPost builds a form-encoded POST request that carries formToken in
// the CSRF form field. When cookieToken is non-empty it is attached as the CSRF
// cookie; an empty cookieToken sends no CSRF cookie at all.
func newCSRFFormPost(formToken, cookieToken string) *http.Request {
	form := url.Values{}
	form.Set(c.CSRFFormField, formToken)
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	if cookieToken != "" {
		req.AddCookie(&http.Cookie{Name: c.CSRFCookieName, Value: cookieToken})
	}
	return req
}

func TestNewCSRFProtection(t *testing.T) {
	csrf := NewCSRFProtection()
	if csrf == nil {
		t.Fatal("NewCSRFProtection should return non-nil")
	}
	if csrf.tokens == nil {
		t.Error("tokens map should be initialized")
	}
}

func TestCSRFProtection_generateToken(t *testing.T) {
	csrf := NewCSRFProtection()

	token, err := csrf.generateToken()
	if err != nil {
		t.Fatalf("generateToken() error = %v", err)
	}
	if token == "" {
		t.Fatal("generateToken() should return non-empty token")
	}

	// Token should be stored
	csrf.mu.RLock()
	entry, exists := csrf.tokens[token]
	csrf.mu.RUnlock()

	// Stop here when missing: entry would be a nil pointer and the field
	// checks below would panic instead of reporting a clean failure.
	if !exists {
		t.Fatal("Generated token should be stored")
	}
	if entry.token != token {
		t.Errorf("Stored token = %q, want %q", entry.token, token)
	}

	// Token should have expiration in the future
	if !entry.expiresAt.After(time.Now()) {
		t.Error("Token expiration should be in the future")
	}
}

func TestCSRFProtection_generateToken_Unique(t *testing.T) {
	csrf := NewCSRFProtection()
	tokens := make(map[string]bool)

	for i := 0; i < 100; i++ {
		token, err := csrf.generateToken()
		if err != nil {
			t.Fatalf("generateToken() error = %v", err)
		}
		if tokens[token] {
			t.Errorf("Tokens should be unique (duplicate %q)", token)
		}
		tokens[token] = true
	}
}

func TestCSRFProtection_GetToken_NewToken(t *testing.T) {
	csrf := NewCSRFProtection()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	token, err := csrf.GetToken(rec, req)
	if err != nil {
		t.Fatalf("GetToken() error = %v", err)
	}
	if token == "" {
		t.Error("GetToken() should return non-empty token")
	}

	// Check cookie was set
	var found bool
	for _, cookie := range rec.Result().Cookies() {
		if cookie.Name != c.CSRFCookieName {
			continue
		}
		found = true
		if cookie.Value != token {
			t.Errorf("Cookie value = %q, want %q", cookie.Value, token)
		}
		if !cookie.HttpOnly {
			t.Error("Cookie should be HttpOnly")
		}
	}
	if !found {
		t.Error("CSRF cookie should be set")
	}
}

func TestCSRFProtection_GetToken_ExistingToken(t *testing.T) {
	csrf := NewCSRFProtection()

	// First request to get token
	token1 := issueCSRFToken(t, csrf)

	// Second request with existing token cookie
	req2 := httptest.NewRequest(http.MethodGet, "/", nil)
	req2.AddCookie(&http.Cookie{Name: c.CSRFCookieName, Value: token1})
	rec2 := httptest.NewRecorder()

	token2, err := csrf.GetToken(rec2, req2)
	if err != nil {
		t.Fatalf("GetToken() error = %v", err)
	}

	// Should return same token
	if token2 != token1 {
		t.Errorf("GetToken() = %q, want %q (same token)", token2, token1)
	}
}

func TestCSRFProtection_validateToken(t *testing.T) {
	csrf := NewCSRFProtection()

	// Generate a token first
	token := issueCSRFToken(t, csrf)

	t.Run("Valid token in form", func(t *testing.T) {
		postReq := newCSRFFormPost(token, token)
		if !csrf.validateToken(postReq) {
			t.Error("validateToken should return true for valid token")
		}
	})

	t.Run("Valid token in header", func(t *testing.T) {
		postReq := httptest.NewRequest(http.MethodPost, "/", nil)
		postReq.Header.Set(c.HeaderXCSRFToken, token)
		postReq.AddCookie(&http.Cookie{Name: c.CSRFCookieName, Value: token})
		if !csrf.validateToken(postReq) {
			t.Error("validateToken should return true for valid token in header")
		}
	})

	t.Run("Missing form token", func(t *testing.T) {
		postReq := httptest.NewRequest(http.MethodPost, "/", nil)
		postReq.AddCookie(&http.Cookie{Name: c.CSRFCookieName, Value: token})
		if csrf.validateToken(postReq) {
			t.Error("validateToken should return false for missing form token")
		}
	})

	t.Run("Missing cookie", func(t *testing.T) {
		postReq := newCSRFFormPost(token, "")
		if csrf.validateToken(postReq) {
			t.Error("validateToken should return false for missing cookie")
		}
	})

	t.Run("Token mismatch", func(t *testing.T) {
		postReq := newCSRFFormPost("wrong-token", token)
		if csrf.validateToken(postReq) {
			t.Error("validateToken should return false for mismatched tokens")
		}
	})

	t.Run("Token not in store", func(t *testing.T) {
		unknownToken := "unknown-token"
		postReq := newCSRFFormPost(unknownToken, unknownToken)
		if csrf.validateToken(postReq) {
			t.Error("validateToken should return false for token not in store")
		}
	})
}

func TestCSRFProtection_Middleware(t *testing.T) {
	csrf := NewCSRFProtection()

	// Generate a valid token
	token := issueCSRFToken(t, csrf)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK"))
	})
	protected := csrf.Middleware(handler)

	t.Run("GET request passes through", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
		}
	})

	t.Run("POST with valid token passes", func(t *testing.T) {
		req := newCSRFFormPost(token, token)
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
		}
	})

	t.Run("POST without token fails", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodPost, "/", nil)
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)
		if rec.Code != http.StatusForbidden {
			t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
		}
	})

	t.Run("PUT without token fails", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodPut, "/", nil)
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)
		if rec.Code != http.StatusForbidden {
			t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
		}
	})

	t.Run("DELETE without token fails", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodDelete, "/", nil)
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)
		if rec.Code != http.StatusForbidden {
			t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
		}
	})

	t.Run("HTMX request gets HTML error", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodPost, "/", nil)
		req.Header.Set(c.HeaderHXRequest, "true")
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)

		if rec.Code != http.StatusForbidden {
			t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
		}

		body := rec.Body.String()
		if !strings.Contains(body, "Security Error") {
			t.Error("HTMX response should contain HTML error message")
		}
		if !strings.Contains(body, "alert") {
			t.Error("HTMX response should contain alert class")
		}
	})
}

func TestCSRFTokenEntry_Expiration(t *testing.T) {
	csrf := &CSRFProtection{
		tokens: make(map[string]*csrfTokenEntry),
	}

	// Add expired token
	expiredToken := "expired-token"
	csrf.tokens[expiredToken] = &csrfTokenEntry{
		token:     expiredToken,
		expiresAt: time.Now().Add(-1 * time.Hour), // Expired
	}

	// Validation should fail for expired token
	req := newCSRFFormPost(expiredToken, expiredToken)
	if csrf.validateToken(req) {
		t.Error("validateToken should return false for expired token")
	}
}

func TestGetToken_ExpiredTokenInCookie(t *testing.T) {
	csrf := &CSRFProtection{
		tokens: make(map[string]*csrfTokenEntry),
	}

	// Add expired token to store
	expiredToken := "expired-token"
	csrf.tokens[expiredToken] = &csrfTokenEntry{
		token:     expiredToken,
		expiresAt: time.Now().Add(-1 * time.Hour), // Expired
	}

	// Request with expired token cookie
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: c.CSRFCookieName, Value: expiredToken})
	rec := httptest.NewRecorder()

	newToken, err := csrf.GetToken(rec, req)
	if err != nil {
		t.Fatalf("GetToken() error = %v", err)
	}

	// Should generate a new, non-empty token
	if newToken == "" {
		t.Error("GetToken() should return non-empty token when existing is expired")
	}
	if newToken == expiredToken {
		t.Error("GetToken() should generate new token when existing is expired")
	}
}