// Middleware Chain
//
// Logger, Recovery, and RequestID HTTP middleware, composed with a Chain helper.
package main
import (
"context"
"crypto/rand"
"encoding/hex"
"log"
"log/slog"
"net/http"
"time"
)
// contextKey is a package-private type for context keys. The context package
// advises against built-in types such as string as keys, because distinct
// types never compare equal and so cannot collide with other packages' keys.
type contextKey string

// requestIDKey is the context key under which the request ID is stored.
const requestIDKey = contextKey("request_id")
// Logger wraps next and, once next returns, emits an Info-level "request"
// log entry carrying the request ID, method, URL path, and the time spent
// inside next.
//
// The entry is written only after next returns normally. If next panics,
// the panic unwinds past this middleware and no entry is logged here, so
// place Logger outside any recovering middleware to see failed requests.
func Logger(next http.Handler) http.Handler {
	logRequest := func(w http.ResponseWriter, r *http.Request) {
		t0 := time.Now()
		next.ServeHTTP(w, r)
		ctx := r.Context()
		slog.Info(
			"request",
			"request_id", GetRequestID(ctx),
			"method", r.Method,
			"path", r.URL.Path,
			"duration", time.Since(t0),
		)
	}
	return http.HandlerFunc(logRequest)
}
// Recovery converts a panic raised by next into a 500 Internal Server Error
// response and logs the recovered value together with the request path.
//
// A panic carrying http.ErrAbortHandler is re-raised instead of recovered.
// net/http defines that sentinel for handlers that want to abort the
// response deliberately. Writing a 500 here would turn the abort into a
// completed response, so it must propagate to the server unchanged.
//
// If next has already written its response header before panicking, the
// 500 status can no longer replace the one that was sent.
func Recovery(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			// recover returns any, not necessarily an error value.
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}
			slog.Error("panic recovered", "error", rec, "path", r.URL.Path)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		}()
		next.ServeHTTP(w, r)
	})
}
// maxRequestIDLen caps the length of a client-supplied X-Request-ID that
// RequestID will reuse; longer values are replaced with a generated ID.
const maxRequestIDLen = 128

// RequestID ensures every request carries a request ID.
//
// It reuses the incoming X-Request-ID header when that value is acceptable:
// non-empty, at most maxRequestIDLen bytes, and printable ASCII with no
// spaces. Otherwise it generates a random ID of 16 hex characters. The ID is
// echoed in the X-Request-ID response header and stored in the request
// context under requestIDKey, where GetRequestID can read it.
func RequestID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The header is client-controlled, and it is both logged and reflected
		// back in the response, so bound it rather than trusting it blindly.
		id := r.Header.Get("X-Request-ID")
		if !validRequestID(id) {
			id = newRequestID()
		}
		w.Header().Set("X-Request-ID", id)
		ctx := context.WithValue(r.Context(), requestIDKey, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// validRequestID reports whether a client-supplied request ID is safe to
// reuse: non-empty, no longer than maxRequestIDLen, and made only of
// printable, non-space ASCII characters.
func validRequestID(id string) bool {
	if id == "" || len(id) > maxRequestIDLen {
		return false
	}
	for i := 0; i < len(id); i++ {
		if c := id[i]; c <= ' ' || c > '~' {
			return false
		}
	}
	return true
}

// newRequestID returns 8 bytes from crypto/rand, hex-encoded (16 characters).
func newRequestID() string {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		// In practice rand.Read fails only when the system randomness source
		// is unavailable, and since Go 1.24 it never returns an error. Rather
		// than silently using whatever bytes were filled, derive the ID from
		// the clock. That is not random, but it is still usable for
		// correlating log lines.
		n := uint64(time.Now().UnixNano())
		for i := range b {
			b[i] = byte(n >> (8 * i))
		}
	}
	return hex.EncodeToString(b[:])
}
// GetRequestID returns the string stored in ctx under requestIDKey (as set
// by the RequestID middleware), or "unknown" when ctx holds no string value
// under that key.
func GetRequestID(ctx context.Context) string {
	id, found := ctx.Value(requestIDKey).(string)
	if !found {
		return "unknown"
	}
	return id
}
// Chain composes middlewares around handler. The first middleware listed
// becomes the outermost wrapper, so a request flows through the middlewares
// in argument order before reaching handler. With no middlewares, handler is
// returned unchanged.
func Chain(handler http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
	wrapped := handler
	last := len(middlewares) - 1
	// Apply from the innermost (last) middleware outward.
	for i := range middlewares {
		wrapped = middlewares[last-i](wrapped)
	}
	return wrapped
}
// main serves a /health endpoint and a /panic demo endpoint on :8080 behind
// the RequestID, Logger, and Recovery middleware.
func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
		// A write error here means the client went away; nothing to recover.
		_, _ = w.Write([]byte("ok"))
	})
	mux.HandleFunc("GET /panic", func(w http.ResponseWriter, r *http.Request) {
		panic("something broke")
	})

	// Middleware is listed outermost first. Recovery sits inside Logger so
	// that a recovered panic returns normally through Logger, and the failed
	// request still gets an access-log line with its request ID.
	handler := Chain(mux, RequestID, Logger, Recovery)

	// http.ListenAndServe uses a server with no timeouts, which lets slow or
	// idle clients hold connections open indefinitely. Bound every phase.
	srv := &http.Server{
		Addr:              ":8080",
		Handler:           handler,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	slog.Info("listening", "addr", srv.Addr)
	log.Fatal(srv.ListenAndServe())
}