package middleware

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	c "github.com/juanatsap/cv-site/internal/constants"
)

// newContactLimiterFixture builds a ContactRateLimiter from a struct literal
// with an empty, initialized client map, so every test starts with no
// recorded clients.
func newContactLimiterFixture() *ContactRateLimiter {
	return &ContactRateLimiter{
		clients: make(map[string]*contactRateLimitEntry),
	}
}

// contactFixtureHandler returns the downstream handler wrapped by the
// middleware in these tests; it only answers 200 OK.
func contactFixtureHandler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
}

// doContactPost sends one POST /api/contact through h and returns the
// recorder. remoteAddr overrides the request's RemoteAddr when non-empty;
// headers (may be nil) are set on the request before it is served.
func doContactPost(h http.Handler, remoteAddr string, headers map[string]string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
	if remoteAddr != "" {
		req.RemoteAddr = remoteAddr
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec
}

// exhaustContactLimit issues exactly c.RateLimitContactRequests requests from
// remoteAddr. Each of them is within the limit and must succeed; failing here
// pinpoints a limiter that blocks too early instead of letting the caller's
// "should be blocked" assertion pass for the wrong reason.
func exhaustContactLimit(t *testing.T, h http.Handler, remoteAddr string) {
	t.Helper()

	limit := c.RateLimitContactRequests
	for i := 0; i < limit; i++ {
		if rec := doContactPost(h, remoteAddr, nil); rec.Code != http.StatusOK {
			t.Fatalf("warm-up request %d of %d: status = %d, want %d",
				i+1, limit, rec.Code, http.StatusOK)
		}
	}
}

// TestNewContactRateLimiter checks that the constructor returns a usable
// limiter with its client map initialized.
func TestNewContactRateLimiter(t *testing.T) {
	rl := NewContactRateLimiter()
	if rl == nil {
		t.Fatal("NewContactRateLimiter should return non-nil")
	}
	if rl.clients == nil {
		t.Error("clients map should be initialized")
	}
}

// TestContactRateLimiter_allow exercises allow directly: one IP may make
// c.RateLimitContactRequests calls, the next call is refused, and a different
// IP is unaffected.
func TestContactRateLimiter_allow(t *testing.T) {
	rl := newContactLimiterFixture()
	ip := "192.168.1.1"

	// First request should be allowed.
	if !rl.allow(ip) {
		t.Error("First request should be allowed")
	}

	// Subsequent requests up to the limit should be allowed. The loop starts
	// at 1 because the first request was already spent above.
	limit := c.RateLimitContactRequests
	for i := 1; i < limit; i++ {
		if !rl.allow(ip) {
			t.Errorf("Request %d should be allowed (limit: %d)", i+1, limit)
		}
	}

	// Request exceeding the limit should be blocked.
	if rl.allow(ip) {
		t.Error("Request exceeding limit should be blocked")
	}

	// Different IP should be allowed.
	if !rl.allow("192.168.1.2") {
		t.Error("Different IP should be allowed")
	}
}

// TestContactRateLimiter_Middleware_Allowed checks that a first request is
// passed through to the wrapped handler.
func TestContactRateLimiter_Middleware_Allowed(t *testing.T) {
	rl := newContactLimiterFixture()

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		// The write result is irrelevant to this test; only the status is checked.
		_, _ = w.Write([]byte("OK"))
	})
	protected := rl.Middleware(handler)

	rec := doContactPost(protected, "192.168.1.1:12345", nil)
	if rec.Code != http.StatusOK {
		t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
	}
}

// TestContactRateLimiter_Middleware_Blocked checks that the request after the
// limit is answered with 429 and a Retry-After header.
func TestContactRateLimiter_Middleware_Blocked(t *testing.T) {
	const addr = "192.168.1.1:12345"

	rl := newContactLimiterFixture()
	protected := rl.Middleware(contactFixtureHandler())

	// Exhaust the rate limit.
	exhaustContactLimit(t, protected, addr)

	// Next request should be blocked.
	rec := doContactPost(protected, addr, nil)
	if rec.Code != http.StatusTooManyRequests {
		t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
	}

	// Should have Retry-After header.
	if rec.Header().Get(c.HeaderRetryAfter) == "" {
		t.Error("Response should have Retry-After header")
	}
}

