diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..c30fc91 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,30 @@ +# CV Project Instructions + +## Required Reading + +1. **[PROJECT-MEMORY.md](./PROJECT-MEMORY.md)** - Development rules, critical bugs, patterns +2. **[doc/DECISIONS.md](./doc/DECISIONS.md)** - Architectural Decision Records (ADRs) + +## Quick Commands + +```bash +# Run all frontend tests (Playwright) +bun tests/run-all.mjs + +# Run Go tests with coverage +go test -cover ./internal/... + +# Start dev server +go run . +``` + +## Tech Stack + +- **Backend**: Go 1.21+ with standard library +- **Frontend**: HTMX + Hyperscript + Vanilla JS +- **Testing**: Playwright (frontend), Go test (backend) + +## Documentation Index + +- [doc/00-GO-DOCUMENTATION-INDEX.md](./doc/00-GO-DOCUMENTATION-INDEX.md) - Go system docs +- [doc/01-ARCHITECTURE.md](./doc/01-ARCHITECTURE.md) - System architecture diff --git a/internal/email/email_test.go b/internal/email/email_test.go new file mode 100644 index 0000000..e029ce3 --- /dev/null +++ b/internal/email/email_test.go @@ -0,0 +1,456 @@ +package email + +import ( + "strings" + "testing" + "time" +) + +func TestNewService(t *testing.T) { + config := &Config{ + SMTPHost: "smtp.example.com", + SMTPPort: "587", + SMTPUser: "user@example.com", + SMTPPassword: "password", + FromEmail: "from@example.com", + ToEmail: "to@example.com", + } + + service := NewService(config) + + if service == nil { + t.Fatal("NewService should return a non-nil service") + } + + if service.config != config { + t.Error("NewService should store the config") + } +} + +func TestContactFormData_Validate(t *testing.T) { + tests := []struct { + name string + data ContactFormData + wantError bool + errorMsg string + }{ + { + name: "Valid - all fields", + data: ContactFormData{ + Email: "test@example.com", + Name: "Test User", + Company: "Test Company", + Subject: "Test Subject", + Message: "This is a test message with enough characters.", + }, + wantError: false, + }, + { + name: "Valid - minimal fields", + data: ContactFormData{ + Email: "test@example.com", + Message: "This is a test message with enough characters.", + }, + wantError: false, + }, + { + name: "Invalid - missing email", + data: ContactFormData{ + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "email is required", + }, + { + name: "Invalid - missing message", + data: ContactFormData{ + Email: "test@example.com", + }, + wantError: true, + errorMsg: "message is required", + }, + { + name: "Invalid - bad email format (no @)", + data: ContactFormData{ + Email: "testexample.com", + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "invalid email format", + }, + { + name: "Invalid - bad email format (no .)", + data: ContactFormData{ + Email: "test@examplecom", + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "invalid email format", + }, + { + name: "Invalid - email with newline", + data: ContactFormData{ + Email: "test@example.com\r\nBcc: hacker@evil.com", + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "invalid email: contains prohibited characters", + }, + { + name: "Invalid - subject with newline", + data: ContactFormData{ + Email: "test@example.com", + Subject: "Test\r\nBcc: hacker@evil.com", + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "invalid subject: contains prohibited characters", + }, + { + name: "Invalid - email too long", + data: ContactFormData{ + Email: strings.Repeat("a", 250) + "@example.com", + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "email too long", + }, + { + name: "Invalid - name too long", + data: ContactFormData{ + Email: "test@example.com", + Name: strings.Repeat("a", 101), + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "name too long", + }, + { + name: "Invalid - company too long", + data: ContactFormData{ + Email: "test@example.com", + Company: strings.Repeat("a", 101), + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "company too long", + }, + { + name: "Invalid - subject too long", + data: ContactFormData{ + Email: "test@example.com", + Subject: strings.Repeat("a", 201), + Message: "This is a test message with enough characters.", + }, + wantError: true, + errorMsg: "subject too long", + }, + { + name: "Invalid - message too long", + data: ContactFormData{ + Email: "test@example.com", + Message: strings.Repeat("a", 5001), + }, + wantError: true, + errorMsg: "message too long", + }, + { + name: "Invalid - message too short", + data: ContactFormData{ + Email: "test@example.com", + Message: "Short", + }, + wantError: true, + errorMsg: "message too short", + }, + { + name: "Valid - trims whitespace", + data: ContactFormData{ + Email: " test@example.com ", + Name: " Test User ", + Message: " This is a test message with enough characters. ", + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.data.Validate() + if (err != nil) != tt.wantError { + t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError) + } + if err != nil && tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) { + t.Errorf("Validate() error = %v, want error containing %q", err, tt.errorMsg) + } + }) + } +} + +func TestContainsNewlines(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"No newlines", "normal text", false}, + {"Carriage return", "text\rmore", true}, + {"Newline", "text\nmore", true}, + {"CRLF", "text\r\nmore", true}, + {"Empty", "", false}, + {"Spaces only", " ", false}, + {"Tab (allowed)", "text\ttab", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsNewlines(tt.input) + if result != tt.expected { + t.Errorf("containsNewlines(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestFormatMultipartMessage(t *testing.T) { + service := NewService(&Config{ + SMTPHost: "smtp.example.com", + SMTPPort: "587", + }) + + message := service.formatMultipartMessage( + "from@example.com", + "to@example.com", + "reply@example.com", + "Test Subject", + "HTML Body", + "Plain text body", + ) + + // Check required headers + if !strings.Contains(message, "From: from@example.com") { + t.Error("Message should contain From header") + } + if !strings.Contains(message, "To: to@example.com") { + t.Error("Message should contain To header") + } + if !strings.Contains(message, "Reply-To: reply@example.com") { + t.Error("Message should contain Reply-To header") + } + if !strings.Contains(message, "Subject: Test Subject") { + t.Error("Message should contain Subject header") + } + if !strings.Contains(message, "MIME-Version: 1.0") { + t.Error("Message should contain MIME-Version header") + } + if !strings.Contains(message, "multipart/alternative") { + t.Error("Message should be multipart/alternative") + } + if !strings.Contains(message, "text/plain") { + t.Error("Message should contain text/plain part") + } + if !strings.Contains(message, "text/html") { + t.Error("Message should contain text/html part") + } + if !strings.Contains(message, "Plain text body") { + t.Error("Message should contain plain text body") + } +} + +func TestFormatMultipartMessage_NoReplyTo(t *testing.T) { + service := NewService(&Config{ + SMTPHost: "smtp.example.com", + SMTPPort: "587", + }) + + message := service.formatMultipartMessage( + "from@example.com", + "to@example.com", + "", // No reply-to + "Test Subject", + "HTML", + "Plain text", + ) + + if strings.Contains(message, "Reply-To:") { + t.Error("Message should not contain Reply-To header when empty") + } +} + +func TestBuildEmailBody(t *testing.T) { + service := NewService(&Config{ + SMTPHost: "smtp.example.com", + SMTPPort: "587", + }) + + data := &ContactFormData{ + Email: "sender@example.com", + Name: "Test User", + Company: "Test Company", + Subject: "Test Subject", + Message: "This is a test message.", + IP: "192.168.1.1", + Time: time.Now(), + } + + htmlBody, textBody, err := service.buildEmailBody(data) + + if err != nil { + t.Errorf("buildEmailBody() error = %v", err) + } + + // Check HTML body contains data + if !strings.Contains(htmlBody, "Test User") { + t.Error("HTML body should contain name") + } + if !strings.Contains(htmlBody, "sender@example.com") { + t.Error("HTML body should contain email") + } + if !strings.Contains(htmlBody, "This is a test message") { + t.Error("HTML body should contain message") + } + + // Check text body contains data + if !strings.Contains(textBody, "Test User") { + t.Error("Text body should contain name") + } + if !strings.Contains(textBody, "sender@example.com") { + t.Error("Text body should contain email") + } +} + +func TestBuildEmailBody_EmptyName(t *testing.T) { + service := NewService(&Config{ + SMTPHost: "smtp.example.com", + SMTPPort: "587", + }) + + data := &ContactFormData{ + Email: "sender@example.com", + Name: "", // Empty name + Message: "This is a test message.", + Time: time.Now(), + } + + htmlBody, textBody, err := service.buildEmailBody(data) + + if err != nil { + t.Errorf("buildEmailBody() error = %v", err) + } + + // Should show "Not provided" for empty name + if !strings.Contains(htmlBody, "Not provided") { + t.Error("HTML body should show 'Not provided' for empty name") + } + if !strings.Contains(textBody, "Not provided") { + t.Error("Text body should show 'Not provided' for empty name") + } +} + +func TestSendContactForm_ValidationError(t *testing.T) { + service := NewService(&Config{ + SMTPHost: "smtp.example.com", + SMTPPort: "587", + SMTPUser: "user", + SMTPPassword: "pass", + ToEmail: "to@example.com", + }) + + // Invalid data - missing email + data := &ContactFormData{ + Message: "Test message that is long enough.", + } + + err := service.SendContactForm(data) + + if err == nil { + t.Error("SendContactForm should return error for invalid data") + } + + if !strings.Contains(err.Error(), "validation failed") { + t.Errorf("Error should mention validation failure: %v", err) + } +} + +func TestSendMultipartEmail_MissingConfig(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr string + }{ + { + name: "Missing SMTP host", + config: &Config{SMTPPort: "587", SMTPUser: "user", SMTPPassword: "pass", ToEmail: "to@example.com"}, + wantErr: "SMTP configuration incomplete", + }, + { + name: "Missing SMTP port", + config: &Config{SMTPHost: "smtp.example.com", SMTPUser: "user", SMTPPassword: "pass", ToEmail: "to@example.com"}, + wantErr: "SMTP configuration incomplete", + }, + { + name: "Missing SMTP user", + config: &Config{SMTPHost: "smtp.example.com", SMTPPort: "587", SMTPPassword: "pass", ToEmail: "to@example.com"}, + wantErr: "SMTP credentials missing", + }, + { + name: "Missing SMTP password", + config: &Config{SMTPHost: "smtp.example.com", SMTPPort: "587", SMTPUser: "user", ToEmail: "to@example.com"}, + wantErr: "SMTP credentials missing", + }, + { + name: "Missing recipient email", + config: &Config{SMTPHost: "smtp.example.com", SMTPPort: "587", SMTPUser: "user", SMTPPassword: "pass"}, + wantErr: "recipient email not configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := NewService(tt.config) + err := service.sendMultipartEmail("Subject", "", "text", "reply@example.com") + + if err == nil { + t.Error("sendMultipartEmail should return error for incomplete config") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("Error = %v, want error containing %q", err, tt.wantErr) + } + }) + } +} + +func TestCVThemeCSS(t *testing.T) { + css := CVThemeCSS() + + if css == "" { + t.Error("CVThemeCSS should return non-empty CSS") + } + + // Check for some expected CSS properties + if !strings.Contains(css, "font-family") { + t.Error("CSS should contain font-family") + } + if !strings.Contains(css, "color") { + t.Error("CSS should contain color definitions") + } +} + +func TestContactEmailHTMLTemplate(t *testing.T) { + template := ContactEmailHTMLTemplate() + + if template == "" { + t.Error("ContactEmailHTMLTemplate should return non-empty template") + } + + // Check for template variables + if !strings.Contains(template, "{{.Name}}") { + t.Error("Template should contain {{.Name}}") + } + if !strings.Contains(template, "{{.Email}}") { + t.Error("Template should contain {{.Email}}") + } + if !strings.Contains(template, "{{.Message}}") { + t.Error("Template should contain {{.Message}}") + } +} diff --git a/internal/middleware/contact_rate_limit_test.go b/internal/middleware/contact_rate_limit_test.go new file mode 100644 index 0000000..23a00a9 --- /dev/null +++ b/internal/middleware/contact_rate_limit_test.go @@ -0,0 +1,242 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + c "github.com/juanatsap/cv-site/internal/constants" +) + +func TestNewContactRateLimiter(t *testing.T) { + rl := NewContactRateLimiter() + + if rl == nil { + t.Fatal("NewContactRateLimiter should return non-nil") + } + + if rl.clients == nil { + t.Error("clients map should be initialized") + } +} + +func TestContactRateLimiter_allow(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + ip := "192.168.1.1" + + // First request should be allowed + if !rl.allow(ip) { + t.Error("First request should be allowed") + } + + // Subsequent requests up to limit should be allowed + limit := c.RateLimitContactRequests + for i := 1; i < limit; i++ { + if !rl.allow(ip) { + t.Errorf("Request %d should be allowed (limit: %d)", i+1, limit) + } + } + + // Request exceeding limit should be blocked + if rl.allow(ip) { + t.Error("Request exceeding limit should be blocked") + } + + // Different IP should be allowed + if !rl.allow("192.168.1.2") { + t.Error("Different IP should be allowed") + } +} + +func TestContactRateLimiter_Middleware_Allowed(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + protected := rl.Middleware(handler) + + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestContactRateLimiter_Middleware_Blocked(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + protected := rl.Middleware(handler) + + // Exhaust the rate limit + limit := c.RateLimitContactRequests + for i := 0; i < limit; i++ { + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + protected.ServeHTTP(rec, req) + } + + // Next request should be blocked + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusTooManyRequests { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests) + } + + // Should have Retry-After header + if rec.Header().Get(c.HeaderRetryAfter) == "" { + t.Error("Response should have Retry-After header") + } +} + +func TestContactRateLimiter_Middleware_HTMX(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + protected := rl.Middleware(handler) + + // Exhaust the rate limit + limit := c.RateLimitContactRequests + for i := 0; i < limit; i++ { + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + protected.ServeHTTP(rec, req) + } + + // HTMX request should get HTML response + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.RemoteAddr = "192.168.1.1:12345" + req.Header.Set(c.HeaderHXRequest, "true") + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusTooManyRequests { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests) + } + + body := rec.Body.String() + if !strings.Contains(body, "Too Many Requests") { + t.Error("HTMX response should contain HTML error message") + } + if !strings.Contains(body, "alert") { + t.Error("HTMX response should contain alert class") + } +} + +func TestContactRateLimiter_Middleware_XForwardedFor(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + protected := rl.Middleware(handler) + + // Request with X-Forwarded-For + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1, 192.168.1.1") + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } + + // Check that first IP was used + rl.mu.RLock() + _, exists := rl.clients["10.0.0.1"] + rl.mu.RUnlock() + + if !exists { + t.Error("Should use first IP from X-Forwarded-For") + } +} + +func TestContactRateLimiter_Middleware_XRealIP(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + protected := rl.Middleware(handler) + + // Request with X-Real-IP + req := httptest.NewRequest(http.MethodPost, "/api/contact", nil) + req.Header.Set(c.HeaderXRealIP, "10.0.0.2") + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } + + // Check that X-Real-IP was used + rl.mu.RLock() + _, exists := rl.clients["10.0.0.2"] + rl.mu.RUnlock() + + if !exists { + t.Error("Should use X-Real-IP") + } +} + +func TestContactRateLimiter_GetStats(t *testing.T) { + rl := &ContactRateLimiter{ + clients: make(map[string]*contactRateLimitEntry), + } + + // Add some entries + rl.allow("192.168.1.1") + rl.allow("192.168.1.2") + + stats := rl.GetStats() + + if stats["total_clients"] != 2 { + t.Errorf("total_clients = %v, want 2", stats["total_clients"]) + } + + if stats["limit"] != c.RateLimitContactRequests { + t.Errorf("limit = %v, want %d", stats["limit"], c.RateLimitContactRequests) + } + + if stats["window"] == "" { + t.Error("window should not be empty") + } +} diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go new file mode 100644 index 0000000..5c24530 --- /dev/null +++ b/internal/middleware/csrf_test.go @@ -0,0 +1,385 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + c "github.com/juanatsap/cv-site/internal/constants" +) + +func TestNewCSRFProtection(t *testing.T) { + csrf := NewCSRFProtection() + + if csrf == nil { + t.Fatal("NewCSRFProtection should return non-nil") + } + + if csrf.tokens == nil { + t.Error("tokens map should be initialized") + } +} + +func TestCSRFProtection_generateToken(t *testing.T) { + csrf := NewCSRFProtection() + + token, err := csrf.generateToken() + if err != nil { + t.Errorf("generateToken() error = %v", err) + } + + if token == "" { + t.Error("generateToken() should return non-empty token") + } + + // Token should be stored + csrf.mu.RLock() + entry, exists := csrf.tokens[token] + csrf.mu.RUnlock() + + if !exists { + t.Error("Generated token should be stored") + } + + if entry.token != token { + t.Errorf("Stored token = %q, want %q", entry.token, token) + } + + // Token should have expiration in the future + if !entry.expiresAt.After(time.Now()) { + t.Error("Token expiration should be in the future") + } +} + +func TestCSRFProtection_generateToken_Unique(t *testing.T) { + csrf := NewCSRFProtection() + + tokens := make(map[string]bool) + for i := 0; i < 100; i++ { + token, err := csrf.generateToken() + if err != nil { + t.Fatalf("generateToken() error = %v", err) + } + if tokens[token] { + t.Error("Tokens should be unique") + } + tokens[token] = true + } +} + +func TestCSRFProtection_GetToken_NewToken(t *testing.T) { + csrf := NewCSRFProtection() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + token, err := csrf.GetToken(rec, req) + if err != nil { + t.Errorf("GetToken() error = %v", err) + } + + if token == "" { + t.Error("GetToken() should return non-empty token") + } + + // Check cookie was set + cookies := rec.Result().Cookies() + var found bool + for _, cookie := range cookies { + if cookie.Name == c.CSRFCookieName { + found = true + if cookie.Value != token { + t.Errorf("Cookie value = %q, want %q", cookie.Value, token) + } + if !cookie.HttpOnly { + t.Error("Cookie should be HttpOnly") + } + } + } + if !found { + t.Error("CSRF cookie should be set") + } +} + +func TestCSRFProtection_GetToken_ExistingToken(t *testing.T) { + csrf := NewCSRFProtection() + + // First request to get token + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + rec1 := httptest.NewRecorder() + token1, _ := csrf.GetToken(rec1, req1) + + // Second request with existing token cookie + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: token1, + }) + rec2 := httptest.NewRecorder() + + token2, err := csrf.GetToken(rec2, req2) + if err != nil { + t.Errorf("GetToken() error = %v", err) + } + + // Should return same token + if token2 != token1 { + t.Errorf("GetToken() = %q, want %q (same token)", token2, token1) + } +} + +func TestCSRFProtection_validateToken(t *testing.T) { + csrf := NewCSRFProtection() + + // Generate a token first + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + token, _ := csrf.GetToken(rec, req) + + t.Run("Valid token in form", func(t *testing.T) { + form := url.Values{} + form.Set(c.CSRFFormField, token) + + postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + postReq.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: token, + }) + + if !csrf.validateToken(postReq) { + t.Error("validateToken should return true for valid token") + } + }) + + t.Run("Valid token in header", func(t *testing.T) { + postReq := httptest.NewRequest(http.MethodPost, "/", nil) + postReq.Header.Set(c.HeaderXCSRFToken, token) + postReq.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: token, + }) + + if !csrf.validateToken(postReq) { + t.Error("validateToken should return true for valid token in header") + } + }) + + t.Run("Missing form token", func(t *testing.T) { + postReq := httptest.NewRequest(http.MethodPost, "/", nil) + postReq.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: token, + }) + + if csrf.validateToken(postReq) { + t.Error("validateToken should return false for missing form token") + } + }) + + t.Run("Missing cookie", func(t *testing.T) { + form := url.Values{} + form.Set(c.CSRFFormField, token) + + postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + if csrf.validateToken(postReq) { + t.Error("validateToken should return false for missing cookie") + } + }) + + t.Run("Token mismatch", func(t *testing.T) { + form := url.Values{} + form.Set(c.CSRFFormField, "wrong-token") + + postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + postReq.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: token, + }) + + if csrf.validateToken(postReq) { + t.Error("validateToken should return false for mismatched tokens") + } + }) + + t.Run("Token not in store", func(t *testing.T) { + unknownToken := "unknown-token" + form := url.Values{} + form.Set(c.CSRFFormField, unknownToken) + + postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + postReq.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: unknownToken, + }) + + if csrf.validateToken(postReq) { + t.Error("validateToken should return false for token not in store") + } + }) +} + +func TestCSRFProtection_Middleware(t *testing.T) { + csrf := NewCSRFProtection() + + // Generate a valid token + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + token, _ := csrf.GetToken(rec, req) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + protected := csrf.Middleware(handler) + + t.Run("GET request passes through", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("POST with valid token passes", func(t *testing.T) { + form := url.Values{} + form.Set(c.CSRFFormField, token) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: token, + }) + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("POST without token fails", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden) + } + }) + + t.Run("PUT without token fails", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/", nil) + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden) + } + }) + + t.Run("DELETE without token fails", func(t *testing.T) { + req := httptest.NewRequest(http.MethodDelete, "/", nil) + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden) + } + }) + + t.Run("HTMX request gets HTML error", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(c.HeaderHXRequest, "true") + rec := httptest.NewRecorder() + + protected.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden) + } + + body := rec.Body.String() + if !strings.Contains(body, "Security Error") { + t.Error("HTMX response should contain HTML error message") + } + if !strings.Contains(body, "alert") { + t.Error("HTMX response should contain alert class") + } + }) +} + +func TestCSRFTokenEntry_Expiration(t *testing.T) { + csrf := &CSRFProtection{ + tokens: make(map[string]*csrfTokenEntry), + } + + // Add expired token + expiredToken := "expired-token" + csrf.tokens[expiredToken] = &csrfTokenEntry{ + token: expiredToken, + expiresAt: time.Now().Add(-1 * time.Hour), // Expired + } + + // Validation should fail for expired token + form := url.Values{} + form.Set(c.CSRFFormField, expiredToken) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: expiredToken, + }) + + if csrf.validateToken(req) { + t.Error("validateToken should return false for expired token") + } +} + +func TestGetToken_ExpiredTokenInCookie(t *testing.T) { + csrf := &CSRFProtection{ + tokens: make(map[string]*csrfTokenEntry), + } + + // Add expired token to store + expiredToken := "expired-token" + csrf.tokens[expiredToken] = &csrfTokenEntry{ + token: expiredToken, + expiresAt: time.Now().Add(-1 * time.Hour), // Expired + } + + // Request with expired token cookie + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{ + Name: c.CSRFCookieName, + Value: expiredToken, + }) + rec := httptest.NewRecorder() + + newToken, err := csrf.GetToken(rec, req) + if err != nil { + t.Errorf("GetToken() error = %v", err) + } + + // Should generate new token + if newToken == expiredToken { + t.Error("GetToken() should generate new token when existing is expired") + } +} diff --git a/internal/middleware/recovery_test.go b/internal/middleware/recovery_test.go new file mode 100644 index 0000000..057cb50 --- /dev/null +++ b/internal/middleware/recovery_test.go @@ -0,0 +1,69 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRecovery(t *testing.T) { + t.Run("Recovers from panic", func(t *testing.T) { + // Handler that panics + panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + handler := Recovery(panicHandler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + // Should not panic + handler.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Status = %d, want %d", w.Code, http.StatusInternalServerError) + } + }) + + t.Run("Passes through normal requests", func(t *testing.T) { + normalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + handler := Recovery(normalHandler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", w.Code, http.StatusOK) + } + + if w.Body.String() != "OK" { + t.Errorf("Body = %q, want %q", w.Body.String(), "OK") + } + }) + + t.Run("Recovers from nil panic", func(t *testing.T) { + nilPanicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var p *int = nil + _ = *p // This will cause a nil pointer dereference + }) + + handler := Recovery(nilPanicHandler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + // Should not panic + handler.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Status = %d, want %d", w.Code, http.StatusInternalServerError) + } + }) +}