test: add comprehensive Go test suite with ~75% coverage

New test files:
- config/config_test.go (100% coverage)
- constants/constants_test.go (100% coverage)
- httputil/response_test.go (100% coverage)
- validation/rules_test.go (91.9% coverage)
- middleware/logger_test.go, security_test.go, security_logger_test.go
- handlers/errors_test.go

Updated documentation:
- doc/27-GO-TESTING.md: Complete testing guide
- doc/00-GO-DOCUMENTATION-INDEX.md: Added testing section
- doc/01-ARCHITECTURE.md: Updated package structure
- doc/DECISIONS.md: Added ADR-004 caching decision
- PROJECT-MEMORY.md: Added Go testing section
This commit is contained in:
juanatsap
2025-12-06 17:51:20 +00:00
parent 6ed6c7780b
commit 69012bb1ae
16 changed files with 3900 additions and 865 deletions
+161
View File
@@ -0,0 +1,161 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestResponseWriter_WriteHeader(t *testing.T) {
rec := httptest.NewRecorder()
rw := &responseWriter{
ResponseWriter: rec,
status: http.StatusOK,
}
// First call should set status
rw.WriteHeader(http.StatusNotFound)
if rw.status != http.StatusNotFound {
t.Errorf("status = %d, want %d", rw.status, http.StatusNotFound)
}
if !rw.wroteHeader {
t.Error("wroteHeader should be true after WriteHeader")
}
// Second call should be ignored
rw.WriteHeader(http.StatusInternalServerError)
if rw.status != http.StatusNotFound {
t.Errorf("status = %d, want %d (should not change)", rw.status, http.StatusNotFound)
}
}
func TestResponseWriter_Write(t *testing.T) {
rec := httptest.NewRecorder()
rw := &responseWriter{
ResponseWriter: rec,
status: http.StatusOK,
}
// Write should set default status if not set
n, err := rw.Write([]byte("Hello"))
if err != nil {
t.Errorf("Write() error = %v", err)
}
if n != 5 {
t.Errorf("Write() n = %d, want 5", n)
}
if rw.written != 5 {
t.Errorf("written = %d, want 5", rw.written)
}
if !rw.wroteHeader {
t.Error("wroteHeader should be true after Write")
}
if rw.status != http.StatusOK {
t.Errorf("status = %d, want %d", rw.status, http.StatusOK)
}
// Write more
n, err = rw.Write([]byte(" World"))
if err != nil {
t.Errorf("Write() error = %v", err)
}
if n != 6 {
t.Errorf("Write() n = %d, want 6", n)
}
if rw.written != 11 {
t.Errorf("written = %d, want 11", rw.written)
}
}
func TestResponseWriter_WriteWithExplicitStatus(t *testing.T) {
rec := httptest.NewRecorder()
rw := &responseWriter{
ResponseWriter: rec,
status: http.StatusOK,
}
// Set status first
rw.WriteHeader(http.StatusCreated)
// Write should not change status
_, _ = rw.Write([]byte("Created"))
if rw.status != http.StatusCreated {
t.Errorf("status = %d, want %d", rw.status, http.StatusCreated)
}
}
func TestLogger(t *testing.T) {
t.Run("Logs successful request", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})
logged := Logger(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Code = %d, want %d", rec.Code, http.StatusOK)
}
if rec.Body.String() != "OK" {
t.Errorf("Body = %q, want %q", rec.Body.String(), "OK")
}
})
t.Run("Logs error response", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not Found", http.StatusNotFound)
})
logged := Logger(handler)
req := httptest.NewRequest(http.MethodGet, "/notfound", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Errorf("Code = %d, want %d", rec.Code, http.StatusNotFound)
}
})
t.Run("Handles POST request", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
})
logged := Logger(handler)
req := httptest.NewRequest(http.MethodPost, "/create", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Errorf("Code = %d, want %d", rec.Code, http.StatusCreated)
}
})
t.Run("Handles request with no explicit status", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Just write body without explicit status
_, _ = w.Write([]byte("Implicit OK"))
})
logged := Logger(handler)
req := httptest.NewRequest(http.MethodGet, "/implicit", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
// Default status should be 200
if rec.Code != http.StatusOK {
t.Errorf("Code = %d, want %d", rec.Code, http.StatusOK)
}
})
}
+318
View File
@@ -0,0 +1,318 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
c "github.com/juanatsap/cv-site/internal/constants"
)
func TestLogSecurityEvent(t *testing.T) {
// Just verify it doesn't panic
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
req.Header.Set(c.HeaderUserAgent, "TestAgent/1.0")
req.RemoteAddr = "192.168.1.1:12345"
// Should not panic
LogSecurityEvent(EventContactFormSent, req, "test details")
LogSecurityEvent(EventBlocked, req, "blocked test")
LogSecurityEvent(EventCSRFViolation, req, "csrf test")
}
func TestGetSeverity(t *testing.T) {
tests := []struct {
eventType string
expected string
}{
{EventBlocked, SeverityHigh},
{EventCSRFViolation, SeverityHigh},
{EventOriginViolation, SeverityHigh},
{EventRateLimitExceeded, SeverityMedium},
{EventValidationFailed, SeverityMedium},
{EventSuspiciousUserAgent, SeverityMedium},
{EventContactFormFailed, SeverityMedium},
{EventPDFGenerationFailed, SeverityMedium},
{EventEmailSendFailed, SeverityMedium},
{EventBotDetected, SeverityLow},
{EventContactFormSent, SeverityInfo},
{EventPDFGenerated, SeverityInfo},
{"UNKNOWN_EVENT", SeverityLow},
}
for _, tt := range tests {
t.Run(tt.eventType, func(t *testing.T) {
result := getSeverity(tt.eventType)
if result != tt.expected {
t.Errorf("getSeverity(%q) = %q, want %q", tt.eventType, result, tt.expected)
}
})
}
}
func TestSecurityLogger(t *testing.T) {
t.Run("Normal request passes through", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})
logged := SecurityLogger(handler)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
}
})
t.Run("Logs security-relevant paths", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
logged := SecurityLogger(handler)
// Test security-relevant path
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
req.Header.Set(c.HeaderUserAgent, "Mozilla/5.0")
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
}
})
t.Run("Logs error responses", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
})
logged := SecurityLogger(handler)
req := httptest.NewRequest(http.MethodGet, "/secret", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
}
})
t.Run("Logs rate limit responses", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
})
logged := SecurityLogger(handler)
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
rec := httptest.NewRecorder()
logged.ServeHTTP(rec, req)
if rec.Code != http.StatusTooManyRequests {
t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
}
})
}
func TestIsSecurityRelevantPath(t *testing.T) {
tests := []struct {
path string
expected bool
}{
{"/api/contact", true},
{"/api/contact/send", true},
{"/export/pdf", true},
{"/export/pdf/cv", true},
{"/toggle/theme", true},
{"/toggle/length", true},
{"/switch-language", true},
{"/", false},
{"/cv", false},
{"/health", false},
{"/static/css/style.css", false},
{"/api/other", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
result := isSecurityRelevantPath(tt.path)
if result != tt.expected {
t.Errorf("isSecurityRelevantPath(%q) = %v, want %v", tt.path, result, tt.expected)
}
})
}
}
// Test preferences helper functions
func TestPreferencesHelperFunctions(t *testing.T) {
// Create a request with preferences in context
prefs := &Preferences{
CVLength: c.CVLengthLong,
CVIcons: c.CVIconsShow,
CVLanguage: c.LangSpanish,
CVTheme: c.CVThemeClean,
ColorTheme: c.ColorThemeDark,
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := context.WithValue(req.Context(), PreferencesKey, prefs)
reqWithPrefs := req.WithContext(ctx)
t.Run("GetLanguage", func(t *testing.T) {
result := GetLanguage(reqWithPrefs)
if result != c.LangSpanish {
t.Errorf("GetLanguage() = %q, want %q", result, c.LangSpanish)
}
})
t.Run("GetCVLength", func(t *testing.T) {
result := GetCVLength(reqWithPrefs)
if result != c.CVLengthLong {
t.Errorf("GetCVLength() = %q, want %q", result, c.CVLengthLong)
}
})
t.Run("GetCVIcons", func(t *testing.T) {
result := GetCVIcons(reqWithPrefs)
if result != c.CVIconsShow {
t.Errorf("GetCVIcons() = %q, want %q", result, c.CVIconsShow)
}
})
t.Run("GetCVTheme", func(t *testing.T) {
result := GetCVTheme(reqWithPrefs)
if result != c.CVThemeClean {
t.Errorf("GetCVTheme() = %q, want %q", result, c.CVThemeClean)
}
})
t.Run("GetColorTheme", func(t *testing.T) {
result := GetColorTheme(reqWithPrefs)
if result != c.ColorThemeDark {
t.Errorf("GetColorTheme() = %q, want %q", result, c.ColorThemeDark)
}
})
t.Run("IsLongCV", func(t *testing.T) {
if !IsLongCV(reqWithPrefs) {
t.Error("IsLongCV() should return true")
}
})
t.Run("IsShortCV", func(t *testing.T) {
if IsShortCV(reqWithPrefs) {
t.Error("IsShortCV() should return false for long CV")
}
})
t.Run("ShowIcons", func(t *testing.T) {
if !ShowIcons(reqWithPrefs) {
t.Error("ShowIcons() should return true")
}
})
t.Run("HideIcons", func(t *testing.T) {
if HideIcons(reqWithPrefs) {
t.Error("HideIcons() should return false when icons shown")
}
})
t.Run("IsCleanTheme", func(t *testing.T) {
if !IsCleanTheme(reqWithPrefs) {
t.Error("IsCleanTheme() should return true")
}
})
t.Run("IsDefaultTheme", func(t *testing.T) {
if IsDefaultTheme(reqWithPrefs) {
t.Error("IsDefaultTheme() should return false for clean theme")
}
})
t.Run("IsDarkMode", func(t *testing.T) {
if !IsDarkMode(reqWithPrefs) {
t.Error("IsDarkMode() should return true")
}
})
t.Run("IsLightMode", func(t *testing.T) {
if IsLightMode(reqWithPrefs) {
t.Error("IsLightMode() should return false for dark mode")
}
})
}
func TestPreferencesHelperFunctions_Defaults(t *testing.T) {
// Request without preferences should return defaults
req := httptest.NewRequest(http.MethodGet, "/", nil)
t.Run("GetLanguage default", func(t *testing.T) {
result := GetLanguage(req)
if result != c.LangEnglish {
t.Errorf("GetLanguage() = %q, want %q", result, c.LangEnglish)
}
})
t.Run("GetCVLength default", func(t *testing.T) {
result := GetCVLength(req)
if result != c.CVLengthShort {
t.Errorf("GetCVLength() = %q, want %q", result, c.CVLengthShort)
}
})
t.Run("IsShortCV default", func(t *testing.T) {
if !IsShortCV(req) {
t.Error("IsShortCV() should return true by default")
}
})
t.Run("ShowIcons default", func(t *testing.T) {
if !ShowIcons(req) {
t.Error("ShowIcons() should return true by default")
}
})
t.Run("IsDefaultTheme default", func(t *testing.T) {
if !IsDefaultTheme(req) {
t.Error("IsDefaultTheme() should return true by default")
}
})
t.Run("IsLightMode default", func(t *testing.T) {
if !IsLightMode(req) {
t.Error("IsLightMode() should return true by default")
}
})
}
func TestPreferencesHelperFunctions_HideIcons(t *testing.T) {
prefs := &Preferences{
CVLength: c.CVLengthShort,
CVIcons: c.CVIconsHide,
CVLanguage: c.LangEnglish,
CVTheme: c.CVThemeDefault,
ColorTheme: c.ColorThemeLight,
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := context.WithValue(req.Context(), PreferencesKey, prefs)
reqWithPrefs := req.WithContext(ctx)
if !HideIcons(reqWithPrefs) {
t.Error("HideIcons() should return true when icons hidden")
}
if ShowIcons(reqWithPrefs) {
t.Error("ShowIcons() should return false when icons hidden")
}
}
+523
View File
@@ -0,0 +1,523 @@
package middleware
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
c "github.com/juanatsap/cv-site/internal/constants"
)
func TestSecurityHeaders(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Check required security headers
tests := []struct {
header string
expected string
}{
{c.HeaderXFrameOptions, c.FrameOptionsSameOrigin},
{c.HeaderXContentTypeOpts, c.NoSniff},
{c.HeaderXXSSProtection, c.XSSProtection},
{c.HeaderReferrerPolicy, c.ReferrerPolicy},
}
for _, tt := range tests {
t.Run(tt.header, func(t *testing.T) {
value := w.Header().Get(tt.header)
if value != tt.expected {
t.Errorf("Header %s = %q, want %q", tt.header, value, tt.expected)
}
})
}
// Check CSP header exists
if w.Header().Get(c.HeaderCSP) == "" {
t.Error("CSP header should be set")
}
// Check Permissions-Policy exists
if w.Header().Get(c.HeaderPermissionsPolicy) == "" {
t.Error("Permissions-Policy header should be set")
}
}
func TestSecurityHeaders_HSTS(t *testing.T) {
// Test in production mode
os.Setenv(c.EnvVarGOEnv, c.EnvProduction)
defer os.Unsetenv(c.EnvVarGOEnv)
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// HSTS should be set in production
if w.Header().Get(c.HeaderHSTS) == "" {
t.Error("HSTS header should be set in production")
}
// Test in development mode
os.Setenv(c.EnvVarGOEnv, c.EnvDevelopment)
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
// HSTS should NOT be set in development
if w.Header().Get(c.HeaderHSTS) != "" {
t.Error("HSTS header should not be set in development")
}
}
func TestBrowserOnly(t *testing.T) {
handler := BrowserOnly(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
userAgent string
referer string
origin string
htmxHeader string
xhrHeader string
browserReq string
expectStatus int
}{
{
name: "Valid browser request with HTMX",
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)",
referer: "https://example.com",
htmxHeader: "true",
expectStatus: http.StatusOK,
},
{
name: "Valid browser request with XHR",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
origin: "https://example.com",
xhrHeader: c.HeaderValueXMLHTTPRequest,
expectStatus: http.StatusOK,
},
{
name: "Valid browser request with custom header",
userAgent: "Mozilla/5.0 (Linux; Android 10)",
referer: "https://example.com",
browserReq: "true",
expectStatus: http.StatusOK,
},
{
name: "Blocked - curl user agent",
userAgent: "curl/7.68.0",
referer: "https://example.com",
htmxHeader: "true",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - wget user agent",
userAgent: "Wget/1.20.3",
referer: "https://example.com",
htmxHeader: "true",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - empty user agent",
userAgent: "",
referer: "https://example.com",
htmxHeader: "true",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - no referer/origin",
userAgent: "Mozilla/5.0",
referer: "",
origin: "",
htmxHeader: "true",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - no browser headers",
userAgent: "Mozilla/5.0",
referer: "https://example.com",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - Postman",
userAgent: "PostmanRuntime/7.26.8",
referer: "https://example.com",
htmxHeader: "true",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - Python requests",
userAgent: "python-requests/2.25.1",
referer: "https://example.com",
htmxHeader: "true",
expectStatus: http.StatusForbidden,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
if tt.userAgent != "" {
req.Header.Set(c.HeaderUserAgent, tt.userAgent)
}
if tt.referer != "" {
req.Header.Set(c.HeaderReferer, tt.referer)
}
if tt.origin != "" {
req.Header.Set(c.HeaderOrigin, tt.origin)
}
if tt.htmxHeader != "" {
req.Header.Set(c.HeaderHXRequest, tt.htmxHeader)
}
if tt.xhrHeader != "" {
req.Header.Set(c.HeaderXRequestedWith, tt.xhrHeader)
}
if tt.browserReq != "" {
req.Header.Set(c.HeaderXBrowserReq, tt.browserReq)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != tt.expectStatus {
t.Errorf("Status = %d, want %d", w.Code, tt.expectStatus)
}
})
}
}
func TestIsBotUserAgent(t *testing.T) {
tests := []struct {
name string
ua string
expected bool
}{
{"Browser - Chrome", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", false},
{"Browser - Firefox", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:93.0) Gecko/20100101 Firefox/93.0", false},
{"Browser - Safari", "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X) AppleWebKit/605.1.15", false},
{"Bot - curl", "curl/7.68.0", true},
{"Bot - wget", "Wget/1.20.3 (linux-gnu)", true},
{"Bot - Postman", "PostmanRuntime/7.26.8", true},
{"Bot - Python requests", "python-requests/2.25.1", true},
{"Bot - Go HTTP client", "Go-http-client/1.1", true},
{"Bot - Insomnia", "insomnia/2021.5.3", true},
{"Bot - HTTPie", "HTTPie/2.4.0", true},
{"Bot - Scrapy", "Scrapy/2.5.0", true},
{"Bot - Generic bot", "Googlebot/2.1", true},
{"Bot - Generic crawler", "AhrefsBot/7.0", true},
{"Bot - Spider", "screaming frog spider/1.0", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isBotUserAgent(tt.ua)
if result != tt.expected {
t.Errorf("isBotUserAgent(%q) = %v, want %v", tt.ua, result, tt.expected)
}
})
}
}
func TestGetRequestIP(t *testing.T) {
tests := []struct {
name string
xForwardedFor string
xRealIP string
remoteAddr string
expected string
}{
{
name: "X-Forwarded-For single IP",
xForwardedFor: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "X-Forwarded-For multiple IPs",
xForwardedFor: "203.0.113.1, 70.41.3.18, 150.172.238.178",
expected: "203.0.113.1",
},
{
name: "X-Real-IP",
xRealIP: "10.0.0.5",
expected: "10.0.0.5",
},
{
name: "RemoteAddr with port",
remoteAddr: "192.168.1.100:54321",
expected: "192.168.1.100",
},
{
name: "RemoteAddr without port",
remoteAddr: "192.168.1.100",
expected: "192.168.1.100",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.xForwardedFor != "" {
req.Header.Set(c.HeaderXForwardedFor, tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set(c.HeaderXRealIP, tt.xRealIP)
}
if tt.remoteAddr != "" {
req.RemoteAddr = tt.remoteAddr
}
result := getRequestIP(req)
if result != tt.expected {
t.Errorf("getRequestIP() = %q, want %q", result, tt.expected)
}
})
}
}
func TestOriginChecker(t *testing.T) {
handler := OriginChecker(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
origin string
referer string
expectStatus int
}{
{
name: "Allowed - localhost",
origin: "http://localhost:3000",
expectStatus: http.StatusOK,
},
{
name: "Allowed - 127.0.0.1",
origin: "http://127.0.0.1:8080",
expectStatus: http.StatusOK,
},
{
name: "Allowed - configured domain",
origin: "https://juan.andres.morenorub.io",
expectStatus: http.StatusOK,
},
{
name: "Blocked - external origin",
origin: "https://malicious-site.com",
expectStatus: http.StatusForbidden,
},
{
name: "Blocked - external referer",
referer: "https://external-site.org/page",
expectStatus: http.StatusForbidden,
},
{
name: "Allowed - no origin/referer (direct)",
expectStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.origin != "" {
req.Header.Set(c.HeaderOrigin, tt.origin)
}
if tt.referer != "" {
req.Header.Set(c.HeaderReferer, tt.referer)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != tt.expectStatus {
t.Errorf("Status = %d, want %d", w.Code, tt.expectStatus)
}
})
}
}
func TestIsAllowedOrigin(t *testing.T) {
allowedOrigins := []string{"localhost", "127.0.0.1", "example.com"}
tests := []struct {
name string
originURL string
expected bool
}{
{"Simple localhost", "localhost", true},
{"HTTP localhost", "http://localhost", true},
{"HTTPS localhost with port", "https://localhost:3000", true},
{"localhost with path", "http://localhost/path/to/page", true},
{"127.0.0.1", "http://127.0.0.1:8080", true},
{"example.com", "https://example.com/api", true},
{"External site", "https://external.com", false},
{"Similar domain", "https://example.com.evil.com", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isAllowedOrigin(tt.originURL, allowedOrigins)
if result != tt.expected {
t.Errorf("isAllowedOrigin(%q) = %v, want %v", tt.originURL, result, tt.expected)
}
})
}
}
func TestRateLimiter(t *testing.T) {
// Create a rate limiter: 3 requests per 100ms
rl := NewRateLimiter(3, 100*time.Millisecond)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First 3 requests should succeed
for i := 0; i < 3; i++ {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request %d: Status = %d, want %d", i+1, w.Code, http.StatusOK)
}
}
// 4th request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("4th request: Status = %d, want %d", w.Code, http.StatusTooManyRequests)
}
// Check Retry-After header
if w.Header().Get(c.HeaderRetryAfter) == "" {
t.Error("Retry-After header should be set")
}
// Different IP should succeed
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.2:1234"
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Different IP: Status = %d, want %d", w.Code, http.StatusOK)
}
// Wait for window to expire
time.Sleep(150 * time.Millisecond)
// Original IP should be able to make requests again
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.1:1234"
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("After window expiry: Status = %d, want %d", w.Code, http.StatusOK)
}
}
func TestRateLimiter_XForwardedFor(t *testing.T) {
rl := NewRateLimiter(2, time.Minute)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make 2 requests from same IP via X-Forwarded-For
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request %d: Status = %d, want %d", i+1, w.Code, http.StatusOK)
}
}
// 3rd request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("3rd request: Status = %d, want %d", w.Code, http.StatusTooManyRequests)
}
}
func TestCacheControl(t *testing.T) {
handler := CacheControl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Development mode
os.Unsetenv(c.EnvVarGOEnv)
req := httptest.NewRequest(http.MethodGet, "/static/file.css", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get(c.HeaderCacheControl) != c.CachePublic1Hour {
t.Errorf("Dev cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic1Hour)
}
// Production mode
os.Setenv(c.EnvVarGOEnv, c.EnvProduction)
defer os.Unsetenv(c.EnvVarGOEnv)
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get(c.HeaderCacheControl) != c.CachePublic1Day {
t.Errorf("Prod cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic1Day)
}
}
func TestDynamicCacheControl(t *testing.T) {
handler := DynamicCacheControl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Development mode - no cache
os.Unsetenv(c.EnvVarGOEnv)
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get(c.HeaderCacheControl) != c.CacheNoStore {
t.Errorf("Dev dynamic cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CacheNoStore)
}
// Production mode - short cache
os.Setenv(c.EnvVarGOEnv, c.EnvProduction)
defer os.Unsetenv(c.EnvVarGOEnv)
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get(c.HeaderCacheControl) != c.CachePublic5Min {
t.Errorf("Prod dynamic cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic5Min)
}
}