package middleware

import (
	"net/http"
	"net/http/httptest"
	"os"
	"testing"
	"time"

	c "github.com/juanatsap/cv-site/internal/constants"
)

// TestSecurityHeaders verifies that SecurityHeaders sets the fixed-value
// security headers and emits non-empty CSP and Permissions-Policy headers.
func TestSecurityHeaders(t *testing.T) {
	handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// Check required security headers
	tests := []struct {
		header   string
		expected string
	}{
		{c.HeaderXFrameOptions, c.FrameOptionsSameOrigin},
		{c.HeaderXContentTypeOpts, c.NoSniff},
		{c.HeaderXXSSProtection, c.XSSProtection},
		{c.HeaderReferrerPolicy, c.ReferrerPolicy},
	}

	for _, tt := range tests {
		t.Run(tt.header, func(t *testing.T) {
			value := w.Header().Get(tt.header)
			if value != tt.expected {
				t.Errorf("Header %s = %q, want %q", tt.header, value, tt.expected)
			}
		})
	}

	// Check CSP header exists
	if w.Header().Get(c.HeaderCSP) == "" {
		t.Error("CSP header should be set")
	}

	// Check Permissions-Policy exists
	if w.Header().Get(c.HeaderPermissionsPolicy) == "" {
		t.Error("Permissions-Policy header should be set")
	}
}

// TestSecurityHeaders_HSTS verifies that the HSTS header is emitted only when
// GO_ENV is the production value. The same handler instance is reused for both
// modes, so this also asserts that the environment is consulted per request
// rather than captured when the middleware is constructed.
func TestSecurityHeaders_HSTS(t *testing.T) {
	// Test in production mode
	t.Setenv(c.EnvVarGOEnv, c.EnvProduction)

	handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// HSTS should be set in production
	if w.Header().Get(c.HeaderHSTS) == "" {
		t.Error("HSTS header should be set in production")
	}

	// Test in development mode
	t.Setenv(c.EnvVarGOEnv, c.EnvDevelopment)
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// HSTS should NOT be set in development
	if w.Header().Get(c.HeaderHSTS) != "" {
		t.Error("HSTS header should not be set in development")
	}
}

// TestBrowserOnly checks which combinations of User-Agent, Referer/Origin and
// browser-marker headers (HX-Request, X-Requested-With, the custom browser
// header) BrowserOnly lets through on a POST to /api/contact.
func TestBrowserOnly(t *testing.T) {
	handler := BrowserOnly(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	tests := []struct {
		name         string
		userAgent    string
		referer      string
		origin       string
		htmxHeader   string
		xhrHeader    string
		browserReq   string
		expectStatus int
	}{
		{
			name:         "Valid browser request with HTMX",
			userAgent:    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)",
			referer:      "https://example.com",
			htmxHeader:   "true",
			expectStatus: http.StatusOK,
		},
		{
			name:         "Valid browser request with XHR",
			userAgent:    "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
			origin:       "https://example.com",
			xhrHeader:    c.HeaderValueXMLHTTPRequest,
			expectStatus: http.StatusOK,
		},
		{
			name:         "Valid browser request with custom header",
			userAgent:    "Mozilla/5.0 (Linux; Android 10)",
			referer:      "https://example.com",
			browserReq:   "true",
			expectStatus: http.StatusOK,
		},
		{
			name:         "Blocked - curl user agent",
			userAgent:    "curl/7.68.0",
			referer:      "https://example.com",
			htmxHeader:   "true",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - wget user agent",
			userAgent:    "Wget/1.20.3",
			referer:      "https://example.com",
			htmxHeader:   "true",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - empty user agent",
			userAgent:    "",
			referer:      "https://example.com",
			htmxHeader:   "true",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - no referer/origin",
			userAgent:    "Mozilla/5.0",
			referer:      "",
			origin:       "",
			htmxHeader:   "true",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - no browser headers",
			userAgent:    "Mozilla/5.0",
			referer:      "https://example.com",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - Postman",
			userAgent:    "PostmanRuntime/7.26.8",
			referer:      "https://example.com",
			htmxHeader:   "true",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - Python requests",
			userAgent:    "python-requests/2.25.1",
			referer:      "https://example.com",
			htmxHeader:   "true",
			expectStatus: http.StatusForbidden,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
			// Only set headers the case specifies, so "missing" and "empty"
			// are both exercised as absent headers.
			if tt.userAgent != "" {
				req.Header.Set(c.HeaderUserAgent, tt.userAgent)
			}
			if tt.referer != "" {
				req.Header.Set(c.HeaderReferer, tt.referer)
			}
			if tt.origin != "" {
				req.Header.Set(c.HeaderOrigin, tt.origin)
			}
			if tt.htmxHeader != "" {
				req.Header.Set(c.HeaderHXRequest, tt.htmxHeader)
			}
			if tt.xhrHeader != "" {
				req.Header.Set(c.HeaderXRequestedWith, tt.xhrHeader)
			}
			if tt.browserReq != "" {
				req.Header.Set(c.HeaderXBrowserReq, tt.browserReq)
			}

			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			if w.Code != tt.expectStatus {
				t.Errorf("Status = %d, want %d", w.Code, tt.expectStatus)
			}
		})
	}
}

