524 lines
14 KiB
Go
524 lines
14 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"os"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
c "github.com/juanatsap/cv-site/internal/constants"
|
||
|
|
)
|
||
|
|
|
||
|
|
// TestSecurityHeaders checks that SecurityHeaders decorates a plain 200
// response with the fixed-value security headers defined in the constants
// package, and that CSP and Permissions-Policy headers are present. Only
// presence is asserted for the latter two because their exact policy strings
// are not pinned here.
func TestSecurityHeaders(t *testing.T) {
	okHandler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})
	wrapped := SecurityHeaders(okHandler)

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	got := rec.Header()

	// Headers whose exact value is fixed by the constants package.
	exact := []struct {
		header string
		want   string
	}{
		{header: c.HeaderXFrameOptions, want: c.FrameOptionsSameOrigin},
		{header: c.HeaderXContentTypeOpts, want: c.NoSniff},
		{header: c.HeaderXXSSProtection, want: c.XSSProtection},
		{header: c.HeaderReferrerPolicy, want: c.ReferrerPolicy},
	}
	for _, check := range exact {
		t.Run(check.header, func(t *testing.T) {
			if value := got.Get(check.header); value != check.want {
				t.Errorf("Header %s = %q, want %q", check.header, value, check.want)
			}
		})
	}

	// Headers whose content is policy-dependent: require them to be non-empty.
	if got.Get(c.HeaderCSP) == "" {
		t.Error("CSP header should be set")
	}
	if got.Get(c.HeaderPermissionsPolicy) == "" {
		t.Error("Permissions-Policy header should be set")
	}
}
|
||
|
|
|
||
|
|
// TestSecurityHeaders_HSTS verifies that SecurityHeaders emits the
// Strict-Transport-Security header (c.HeaderHSTS) only when the environment
// variable c.EnvVarGOEnv equals c.EnvProduction. The same handler serves both
// requests, so the test expects the middleware to consult the environment per
// request rather than only when it is constructed.
func TestSecurityHeaders_HSTS(t *testing.T) {
	// t.Setenv restores the variable's prior value (or unsets it if it was
	// absent) when the test ends, instead of erasing it unconditionally. It
	// also refuses to run inside parallel tests, where a process-wide
	// environment change would race with other tests.
	t.Setenv(c.EnvVarGOEnv, c.EnvProduction)

	handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve sends GET / through handler and returns the response headers.
	serve := func() http.Header {
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
		return w.Header()
	}

	// HSTS should be set in production.
	if serve().Get(c.HeaderHSTS) == "" {
		t.Error("HSTS header should be set in production")
	}

	// HSTS should NOT be set in development. Each t.Setenv call registers its
	// own cleanup; they run in reverse order, so the original value wins.
	t.Setenv(c.EnvVarGOEnv, c.EnvDevelopment)
	if serve().Get(c.HeaderHSTS) != "" {
		t.Error("HSTS header should not be set in development")
	}
}
|
||
|
|
|
||
|
|
func TestBrowserOnly(t *testing.T) {
|
||
|
|
handler := BrowserOnly(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
userAgent string
|
||
|
|
referer string
|
||
|
|
origin string
|
||
|
|
htmxHeader string
|
||
|
|
xhrHeader string
|
||
|
|
browserReq string
|
||
|
|
expectStatus int
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "Valid browser request with HTMX",
|
||
|
|
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)",
|
||
|
|
referer: "https://example.com",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusOK,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Valid browser request with XHR",
|
||
|
|
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
|
||
|
|
origin: "https://example.com",
|
||
|
|
xhrHeader: c.HeaderValueXMLHTTPRequest,
|
||
|
|
expectStatus: http.StatusOK,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Valid browser request with custom header",
|
||
|
|
userAgent: "Mozilla/5.0 (Linux; Android 10)",
|
||
|
|
referer: "https://example.com",
|
||
|
|
browserReq: "true",
|
||
|
|
expectStatus: http.StatusOK,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - curl user agent",
|
||
|
|
userAgent: "curl/7.68.0",
|
||
|
|
referer: "https://example.com",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - wget user agent",
|
||
|
|
userAgent: "Wget/1.20.3",
|
||
|
|
referer: "https://example.com",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - empty user agent",
|
||
|
|
userAgent: "",
|
||
|
|
referer: "https://example.com",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - no referer/origin",
|
||
|
|
userAgent: "Mozilla/5.0",
|
||
|
|
referer: "",
|
||
|
|
origin: "",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - no browser headers",
|
||
|
|
userAgent: "Mozilla/5.0",
|
||
|
|
referer: "https://example.com",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - Postman",
|
||
|
|
userAgent: "PostmanRuntime/7.26.8",
|
||
|
|
referer: "https://example.com",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "Blocked - Python requests",
|
||
|
|
userAgent: "python-requests/2.25.1",
|
||
|
|
referer: "https://example.com",
|
||
|
|
htmxHeader: "true",
|
||
|
|
expectStatus: http.StatusForbidden,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||
|
|
if tt.userAgent != "" {
|
||
|
|
req.Header.Set(c.HeaderUserAgent, tt.userAgent)
|
||
|
|
}
|
||
|
|
if tt.referer != "" {
|
||
|
|
req.Header.Set(c.HeaderReferer, tt.referer)
|
||
|
|
}
|
||
|
|
if tt.origin != "" {
|
||
|
|
req.Header.Set(c.HeaderOrigin, tt.origin)
|
||
|
|
}
|
||
|
|
if tt.htmxHeader != "" {
|
||
|
|
req.Header.Set(c.HeaderHXRequest, tt.htmxHeader)
|
||
|
|
}
|
||
|
|
if tt.xhrHeader != "" {
|
||
|
|
req.Header.Set(c.HeaderXRequestedWith, tt.xhrHeader)
|
||
|
|
}
|
||
|
|
if tt.browserReq != "" {
|
||
|
|
req.Header.Set(c.HeaderXBrowserReq, tt.browserReq)
|
||
|
|
}
|
||
|
|
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != tt.expectStatus {
|
||
|
|
t.Errorf("Status = %d, want %d", w.Code, tt.expectStatus)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestIsBotUserAgent(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
ua string
|
||
|
|
expected bool
|
||
|
|
}{
|
||
|
|
{"Browser - Chrome", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", false},
|
||
|
|
{"Browser - Firefox", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:93.0) Gecko/20100101 Firefox/93.0", false},
|
||
|
|
{"Browser - Safari", "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X) AppleWebKit/605.1.15", false},
|
||
|
|
{"Bot - curl", "curl/7.68.0", true},
|
||
|
|
{"Bot - wget", "Wget/1.20.3 (linux-gnu)", true},
|
||
|
|
{"Bot - Postman", "PostmanRuntime/7.26.8", true},
|
||
|
|
{"Bot - Python requests", "python-requests/2.25.1", true},
|
||
|
|
{"Bot - Go HTTP client", "Go-http-client/1.1", true},
|
||
|
|
{"Bot - Insomnia", "insomnia/2021.5.3", true},
|
||
|
|
{"Bot - HTTPie", "HTTPie/2.4.0", true},
|
||
|
|
{"Bot - Scrapy", "Scrapy/2.5.0", true},
|
||
|
|
{"Bot - Generic bot", "Googlebot/2.1", true},
|
||
|
|
{"Bot - Generic crawler", "AhrefsBot/7.0", true},
|
||
|
|
{"Bot - Spider", "screaming frog spider/1.0", true},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
result := isBotUserAgent(tt.ua)
|
||
|
|
if result != tt.expected {
|
||
|
|
t.Errorf("isBotUserAgent(%q) = %v, want %v", tt.ua, result, tt.expected)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// TestGetRequestIP checks each client-IP source of getRequestIP in isolation:
//   - the first entry of a single or comma-separated X-Forwarded-For list,
//   - X-Real-IP,
//   - RemoteAddr with or without a port (the port is expected to be stripped).
//
// Cases that do not override RemoteAddr keep httptest's default peer address,
// so the header-based cases also require the header to win over RemoteAddr.
func TestGetRequestIP(t *testing.T) {
	cases := []struct {
		name         string
		forwardedFor string
		realIP       string
		remoteAddr   string
		want         string
	}{
		{name: "X-Forwarded-For single IP", forwardedFor: "192.168.1.1", want: "192.168.1.1"},
		{name: "X-Forwarded-For multiple IPs", forwardedFor: "203.0.113.1, 70.41.3.18, 150.172.238.178", want: "203.0.113.1"},
		{name: "X-Real-IP", realIP: "10.0.0.5", want: "10.0.0.5"},
		{name: "RemoteAddr with port", remoteAddr: "192.168.1.100:54321", want: "192.168.1.100"},
		{name: "RemoteAddr without port", remoteAddr: "192.168.1.100", want: "192.168.1.100"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)

			if tc.forwardedFor != "" {
				req.Header.Set(c.HeaderXForwardedFor, tc.forwardedFor)
			}
			if tc.realIP != "" {
				req.Header.Set(c.HeaderXRealIP, tc.realIP)
			}
			// Leave httptest's default RemoteAddr unless the case overrides it.
			if tc.remoteAddr != "" {
				req.RemoteAddr = tc.remoteAddr
			}

			if got := getRequestIP(req); got != tc.want {
				t.Errorf("getRequestIP() = %q, want %q", got, tc.want)
			}
		})
	}
}
|
||
|
|
|
||
|
|
// TestOriginChecker sends GET / through OriginChecker with various Origin and
// Referer values. The table expects these requests to pass:
//   - loopback origins,
//   - the configured site domain,
//   - requests carrying neither Origin nor Referer.
//
// Requests with a foreign Origin or Referer are expected to be rejected with
// 403.
//
// NOTE(review): the allow-list is configured outside this file; the
// "configured domain" case assumes it includes juan.andres.morenorub.io.
func TestOriginChecker(t *testing.T) {
	guarded := OriginChecker(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	}))

	cases := []struct {
		name       string
		origin     string
		referer    string
		wantStatus int
	}{
		{name: "Allowed - localhost", origin: "http://localhost:3000", wantStatus: http.StatusOK},
		{name: "Allowed - 127.0.0.1", origin: "http://127.0.0.1:8080", wantStatus: http.StatusOK},
		{name: "Allowed - configured domain", origin: "https://juan.andres.morenorub.io", wantStatus: http.StatusOK},
		{name: "Blocked - external origin", origin: "https://malicious-site.com", wantStatus: http.StatusForbidden},
		{name: "Blocked - external referer", referer: "https://external-site.org/page", wantStatus: http.StatusForbidden},
		{name: "Allowed - no origin/referer (direct)", wantStatus: http.StatusOK},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if origin := tc.origin; origin != "" {
				req.Header.Set(c.HeaderOrigin, origin)
			}
			if referer := tc.referer; referer != "" {
				req.Header.Set(c.HeaderReferer, referer)
			}

			rec := httptest.NewRecorder()
			guarded.ServeHTTP(rec, req)

			if got := rec.Code; got != tc.wantStatus {
				t.Errorf("Status = %d, want %d", got, tc.wantStatus)
			}
		})
	}
}
|
||
|
|
|
||
|
|
// TestIsAllowedOrigin checks isAllowedOrigin against a fixed allow-list. The
// table expects bare hosts and full URLs to match regardless of scheme, port,
// or path. It also expects a lookalike host that merely starts with an
// allowed domain (example.com.evil.com) to be rejected.
func TestIsAllowedOrigin(t *testing.T) {
	allowedOrigins := []string{"localhost", "127.0.0.1", "example.com"}

	cases := []struct {
		desc    string
		url     string
		allowed bool
	}{
		{desc: "Simple localhost", url: "localhost", allowed: true},
		{desc: "HTTP localhost", url: "http://localhost", allowed: true},
		{desc: "HTTPS localhost with port", url: "https://localhost:3000", allowed: true},
		{desc: "localhost with path", url: "http://localhost/path/to/page", allowed: true},
		{desc: "127.0.0.1", url: "http://127.0.0.1:8080", allowed: true},
		{desc: "example.com", url: "https://example.com/api", allowed: true},
		{desc: "External site", url: "https://external.com", allowed: false},
		{desc: "Similar domain", url: "https://example.com.evil.com", allowed: false},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			got := isAllowedOrigin(tc.url, allowedOrigins)
			if got == tc.allowed {
				return
			}
			t.Errorf("isAllowedOrigin(%q) = %v, want %v", tc.url, got, tc.allowed)
		})
	}
}
|
||
|
|
|
||
|
|
func TestRateLimiter(t *testing.T) {
|
||
|
|
// Create a rate limiter: 3 requests per 100ms
|
||
|
|
rl := NewRateLimiter(3, 100*time.Millisecond)
|
||
|
|
|
||
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
// First 3 requests should succeed
|
||
|
|
for i := 0; i < 3; i++ {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("Request %d: Status = %d, want %d", i+1, w.Code, http.StatusOK)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 4th request should be rate limited
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusTooManyRequests {
|
||
|
|
t.Errorf("4th request: Status = %d, want %d", w.Code, http.StatusTooManyRequests)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Check Retry-After header
|
||
|
|
if w.Header().Get(c.HeaderRetryAfter) == "" {
|
||
|
|
t.Error("Retry-After header should be set")
|
||
|
|
}
|
||
|
|
|
||
|
|
// Different IP should succeed
|
||
|
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
req.RemoteAddr = "192.168.1.2:1234"
|
||
|
|
w = httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("Different IP: Status = %d, want %d", w.Code, http.StatusOK)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Wait for window to expire
|
||
|
|
time.Sleep(150 * time.Millisecond)
|
||
|
|
|
||
|
|
// Original IP should be able to make requests again
|
||
|
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
||
|
|
w = httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("After window expiry: Status = %d, want %d", w.Code, http.StatusOK)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRateLimiter_XForwardedFor(t *testing.T) {
|
||
|
|
rl := NewRateLimiter(2, time.Minute)
|
||
|
|
|
||
|
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
|
||
|
|
// Make 2 requests from same IP via X-Forwarded-For
|
||
|
|
for i := 0; i < 2; i++ {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1")
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusOK {
|
||
|
|
t.Errorf("Request %d: Status = %d, want %d", i+1, w.Code, http.StatusOK)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 3rd request should be rate limited
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||
|
|
req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1")
|
||
|
|
w := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(w, req)
|
||
|
|
|
||
|
|
if w.Code != http.StatusTooManyRequests {
|
||
|
|
t.Errorf("3rd request: Status = %d, want %d", w.Code, http.StatusTooManyRequests)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// TestCacheControl verifies the Cache-Control value that CacheControl sets on
// a static asset request:
//   - c.CachePublic1Hour when the environment variable c.EnvVarGOEnv is unset,
//   - c.CachePublic1Day when it equals c.EnvProduction.
//
// The handler is built once and the environment changes between requests, so
// the test expects the middleware to consult the environment per request.
func TestCacheControl(t *testing.T) {
	handler := CacheControl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Development mode (variable absent). t.Setenv records the current value
	// and restores it when the test ends. Unsetting afterwards gives the
	// "absent" state without destroying whatever the caller had configured.
	t.Setenv(c.EnvVarGOEnv, "")
	if err := os.Unsetenv(c.EnvVarGOEnv); err != nil {
		t.Fatalf("unsetting %s: %v", c.EnvVarGOEnv, err)
	}
	req := httptest.NewRequest(http.MethodGet, "/static/file.css", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CachePublic1Hour {
		t.Errorf("Dev cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic1Hour)
	}

	// Production mode. The cleanups registered by t.Setenv run in reverse
	// order, so the caller's original value is what remains afterwards.
	t.Setenv(c.EnvVarGOEnv, c.EnvProduction)

	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CachePublic1Day {
		t.Errorf("Prod cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic1Day)
	}
}
|
||
|
|
|
||
|
|
// TestDynamicCacheControl verifies the Cache-Control value that
// DynamicCacheControl sets on a dynamic page request:
//   - c.CacheNoStore when the environment variable c.EnvVarGOEnv is unset,
//   - c.CachePublic5Min when it equals c.EnvProduction.
//
// The handler is built once and the environment changes between requests, so
// the test expects the middleware to consult the environment per request.
func TestDynamicCacheControl(t *testing.T) {
	handler := DynamicCacheControl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Development mode (variable absent) - no cache. t.Setenv records the
	// current value and restores it when the test ends. Unsetting afterwards
	// gives the "absent" state without destroying the caller's configuration.
	t.Setenv(c.EnvVarGOEnv, "")
	if err := os.Unsetenv(c.EnvVarGOEnv); err != nil {
		t.Fatalf("unsetting %s: %v", c.EnvVarGOEnv, err)
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CacheNoStore {
		t.Errorf("Dev dynamic cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CacheNoStore)
	}

	// Production mode - short cache. The cleanups registered by t.Setenv run
	// in reverse order, so the caller's original value is what remains.
	t.Setenv(c.EnvVarGOEnv, c.EnvProduction)

	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Header().Get(c.HeaderCacheControl) != c.CachePublic5Min {
		t.Errorf("Prod dynamic cache = %q, want %q", w.Header().Get(c.HeaderCacheControl), c.CachePublic5Min)
	}
}
|