docs: add CLAUDE.md pointing to key project documentation
Links to PROJECT-MEMORY.md and DECISIONS.md for development rules and architectural decisions, plus quick commands and doc index.
This commit is contained in:
@@ -0,0 +1,30 @@
|
|||||||
|
# CV Project Instructions
|
||||||
|
|
||||||
|
## Required Reading
|
||||||
|
|
||||||
|
1. **[PROJECT-MEMORY.md](./PROJECT-MEMORY.md)** - Development rules, critical bugs, patterns
|
||||||
|
2. **[doc/DECISIONS.md](./doc/DECISIONS.md)** - Architectural Decision Records (ADRs)
|
||||||
|
|
||||||
|
## Quick Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all frontend tests (Playwright)
|
||||||
|
bun tests/run-all.mjs
|
||||||
|
|
||||||
|
# Run Go tests with coverage
|
||||||
|
go test -cover ./internal/...
|
||||||
|
|
||||||
|
# Start dev server
|
||||||
|
go run .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
- **Backend**: Go 1.21+ with standard library
|
||||||
|
- **Frontend**: HTMX + Hyperscript + Vanilla JS
|
||||||
|
- **Testing**: Playwright (frontend), Go test (backend)
|
||||||
|
|
||||||
|
## Documentation Index
|
||||||
|
|
||||||
|
- [doc/00-GO-DOCUMENTATION-INDEX.md](./doc/00-GO-DOCUMENTATION-INDEX.md) - Go system docs
|
||||||
|
- [doc/01-ARCHITECTURE.md](./doc/01-ARCHITECTURE.md) - System architecture
|
||||||
@@ -0,0 +1,456 @@
|
|||||||
|
package email
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewService(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
SMTPHost: "smtp.example.com",
|
||||||
|
SMTPPort: "587",
|
||||||
|
SMTPUser: "user@example.com",
|
||||||
|
SMTPPassword: "password",
|
||||||
|
FromEmail: "from@example.com",
|
||||||
|
ToEmail: "to@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
service := NewService(config)
|
||||||
|
|
||||||
|
if service == nil {
|
||||||
|
t.Fatal("NewService should return a non-nil service")
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.config != config {
|
||||||
|
t.Error("NewService should store the config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactFormData_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data ContactFormData
|
||||||
|
wantError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid - all fields",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Name: "Test User",
|
||||||
|
Company: "Test Company",
|
||||||
|
Subject: "Test Subject",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid - minimal fields",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - missing email",
|
||||||
|
data: ContactFormData{
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "email is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - missing message",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "message is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - bad email format (no @)",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "testexample.com",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "invalid email format",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - bad email format (no .)",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@examplecom",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "invalid email format",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - email with newline",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com\r\nBcc: hacker@evil.com",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "invalid email: contains prohibited characters",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - subject with newline",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Subject: "Test\r\nBcc: hacker@evil.com",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "invalid subject: contains prohibited characters",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - email too long",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: strings.Repeat("a", 250) + "@example.com",
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "email too long",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - name too long",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Name: strings.Repeat("a", 101),
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "name too long",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - company too long",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Company: strings.Repeat("a", 101),
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "company too long",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - subject too long",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Subject: strings.Repeat("a", 201),
|
||||||
|
Message: "This is a test message with enough characters.",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "subject too long",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - message too long",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Message: strings.Repeat("a", 5001),
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "message too long",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid - message too short",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Message: "Short",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
errorMsg: "message too short",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid - trims whitespace",
|
||||||
|
data: ContactFormData{
|
||||||
|
Email: " test@example.com ",
|
||||||
|
Name: " Test User ",
|
||||||
|
Message: " This is a test message with enough characters. ",
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.data.Validate()
|
||||||
|
if (err != nil) != tt.wantError {
|
||||||
|
t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError)
|
||||||
|
}
|
||||||
|
if err != nil && tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
|
t.Errorf("Validate() error = %v, want error containing %q", err, tt.errorMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContainsNewlines(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"No newlines", "normal text", false},
|
||||||
|
{"Carriage return", "text\rmore", true},
|
||||||
|
{"Newline", "text\nmore", true},
|
||||||
|
{"CRLF", "text\r\nmore", true},
|
||||||
|
{"Empty", "", false},
|
||||||
|
{"Spaces only", " ", false},
|
||||||
|
{"Tab (allowed)", "text\ttab", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := containsNewlines(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("containsNewlines(%q) = %v, want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatMultipartMessage(t *testing.T) {
|
||||||
|
service := NewService(&Config{
|
||||||
|
SMTPHost: "smtp.example.com",
|
||||||
|
SMTPPort: "587",
|
||||||
|
})
|
||||||
|
|
||||||
|
message := service.formatMultipartMessage(
|
||||||
|
"from@example.com",
|
||||||
|
"to@example.com",
|
||||||
|
"reply@example.com",
|
||||||
|
"Test Subject",
|
||||||
|
"<html><body>HTML Body</body></html>",
|
||||||
|
"Plain text body",
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check required headers
|
||||||
|
if !strings.Contains(message, "From: from@example.com") {
|
||||||
|
t.Error("Message should contain From header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "To: to@example.com") {
|
||||||
|
t.Error("Message should contain To header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "Reply-To: reply@example.com") {
|
||||||
|
t.Error("Message should contain Reply-To header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "Subject: Test Subject") {
|
||||||
|
t.Error("Message should contain Subject header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "MIME-Version: 1.0") {
|
||||||
|
t.Error("Message should contain MIME-Version header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "multipart/alternative") {
|
||||||
|
t.Error("Message should be multipart/alternative")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "text/plain") {
|
||||||
|
t.Error("Message should contain text/plain part")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "text/html") {
|
||||||
|
t.Error("Message should contain text/html part")
|
||||||
|
}
|
||||||
|
if !strings.Contains(message, "Plain text body") {
|
||||||
|
t.Error("Message should contain plain text body")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatMultipartMessage_NoReplyTo(t *testing.T) {
|
||||||
|
service := NewService(&Config{
|
||||||
|
SMTPHost: "smtp.example.com",
|
||||||
|
SMTPPort: "587",
|
||||||
|
})
|
||||||
|
|
||||||
|
message := service.formatMultipartMessage(
|
||||||
|
"from@example.com",
|
||||||
|
"to@example.com",
|
||||||
|
"", // No reply-to
|
||||||
|
"Test Subject",
|
||||||
|
"<html>HTML</html>",
|
||||||
|
"Plain text",
|
||||||
|
)
|
||||||
|
|
||||||
|
if strings.Contains(message, "Reply-To:") {
|
||||||
|
t.Error("Message should not contain Reply-To header when empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildEmailBody(t *testing.T) {
|
||||||
|
service := NewService(&Config{
|
||||||
|
SMTPHost: "smtp.example.com",
|
||||||
|
SMTPPort: "587",
|
||||||
|
})
|
||||||
|
|
||||||
|
data := &ContactFormData{
|
||||||
|
Email: "sender@example.com",
|
||||||
|
Name: "Test User",
|
||||||
|
Company: "Test Company",
|
||||||
|
Subject: "Test Subject",
|
||||||
|
Message: "This is a test message.",
|
||||||
|
IP: "192.168.1.1",
|
||||||
|
Time: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
htmlBody, textBody, err := service.buildEmailBody(data)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("buildEmailBody() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check HTML body contains data
|
||||||
|
if !strings.Contains(htmlBody, "Test User") {
|
||||||
|
t.Error("HTML body should contain name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(htmlBody, "sender@example.com") {
|
||||||
|
t.Error("HTML body should contain email")
|
||||||
|
}
|
||||||
|
if !strings.Contains(htmlBody, "This is a test message") {
|
||||||
|
t.Error("HTML body should contain message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check text body contains data
|
||||||
|
if !strings.Contains(textBody, "Test User") {
|
||||||
|
t.Error("Text body should contain name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(textBody, "sender@example.com") {
|
||||||
|
t.Error("Text body should contain email")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildEmailBody_EmptyName(t *testing.T) {
|
||||||
|
service := NewService(&Config{
|
||||||
|
SMTPHost: "smtp.example.com",
|
||||||
|
SMTPPort: "587",
|
||||||
|
})
|
||||||
|
|
||||||
|
data := &ContactFormData{
|
||||||
|
Email: "sender@example.com",
|
||||||
|
Name: "", // Empty name
|
||||||
|
Message: "This is a test message.",
|
||||||
|
Time: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
htmlBody, textBody, err := service.buildEmailBody(data)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("buildEmailBody() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should show "Not provided" for empty name
|
||||||
|
if !strings.Contains(htmlBody, "Not provided") {
|
||||||
|
t.Error("HTML body should show 'Not provided' for empty name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(textBody, "Not provided") {
|
||||||
|
t.Error("Text body should show 'Not provided' for empty name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendContactForm_ValidationError(t *testing.T) {
|
||||||
|
service := NewService(&Config{
|
||||||
|
SMTPHost: "smtp.example.com",
|
||||||
|
SMTPPort: "587",
|
||||||
|
SMTPUser: "user",
|
||||||
|
SMTPPassword: "pass",
|
||||||
|
ToEmail: "to@example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Invalid data - missing email
|
||||||
|
data := &ContactFormData{
|
||||||
|
Message: "Test message that is long enough.",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.SendContactForm(data)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("SendContactForm should return error for invalid data")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "validation failed") {
|
||||||
|
t.Errorf("Error should mention validation failure: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendMultipartEmail_MissingConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *Config
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Missing SMTP host",
|
||||||
|
config: &Config{SMTPPort: "587", SMTPUser: "user", SMTPPassword: "pass", ToEmail: "to@example.com"},
|
||||||
|
wantErr: "SMTP configuration incomplete",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing SMTP port",
|
||||||
|
config: &Config{SMTPHost: "smtp.example.com", SMTPUser: "user", SMTPPassword: "pass", ToEmail: "to@example.com"},
|
||||||
|
wantErr: "SMTP configuration incomplete",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing SMTP user",
|
||||||
|
config: &Config{SMTPHost: "smtp.example.com", SMTPPort: "587", SMTPPassword: "pass", ToEmail: "to@example.com"},
|
||||||
|
wantErr: "SMTP credentials missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing SMTP password",
|
||||||
|
config: &Config{SMTPHost: "smtp.example.com", SMTPPort: "587", SMTPUser: "user", ToEmail: "to@example.com"},
|
||||||
|
wantErr: "SMTP credentials missing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing recipient email",
|
||||||
|
config: &Config{SMTPHost: "smtp.example.com", SMTPPort: "587", SMTPUser: "user", SMTPPassword: "pass"},
|
||||||
|
wantErr: "recipient email not configured",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
service := NewService(tt.config)
|
||||||
|
err := service.sendMultipartEmail("Subject", "<html>", "text", "reply@example.com")
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("sendMultipartEmail should return error for incomplete config")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||||
|
t.Errorf("Error = %v, want error containing %q", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCVThemeCSS(t *testing.T) {
|
||||||
|
css := CVThemeCSS()
|
||||||
|
|
||||||
|
if css == "" {
|
||||||
|
t.Error("CVThemeCSS should return non-empty CSS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for some expected CSS properties
|
||||||
|
if !strings.Contains(css, "font-family") {
|
||||||
|
t.Error("CSS should contain font-family")
|
||||||
|
}
|
||||||
|
if !strings.Contains(css, "color") {
|
||||||
|
t.Error("CSS should contain color definitions")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactEmailHTMLTemplate(t *testing.T) {
|
||||||
|
template := ContactEmailHTMLTemplate()
|
||||||
|
|
||||||
|
if template == "" {
|
||||||
|
t.Error("ContactEmailHTMLTemplate should return non-empty template")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for template variables
|
||||||
|
if !strings.Contains(template, "{{.Name}}") {
|
||||||
|
t.Error("Template should contain {{.Name}}")
|
||||||
|
}
|
||||||
|
if !strings.Contains(template, "{{.Email}}") {
|
||||||
|
t.Error("Template should contain {{.Email}}")
|
||||||
|
}
|
||||||
|
if !strings.Contains(template, "{{.Message}}") {
|
||||||
|
t.Error("Template should contain {{.Message}}")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
c "github.com/juanatsap/cv-site/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewContactRateLimiter(t *testing.T) {
|
||||||
|
rl := NewContactRateLimiter()
|
||||||
|
|
||||||
|
if rl == nil {
|
||||||
|
t.Fatal("NewContactRateLimiter should return non-nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rl.clients == nil {
|
||||||
|
t.Error("clients map should be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_allow(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := "192.168.1.1"
|
||||||
|
|
||||||
|
// First request should be allowed
|
||||||
|
if !rl.allow(ip) {
|
||||||
|
t.Error("First request should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent requests up to limit should be allowed
|
||||||
|
limit := c.RateLimitContactRequests
|
||||||
|
for i := 1; i < limit; i++ {
|
||||||
|
if !rl.allow(ip) {
|
||||||
|
t.Errorf("Request %d should be allowed (limit: %d)", i+1, limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request exceeding limit should be blocked
|
||||||
|
if rl.allow(ip) {
|
||||||
|
t.Error("Request exceeding limit should be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different IP should be allowed
|
||||||
|
if !rl.allow("192.168.1.2") {
|
||||||
|
t.Error("Different IP should be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_Middleware_Allowed(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
protected := rl.Middleware(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_Middleware_Blocked(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
protected := rl.Middleware(handler)
|
||||||
|
|
||||||
|
// Exhaust the rate limit
|
||||||
|
limit := c.RateLimitContactRequests
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request should be blocked
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have Retry-After header
|
||||||
|
if rec.Header().Get(c.HeaderRetryAfter) == "" {
|
||||||
|
t.Error("Response should have Retry-After header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_Middleware_HTMX(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
protected := rl.Middleware(handler)
|
||||||
|
|
||||||
|
// Exhaust the rate limit
|
||||||
|
limit := c.RateLimitContactRequests
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTMX request should get HTML response
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
req.Header.Set(c.HeaderHXRequest, "true")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
if !strings.Contains(body, "Too Many Requests") {
|
||||||
|
t.Error("HTMX response should contain HTML error message")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "alert") {
|
||||||
|
t.Error("HTMX response should contain alert class")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_Middleware_XForwardedFor(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
protected := rl.Middleware(handler)
|
||||||
|
|
||||||
|
// Request with X-Forwarded-For
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.Header.Set(c.HeaderXForwardedFor, "10.0.0.1, 192.168.1.1")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that first IP was used
|
||||||
|
rl.mu.RLock()
|
||||||
|
_, exists := rl.clients["10.0.0.1"]
|
||||||
|
rl.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Error("Should use first IP from X-Forwarded-For")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_Middleware_XRealIP(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
protected := rl.Middleware(handler)
|
||||||
|
|
||||||
|
// Request with X-Real-IP
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/contact", nil)
|
||||||
|
req.Header.Set(c.HeaderXRealIP, "10.0.0.2")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that X-Real-IP was used
|
||||||
|
rl.mu.RLock()
|
||||||
|
_, exists := rl.clients["10.0.0.2"]
|
||||||
|
rl.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Error("Should use X-Real-IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContactRateLimiter_GetStats(t *testing.T) {
|
||||||
|
rl := &ContactRateLimiter{
|
||||||
|
clients: make(map[string]*contactRateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some entries
|
||||||
|
rl.allow("192.168.1.1")
|
||||||
|
rl.allow("192.168.1.2")
|
||||||
|
|
||||||
|
stats := rl.GetStats()
|
||||||
|
|
||||||
|
if stats["total_clients"] != 2 {
|
||||||
|
t.Errorf("total_clients = %v, want 2", stats["total_clients"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats["limit"] != c.RateLimitContactRequests {
|
||||||
|
t.Errorf("limit = %v, want %d", stats["limit"], c.RateLimitContactRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats["window"] == "" {
|
||||||
|
t.Error("window should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,385 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
c "github.com/juanatsap/cv-site/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewCSRFProtection(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
if csrf == nil {
|
||||||
|
t.Fatal("NewCSRFProtection should return non-nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if csrf.tokens == nil {
|
||||||
|
t.Error("tokens map should be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFProtection_generateToken(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
token, err := csrf.generateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("generateToken() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Error("generateToken() should return non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token should be stored
|
||||||
|
csrf.mu.RLock()
|
||||||
|
entry, exists := csrf.tokens[token]
|
||||||
|
csrf.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Error("Generated token should be stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.token != token {
|
||||||
|
t.Errorf("Stored token = %q, want %q", entry.token, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token should have expiration in the future
|
||||||
|
if !entry.expiresAt.After(time.Now()) {
|
||||||
|
t.Error("Token expiration should be in the future")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFProtection_generateToken_Unique(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
tokens := make(map[string]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
token, err := csrf.generateToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateToken() error = %v", err)
|
||||||
|
}
|
||||||
|
if tokens[token] {
|
||||||
|
t.Error("Tokens should be unique")
|
||||||
|
}
|
||||||
|
tokens[token] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFProtection_GetToken_NewToken(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
token, err := csrf.GetToken(rec, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GetToken() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Error("GetToken() should return non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check cookie was set
|
||||||
|
cookies := rec.Result().Cookies()
|
||||||
|
var found bool
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if cookie.Name == c.CSRFCookieName {
|
||||||
|
found = true
|
||||||
|
if cookie.Value != token {
|
||||||
|
t.Errorf("Cookie value = %q, want %q", cookie.Value, token)
|
||||||
|
}
|
||||||
|
if !cookie.HttpOnly {
|
||||||
|
t.Error("Cookie should be HttpOnly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("CSRF cookie should be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFProtection_GetToken_ExistingToken(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
// First request to get token
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
token1, _ := csrf.GetToken(rec1, req1)
|
||||||
|
|
||||||
|
// Second request with existing token cookie
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req2.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: token1,
|
||||||
|
})
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
|
||||||
|
token2, err := csrf.GetToken(rec2, req2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GetToken() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return same token
|
||||||
|
if token2 != token1 {
|
||||||
|
t.Errorf("GetToken() = %q, want %q (same token)", token2, token1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFProtection_validateToken(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
// Generate a token first
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
token, _ := csrf.GetToken(rec, req)
|
||||||
|
|
||||||
|
t.Run("Valid token in form", func(t *testing.T) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set(c.CSRFFormField, token)
|
||||||
|
|
||||||
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
postReq.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !csrf.validateToken(postReq) {
|
||||||
|
t.Error("validateToken should return true for valid token")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Valid token in header", func(t *testing.T) {
|
||||||
|
postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
postReq.Header.Set(c.HeaderXCSRFToken, token)
|
||||||
|
postReq.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !csrf.validateToken(postReq) {
|
||||||
|
t.Error("validateToken should return true for valid token in header")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Missing form token", func(t *testing.T) {
|
||||||
|
postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
postReq.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
if csrf.validateToken(postReq) {
|
||||||
|
t.Error("validateToken should return false for missing form token")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Missing cookie", func(t *testing.T) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set(c.CSRFFormField, token)
|
||||||
|
|
||||||
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
if csrf.validateToken(postReq) {
|
||||||
|
t.Error("validateToken should return false for missing cookie")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Token mismatch", func(t *testing.T) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set(c.CSRFFormField, "wrong-token")
|
||||||
|
|
||||||
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
postReq.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: token,
|
||||||
|
})
|
||||||
|
|
||||||
|
if csrf.validateToken(postReq) {
|
||||||
|
t.Error("validateToken should return false for mismatched tokens")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Token not in store", func(t *testing.T) {
|
||||||
|
unknownToken := "unknown-token"
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set(c.CSRFFormField, unknownToken)
|
||||||
|
|
||||||
|
postReq := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
postReq.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: unknownToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
if csrf.validateToken(postReq) {
|
||||||
|
t.Error("validateToken should return false for token not in store")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFProtection_Middleware(t *testing.T) {
|
||||||
|
csrf := NewCSRFProtection()
|
||||||
|
|
||||||
|
// Generate a valid token
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
token, _ := csrf.GetToken(rec, req)
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
protected := csrf.Middleware(handler)
|
||||||
|
|
||||||
|
t.Run("GET request passes through", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("POST with valid token passes", func(t *testing.T) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set(c.CSRFFormField, token)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: token,
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("POST without token fails", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PUT without token fails", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DELETE without token fails", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HTMX request gets HTML error", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
req.Header.Set(c.HeaderHXRequest, "true")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
protected.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
if !strings.Contains(body, "Security Error") {
|
||||||
|
t.Error("HTMX response should contain HTML error message")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "alert") {
|
||||||
|
t.Error("HTMX response should contain alert class")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFTokenEntry_Expiration(t *testing.T) {
|
||||||
|
csrf := &CSRFProtection{
|
||||||
|
tokens: make(map[string]*csrfTokenEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add expired token
|
||||||
|
expiredToken := "expired-token"
|
||||||
|
csrf.tokens[expiredToken] = &csrfTokenEntry{
|
||||||
|
token: expiredToken,
|
||||||
|
expiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validation should fail for expired token
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set(c.CSRFFormField, expiredToken)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: expiredToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
if csrf.validateToken(req) {
|
||||||
|
t.Error("validateToken should return false for expired token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetToken_ExpiredTokenInCookie(t *testing.T) {
|
||||||
|
csrf := &CSRFProtection{
|
||||||
|
tokens: make(map[string]*csrfTokenEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add expired token to store
|
||||||
|
expiredToken := "expired-token"
|
||||||
|
csrf.tokens[expiredToken] = &csrfTokenEntry{
|
||||||
|
token: expiredToken,
|
||||||
|
expiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request with expired token cookie
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: c.CSRFCookieName,
|
||||||
|
Value: expiredToken,
|
||||||
|
})
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
newToken, err := csrf.GetToken(rec, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GetToken() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should generate new token
|
||||||
|
if newToken == expiredToken {
|
||||||
|
t.Error("GetToken() should generate new token when existing is expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecovery(t *testing.T) {
|
||||||
|
t.Run("Recovers from panic", func(t *testing.T) {
|
||||||
|
// Handler that panics
|
||||||
|
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("test panic")
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := Recovery(panicHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Passes through normal requests", func(t *testing.T) {
|
||||||
|
normalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := Recovery(normalHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Body.String() != "OK" {
|
||||||
|
t.Errorf("Body = %q, want %q", w.Body.String(), "OK")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Recovers from nil panic", func(t *testing.T) {
|
||||||
|
nilPanicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var p *int = nil
|
||||||
|
_ = *p // This will cause a nil pointer dereference
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := Recovery(nilPanicHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("Status = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user