Topic 8
Middleware
Middleware sits between the client and your route handlers — it intercepts every request before it reaches your routes, and every response before it leaves your app. Think of it as a security checkpoint, logger, or transformer that wraps your entire application.
Real-world analogy: Middleware is like the reception desk at an office. Every visitor (request) passes through reception first — they can be checked, logged, redirected, or turned away — before reaching the actual office (route). On the way out, reception can also stamp the visitor's pass (add response headers).
8.1
Middleware Fundamentals
Request Flow
Every HTTP request passes through a middleware stack — a chain of middleware functions — before reaching your route. The stack executes in order on the way in, and in reverse order on the way out.
```
Incoming Request
       │
       ▼
┌────────────────────────────────┐
│ Middleware 1 (CORS)            │ ← runs first on request
│  ┌──────────────────────────┐  │
│  │ Middleware 2 (Log)       │  │ ← runs second
│  │  ┌────────────────────┐  │  │
│  │  │     Your Route     │  │  │ ← business logic
│  │  └────────────────────┘  │  │
│  │     ↑ response here      │  │
│  └──────────────────────────┘  │ ← Middleware 2 wraps response
└────────────────────────────────┘ ← Middleware 1 wraps response
       │
       ▼
Outgoing Response
```
Key: each middleware calls `await call_next(request)` to pass the request to the next layer.
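To make the nesting concrete, here is a framework-free sketch of the same idea in plain Python. The `build_chain` helper, the dict stand-ins for requests and responses, and the `cors`/`auth` functions are all hypothetical; they only mimic the `dispatch(request, call_next)` shape used by Starlette's `BaseHTTPMiddleware`. The second request shows a middleware turning a request away by returning early without ever calling `call_next`.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical stand-ins: plain dicts instead of real Request/Response objects
Request = dict
Response = dict
CallNext = Callable[[Request], Awaitable[Response]]
Dispatch = Callable[[Request, CallNext], Awaitable[Response]]

def build_chain(middlewares: list[Dispatch], route: CallNext) -> CallNext:
    """Compose dispatch functions so that middlewares[0] is the outermost layer."""
    handler = route
    for mw in reversed(middlewares):
        # Default arguments bind the current mw/handler (avoids late-binding bugs)
        async def layer(request: Request, mw=mw, call_next=handler) -> Response:
            return await mw(request, call_next)
        handler = layer
    return handler

trace: list[str] = []

async def cors(request: Request, call_next: CallNext) -> Response:
    trace.append("cors in")
    response = await call_next(request)
    trace.append("cors out")
    return response

async def auth(request: Request, call_next: CallNext) -> Response:
    trace.append("auth in")
    if request.get("token") != "secret":  # hypothetical check
        trace.append("auth blocked")
        return {"status": 401}  # short-circuit: call_next is never awaited
    response = await call_next(request)
    trace.append("auth out")
    return response

async def route(request: Request) -> Response:
    trace.append("route")
    return {"status": 200}

app = build_chain([cors, auth], route)

ok = asyncio.run(app({"token": "secret"}))
ok_trace = list(trace)  # cors in, auth in, route, auth out, cors out

trace.clear()
blocked = asyncio.run(app({}))
blocked_trace = list(trace)  # cors in, auth in, auth blocked, cors out
```

Note how the outer `cors` layer still runs its "after" step even when the inner layer refuses the request.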
The middleware lifecycle — 3 phases:
```python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

class LifecycleMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # ── PHASE 1: Before the route runs ──────────────────
        print(f"→ Incoming: {request.method} {request.url.path}")
        # You can: inspect headers, validate tokens, log, block requests

        # ── PHASE 2: Call the actual route ──────────────────
        response = await call_next(request)
        # Everything ABOVE this line runs BEFORE the route
        # Everything BELOW this line runs AFTER the route

        # ── PHASE 3: After the route runs ───────────────────
        print(f"← Outgoing: status={response.status_code}")
        # You can: add response headers, log timing, compress

        return response  # must return the response!

app.add_middleware(LifecycleMiddleware)

@app.get("/hello")
def hello():
    return {"message": "Hello!"}

# Console output when hitting GET /hello:
# → Incoming: GET /hello
# ← Outgoing: status=200
```
Response Flow
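After `await call_next(request)` returns, you hold the response object. You can inspect or mutate it (add headers, change the status code) or return a different Response object entirely before it goes back to the client. The body arrives as a stream (`response.body_iterator`), so reading or rewriting it means consuming that stream first.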
Example — adding custom headers to every response:
```python
import time
import uuid

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

class ResponseEnricherMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        start = time.perf_counter()
        request_id = str(uuid.uuid4())[:8]

        # Pass request_id downstream (other middleware/routes can read it)
        request.state.request_id = request_id

        response = await call_next(request)

        # Mutate the response: add headers
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = f"{elapsed_ms:.2f}ms"
        response.headers["X-Powered-By"] = "FastAPI"
        return response

app.add_middleware(ResponseEnricherMiddleware)

@app.get("/ping")
def ping(request: Request):
    return {
        "pong": True,
        "request_id": request.state.request_id,  # set by middleware!
    }

# Response headers will contain:
# X-Request-ID: a1b2c3d4
# X-Process-Time: 1.23ms
# X-Powered-By: FastAPI
```
request.state is a scratchpad you can use to pass data from middleware to your route handlers — like a request-scoped context object.
Middleware order matters — stacking example:
```python
# Middleware is applied in REVERSE order of add_middleware calls!
# Last added = outermost layer (runs first on the request)
# (MiddlewareA / MiddlewareB stand for any middleware classes)
app.add_middleware(MiddlewareA)  # runs SECOND (inner)
app.add_middleware(MiddlewareB)  # runs FIRST (outer)

# Request flow:  MiddlewareB → MiddlewareA → Route
# Response flow: Route → MiddlewareA → MiddlewareB
```
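Why does the last call end up outermost? The sketch below is a simplified model, not Starlette's source: the hypothetical `MiniApp` keeps a list where each new middleware is inserted at the front, then wraps the endpoint by walking that list in reverse, so the most recently added class is applied last and becomes the outer shell. `Recorder` is a hypothetical pass-through ASGI middleware that logs when it is entered and exited.

```python
import asyncio

class Recorder:
    """Hypothetical ASGI middleware that records entry/exit in a shared trace."""
    def __init__(self, app, name: str, trace: list):
        self.app, self.name, self.trace = app, name, trace

    async def __call__(self, scope, receive, send):
        self.trace.append(f"{self.name} in")
        await self.app(scope, receive, send)
        self.trace.append(f"{self.name} out")

class MiniApp:
    """Simplified model of how add_middleware calls become a nested stack."""
    def __init__(self, endpoint):
        self.endpoint = endpoint
        self.user_middleware = []

    def add_middleware(self, cls, **kwargs):
        self.user_middleware.insert(0, (cls, kwargs))  # newest goes to the front

    def build(self):
        app = self.endpoint
        for cls, kwargs in reversed(self.user_middleware):
            app = cls(app, **kwargs)  # the front of the list is wrapped last = outermost
        return app

trace: list[str] = []

async def endpoint(scope, receive, send):
    trace.append("route")

mini = MiniApp(endpoint)
mini.add_middleware(Recorder, name="MiddlewareA", trace=trace)
mini.add_middleware(Recorder, name="MiddlewareB", trace=trace)

async def noop_receive():
    return {"type": "http.request", "body": b"", "more_body": False}

async def noop_send(message):
    pass

asyncio.run(mini.build()({"type": "http"}, noop_receive, noop_send))
# trace: MiddlewareB in, MiddlewareA in, route, MiddlewareA out, MiddlewareB out
```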
8.2
Built-in Middleware
CORS — Cross-Origin Resource Sharing
CORS controls which origins are allowed to call your API from a browser. Without CORS middleware, a React/Vue frontend served from `http://localhost:3000` could not read responses from your API on `http://localhost:8000`: a different port counts as a different origin, so the browser blocks it.
Why does CORS exist? Browsers enforce the Same-Origin Policy: JS can only read responses from the same origin (scheme, host, and port) unless the API explicitly allows cross-origin requests via CORS headers.
```
Browser (localhost:3000) ──► OPTIONS /api/users ──► FastAPI (localhost:8000)
                         ◄── Access-Control-Allow-Origin: * ◄──
                         ──► GET /api/users ──────────────►
                         ◄── 200 OK ◄─────────────────────
```
The browser first sends a "preflight" OPTIONS request (only for non-simple requests, e.g. ones carrying an `Authorization` header, a JSON `Content-Type`, or a `PUT`/`DELETE` method).
CORS middleware replies with permission headers.
Then the browser proceeds with the real request.
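What does the middleware actually decide during that preflight? The function below is a deliberately simplified model of the check; it ignores requested headers, credentials, and the exact error response a real CORS middleware sends, and `preflight_headers` and its parameters are hypothetical names. It approves the preflight only if both the origin and the requested method are allowed, and echoes back the specific origin unless a wildcard is configured.

```python
from typing import Optional

def preflight_headers(
    origin: str,
    requested_method: str,
    allowed_origins: list[str],
    allowed_methods: list[str],
    max_age: int = 600,
) -> Optional[dict[str, str]]:
    """Return CORS headers for a preflight, or None if the browser should be refused."""
    allow_any_origin = "*" in allowed_origins
    origin_ok = allow_any_origin or origin in allowed_origins
    method_ok = "*" in allowed_methods or requested_method in allowed_methods
    if not (origin_ok and method_ok):
        return None  # no Access-Control-Allow-* headers: the browser blocks the real request

    headers = {
        "Access-Control-Allow-Origin": "*" if allow_any_origin else origin,
        "Access-Control-Allow-Methods": ", ".join(allowed_methods),
        "Access-Control-Max-Age": str(max_age),
    }
    if not allow_any_origin:
        headers["Vary"] = "Origin"  # response differs per origin, so caches must key on it
    return headers

allowed = ["https://myapp.com", "http://localhost:3000"]
print(preflight_headers("http://localhost:3000", "PUT", allowed, ["GET", "PUT"]))
print(preflight_headers("https://evil-site.com", "GET", allowed, ["GET", "PUT"]))  # None
```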
Example 1 — Development setup (allow all origins):
```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# ⚠️ Development only: never use allow_origins=["*"] in production!
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # any domain can access
    allow_credentials=False,  # must be False with allow_origins=["*"]
    allow_methods=["*"],      # GET, POST, PUT, DELETE, etc.
    allow_headers=["*"],      # any request header
)

@app.get("/public-data")
def public_data():
    return {"data": "accessible from any origin"}
```
Example 2 — Production setup (specific origins):
```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

ALLOWED_ORIGINS = [
    "https://myapp.com",
    "https://www.myapp.com",
    "https://admin.myapp.com",
    "http://localhost:3000",  # local dev frontend
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,  # allow cookies/auth headers
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    allow_headers=["Authorization", "Content-Type", "X-API-Key"],
    expose_headers=["X-Request-ID", "X-Process-Time"],
    max_age=3600,  # preflight cache: 1 hour (reduces OPTIONS requests)
)

@app.get("/protected-data")
def protected_data():
    return {"data": "only accessible from allowed origins"}
```
| Parameter | What it controls | Common value |
|---|---|---|
| allow_origins | Which domains can access the API | ["https://myapp.com"] |
| allow_credentials | Allow cookies & auth headers | True (with specific origins) |
| allow_methods | Which HTTP methods are allowed | ["GET","POST","PUT","DELETE"] |
| allow_headers | Which request headers are allowed | ["Authorization","Content-Type"] |
| expose_headers | Which response headers JS can read | ["X-Request-ID"] |
| max_age | Preflight cache duration (seconds) | 3600 |
GZip — Response Compression
GZip middleware automatically compresses responses when the client supports it (sends `Accept-Encoding: gzip`). This can reduce response sizes by 60–80% for JSON payloads, a huge win for large API responses.
```python
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,  # only compress responses of at least 1000 bytes;
                        # small responses aren't worth compressing
)

@app.get("/large-data")
def large_data():
    # The middleware compresses this automatically if the client supports it
    return {"items": [f"item-{i}" for i in range(1000)]}

# Without GZip: ~10,900 bytes of compact JSON
# With GZip:    a small fraction of that (exact size depends on the payload)
```
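The middleware's decision can be modeled with the standard library alone. This sketch simplifies what a GZip middleware does (a real one also handles streamed bodies and other edge cases); `maybe_compress` and `MINIMUM_SIZE` are hypothetical names. It serializes the same payload as the route above as compact JSON and compresses it only when the client advertises gzip support and the body reaches the size threshold.

```python
import gzip
import json

MINIMUM_SIZE = 1000  # mirrors GZipMiddleware(minimum_size=1000) above

def maybe_compress(body: bytes, accept_encoding: str) -> tuple[bytes, dict[str, str]]:
    """Simplified model: gzip the body only if the client accepts it and it is large enough."""
    if "gzip" not in accept_encoding.lower() or len(body) < MINIMUM_SIZE:
        return body, {}
    return gzip.compress(body), {"Content-Encoding": "gzip", "Vary": "Accept-Encoding"}

# Same payload as /large-data, serialized compactly (no spaces after separators)
payload = json.dumps(
    {"items": [f"item-{i}" for i in range(1000)]}, separators=(",", ":")
).encode()

compressed, headers = maybe_compress(payload, "gzip, deflate")
print(f"raw={len(payload)} bytes, gzipped={len(compressed)} bytes, headers={headers}")

small, small_headers = maybe_compress(b'{"ok":true}', "gzip")  # below threshold: untouched
plain, plain_headers = maybe_compress(payload, "identity")     # client didn't ask for gzip
```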
GZip middleware only activates if the client sends `Accept-Encoding: gzip`. Modern browsers and httpx/requests send this by default.
Trusted Host — Host Header Validation
`TrustedHostMiddleware` rejects requests whose `Host` header doesn't match your allowed domains. This protects against Host header injection attacks, where attackers send forged Host headers to trick your app (for example, into building password-reset links that point at the attacker's domain).
```python
from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware

app = FastAPI()

app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=[
        "myapp.com",
        "*.myapp.com",  # wildcard subdomains
        "localhost",
        "127.0.0.1",
    ],
)

@app.get("/data")
def data():
    return {"ok": True}

# Request with Host: myapp.com     → 200 OK
# Request with Host: evil-site.com → 400 Bad Request (rejected!)
# Request with Host: api.myapp.com → 200 OK (wildcard match)
```
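The matching rules in the comments above can be sketched as a small pure function. This is a simplified model, not Starlette's implementation: `is_host_allowed` is a hypothetical name, and it ignores details such as IPv6 literals in the Host header. In this model a `*.myapp.com` pattern covers subdomains only, so the bare `myapp.com` needs its own entry, which is why the list above includes both.

```python
def is_host_allowed(host_header: str, allowed_hosts: list[str]) -> bool:
    """Simplified model of Host header validation with leading-wildcard patterns."""
    host = host_header.split(":")[0].lower()  # drop any port: "localhost:8000" -> "localhost"
    for pattern in allowed_hosts:
        if pattern == "*" or host == pattern:
            return True
        # "*.myapp.com" matches "api.myapp.com" but not "myapp.com" or "evilmyapp.com"
        if pattern.startswith("*.") and host.endswith(pattern[1:]):
            return True
    return False

ALLOWED = ["myapp.com", "*.myapp.com", "localhost", "127.0.0.1"]
for h in ["myapp.com", "api.myapp.com", "localhost:8000", "evil-site.com", "myapp.com.evil-site.com"]:
    print(h, "→", "200 OK" if is_host_allowed(h, ALLOWED) else "400 Bad Request")
```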
HTTPS Redirect
`HTTPSRedirectMiddleware` automatically redirects all HTTP requests to HTTPS. Any request coming in on plain HTTP gets a `307 Temporary Redirect` to the same URL but with `https://`.
```python
import os

from fastapi import FastAPI
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware

app = FastAPI()

# Only add in production: it will break local development!
if os.getenv("ENV") == "production":
    app.add_middleware(HTTPSRedirectMiddleware)

@app.get("/secure")
def secure():
    return {"message": "Secure!"}

# http://myapp.com/secure → 307 redirect to https://myapp.com/secure
```
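Don't use in development! Your local server most likely has no TLS configured, so every request would be redirected to an https:// URL it cannot answer. Guard it behind an environment check. Also note that behind a TLS-terminating reverse proxy the app receives plain HTTP; unless the server trusts the proxy's `X-Forwarded-Proto` header (for Uvicorn, see `--forwarded-allow-ips`), every request is redirected again and you get a redirect loop.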
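The redirect itself is just a URL rewrite. Here is a sketch using only `urllib.parse`, modeled on the behavior described above rather than copied from Starlette: `https_redirect_target` is a hypothetical helper that swaps the scheme (http→https, ws→wss), drops an explicit default port, and keeps the path and query string intact.

```python
from urllib.parse import urlsplit, urlunsplit

def https_redirect_target(url: str) -> str:
    """Return the https:// (or wss://) URL that a plain-HTTP request should be sent to."""
    parts = urlsplit(url)
    scheme = {"http": "https", "ws": "wss"}.get(parts.scheme, parts.scheme)
    # An explicit :80 would be wrong on HTTPS, so keep only the hostname in that case
    netloc = parts.hostname if parts.port in (80, 443) else parts.netloc
    return urlunsplit((scheme, netloc, parts.path, parts.query, parts.fragment))

print(https_redirect_target("http://myapp.com/secure"))
print(https_redirect_target("http://myapp.com:80/search?q=fastapi"))
```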
8.3
Custom Middleware
Logging Middleware
A logging middleware records every request and response — method, path, status code, timing. This gives you full visibility into traffic without cluttering your route handlers.
Example 1 — Basic request/response logger:
```python
import logging
import time

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

# Set up a formatted logger
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger("api")

app = FastAPI()

class RequestLoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        start = time.perf_counter()
        # request.client is optional in ASGI (e.g. Unix sockets), so guard against None
        client = request.client.host if request.client else "unknown"

        # Log request (before route runs)
        logger.info(
            f"REQUEST {request.method} {request.url.path}"
            f" | client={client}"
            f" | ua={request.headers.get('user-agent', 'unknown')[:40]}"
        )

        try:
            response = await call_next(request)
            elapsed = (time.perf_counter() - start) * 1000

            # Log response (after route runs)
            logger.info(
                f"RESPONSE {request.method} {request.url.path}"
                f" | status={response.status_code}"
                f" | time={elapsed:.1f}ms"
            )
            return response
        except Exception as e:
            elapsed = (time.perf_counter() - start) * 1000
            logger.error(
                f"ERROR {request.method} {request.url.path}"
                f" | error={e!r} | time={elapsed:.1f}ms"
            )
            raise  # re-raise so FastAPI/Starlette turns it into a 500 response

app.add_middleware(RequestLoggingMiddleware)

@app.get("/items")
def get_items():
    return ["item1", "item2"]

# Console output (timestamps shortened):
# 2024-01-15 | INFO | REQUEST GET /items | client=127.0.0.1 | ua=curl/7.68
# 2024-01-15 | INFO | RESPONSE GET /items | status=200 | time=2.1ms
```
Example 2 — Structured JSON logging (production-grade):
```python
import json
import logging
import time
import uuid

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

class StructuredLogMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        start = time.perf_counter()

        response = await call_next(request)

        elapsed_ms = (time.perf_counter() - start) * 1000

        # Structured log as JSON: parseable by tools like Datadog, ELK
        log_entry = {
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "query": str(request.query_params),
            "status": response.status_code,
            "duration_ms": round(elapsed_ms, 2),
            "client_ip": request.client.host if request.client else None,
        }
        print(json.dumps(log_entry))  # or: logging.getLogger("api").info(json.dumps(log_entry))

        response.headers["X-Request-ID"] = request_id
        return response

app.add_middleware(StructuredLogMiddleware)
```
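Instead of calling `print`, you can route the same dictionaries through the standard `logging` module so that handlers, levels, and log shipping work as usual. A minimal sketch, assuming you pass the structured fields through `extra=`; the `JsonFormatter` class and the `fields` attribute name are hypothetical choices, not a library API.

```python
import json
import logging

class JsonFormatter(logging.Formatter):
    """Hypothetical formatter: renders each record as a single JSON line."""
    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Structured fields passed as logger.info(..., extra={"fields": {...}})
        entry.update(getattr(record, "fields", {}))
        return json.dumps(entry)

handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())

access_log = logging.getLogger("api.access")
access_log.addHandler(handler)
access_log.setLevel(logging.INFO)
access_log.propagate = False  # avoid duplicate lines via the root logger

# Inside the middleware, this would replace print(json.dumps(log_entry)):
access_log.info("request finished", extra={"fields": {"path": "/items", "status": 200}})
# stderr: {"level": "INFO", "logger": "api.access", "message": "request finished", "path": "/items", "status": 200}
```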
Metrics Middleware
Metrics middleware tracks aggregate numbers about your API — request counts, latency percentiles, error rates. These power dashboards and alerting. Here's how to build one from scratch (usable with any metrics backend).
```python
import time
from collections import defaultdict

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

# ---- In-memory metrics store ----
# In production: use prometheus_client, statsd, or OpenTelemetry
metrics = {
    "request_count": defaultdict(int),    # {route: count}
    "error_count": defaultdict(int),      # {route: error_count}
    "total_latency": defaultdict(float),  # {route: total_ms}
}

class MetricsMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Note: raw paths like /users/1 and /users/2 each become a separate key
        route = f"{request.method} {request.url.path}"
        start = time.perf_counter()

        response = await call_next(request)

        elapsed_ms = (time.perf_counter() - start) * 1000

        # Record metrics
        metrics["request_count"][route] += 1
        metrics["total_latency"][route] += elapsed_ms
        if response.status_code >= 400:
            metrics["error_count"][route] += 1

        return response

app.add_middleware(MetricsMiddleware)

# ---- Metrics endpoint ----
@app.get("/metrics")
def get_metrics():
    result = {}
    # Snapshot with list(): this sync route runs in a worker thread while
    # the middleware may add new routes to the dict concurrently
    for route, count in list(metrics["request_count"].items()):
        avg_latency = metrics["total_latency"][route] / count
        errors = metrics["error_count"][route]
        result[route] = {
            "requests": count,
            "avg_latency_ms": round(avg_latency, 2),
            "errors": errors,
            "error_rate": f"{(errors / count) * 100:.1f}%",
        }
    return result

@app.get("/users")
def get_users():
    return ["Alice", "Bob"]

# After some requests, GET /metrics returns:
# {
#   "GET /users":   {"requests": 42, "avg_latency_ms": 1.8, "errors": 0, "error_rate": "0.0%"},
#   "GET /metrics": {"requests": 5,  "avg_latency_ms": 0.4, "errors": 0, "error_rate": "0.0%"}
# }
```
In production, use `prometheus_client` and expose a `/metrics` endpoint that Prometheus scrapes. The middleware structure is identical: just replace the dict with Prometheus counters/histograms.
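The store above only keeps totals, so it can report averages but not the latency percentiles mentioned earlier. One simple way to get them is to keep the raw samples per route and compute cut points with `statistics.quantiles` from the standard library. This is a sketch: `record_latency`, `latency_summary`, and the unbounded sample list are illustrative, and in production a histogram (such as the Prometheus ones mentioned above) keeps memory bounded.

```python
import statistics
from collections import defaultdict

# {route: [latency_ms, ...]}; unbounded, so only suitable for a demo
latency_samples: dict[str, list[float]] = defaultdict(list)

def record_latency(route: str, elapsed_ms: float) -> None:
    # Would be called from MetricsMiddleware.dispatch alongside the counters
    latency_samples[route].append(elapsed_ms)

def latency_summary(route: str) -> dict[str, float]:
    samples = latency_samples[route]
    if len(samples) < 2:
        raise ValueError("need at least two samples for percentiles")
    cuts = statistics.quantiles(samples, n=100, method="inclusive")  # 99 cut points
    return {"p50": cuts[49], "p95": cuts[94], "p99": cuts[98], "max": max(samples)}

for ms in range(1, 101):  # fake latencies: 1ms .. 100ms
    record_latency("GET /users", float(ms))

summary = latency_summary("GET /users")
print(summary)
```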
Correlation IDs
A Correlation ID is a unique ID assigned to each request that flows through every layer of your system — middleware, routes, services, database calls, logs. When something goes wrong, you can grep all logs by the same correlation ID to trace the full request journey.
```
Client ──► Middleware assigns ID:   abc-123
       ──► Route logs with ID:      abc-123
       ──► Service logs with ID:    abc-123
       ──► DB query logged with ID: abc-123
Client ◄── Response header: X-Correlation-ID: abc-123
```
Now you can grep logs for "abc-123" and see the FULL story of one request.
Example — Full correlation ID system:
```python
import logging
import uuid
from contextvars import ContextVar

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

logging.basicConfig(level=logging.INFO)  # without this, INFO messages are not shown

app = FastAPI()

# ContextVar: async-safe per-request storage (like threading.local, but for async tasks)
correlation_id_var: ContextVar[str] = ContextVar("correlation_id", default="none")

class CorrelationIDMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Accept ID from client (useful for distributed tracing)
        # or generate a new one if not provided
        cid = request.headers.get("X-Correlation-ID") or str(uuid.uuid4())

        # Store in context var: visible to the code that handles this request
        correlation_id_var.set(cid)
        request.state.correlation_id = cid

        response = await call_next(request)

        # Return the ID in response so client can reference it
        response.headers["X-Correlation-ID"] = cid
        return response

app.add_middleware(CorrelationIDMiddleware)

# ---- Helper to get current request's correlation ID ----
def get_correlation_id() -> str:
    return correlation_id_var.get()

# ---- Use in logging ----
class CorrelatedLogger:
    def __init__(self, name: str):
        self._logger = logging.getLogger(name)

    def info(self, msg: str):
        cid = get_correlation_id()
        self._logger.info(f"[{cid}] {msg}")

log = CorrelatedLogger("myapp")

# ---- Route: correlation ID is available everywhere ----
@app.get("/process")
def process(request: Request):
    cid = request.state.correlation_id
    log.info("Processing request in route")  # [abc-123] Processing...
    log.info("Calling external service")     # [abc-123] Calling...
    return {"processed": True, "correlation_id": cid}

# Client sends:    X-Correlation-ID: my-trace-123
# All logs tagged: [my-trace-123] Processing request in route
# Response header: X-Correlation-ID: my-trace-123
```
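ContextVar is crucial here. A regular global variable is shared by all concurrent requests, so one request could overwrite another's ID mid-flight. A ContextVar is async-safe: each asyncio task gets its own isolated value, and the value set in the middleware is inherited by the code that handles the rest of that request.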
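You can see the difference without FastAPI. In this sketch three simulated requests run concurrently as asyncio tasks; each stores its ID both in a `ContextVar` and in a shared module-level dict (the anti-pattern), then yields to the event loop. The names `handle` and `shared_global` are illustrative.

```python
import asyncio
from contextvars import ContextVar

request_id_var: ContextVar[str] = ContextVar("request_id", default="none")
shared_global = {"request_id": "none"}  # anti-pattern: one slot shared by every request

async def handle(request_id: str, results: dict) -> None:
    request_id_var.set(request_id)            # visible only in this task's context
    shared_global["request_id"] = request_id  # visible to (and overwritten by) everyone
    await asyncio.sleep(0.01)                 # yield, letting the other "requests" run
    results[request_id] = {
        "contextvar": request_id_var.get(),
        "global": shared_global["request_id"],
    }

async def main() -> dict:
    results: dict = {}
    # gather() wraps each coroutine in its own Task, each with its own copy of the context
    await asyncio.gather(*(handle(f"req-{i}", results) for i in range(3)))
    return results

results = asyncio.run(main())
for rid, seen in sorted(results.items()):
    print(rid, seen)
# Each request still sees its own ID in the ContextVar,
# but all of them read the last writer's ID ("req-2") from the global.
```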
Pure ASGI Middleware (Advanced)
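BaseHTTPMiddleware is convenient but adds some per-request overhead: it runs the rest of the app in a separate task and relays the response to you through an internal stream, and changes made to ContextVars downstream of `call_next` don't propagate back up to `dispatch`. For maximum performance, you can write pure ASGI middleware that works directly at the raw ASGI protocol level, seeing each message as it is sent with no extra wrapping.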
Example — pure ASGI middleware:
```python
import time

from fastapi import FastAPI
from starlette.types import ASGIApp, Message, Receive, Scope, Send

class TimingMiddleware:
    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "http":
            # Pass WebSocket/lifespan events through unchanged
            await self.app(scope, receive, send)
            return

        start = time.perf_counter()

        # Wrap the send callable to intercept the start of the response
        async def send_with_timing(message: Message):
            if message["type"] == "http.response.start":
                elapsed = (time.perf_counter() - start) * 1000
                # Headers are a list of (name, value) byte pairs. Append instead of
                # building a dict, which would drop duplicates such as multiple Set-Cookie.
                headers = list(message.get("headers", []))
                headers.append((b"x-process-time", f"{elapsed:.2f}ms".encode()))
                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_with_timing)

app = FastAPI()
app.add_middleware(TimingMiddleware)

@app.get("/fast")
async def fast():
    return {"fast": True}

# Response includes header: x-process-time: 0.42ms (measured up to the response start)
```
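Because pure ASGI middleware is just an async callable, you can exercise it without a server. The harness below is a sketch: `demo_app` and `run_once` are hypothetical helpers that fake the ASGI `receive`/`send` callables. It declares a standalone `TimingMiddleware` (no Starlette imports) that appends its header to the existing header list, and the demo app deliberately sends two `set-cookie` headers to check that both survive; converting the list to a dict would collapse them into one.

```python
import asyncio
import time

class TimingMiddleware:
    """Standalone pure ASGI middleware that appends an x-process-time header."""
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        start = time.perf_counter()

        async def send_with_timing(message):
            if message["type"] == "http.response.start":
                elapsed = (time.perf_counter() - start) * 1000
                headers = list(message.get("headers", []))
                headers.append((b"x-process-time", f"{elapsed:.2f}ms".encode()))
                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_with_timing)

async def demo_app(scope, receive, send):
    # Hypothetical ASGI app; duplicate header names (two Set-Cookie) are legal in HTTP
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"set-cookie", b"a=1"), (b"set-cookie", b"b=2")],
    })
    await send({"type": "http.response.body", "body": b"ok"})

async def run_once(app):
    """Call an ASGI app once with fake receive/send and return every message it sent."""
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": "/fast", "headers": []}
    await app(scope, receive, send)
    return sent

messages = asyncio.run(run_once(TimingMiddleware(demo_app)))
start_message = messages[0]
print(start_message["status"], start_message["headers"])
```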
| | BaseHTTPMiddleware | Pure ASGI |
|---|---|---|
| Complexity | Simple ✅ | More complex |
| Performance | Slight overhead (extra task + stream per request) | Minimal overhead ✅ |
| Streaming / raw messages | Body only via `response.body_iterator` | Full control of every message ✅ |
| Best for | Most use cases | High-throughput APIs, streaming |
Combining Multiple Middleware
Real applications stack multiple middleware together. Here's a production-ready pattern combining CORS, logging, metrics, and correlation IDs — in the correct order.
```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

# assuming you defined these above:
# from middleware import CorrelationIDMiddleware, RequestLoggingMiddleware, MetricsMiddleware

app = FastAPI()

# ── Add middleware in REVERSE execution order ──────────────────
# Each add_middleware call wraps everything added before it:
# first added = innermost, last added = OUTERMOST (runs first on request, last on response)

# 1. Metrics: innermost, measures time spent in the route itself
app.add_middleware(MetricsMiddleware)

# 2. Logging: wraps metrics + route
app.add_middleware(RequestLoggingMiddleware)

# 3. Correlation ID: wraps logging and metrics, so the ID is already set when they run
app.add_middleware(CorrelationIDMiddleware)

# 4. GZip: compresses whatever the inner layers return
app.add_middleware(GZipMiddleware, minimum_size=500)

# 5. CORS: near the edge, so preflight requests are answered without touching inner layers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://myapp.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 6. Trusted Host: outermost, rejects forged Host headers before anything else runs
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["myapp.com", "localhost"])

# Execution order (request):  TrustedHost → CORS → GZip → CorrelationID → Logging → Metrics → Route
# Execution order (response): Route → Metrics → Logging → CorrelationID → GZip → CORS → TrustedHost
```
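Order tip: TrustedHost and CORS should sit at the outermost layers, so bad hosts are rejected and preflight requests are answered before any expensive work is done. GZip sits inside those edge layers but outside your logging and metrics middleware, so it compresses whatever the inner layers return.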
📋 Topic 8 Summary
| Middleware | Purpose | Source |
|---|---|---|
| BaseHTTPMiddleware | Base class for custom middleware | starlette |
| CORSMiddleware | Allow cross-origin browser requests | Built-in |
| GZipMiddleware | Compress large responses automatically | Built-in |
| TrustedHostMiddleware | Reject forged Host headers | Built-in |
| HTTPSRedirectMiddleware | Force HTTPS in production | Built-in |
| Logging Middleware | Record every request/response | Custom |
| Metrics Middleware | Count requests, measure latency | Custom |
| Correlation ID | Trace one request across all logs | Custom |
| Pure ASGI | Lowest overhead, direct access to ASGI messages | Custom |