// TestIsBotUserAgent checks isBotUserAgent against common browser strings
// (expected false) and HTTP clients, scrapers and crawlers (expected true).
func TestIsBotUserAgent(t *testing.T) {
	tests := []struct {
		name     string
		ua       string
		expected bool
	}{
		{"Browser - Chrome", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", false},
		{"Browser - Firefox", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:93.0) Gecko/20100101 Firefox/93.0", false},
		{"Browser - Safari", "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X) AppleWebKit/605.1.15", false},
		{"Bot - curl", "curl/7.68.0", true},
		{"Bot - wget", "Wget/1.20.3 (linux-gnu)", true},
		{"Bot - Postman", "PostmanRuntime/7.26.8", true},
		{"Bot - Python requests", "python-requests/2.25.1", true},
		{"Bot - Go HTTP client", "Go-http-client/1.1", true},
		{"Bot - Insomnia", "insomnia/2021.5.3", true},
		{"Bot - HTTPie", "HTTPie/2.4.0", true},
		{"Bot - Scrapy", "Scrapy/2.5.0", true},
		{"Bot - Generic bot", "Googlebot/2.1", true},
		{"Bot - Generic crawler", "AhrefsBot/7.0", true},
		{"Bot - Spider", "screaming frog spider/1.0", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := isBotUserAgent(tt.ua)
			if result != tt.expected {
				t.Errorf("isBotUserAgent(%q) = %v, want %v", tt.ua, result, tt.expected)
			}
		})
	}
}

// TestGetRequestIP checks client-IP extraction from X-Forwarded-For (first
// entry of a list), X-Real-IP, and RemoteAddr with or without a port.
// httptest.NewRequest pre-fills RemoteAddr, so the header cases also show that
// those headers take precedence over it.
func TestGetRequestIP(t *testing.T) {
	tests := []struct {
		name          string
		xForwardedFor string
		xRealIP       string
		remoteAddr    string
		expected      string
	}{
		{
			name:          "X-Forwarded-For single IP",
			xForwardedFor: "192.168.1.1",
			expected:      "192.168.1.1",
		},
		{
			name:          "X-Forwarded-For multiple IPs",
			xForwardedFor: "203.0.113.1, 70.41.3.18, 150.172.238.178",
			expected:      "203.0.113.1",
		},
		{
			name:     "X-Real-IP",
			xRealIP:  "10.0.0.5",
			expected: "10.0.0.5",
		},
		{
			name:       "RemoteAddr with port",
			remoteAddr: "192.168.1.100:54321",
			expected:   "192.168.1.100",
		},
		{
			name:       "RemoteAddr without port",
			remoteAddr: "192.168.1.100",
			expected:   "192.168.1.100",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tt.xForwardedFor != "" {
				req.Header.Set(c.HeaderXForwardedFor, tt.xForwardedFor)
			}
			if tt.xRealIP != "" {
				req.Header.Set(c.HeaderXRealIP, tt.xRealIP)
			}
			if tt.remoteAddr != "" {
				req.RemoteAddr = tt.remoteAddr
			}

			result := getRequestIP(req)
			if result != tt.expected {
				t.Errorf("getRequestIP() = %q, want %q", result, tt.expected)
			}
		})
	}
}

