package middleware import ( "net/http" "os" "strings" "sync" "time" ) // SecurityHeaders adds production-grade security headers to responses func SecurityHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Prevent clickjacking w.Header().Set("X-Frame-Options", "SAMEORIGIN") // Prevent MIME type sniffing w.Header().Set("X-Content-Type-Options", "nosniff") // XSS Protection (legacy but still useful for older browsers) w.Header().Set("X-XSS-Protection", "1; mode=block") // Referrer policy - strict privacy w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // Permissions Policy - disable unnecessary features w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=(), payment=(), usb=(), "+ "magnetometer=(), gyroscope=(), accelerometer=()") // Content Security Policy (comprehensive) csp := "default-src 'self'; " + "script-src 'self' 'unsafe-inline' https://unpkg.com https://cdn.jsdelivr.net https://matomo.morenorub.io; " + "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " + "font-src 'self' https://fonts.gstatic.com; " + "img-src 'self' data: https:; " + "connect-src 'self' https://api.iconify.design https://matomo.morenorub.io; " + "frame-ancestors 'self'; " + "base-uri 'self'; " + "form-action 'self'" w.Header().Set("Content-Security-Policy", csp) // HSTS - only in production with HTTPS if os.Getenv("GO_ENV") == "production" { // 1 year max-age, include subdomains w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") } next.ServeHTTP(w, r) }) } // OriginChecker restricts API access to requests from allowed origins only // Prevents external sites from hotlinking/accessing resource-intensive endpoints func OriginChecker(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get allowed domains from environment (comma-separated) // Example: ALLOWED_ORIGINS="yourdomain.com,www.yourdomain.com" allowedOriginsEnv := os.Getenv("ALLOWED_ORIGINS") // If empty, add "juan.andres.morenorub.io", as it is the domain of the CV if allowedOriginsEnv == "" { allowedOriginsEnv = "juan.andres.morenorub.io" } // Default to localhost for development allowedOrigins := []string{"localhost", "127.0.0.1"} if allowedOriginsEnv != "" { customOrigins := strings.Split(allowedOriginsEnv, ",") for _, origin := range customOrigins { allowedOrigins = append(allowedOrigins, strings.TrimSpace(origin)) } } // Check Origin header (for CORS requests) origin := r.Header.Get("Origin") if origin != "" { if !isAllowedOrigin(origin, allowedOrigins) { http.Error(w, "Forbidden: External access not allowed", http.StatusForbidden) return } } // Check Referer header (for direct requests) referer := r.Header.Get("Referer") if referer != "" { if !isAllowedOrigin(referer, allowedOrigins) { http.Error(w, "Forbidden: External access not allowed", http.StatusForbidden) return } } // Allow if no Origin/Referer (direct browser access) // This allows your own site visitors to access the endpoint if origin == "" && referer == "" { // For production, you might want to be stricter here // For now, allow it (users can bookmark /export/pdf directly) if os.Getenv("GO_ENV") == "production" && r.URL.Path == "/export/pdf" { // In production, require at least a referer for PDF endpoint http.Error(w, "Forbidden: Direct access not allowed", http.StatusForbidden) return } } next.ServeHTTP(w, r) }) } // isAllowedOrigin checks if the origin/referer matches allowed domains func isAllowedOrigin(originURL string, allowedOrigins []string) bool { originURL = strings.TrimSpace(originURL) originURL = strings.TrimPrefix(originURL, "http://") originURL = strings.TrimPrefix(originURL, "https://") // Extract domain from URL (remove path) parts := strings.Split(originURL, "/") domain := parts[0] // Remove port if present domain = strings.Split(domain, ":")[0] for _, allowed := range allowedOrigins { if strings.EqualFold(domain, allowed) { return true } } return false } // rateLimitEntry tracks rate limiting per IP type rateLimitEntry struct { count int resetTime time.Time } // RateLimiter provides simple in-memory rate limiting type RateLimiter struct { mu sync.RWMutex clients map[string]*rateLimitEntry limit int // requests allowed window time.Duration // time window } // NewRateLimiter creates a new rate limiter func NewRateLimiter(limit int, window time.Duration) *RateLimiter { rl := &RateLimiter{ clients: make(map[string]*rateLimitEntry), limit: limit, window: window, } // Cleanup expired entries every minute go rl.cleanup() return rl } // Middleware returns rate limiting middleware func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get client IP (handle X-Forwarded-For for proxies) ip := r.Header.Get("X-Forwarded-For") if ip == "" { ip = r.Header.Get("X-Real-IP") } if ip == "" { ip = strings.Split(r.RemoteAddr, ":")[0] } if !rl.allow(ip) { w.Header().Set("Retry-After", "60") http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // allow checks if the request is allowed based on rate limit func (rl *RateLimiter) allow(ip string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() entry, exists := rl.clients[ip] if !exists || now.After(entry.resetTime) { // New client or window expired rl.clients[ip] = &rateLimitEntry{ count: 1, resetTime: now.Add(rl.window), } return true } if entry.count >= rl.limit { return false } entry.count++ return true } // cleanup removes expired entries periodically func (rl *RateLimiter) cleanup() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for range ticker.C { rl.mu.Lock() now := time.Now() for ip, entry := range rl.clients { if now.After(entry.resetTime) { delete(rl.clients, ip) } } rl.mu.Unlock() } } // CacheControl adds cache headers to static files // 1 hour in development, 1 day in production func CacheControl(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { maxAge := "3600" // 1 hour if os.Getenv("GO_ENV") == "production" { maxAge = "86400" // 1 day } w.Header().Set("Cache-Control", "public, max-age="+maxAge) next.ServeHTTP(w, r) }) }