// TestContactRateLimiter_Middleware_HTMX checks that a blocked request carrying
// the HTMX request header gets a 429 whose body contains the expected HTML
// fragments.
func TestContactRateLimiter_Middleware_HTMX(t *testing.T) {
	const addr = "192.168.1.1:12345"

	rl := newContactLimiterFixture()
	protected := rl.Middleware(contactFixtureHandler())

	// Exhaust the rate limit.
	exhaustContactLimit(t, protected, addr)

	// HTMX request should get HTML response.
	rec := doContactPost(protected, addr, map[string]string{c.HeaderHXRequest: "true"})
	if rec.Code != http.StatusTooManyRequests {
		// The body assertions below only make sense for a blocked response.
		t.Fatalf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
	}

	body := rec.Body.String()
	if !strings.Contains(body, "Too Many Requests") {
		t.Error("HTMX response should contain HTML error message")
	}
	if !strings.Contains(body, "alert") {
		t.Error("HTMX response should contain alert class")
	}
}

// TestContactRateLimiter_Middleware_XForwardedFor checks that the limiter keys
// the client by the first address listed in X-Forwarded-For.
func TestContactRateLimiter_Middleware_XForwardedFor(t *testing.T) {
	rl := newContactLimiterFixture()
	protected := rl.Middleware(contactFixtureHandler())

	// Request with X-Forwarded-For; RemoteAddr is left at the httptest default.
	rec := doContactPost(protected, "", map[string]string{
		c.HeaderXForwardedFor: "10.0.0.1, 192.168.1.1",
	})
	if rec.Code != http.StatusOK {
		t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
	}

	// Check that the first IP was used as the client key.
	rl.mu.RLock()
	_, exists := rl.clients["10.0.0.1"]
	rl.mu.RUnlock()
	if !exists {
		t.Error("Should use first IP from X-Forwarded-For")
	}
}

// TestContactRateLimiter_Middleware_XRealIP checks that the limiter keys the
// client by the X-Real-IP header value.
func TestContactRateLimiter_Middleware_XRealIP(t *testing.T) {
	rl := newContactLimiterFixture()
	protected := rl.Middleware(contactFixtureHandler())

	// Request with X-Real-IP; RemoteAddr is left at the httptest default.
	rec := doContactPost(protected, "", map[string]string{
		c.HeaderXRealIP: "10.0.0.2",
	})
	if rec.Code != http.StatusOK {
		t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
	}

	// Check that X-Real-IP was used as the client key.
	rl.mu.RLock()
	_, exists := rl.clients["10.0.0.2"]
	rl.mu.RUnlock()
	if !exists {
		t.Error("Should use X-Real-IP")
	}
}

// TestContactRateLimiter_GetStats checks the client count, limit and window
// entries reported by GetStats after two distinct IPs have been seen.
func TestContactRateLimiter_GetStats(t *testing.T) {
	rl := newContactLimiterFixture()

	// Add some entries; only the side effect of registering each IP matters.
	rl.allow("192.168.1.1")
	rl.allow("192.168.1.2")

	stats := rl.GetStats()

	if stats["total_clients"] != 2 {
		t.Errorf("total_clients = %v, want 2", stats["total_clients"])
	}
	if stats["limit"] != c.RateLimitContactRequests {
		t.Errorf("limit = %v, want %d", stats["limit"], c.RateLimitContactRequests)
	}

	// A missing key yields a nil interface value, which never equals "", so a
	// bare `== ""` comparison would pass even when "window" is absent. Require
	// the key to be present and non-nil as well.
	if window, ok := stats["window"]; !ok || window == nil || window == "" {
		t.Error("window should not be empty")
	}
}