feat: add origin validation and rate limiting for PDF endpoint
- Implemented origin checker middleware to prevent external sites from hotlinking the PDF generation endpoint - Added rate limiter (3 requests per minute per IP) to protect resource-intensive PDF operations - Configured allowed origins via ALLOWED_ORIGINS environment variable with localhost defaults for development
This commit is contained in:
@@ -3,6 +3,9 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityHeaders adds production-grade security headers to responses
|
||||
@@ -47,3 +50,174 @@ func SecurityHeaders(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// OriginChecker restricts API access to requests from allowed origins only
|
||||
// Prevents external sites from hotlinking/accessing resource-intensive endpoints
|
||||
func OriginChecker(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get allowed domains from environment (comma-separated)
|
||||
// Example: ALLOWED_ORIGINS="yourdomain.com,www.yourdomain.com"
|
||||
allowedOriginsEnv := os.Getenv("ALLOWED_ORIGINS")
|
||||
|
||||
// If empty, add "juan.andres.morenorub.io", as it is the domain of the CV
|
||||
if allowedOriginsEnv == "" {
|
||||
allowedOriginsEnv = "juan.andres.morenorub.io"
|
||||
}
|
||||
|
||||
// Default to localhost for development
|
||||
allowedOrigins := []string{"localhost", "127.0.0.1"}
|
||||
|
||||
if allowedOriginsEnv != "" {
|
||||
customOrigins := strings.Split(allowedOriginsEnv, ",")
|
||||
for _, origin := range customOrigins {
|
||||
allowedOrigins = append(allowedOrigins, strings.TrimSpace(origin))
|
||||
}
|
||||
}
|
||||
|
||||
// Check Origin header (for CORS requests)
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
if !isAllowedOrigin(origin, allowedOrigins) {
|
||||
http.Error(w, "Forbidden: External access not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check Referer header (for direct requests)
|
||||
referer := r.Header.Get("Referer")
|
||||
if referer != "" {
|
||||
if !isAllowedOrigin(referer, allowedOrigins) {
|
||||
http.Error(w, "Forbidden: External access not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Allow if no Origin/Referer (direct browser access)
|
||||
// This allows your own site visitors to access the endpoint
|
||||
if origin == "" && referer == "" {
|
||||
// For production, you might want to be stricter here
|
||||
// For now, allow it (users can bookmark /export/pdf directly)
|
||||
if os.Getenv("GO_ENV") == "production" && r.URL.Path == "/export/pdf" {
|
||||
// In production, require at least a referer for PDF endpoint
|
||||
http.Error(w, "Forbidden: Direct access not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// isAllowedOrigin checks if the origin/referer matches allowed domains
|
||||
func isAllowedOrigin(originURL string, allowedOrigins []string) bool {
|
||||
originURL = strings.TrimSpace(originURL)
|
||||
originURL = strings.TrimPrefix(originURL, "http://")
|
||||
originURL = strings.TrimPrefix(originURL, "https://")
|
||||
|
||||
// Extract domain from URL (remove path)
|
||||
parts := strings.Split(originURL, "/")
|
||||
domain := parts[0]
|
||||
|
||||
// Remove port if present
|
||||
domain = strings.Split(domain, ":")[0]
|
||||
|
||||
for _, allowed := range allowedOrigins {
|
||||
if strings.EqualFold(domain, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// rateLimitEntry tracks rate limiting per IP
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
resetTime time.Time
|
||||
}
|
||||
|
||||
// RateLimiter provides simple in-memory rate limiting
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*rateLimitEntry
|
||||
limit int // requests allowed
|
||||
window time.Duration // time window
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
clients: make(map[string]*rateLimitEntry),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
|
||||
// Cleanup expired entries every minute
|
||||
go rl.cleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Middleware returns rate limiting middleware
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get client IP (handle X-Forwarded-For for proxies)
|
||||
ip := r.Header.Get("X-Forwarded-For")
|
||||
if ip == "" {
|
||||
ip = r.Header.Get("X-Real-IP")
|
||||
}
|
||||
if ip == "" {
|
||||
ip = strings.Split(r.RemoteAddr, ":")[0]
|
||||
}
|
||||
|
||||
if !rl.allow(ip) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// allow checks if the request is allowed based on rate limit
|
||||
func (rl *RateLimiter) allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
entry, exists := rl.clients[ip]
|
||||
if !exists || now.After(entry.resetTime) {
|
||||
// New client or window expired
|
||||
rl.clients[ip] = &rateLimitEntry{
|
||||
count: 1,
|
||||
resetTime: now.Add(rl.window),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if entry.count >= rl.limit {
|
||||
return false
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanup removes expired entries periodically
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for ip, entry := range rl.clients {
|
||||
if now.After(entry.resetTime) {
|
||||
delete(rl.clients, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user