package middleware

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	c "github.com/juanatsap/cv-site/internal/constants"
)

// csrfTokenEntry is the server-side record of an issued CSRF token and the
// instant after which it is no longer accepted.
type csrfTokenEntry struct {
	token     string
	expiresAt time.Time
}

// CSRFProtection issues CSRF tokens and validates them on unsafe requests.
//
// Issued tokens are kept in an in-memory store keyed by the token string. A
// request passes validation only when the token it submits (form field or
// header) equals the token in its CSRF cookie AND that token is present in
// the store and not yet expired.
type CSRFProtection struct {
	mu     sync.RWMutex
	tokens map[string]*csrfTokenEntry // map[token]entry

	// stop is closed by Stop to terminate the background cleanup goroutine.
	stop     chan struct{}
	stopOnce sync.Once
}

// NewCSRFProtection creates a CSRF protection instance with an empty token
// store and starts a background goroutine that purges expired tokens every
// c.CSRFCleanupPeriod. Call Stop to terminate that goroutine once the
// instance is no longer needed.
func NewCSRFProtection() *CSRFProtection {
	csrf := &CSRFProtection{
		tokens: make(map[string]*csrfTokenEntry),
		stop:   make(chan struct{}),
	}

	go csrf.cleanup()

	return csrf
}

// Stop terminates the background cleanup goroutine started by
// NewCSRFProtection. It is safe to call more than once. Token issuance and
// validation keep working after Stop; expired tokens are simply no longer
// purged from memory.
func (csrf *CSRFProtection) Stop() {
	csrf.stopOnce.Do(func() {
		// stop is nil only if the struct was built without NewCSRFProtection,
		// in which case no cleanup goroutine was started either.
		if csrf.stop != nil {
			close(csrf.stop)
		}
	})
}

// Middleware enforces CSRF validation on every request whose method is not
// one of the safe methods GET, HEAD, OPTIONS or TRACE (RFC 9110 §9.2.1). That
// covers POST, PUT, PATCH and DELETE as well as any other method that may
// change server state. Safe-method requests are passed through untouched;
// this middleware does not issue tokens — handlers that render forms should
// call GetToken for that.
//
// On validation failure it logs the client IP and responds 403 Forbidden:
// an HTML fragment when the c.HeaderHXRequest header is non-empty (HTMX
// requests), a plain-text error otherwise.
func (csrf *CSRFProtection) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			// Safe methods must not change state, so no token is required.
			// Everything else (including PATCH) falls through to validation.
			next.ServeHTTP(w, r)
			return
		}

		if !csrf.validateToken(r) {
			log.Printf("SECURITY: CSRF validation failed from IP %s", getClientIP(r))

			if r.Header.Get(c.HeaderHXRequest) != "" {
				w.Header().Set(c.HeaderContentType, c.ContentTypeHTML)
				w.WriteHeader(http.StatusForbidden)
				// Best effort: the status is already sent, nothing useful to do on a write error.
				_, _ = w.Write([]byte(`

Security Error

Invalid security token. Please refresh the page and try again.

`))
				return
			}

			http.Error(w, "CSRF validation failed", http.StatusForbidden)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// generateToken creates a token from c.CSRFTokenLength bytes of crypto/rand
// output, encodes it as URL-safe base64, and records it in the store with an
// expiry of now + c.CSRFTokenTTL. It returns the crypto/rand error unwrapped
// if random generation fails.
func (csrf *CSRFProtection) generateToken() (string, error) {
	buf := make([]byte, c.CSRFTokenLength)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	token := base64.URLEncoding.EncodeToString(buf)

	csrf.mu.Lock()
	csrf.tokens[token] = &csrfTokenEntry{
		token:     token,
		expiresAt: time.Now().Add(c.CSRFTokenTTL),
	}
	csrf.mu.Unlock()

	return token, nil
}

// GetToken returns the CSRF token to embed in a rendered form.
//
// If the request already carries a CSRF cookie whose token is still in the
// store and unexpired, that token is returned unchanged. Otherwise a new token
// is generated, stored, and set as an HttpOnly, SameSite=Strict cookie with
// MaxAge equal to c.CSRFTokenTTL. Each call that has to mint a new token adds
// a store entry that is only removed by the periodic cleanup.
func (csrf *CSRFProtection) GetToken(w http.ResponseWriter, r *http.Request) (string, error) {
	// Reuse the token from an existing cookie when it is still valid.
	if cookie, err := r.Cookie(c.CSRFCookieName); err == nil && cookie.Value != "" {
		csrf.mu.RLock()
		entry, exists := csrf.tokens[cookie.Value]
		csrf.mu.RUnlock()

		if exists && time.Now().Before(entry.expiresAt) {
			return cookie.Value, nil
		}
	}

	token, err := csrf.generateToken()
	if err != nil {
		return "", fmt.Errorf("failed to generate CSRF token: %w", err)
	}

	http.SetCookie(w, &http.Cookie{
		Name:     c.CSRFCookieName,
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		// Secure only when this server terminated TLS itself.
		// NOTE(review): behind a TLS-terminating proxy r.TLS is nil, so the
		// Secure flag is not set — confirm against the deployment.
		Secure:   r.TLS != nil,
		SameSite: http.SameSiteStrictMode,
		MaxAge:   int(c.CSRFTokenTTL.Seconds()),
	})

	return token, nil
}

// validateToken reports whether r carries a valid CSRF token.
//
// The submitted token is read from the c.CSRFFormField form value (which,
// via r.FormValue, also includes URL query parameters), falling back to the
// c.HeaderXCSRFToken header for AJAX requests. It must equal the value of the
// c.CSRFCookieName cookie, and that token must exist in the store and not be
// expired. Each rejection reason is logged.
func (csrf *CSRFProtection) validateToken(r *http.Request) bool {
	var formToken string

	// A ParseForm failure (e.g. a malformed body) is deliberately not fatal:
	// the token may still arrive via the header below.
	if err := r.ParseForm(); err == nil {
		formToken = r.FormValue(c.CSRFFormField)
	}

	if formToken == "" {
		formToken = r.Header.Get(c.HeaderXCSRFToken)
	}

	if formToken == "" {
		log.Printf("CSRF: No token in request")
		return false
	}

	cookie, err := r.Cookie(c.CSRFCookieName)
	if err != nil || cookie.Value == "" {
		log.Printf("CSRF: No token in cookie")
		return false
	}

	// Constant-time comparison so response timing does not reveal how much
	// of the cookie token a forged value matched.
	if subtle.ConstantTimeCompare([]byte(formToken), []byte(cookie.Value)) != 1 {
		log.Printf("CSRF: Token mismatch")
		return false
	}

	csrf.mu.RLock()
	entry, exists := csrf.tokens[formToken]
	csrf.mu.RUnlock()

	if !exists {
		log.Printf("CSRF: Token not found in store")
		return false
	}

	if time.Now().After(entry.expiresAt) {
		log.Printf("CSRF: Token expired")
		return false
	}

	return true
}

// cleanup deletes expired tokens from the store every c.CSRFCleanupPeriod
// until Stop closes the stop channel.
func (csrf *CSRFProtection) cleanup() {
	ticker := time.NewTicker(c.CSRFCleanupPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-csrf.stop:
			return
		case <-ticker.C:
			csrf.mu.Lock()
			now := time.Now()
			for token, entry := range csrf.tokens {
				if now.After(entry.expiresAt) {
					delete(csrf.tokens, token)
				}
			}
			csrf.mu.Unlock()
		}
	}
}

// Note: getClientIP is defined in security_logger.go