6ed6c7780b
Links to PROJECT-MEMORY.md and DECISIONS.md for development rules and architectural decisions, plus quick commands and doc index.
243 lines
5.9 KiB
Go
243 lines
5.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
c "github.com/juanatsap/cv-site/internal/constants"
|
|
)
|
|
|
|
func TestNewContactRateLimiter(t *testing.T) {
|
|
rl := NewContactRateLimiter()
|
|
|
|
if rl == nil {
|
|
t.Fatal("NewContactRateLimiter should return non-nil")
|
|
}
|
|
|
|
if rl.clients == nil {
|
|
t.Error("clients map should be initialized")
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_allow(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
ip := "192.168.1.1"
|
|
|
|
// First request should be allowed
|
|
if !rl.allow(ip) {
|
|
t.Error("First request should be allowed")
|
|
}
|
|
|
|
// Subsequent requests up to limit should be allowed
|
|
limit := c.RateLimitContactRequests
|
|
for i := 1; i < limit; i++ {
|
|
if !rl.allow(ip) {
|
|
t.Errorf("Request %d should be allowed (limit: %d)", i+1, limit)
|
|
}
|
|
}
|
|
|
|
// Request exceeding limit should be blocked
|
|
if rl.allow(ip) {
|
|
t.Error("Request exceeding limit should be blocked")
|
|
}
|
|
|
|
// Different IP should be allowed
|
|
if !rl.allow("192.168.1.2") {
|
|
t.Error("Different IP should be allowed")
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_Middleware_Allowed(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("OK"))
|
|
})
|
|
|
|
protected := rl.Middleware(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_Middleware_Blocked(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
protected := rl.Middleware(handler)
|
|
|
|
// Exhaust the rate limit
|
|
limit := c.RateLimitContactRequests
|
|
for i := 0; i < limit; i++ {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
protected.ServeHTTP(rec, req)
|
|
}
|
|
|
|
// Next request should be blocked
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
|
|
}
|
|
|
|
// Should have Retry-After header
|
|
if rec.Header().Get(c.HeaderRetryAfter) == "" {
|
|
t.Error("Response should have Retry-After header")
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_Middleware_HTMX(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
protected := rl.Middleware(handler)
|
|
|
|
// Exhaust the rate limit
|
|
limit := c.RateLimitContactRequests
|
|
for i := 0; i < limit; i++ {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
protected.ServeHTTP(rec, req)
|
|
}
|
|
|
|
// HTMX request should get HTML response
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
req.Header.Set(c.HeaderHXRequest, "true")
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if !strings.Contains(body, "Too Many Requests") {
|
|
t.Error("HTMX response should contain HTML error message")
|
|
}
|
|
if !strings.Contains(body, "alert") {
|
|
t.Error("HTMX response should contain alert class")
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_Middleware_XForwardedFor(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
protected := rl.Middleware(handler)
|
|
|
|
// Request with X-Forwarded-For
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1, 192.168.1.1")
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
// Check that first IP was used
|
|
rl.mu.RLock()
|
|
_, exists := rl.clients["10.0.0.1"]
|
|
rl.mu.RUnlock()
|
|
|
|
if !exists {
|
|
t.Error("Should use first IP from X-Forwarded-For")
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_Middleware_XRealIP(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
protected := rl.Middleware(handler)
|
|
|
|
// Request with X-Real-IP
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
|
req.Header.Set(c.HeaderXRealIP, "10.0.0.2")
|
|
rec := httptest.NewRecorder()
|
|
|
|
protected.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
// Check that X-Real-IP was used
|
|
rl.mu.RLock()
|
|
_, exists := rl.clients["10.0.0.2"]
|
|
rl.mu.RUnlock()
|
|
|
|
if !exists {
|
|
t.Error("Should use X-Real-IP")
|
|
}
|
|
}
|
|
|
|
func TestContactRateLimiter_GetStats(t *testing.T) {
|
|
rl := &ContactRateLimiter{
|
|
clients: make(map[string]*contactRateLimitEntry),
|
|
}
|
|
|
|
// Add some entries
|
|
rl.allow("192.168.1.1")
|
|
rl.allow("192.168.1.2")
|
|
|
|
stats := rl.GetStats()
|
|
|
|
if stats["total_clients"] != 2 {
|
|
t.Errorf("total_clients = %v, want 2", stats["total_clients"])
|
|
}
|
|
|
|
if stats["limit"] != c.RateLimitContactRequests {
|
|
t.Errorf("limit = %v, want %d", stats["limit"], c.RateLimitContactRequests)
|
|
}
|
|
|
|
if stats["window"] == "" {
|
|
t.Error("window should not be empty")
|
|
}
|
|
}
|