Middleware
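Middleware functions run around route handlers. Code before c.Next() executes before the handler, and code after it executes once the rest of the chain has returned. Use middleware for cross-cutting concerns such as authentication, logging, and rate limiting.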
Basic Usage
Middleware is a function that wraps your handlers:
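package main

import (
	"fmt"
	"log"
	"net/http"
	"time"

	"rivaas.dev/router"
)

// Logger prints the method, path, and duration of every request.
func Logger() router.HandlerFunc {
	return func(c *router.Context) {
		start := time.Now()
		path := c.Request.URL.Path
		c.Next() // Run the remaining middleware and the handler
		duration := time.Since(start)
		fmt.Printf("[%s] %s - %v\n", c.Request.Method, path, duration)
	}
}

func handler(c *router.Context) {
	c.JSON(http.StatusOK, map[string]string{"message": "Hello"})
}

func main() {
	r := router.MustNew()
	// Apply middleware globally
	r.Use(Logger())
	r.GET("/", handler)
	log.Fatal(http.ListenAndServe(":8080", r))
}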
Key concepts:
- c.Next() continues to the next middleware or handler in the chain.
- Call c.Next() to proceed. Return without calling it to stop the chain.
- Middleware runs in registration order.
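To make these rules concrete, here is a toy model of how an index-based c.Next() can work. The Context and HandlerFunc types below are stand-ins written for illustration, not the router's actual implementation. The sketch only assumes the behavior described above: handlers run in registration order, and the chain stops when a middleware returns without calling c.Next().

```go
package main

import "fmt"

// HandlerFunc and Context are toy stand-ins for router.HandlerFunc and
// router.Context; they are not the router's real implementation.
type HandlerFunc func(*Context)

type Context struct {
	handlers []HandlerFunc
	index    int
	Log      []string
}

// Next runs the next handler in the chain. Each middleware decides whether to
// call Next again, so one that returns without calling it stops the chain.
func (c *Context) Next() {
	c.index++
	if c.index < len(c.handlers) {
		c.handlers[c.index](c)
	}
}

// run executes a chain from the start and returns what each step recorded.
func run(handlers ...HandlerFunc) []string {
	c := &Context{handlers: handlers, index: -1}
	c.Next()
	return c.Log
}

// tracer records an entry before and after calling c.Next().
func tracer(name string) HandlerFunc {
	return func(c *Context) {
		c.Log = append(c.Log, name+" before")
		c.Next()
		c.Log = append(c.Log, name+" after")
	}
}

// stopper never calls c.Next(), so nothing registered after it runs.
func stopper(c *Context) {
	c.Log = append(c.Log, "stop")
}

func handler(c *Context) {
	c.Log = append(c.Log, "handler")
}

func main() {
	fmt.Println(run(tracer("A"), tracer("B"), handler))
	// [A before B before handler B after A after]
	fmt.Println(run(tracer("A"), stopper, handler))
	// [A before stop A after]
}
```

Code after c.Next() runs on the way back out, in reverse registration order. That is why a logger registered early can time everything that follows it.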
Middleware Scope
Global Middleware
Applied to all routes:
r := router.MustNew()
// These apply to ALL routes
r.Use(Logger())
r.Use(Recovery())
r.Use(CORS())
r.GET("/", handler)
r.GET("/users", usersHandler)
Group Middleware
Applied only to routes in a group:
r := router.MustNew()
r.Use(Logger()) // Global
// Public routes - no auth
public := r.Group("/api/public")
public.GET("/status", statusHandler)
// Private routes - auth required
private := r.Group("/api/private")
private.Use(AuthRequired()) // Group-level
private.GET("/profile", profileHandler)
Route-Specific Middleware
Applied to individual routes:
r := router.MustNew()
r.Use(Logger()) // Global
// Auth only for this route
r.GET("/admin", AdminAuth(), adminHandler)
// Multiple middleware for one route
r.POST("/upload", RateLimit(), ValidateFile(), uploadHandler)
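The three scopes combine into one chain per route. The sketch below assumes the router concatenates them as global, then group, then route-level middleware, then the handler. That is the usual convention in Go routers and matches how the examples above read, but it is an assumption here, not something this page specifies. The types and the chainFor helper are hypothetical.

```go
package main

import (
	"fmt"
	"strings"
)

// HandlerFunc is a toy handler that just records its name in a trace.
type HandlerFunc func(trace *[]string)

func named(name string) HandlerFunc {
	return func(trace *[]string) { *trace = append(*trace, name) }
}

// chainFor builds a route's handler list: global middleware first, then
// group middleware, then route-level middleware, and finally the handler.
func chainFor(global, group, route []HandlerFunc, h HandlerFunc) []HandlerFunc {
	chain := make([]HandlerFunc, 0, len(global)+len(group)+len(route)+1)
	chain = append(chain, global...)
	chain = append(chain, group...)
	chain = append(chain, route...)
	return append(chain, h)
}

// describe runs the chain and joins the recorded names.
func describe(chain []HandlerFunc) string {
	var trace []string
	for _, h := range chain {
		h(&trace)
	}
	return strings.Join(trace, " -> ")
}

func main() {
	global := []HandlerFunc{named("Logger")}
	private := []HandlerFunc{named("AuthRequired")}
	fmt.Println(describe(chainFor(global, private, nil, named("profileHandler"))))
	route := []HandlerFunc{named("RateLimit"), named("ValidateFile")}
	fmt.Println(describe(chainFor(global, nil, route, named("uploadHandler"))))
}
```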
Built-in Middleware
The router includes production-ready middleware in sub-packages. See the Middleware Reference for complete options.
Security
Security Headers
import "rivaas.dev/router/middleware/security"
r.Use(security.New(
security.WithHSTS(true),
security.WithFrameDeny(true),
security.WithContentTypeNosniff(true),
))
CORS
import "rivaas.dev/router/middleware/cors"
r.Use(cors.New(
cors.WithAllowedOrigins("https://example.com"),
cors.WithAllowedMethods("GET", "POST", "PUT", "DELETE"),
cors.WithAllowedHeaders("Content-Type", "Authorization"),
cors.WithAllowCredentials(true),
))
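For intuition about what a CORS middleware does with these options, here is a simplified standard-library sketch. It is not the rivaas.dev/router/middleware/cors implementation. corsSketch and preflight are hypothetical names, and real middleware also handles details this sketch skips, such as validating requested headers.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// corsSketch echoes allowed origins and answers preflight requests itself.
func corsSketch(allowed map[string]bool, methods []string, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Responses differ by Origin, so shared caches must key on it.
		w.Header().Add("Vary", "Origin")
		origin := r.Header.Get("Origin")
		if allowed[origin] {
			w.Header().Set("Access-Control-Allow-Origin", origin)
			w.Header().Set("Access-Control-Allow-Credentials", "true")
		}
		// A preflight is an OPTIONS request carrying Access-Control-Request-Method:
		// answer it here and stop the chain.
		if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
			w.Header().Set("Access-Control-Allow-Methods", strings.Join(methods, ", "))
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// preflight sends a POST preflight from origin and reports the status
// code and the Access-Control-Allow-Origin header.
func preflight(origin string) (int, string) {
	h := corsSketch(map[string]bool{"https://example.com": true},
		[]string{"GET", "POST", "PUT", "DELETE"},
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	req := httptest.NewRequest(http.MethodOptions, "/users", nil)
	req.Header.Set("Origin", origin)
	req.Header.Set("Access-Control-Request-Method", "POST")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code, rec.Header().Get("Access-Control-Allow-Origin")
}

func main() {
	fmt.Println(preflight("https://example.com"))
	fmt.Println(preflight("https://evil.example"))
}
```

A disallowed origin still gets a 204, but without Access-Control-Allow-Origin, so the browser refuses to expose the response to that page.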
Basic Auth
import "rivaas.dev/router/middleware/basicauth"
admin := r.Group("/admin")
admin.Use(basicauth.New(
basicauth.WithCredentials("admin", "secret"), // Example only: load real credentials from configuration
))
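The core check behind Basic Auth can be sketched with net/http alone, using Request.BasicAuth and crypto/subtle. This illustrates the technique; it is not the basicauth package's code, and basicAuthSketch and statusFor are hypothetical names.

```go
package main

import (
	"crypto/subtle"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// basicAuthSketch rejects requests whose Basic credentials do not match.
func basicAuthSketch(user, pass string, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		u, p, ok := r.BasicAuth()
		// ConstantTimeCompare does not stop at the first differing byte,
		// so response timing does not reveal how much of a guess matched.
		userOK := subtle.ConstantTimeCompare([]byte(u), []byte(user)) == 1
		passOK := subtle.ConstantTimeCompare([]byte(p), []byte(pass)) == 1
		if !ok || !userOK || !passOK {
			w.Header().Set("WWW-Authenticate", `Basic realm="admin"`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// statusFor sends a request with the given credentials and returns the status.
func statusFor(user, pass string) int {
	h := basicAuthSketch("admin", "secret", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/admin", nil)
	req.SetBasicAuth(user, pass)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code
}

func main() {
	fmt.Println(statusFor("admin", "secret"), statusFor("admin", "wrong"))
}
```

Basic Auth sends credentials in a reversible encoding, so serve it only over HTTPS.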
Observability
Access Log
import (
"log/slog"
"os"
"time"
"rivaas.dev/router/middleware/accesslog"
)
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
r.Use(accesslog.New(
accesslog.WithLogger(logger),
accesslog.WithExcludePaths("/health", "/metrics"),
accesslog.WithSlowThreshold(500 * time.Millisecond),
))
Request ID
import "rivaas.dev/router/middleware/requestid"
// Register one of the following configurations:
// UUID v7 by default (36 chars, time-ordered, RFC 9562)
r.Use(requestid.New())
// Or ULID for shorter IDs (26 chars)
// r.Use(requestid.New(requestid.WithULID()))
// Or a custom header name
// r.Use(requestid.New(requestid.WithHeader("X-Correlation-ID")))
// Later in handlers:
func handler(c *router.Context) {
id := requestid.Get(c)
fmt.Println("Request ID:", id)
}
Reliability
Recovery
import "rivaas.dev/router/middleware/recovery"
r.Use(recovery.New(
recovery.WithPrintStack(true),
recovery.WithLogger(logger),
))
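A recovery middleware is essentially a deferred recover() around the rest of the chain. The standard-library sketch below shows that idea. recoverySketch is a hypothetical name, and logging debug.Stack() only approximates what an option like WithPrintStack might do.

```go
package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"
	"runtime/debug"
)

// recoverySketch turns panics in downstream handlers into 500 responses.
func recoverySketch(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			if err := recover(); err != nil {
				// http.ErrAbortHandler is a deliberate abort; let net/http handle it.
				if err == http.ErrAbortHandler {
					panic(err)
				}
				log.Printf("panic: %v\n%s", err, debug.Stack())
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			}
		}()
		next.ServeHTTP(w, r)
	})
}

// statusAfterPanic runs a handler that panics and returns the response status.
func statusAfterPanic() int {
	h := recoverySketch(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		panic("boom")
	}))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return rec.Code
}

func main() {
	fmt.Println(statusAfterPanic())
}
```

If the handler already wrote part of a response before panicking, the status code can no longer be changed. The 500 only reaches clients that had not yet received headers.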
Timeout
import "rivaas.dev/router/middleware/timeout"
r.Use(timeout.New(
timeout.WithDuration(30 * time.Second),
timeout.WithMessage("Request timeout"),
))
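The standard library offers a comparable building block, http.TimeoutHandler, which replies with 503 and a message when the wrapped handler exceeds its deadline and cancels the request context. The router's timeout package may behave differently (for example, in status code). The sketch below only shows the standard-library version. slowHandler and timedOut are hypothetical names.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"time"
)

// slowHandler takes 200ms unless its request context is canceled first.
func slowHandler(w http.ResponseWriter, r *http.Request) {
	select {
	case <-time.After(200 * time.Millisecond):
		fmt.Fprint(w, "done")
	case <-r.Context().Done():
		// TimeoutHandler cancels the context when the deadline passes,
		// so well-behaved handlers can stop work early.
	}
}

// timedOut serves one request under a limit given in milliseconds and
// returns the status code and body.
func timedOut(limitMs int) (int, string) {
	limit := time.Duration(limitMs) * time.Millisecond
	h := http.TimeoutHandler(http.HandlerFunc(slowHandler), limit, "Request timeout")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return rec.Code, rec.Body.String()
}

func main() {
	fmt.Println(timedOut(20))   // deadline shorter than the handler
	fmt.Println(timedOut(2000)) // deadline long enough
}
```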
Rate Limit
import "rivaas.dev/router/middleware/ratelimit"
r.Use(ratelimit.New(
ratelimit.WithRequestsPerSecond(1000),
ratelimit.WithBurst(100),
ratelimit.WithKeyFunc(func(c *router.Context) string {
return c.ClientIP() // Rate limit by IP
}),
))
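Rate limiting by a key such as the client IP usually means one token bucket per key. The sketch below assumes token-bucket semantics (a steady refill rate plus a burst capacity), which is what WithRequestsPerSecond and WithBurst suggest but is not confirmed here. keyedLimiter is a hypothetical type; it takes an injectable clock so the behavior is easy to test, and it never evicts idle keys, which real middleware must do.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// keyedLimiter keeps one token bucket per key (for example, per client IP).
type keyedLimiter struct {
	mu      sync.Mutex
	rate    float64 // tokens added per second
	burst   float64 // bucket capacity
	buckets map[string]*bucket
	now     func() time.Time
}

type bucket struct {
	tokens float64
	last   time.Time
}

func newKeyedLimiter(perSecond, burst int, now func() time.Time) *keyedLimiter {
	return &keyedLimiter{rate: float64(perSecond), burst: float64(burst),
		buckets: map[string]*bucket{}, now: now}
}

// Allow refills the key's bucket for the elapsed time, then spends one token if available.
func (l *keyedLimiter) Allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	t := l.now()
	b, ok := l.buckets[key]
	if !ok {
		b = &bucket{tokens: l.burst, last: t} // new keys start with a full bucket
		l.buckets[key] = b
	}
	b.tokens += t.Sub(b.last).Seconds() * l.rate
	if b.tokens > l.burst {
		b.tokens = l.burst
	}
	b.last = t
	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// demo uses a fake clock: 1 token/second with a burst of 2.
func demo() []bool {
	clock := time.Unix(0, 0)
	l := newKeyedLimiter(1, 2, func() time.Time { return clock })
	var out []bool
	for i := 0; i < 3; i++ {
		out = append(out, l.Allow("10.0.0.1"))
	}
	out = append(out, l.Allow("10.0.0.2")) // a different IP has its own bucket
	clock = clock.Add(time.Second)         // one token refills for 10.0.0.1
	out = append(out, l.Allow("10.0.0.1"))
	return out
}

func main() {
	fmt.Println(demo()) // [true true false true true]
}
```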
Body Limit
import "rivaas.dev/router/middleware/bodylimit"
r.Use(bodylimit.New(
bodylimit.WithLimit(10 * 1024 * 1024), // 10MB
))
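With the standard library, the same protection comes from http.MaxBytesReader, which makes reads fail once the limit is exceeded. Since Go 1.19 the error is an *http.MaxBytesError, so handlers can answer 413. This is a sketch of the technique, not the bodylimit package. bodyLimitSketch, upload, and post are hypothetical names.

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

// bodyLimitSketch caps how many body bytes downstream handlers can read.
func bodyLimitSketch(limit int64, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, limit)
		next.ServeHTTP(w, r)
	})
}

// upload reads the whole body and maps the limit error to 413.
func upload(w http.ResponseWriter, r *http.Request) {
	data, err := io.ReadAll(r.Body)
	var tooLarge *http.MaxBytesError
	if errors.As(err, &tooLarge) {
		http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
		return
	}
	if err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}
	fmt.Fprintf(w, "received %d bytes", len(data))
}

// post sends body through a 10-byte limit and returns the status code.
func post(body string) int {
	h := bodyLimitSketch(10, http.HandlerFunc(upload))
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/upload", strings.NewReader(body)))
	return rec.Code
}

func main() {
	fmt.Println(post("hello"), post("this body is definitely too long"))
}
```

The limit takes effect only when the body is read, so handlers that never read it are unaffected.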
Performance
Compression
import "rivaas.dev/router/middleware/compression"
r.Use(compression.New(
compression.WithLevel(compression.DefaultCompression),
compression.WithMinSize(1024), // Don't compress <1KB
))
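To show how a minimum size can interact with gzip, the standard-library sketch below buffers the handler's output and compresses only responses of at least minSize bytes for clients that send Accept-Encoding: gzip. It is an illustration, not the compression package. Buffering the whole response costs memory and defeats streaming, so real middleware may behave differently. gzipSketch, bufferedWriter, and fetch are hypothetical names.

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

// bufferedWriter holds the handler's status and body so the middleware can
// decide afterwards whether the response is large enough to compress.
type bufferedWriter struct {
	http.ResponseWriter
	buf    bytes.Buffer
	status int
}

func (b *bufferedWriter) WriteHeader(code int)        { b.status = code }
func (b *bufferedWriter) Write(p []byte) (int, error) { return b.buf.Write(p) }

// gzipSketch compresses responses of at least minSize bytes. The
// Accept-Encoding check is simplified and ignores q-values.
func gzipSketch(minSize int, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("Vary", "Accept-Encoding")
		if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
			next.ServeHTTP(w, r)
			return
		}
		bw := &bufferedWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(bw, r)
		if bw.buf.Len() < minSize { // too small: send uncompressed
			w.WriteHeader(bw.status)
			w.Write(bw.buf.Bytes())
			return
		}
		w.Header().Set("Content-Encoding", "gzip")
		w.Header().Del("Content-Length")
		w.WriteHeader(bw.status)
		gz := gzip.NewWriter(w)
		gz.Write(bw.buf.Bytes())
		gz.Close()
	})
}

// fetch serves an n-byte body through a 1 KB threshold and returns the
// Content-Encoding header and the length of the decoded body.
func fetch(n int) (string, int) {
	h := gzipSketch(1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		io.WriteString(w, strings.Repeat("a", n))
	}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Accept-Encoding", "gzip")
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)

	enc := rec.Header().Get("Content-Encoding")
	body := io.Reader(rec.Body)
	if enc == "gzip" {
		zr, err := gzip.NewReader(rec.Body)
		if err != nil {
			return enc, -1
		}
		body = zr
	}
	data, err := io.ReadAll(body)
	if err != nil {
		return enc, -1
	}
	return enc, len(data)
}

func main() {
	fmt.Println(fetch(100))
	fmt.Println(fetch(4096))
}
```

The handler sets Content-Type explicitly. Otherwise net/http would sniff the compressed bytes when choosing one.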
Middleware Ordering
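The order in which middleware is registered matters. Each middleware wraps everything registered after it, so it runs earlier on the way in and later on the way out. Recommended order: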
r := router.MustNew()
// 1. Request ID - Generate early for logging/tracing
r.Use(requestid.New())
// 2. AccessLog - Log all requests including failed ones
r.Use(accesslog.New())
// 3. Recovery - Catch panics from all middleware and handlers registered after it
r.Use(recovery.New())
// 4. Security/CORS - Set security headers early
r.Use(security.New())
r.Use(cors.New())
// 5. Body Limit - Reject large requests before processing
r.Use(bodylimit.New())
// 6. Rate Limit - Reject excessive requests before processing
r.Use(ratelimit.New())
// 7. Timeout - Set time limits for downstream processing
r.Use(timeout.New())
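// 8. Authentication - Verify identity after rate limiting
// auth.New() is a placeholder for your authentication middleware (e.g., basicauth or a custom JWT middleware)
r.Use(auth.New())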
// 9. Compression - Compress responses (last)
r.Use(compression.New())
// 10. Your application routes
r.GET("/", handler)
Why this order?
- RequestID first - Generates a unique ID that other middleware can use
- AccessLog early - Captures all requests, including failed ones
- Recovery early - Turns panics in later middleware and handlers into error responses instead of aborted connections
- Security/CORS - Applies security policies before business logic
- BodyLimit - Prevents reading excessive request bodies (DoS protection)
- RateLimit - Blocks excessive requests before expensive operations
- Timeout - Sets deadlines for request processing
- Auth - Authenticates after rate limiting but before business logic
- Compression - Compresses response bodies (should be last)
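The ordering rules above follow from nesting. The net/http sketch below assumes the router's model is equivalent to wrapping handlers so that the first registered middleware is outermost. chain and trace are hypothetical helpers. One consequence is that accesslog, registered before recovery, finishes after recovery and can therefore log the final status of a recovered panic.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

type middleware func(http.Handler) http.Handler

// chain wraps h so that mws[0] is outermost, matching registration order:
// it runs first on the way in and last on the way out.
func chain(h http.Handler, mws ...middleware) http.Handler {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

// trace records entry ("name>") and exit ("<name") around the next handler.
func trace(name string, events *[]string) middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*events = append(*events, name+">")
			next.ServeHTTP(w, r)
			*events = append(*events, "<"+name)
		})
	}
}

func order() string {
	var events []string
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		events = append(events, "handler")
	})
	h := chain(final, trace("requestid", &events), trace("accesslog", &events), trace("recovery", &events))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return strings.Join(events, " ")
}

func main() {
	fmt.Println(order())
	// requestid> accesslog> recovery> handler <recovery <accesslog <requestid
}
```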
Writing Custom Middleware
Basic Middleware Pattern
func MyMiddleware() router.HandlerFunc {
return func(c *router.Context) {
// Before request processing
fmt.Println("Before handler")
c.Next() // Execute next middleware/handler
// After request processing
fmt.Println("After handler")
}
}
Middleware with Configuration
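// Uses golang.org/x/time/rate; one limiter is shared by all clients
func RateLimit(requestsPerSecond int) router.HandlerFunc {
// Setup (runs once when middleware is created); burst equals the per-second rate
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), requestsPerSecond)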
return func(c *router.Context) {
// Per-request logic
if !limiter.Allow() {
c.JSON(429, map[string]string{
"error": "Too many requests",
})
return // Don't call c.Next() - stop the chain
}
c.Next()
}
}
// Usage
r.Use(RateLimit(100)) // 100 requests per second
Middleware with Dependencies
// ctxKey is an unexported key type. The context package documentation advises
// against string or other built-in types as keys, to avoid collisions.
type ctxKey string

const userKey ctxKey = "user"

func Auth(db *Database) router.HandlerFunc {
	return func(c *router.Context) {
		token := c.Request.Header.Get("Authorization")
		user, err := db.ValidateToken(token)
		if err != nil {
			c.JSON(401, map[string]string{
				"error": "Unauthorized",
			})
			return
		}
		// Store user in request context for handlers
		ctx := context.WithValue(c.Request.Context(), userKey, user)
		c.Request = c.Request.WithContext(ctx)
		c.Next()
	}
}
// Usage
db := NewDatabase()
r.Use(Auth(db))
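Handlers need a safe way to read what the middleware stored. A common pattern is an unexported key type plus small accessor functions, using the comma-ok type assertion so a route without the middleware gets ok == false instead of a panic. User, WithUser, UserFromContext, and lookup are hypothetical names standing in for your own types.

```go
package main

import (
	"context"
	"fmt"
)

// User is a hypothetical stand-in for whatever db.ValidateToken returns.
type User struct{ Name string }

// ctxKey is unexported, so no other package can create a colliding key.
type ctxKey int

const userKey ctxKey = iota

// WithUser is what the auth middleware would call after validating a token.
func WithUser(ctx context.Context, u *User) context.Context {
	return context.WithValue(ctx, userKey, u)
}

// UserFromContext is what handlers call; ok is false if no user was stored.
func UserFromContext(ctx context.Context) (*User, bool) {
	u, ok := ctx.Value(userKey).(*User)
	return u, ok
}

// lookup simulates a request with or without the auth middleware having run.
func lookup(authenticated bool) (string, bool) {
	ctx := context.Background()
	if authenticated {
		ctx = WithUser(ctx, &User{Name: "alice"})
	}
	u, ok := UserFromContext(ctx)
	if !ok {
		return "", false
	}
	return u.Name, true
}

func main() {
	fmt.Println(lookup(true))  // alice true
	fmt.Println(lookup(false)) //  false
}
```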
Conditional Middleware
func ConditionalAuth() router.HandlerFunc {
return func(c *router.Context) {
// Skip auth for public endpoints
if c.Request.URL.Path == "/public" {
c.Next()
return
}
// Require auth for other endpoints
token := c.Request.Header.Get("Authorization")
if token == "" {
c.JSON(401, map[string]string{
"error": "Unauthorized",
})
return
}
c.Next()
}
}
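Hard-coding path checks inside the middleware works, but the exemption can also live in a wrapper so the auth logic stays unaware of routing. The net/http sketch below shows that alternative. skipPaths, requireToken, and status are hypothetical names. In this router, a route group without the middleware, as shown under Group Middleware, is often the cleaner option.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// skipPaths runs mw for every request except exact path matches in skip,
// which go straight to next.
func skipPaths(skip map[string]bool, mw func(http.Handler) http.Handler, next http.Handler) http.Handler {
	wrapped := mw(next)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if skip[r.URL.Path] {
			next.ServeHTTP(w, r)
			return
		}
		wrapped.ServeHTTP(w, r)
	})
}

// requireToken rejects requests without an Authorization header.
func requireToken(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") == "" {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// status sends an unauthenticated GET to path and returns the status code.
func status(path string) int {
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
	h := skipPaths(map[string]bool{"/public": true}, requireToken, ok)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
	return rec.Code
}

func main() {
	fmt.Println(status("/public"), status("/private"))
}
```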
Middleware Patterns
Pattern: Error Handling Middleware
func ErrorHandler() router.HandlerFunc {
return func(c *router.Context) {
defer func() {
if err := recover(); err != nil {
log.Printf("Panic: %v", err)
c.JSON(500, map[string]string{
"error": "Internal server error",
})
}
}()
c.Next()
}
}
Pattern: Logging Middleware
func Logger() router.HandlerFunc {
return func(c *router.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
c.Next()
duration := time.Since(start)
status := c.Writer.Status()
log.Printf("[%s] %s %s - %d (%v)",
method,
path,
c.ClientIP(),
status,
duration,
)
}
}
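Reading the status after c.Next() requires the response writer to remember it. c.Writer.Status() presumably relies on a wrapper like the one below, although the router's internals are not shown on this page. This is the usual net/http technique. statusRecorder, withStatus, and observed are hypothetical names.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// statusRecorder remembers the status code a handler writes.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

func (s *statusRecorder) WriteHeader(code int) {
	s.status = code
	s.ResponseWriter.WriteHeader(code)
}

// withStatus reports the final status of each request after the handler runs.
func withStatus(next http.Handler, report func(int)) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Default to 200: handlers that only call Write never call WriteHeader.
		rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(rec, r)
		report(rec.status)
	})
}

// observed runs a handler that writes code (or only a body when code is 0)
// and returns the status the wrapper saw.
func observed(code int) int {
	var got int
	h := withStatus(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if code == 0 {
			fmt.Fprint(w, "ok")
			return
		}
		w.WriteHeader(code)
	}), func(s int) { got = s })
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return got
}

func main() {
	fmt.Println(observed(404), observed(0))
}
```

Embedding http.ResponseWriter hides optional interfaces such as http.Flusher, so production wrappers often forward those explicitly.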
Pattern: Authentication Middleware
// claimsKey is an unexported key type, so these keys cannot collide
// with keys set by other packages.
type claimsKey string

const (
	userIDKey    claimsKey = "user_id"
	userEmailKey claimsKey = "user_email"
)

func JWTAuth(secret string) router.HandlerFunc {
	return func(c *router.Context) {
		authHeader := c.Request.Header.Get("Authorization")
		if authHeader == "" {
			c.JSON(401, map[string]string{
				"error": "Missing authorization header",
			})
			return
		}
		// Extract token (Bearer <token>); the scheme name is case-insensitive
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			c.JSON(401, map[string]string{
				"error": "Invalid authorization header format",
			})
			return
		}
		token := parts[1]
		claims, err := validateJWT(token, secret) // validateJWT is application code (not shown)
		if err != nil {
			c.JSON(401, map[string]string{
				"error": "Invalid token",
			})
			return
		}
		// Store claims in request context
		ctx := c.Request.Context()
		ctx = context.WithValue(ctx, userIDKey, claims.UserID)
		ctx = context.WithValue(ctx, userEmailKey, claims.Email)
		c.Request = c.Request.WithContext(ctx)
		c.Next()
	}
}
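The validateJWT helper above is left undefined. For HS256 tokens, the check can be sketched with the standard library: pin the algorithm, verify the HMAC-SHA256 signature over header.payload, then enforce exp. This is a hypothetical sketch (validateHS256, makeToken, check, and the "email" claim are assumptions). In production, prefer a maintained JWT library.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"
)

// Claims mirrors claims.UserID and claims.Email used above.
// "sub" and "exp" are registered JWT claims; "email" is an assumed custom claim.
type Claims struct {
	UserID string `json:"sub"`
	Email  string `json:"email"`
	Expiry int64  `json:"exp"`
}

var b64 = base64.RawURLEncoding // JWT segments are unpadded base64url

func mac(signingInput string, secret []byte) []byte {
	m := hmac.New(sha256.New, secret)
	m.Write([]byte(signingInput))
	return m.Sum(nil)
}

// validateHS256 checks the alg header, the HMAC-SHA256 signature, and exp.
func validateHS256(token string, secret []byte, now time.Time) (*Claims, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil, errors.New("malformed token")
	}
	var header struct {
		Alg string `json:"alg"`
	}
	rawHeader, err := b64.DecodeString(parts[0])
	// Pinning the algorithm rejects "alg": "none" and algorithm-confusion tricks.
	if err != nil || json.Unmarshal(rawHeader, &header) != nil || header.Alg != "HS256" {
		return nil, errors.New("unsupported token header")
	}
	sig, err := b64.DecodeString(parts[2])
	if err != nil || !hmac.Equal(sig, mac(parts[0]+"."+parts[1], secret)) {
		return nil, errors.New("invalid signature")
	}
	payload, err := b64.DecodeString(parts[1])
	if err != nil {
		return nil, err
	}
	var c Claims
	if err := json.Unmarshal(payload, &c); err != nil {
		return nil, err
	}
	// Tokens without exp are accepted here; you may want to require it.
	if c.Expiry != 0 && now.Unix() >= c.Expiry {
		return nil, errors.New("token expired")
	}
	return &c, nil
}

// makeToken signs claims so the example can be exercised end to end.
func makeToken(c Claims, secret []byte) string {
	header := b64.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
	payload, _ := json.Marshal(c)
	input := header + "." + b64.EncodeToString(payload)
	return input + "." + b64.EncodeToString(mac(input, secret))
}

// check signs with signSecret, validates with verifySecret, and sets exp
// to expIn seconds after a fixed "now".
func check(signSecret, verifySecret string, expIn int64) (string, error) {
	now := time.Unix(1_700_000_000, 0)
	token := makeToken(Claims{UserID: "42", Email: "john@example.com", Expiry: now.Unix() + expIn}, []byte(signSecret))
	claims, err := validateHS256(token, []byte(verifySecret), now)
	if err != nil {
		return "", err
	}
	return claims.UserID, nil
}

func main() {
	fmt.Println(check("secret", "secret", 60))
	fmt.Println(check("secret", "other", 60))
	fmt.Println(check("secret", "secret", -1))
}
```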
Pattern: Request ID Middleware
The built-in requestid middleware handles this pattern with UUID v7 or ULID:
import "rivaas.dev/router/middleware/requestid"
// UUID v7 (default) - time-ordered, 36 chars
r.Use(requestid.New())
// Or ULID - shorter, 26 chars
// r.Use(requestid.New(requestid.WithULID()))
// Access in handlers
func handler(c *router.Context) {
	id := requestid.Get(c) // Get from context
	// Or from the response header: c.Response.Header().Get("X-Request-ID")
	fmt.Println("Request ID:", id)
}
If you need a custom implementation:
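// requestIDKey is an unexported context key type, so it cannot collide
// with keys set by other packages.
type requestIDKey struct{}

func RequestID() router.HandlerFunc {
	return func(c *router.Context) {
		// Reuse an incoming request ID if present (consider validating its
		// length and format, since clients control this header)
		requestID := c.Request.Header.Get("X-Request-ID")
		if requestID == "" {
			// Generate a new UUID v7 (github.com/google/uuid)
			requestID = uuid.Must(uuid.NewV7()).String()
		}
		// Store in request context and response header
		ctx := context.WithValue(c.Request.Context(), requestIDKey{}, requestID)
		c.Request = c.Request.WithContext(ctx)
		c.Header("X-Request-ID", requestID)
		c.Next()
	}
}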
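If you would rather not depend on a UUID package, a version 7 UUID can be built with the standard library by following the RFC 9562 layout: a 48-bit big-endian Unix millisecond timestamp, the version nibble 7, the variant bits 10, and random bits elsewhere. newUUIDv7 and idAt are hypothetical names. This sketch omits the optional counter that keeps IDs ordered within the same millisecond.

```go
package main

import (
	"crypto/rand"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"time"
)

// newUUIDv7 returns an RFC 9562 version 7 UUID in canonical 8-4-4-4-12 form.
func newUUIDv7(now time.Time) (string, error) {
	var u [16]byte
	if _, err := rand.Read(u[6:]); err != nil { // random bits after the timestamp
		return "", err
	}
	var ts [8]byte
	binary.BigEndian.PutUint64(ts[:], uint64(now.UnixMilli()))
	copy(u[0:6], ts[2:8])       // low 48 bits of the millisecond timestamp
	u[6] = (u[6] & 0x0f) | 0x70 // version 7 in the high nibble
	u[8] = (u[8] & 0x3f) | 0x80 // variant bits 10
	h := hex.EncodeToString(u[:])
	return h[0:8] + "-" + h[8:12] + "-" + h[12:16] + "-" + h[16:20] + "-" + h[20:32], nil
}

// idAt generates an ID for a fixed Unix millisecond timestamp.
func idAt(ms int64) string {
	id, err := newUUIDv7(time.UnixMilli(ms))
	if err != nil {
		panic(err)
	}
	return id
}

func main() {
	fmt.Println(newUUIDv7(time.Now()))
}
```

Because the timestamp comes first and is big-endian, IDs from different milliseconds sort lexicographically in time order, which is what makes UUID v7 convenient for logs.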
Best Practices
1. Always Call c.Next()
Call c.Next() to continue the chain. Skip it only when you intend to stop processing, for example after writing an error response:
// ✅ GOOD: Calls c.Next() to continue
func Logger() router.HandlerFunc {
return func(c *router.Context) {
start := time.Now()
c.Next() // Continue to handler
duration := time.Since(start)
log.Printf("Duration: %v", duration)
}
}
// ✅ GOOD: Doesn't call c.Next() to stop chain
func Auth() router.HandlerFunc {
return func(c *router.Context) {
if !isAuthorized(c) {
c.JSON(401, map[string]string{"error": "Unauthorized"})
return // Don't call c.Next()
}
c.Next()
}
}
2. Keep Middleware Focused
Each middleware should do one thing:
// ✅ GOOD: Single responsibility
func Logger() router.HandlerFunc { ... }
func Auth() router.HandlerFunc { ... }
func RateLimit() router.HandlerFunc { ... }
// ❌ BAD: Does too much
func SuperMiddleware() router.HandlerFunc {
return func(c *router.Context) {
// Logging
// Auth
// Rate limiting
// ...
c.Next()
}
}
3. Use Functional Options for Configuration
type Config struct {
Limit int
Burst int
}
type Option func(*Config)
func WithLimit(limit int) Option {
return func(c *Config) {
c.Limit = limit
}
}
func WithBurst(burst int) Option {
return func(c *Config) {
c.Burst = burst
}
}
func RateLimit(opts ...Option) router.HandlerFunc {
config := &Config{
Limit: 100,
Burst: 10,
}
for _, opt := range opts {
opt(config)
}
limiter := rate.NewLimiter(rate.Limit(config.Limit), config.Burst)
return func(c *router.Context) {
if !limiter.Allow() {
c.JSON(429, map[string]string{"error": "Too many requests"})
return
}
c.Next()
}
}
// Usage
r.Use(RateLimit(
WithLimit(1000),
WithBurst(100),
))
4. Handle Errors Gracefully
func Middleware() router.HandlerFunc {
return func(c *router.Context) {
if err := doSomething(c); err != nil {
// Log error
log.Printf("Middleware error: %v", err)
// Return error response
c.JSON(500, map[string]string{
"error": "Internal server error",
})
return // Don't call c.Next()
}
c.Next()
}
}
Complete Example
package main
import (
"log"
"log/slog"
"net/http"
"os"
"rivaas.dev/router"
"rivaas.dev/router/middleware/accesslog"
"rivaas.dev/router/middleware/cors"
"rivaas.dev/router/middleware/recovery"
"rivaas.dev/router/middleware/requestid"
"rivaas.dev/router/middleware/security"
)
func main() {
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
r := router.MustNew()
// Global middleware (applies to all routes)
r.Use(requestid.New())
r.Use(accesslog.New(accesslog.WithLogger(logger)))
r.Use(recovery.New())
r.Use(security.New())
r.Use(cors.New(
cors.WithAllowedOrigins("*"),
cors.WithAllowedMethods("GET", "POST", "PUT", "DELETE"),
))
// Public routes
r.GET("/health", healthHandler)
r.GET("/public", publicHandler)
// API routes with auth
api := r.Group("/api")
api.Use(JWTAuth("your-secret-key"))
{
api.GET("/profile", profileHandler)
api.POST("/posts", createPostHandler)
// Admin routes with additional middleware
admin := api.Group("/admin")
admin.Use(RequireAdmin())
{
admin.GET("/users", listUsersHandler)
admin.DELETE("/users/:id", deleteUserHandler)
}
}
log.Fatal(http.ListenAndServe(":8080", r))
}
// Custom middleware
func JWTAuth(secret string) router.HandlerFunc {
return func(c *router.Context) {
token := c.Request.Header.Get("Authorization")
if token == "" {
c.JSON(401, map[string]string{"error": "Unauthorized"})
return
}
// Validate token...
c.Next()
}
}
func RequireAdmin() router.HandlerFunc {
return func(c *router.Context) {
// Check if user is admin...
c.Next()
}
}
// Handlers
func healthHandler(c *router.Context) {
c.JSON(200, map[string]string{"status": "OK"})
}
func publicHandler(c *router.Context) {
c.JSON(200, map[string]string{"message": "Public endpoint"})
}
func profileHandler(c *router.Context) {
c.JSON(200, map[string]string{"user": "john@example.com"})
}
func createPostHandler(c *router.Context) {
c.JSON(201, map[string]string{"message": "Post created"})
}
func listUsersHandler(c *router.Context) {
c.JSON(200, []string{"user1", "user2"})
}
func deleteUserHandler(c *router.Context) {
c.Status(204)
}
Next Steps
- Context API: Learn about the Context and its lifecycle
- Middleware Reference: See all built-in middleware options
- Examples: Browse working examples with middleware