// TestOriginChecker checks that OriginChecker admits local and configured
// origins, rejects foreign Origin or Referer values, and admits requests that
// carry neither header.
func TestOriginChecker(t *testing.T) {
	handler := OriginChecker(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	tests := []struct {
		name         string
		origin       string
		referer      string
		expectStatus int
	}{
		{
			name:         "Allowed - localhost",
			origin:       "http://localhost:3000",
			expectStatus: http.StatusOK,
		},
		{
			name:         "Allowed - 127.0.0.1",
			origin:       "http://127.0.0.1:8080",
			expectStatus: http.StatusOK,
		},
		{
			name:         "Allowed - configured domain",
			origin:       "https://juan.andres.morenorub.io",
			expectStatus: http.StatusOK,
		},
		{
			name:         "Blocked - external origin",
			origin:       "https://malicious-site.com",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Blocked - external referer",
			referer:      "https://external-site.org/page",
			expectStatus: http.StatusForbidden,
		},
		{
			name:         "Allowed - no origin/referer (direct)",
			expectStatus: http.StatusOK,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tt.origin != "" {
				req.Header.Set(c.HeaderOrigin, tt.origin)
			}
			if tt.referer != "" {
				req.Header.Set(c.HeaderReferer, tt.referer)
			}

			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			if w.Code != tt.expectStatus {
				t.Errorf("Status = %d, want %d", w.Code, tt.expectStatus)
			}
		})
	}
}

// TestIsAllowedOrigin checks host matching against an allow-list for bare
// hosts, URLs with schemes, ports and paths, and a look-alike suffix domain
// that must not match.
func TestIsAllowedOrigin(t *testing.T) {
	allowedOrigins := []string{"localhost", "127.0.0.1", "example.com"}

	tests := []struct {
		name      string
		originURL string
		expected  bool
	}{
		{"Simple localhost", "localhost", true},
		{"HTTP localhost", "http://localhost", true},
		{"HTTPS localhost with port", "https://localhost:3000", true},
		{"localhost with path", "http://localhost/path/to/page", true},
		{"127.0.0.1", "http://127.0.0.1:8080", true},
		{"example.com", "https://example.com/api", true},
		{"External site", "https://external.com", false},
		{"Similar domain", "https://example.com.evil.com", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := isAllowedOrigin(tt.originURL, allowedOrigins)
			if result != tt.expected {
				t.Errorf("isAllowedOrigin(%q) = %v, want %v", tt.originURL, result, tt.expected)
			}
		})
	}
}

// TestRateLimiter exercises the per-client budget: the first `limit` requests
// pass, the next is rejected with 429 and a Retry-After header, another client
// is unaffected, and the original client recovers once the window elapses.
func TestRateLimiter(t *testing.T) {
	// Create a rate limiter: 3 requests per 100ms
	const window = 100 * time.Millisecond
	rl := NewRateLimiter(3, window)

	handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// First 3 requests should succeed
	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.RemoteAddr = "192.168.1.1:1234"
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)

		if w.Code != http.StatusOK {
			t.Errorf("Request %d: Status = %d, want %d", i+1, w.Code, http.StatusOK)
		}
	}

	// 4th request should be rate limited
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.168.1.1:1234"
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusTooManyRequests {
		t.Errorf("4th request: Status = %d, want %d", w.Code, http.StatusTooManyRequests)
	}

	// Check Retry-After header
	if w.Header().Get(c.HeaderRetryAfter) == "" {
		t.Error("Retry-After header should be set")
	}

	// Different IP should succeed
	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.168.1.2:1234"
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("Different IP: Status = %d, want %d", w.Code, http.StatusOK)
	}

	// Wait for window to expire (with a 50ms margin).
	time.Sleep(window + 50*time.Millisecond)

	// Original IP should be able to make requests again
	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.168.1.1:1234"
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("After window expiry: Status = %d, want %d", w.Code, http.StatusOK)
	}
}

// TestRateLimiter_XForwardedFor verifies that clients behind a proxy are
// limited by their X-Forwarded-For address. Every request arrives from the
// same proxy RemoteAddr, so only the forwarded address distinguishes clients.
func TestRateLimiter_XForwardedFor(t *testing.T) {
	rl := NewRateLimiter(2, time.Minute)

	handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve sends one request from the shared proxy address on behalf of
	// forwardedFor and returns the recorded response.
	serve := func(forwardedFor string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.RemoteAddr = "192.0.2.1:1234"
		req.Header.Set(c.HeaderXForwardedFor, forwardedFor)
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		return w
	}

	// Make 2 requests from same IP via X-Forwarded-For
	for i := 0; i < 2; i++ {
		if w := serve("10.0.0.1"); w.Code != http.StatusOK {
			t.Errorf("Request %d: Status = %d, want %d", i+1, w.Code, http.StatusOK)
		}
	}

	// 3rd request should be rate limited
	if w := serve("10.0.0.1"); w.Code != http.StatusTooManyRequests {
		t.Errorf("3rd request: Status = %d, want %d", w.Code, http.StatusTooManyRequests)
	}

	// A different forwarded client behind the same proxy has its own budget.
	// Had the limiter keyed on RemoteAddr, this request would be rejected.
	if w := serve("10.0.0.2"); w.Code != http.StatusOK {
		t.Errorf("Different forwarded IP: Status = %d, want %d", w.Code, http.StatusOK)
	}
}

// TestCacheControl checks the static-asset Cache-Control value with GO_ENV
// unset (development) and set to production.
func TestCacheControl(t *testing.T) {
	handler := CacheControl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Development mode: GO_ENV unset. t.Setenv registers a cleanup that
	// restores the caller's original value; the Unsetenv below then removes
	// the variable for the duration of this test only.
	t.Setenv(c.EnvVarGOEnv, "")
	if err := os.Unsetenv(c.EnvVarGOEnv); err != nil {
		t.Fatalf("unset %s: %v", c.EnvVarGOEnv, err)
	}

	req := httptest.NewRequest(http.MethodGet, "/static/file.css", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CachePublic1Hour {
		t.Errorf("Dev cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic1Hour)
	}

	// Production mode
	t.Setenv(c.EnvVarGOEnv, c.EnvProduction)
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CachePublic1Day {
		t.Errorf("Prod cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic1Day)
	}
}

// TestDynamicCacheControl checks the dynamic-page Cache-Control value: no-store
// with GO_ENV unset (development) and a short public cache in production.
func TestDynamicCacheControl(t *testing.T) {
	handler := DynamicCacheControl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Development mode - no cache. As in TestCacheControl, t.Setenv first so
	// the original GO_ENV is restored after the test, then truly unset it.
	t.Setenv(c.EnvVarGOEnv, "")
	if err := os.Unsetenv(c.EnvVarGOEnv); err != nil {
		t.Fatalf("unset %s: %v", c.EnvVarGOEnv, err)
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CacheNoStore {
		t.Errorf("Dev dynamic cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CacheNoStore)
	}

	// Production mode - short cache
	t.Setenv(c.EnvVarGOEnv, c.EnvProduction)
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CachePublic5Min {
		t.Errorf("Prod dynamic cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic5Min)
	}
}