Introduction
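Middleware in FastAPI lets you run code for every request before it reaches a route handler, and for every response before it is sent back to the client. This tutorial covers the built-in CORS middleware and shows how to write custom middleware for timing, logging, rate limiting, error handling, and security headers.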
CORS Configuration
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()

# Configure CORS.
# Browsers running on the two listed front-end origins may call this API,
# including with cookies/credentials; any HTTP method and any request header
# is permitted for those origins.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=["https://example.com", "https://www.example.com"],
)
# Different origins for different environments.
# Middleware must be registered before the application starts handling
# traffic: Starlette builds its middleware stack on the first ASGI call
# (which includes the lifespan "startup" event), and current versions raise
# RuntimeError("Cannot add middleware after an application has started")
# if add_middleware is called from a startup handler. The environment is
# therefore inspected at import time rather than in @app.on_event("startup")
# (an API FastAPI also marks as deprecated).
def add_cors_middleware(application: FastAPI = app) -> None:
    """Register CORSMiddleware with a policy selected by the ENV variable.

    When ``ENV`` is ``"production"``, only ``https://app.example.com`` is
    allowed, with credentials, for GET/POST requests carrying the
    Authorization and Content-Type headers. Any other value (or an unset
    variable) selects a permissive development policy that accepts every
    origin, method and header, without credentials.

    Args:
        application: The app to configure; defaults to the module-level
            ``app``.
    """
    import os

    if os.getenv("ENV") == "production":
        application.add_middleware(
            CORSMiddleware,
            allow_origins=["https://app.example.com"],
            allow_credentials=True,
            allow_methods=["GET", "POST"],
            allow_headers=["Authorization", "Content-Type"],
        )
    else:
        # A wildcard origin must not be combined with credentials: the CORS
        # spec forbids "Access-Control-Allow-Origin: *" on credentialed
        # requests, and Starlette's CORSMiddleware then reflects the caller's
        # Origin instead, letting ANY site make cookie-authenticated calls.
        # To test cookies locally, list explicit dev origins instead of "*".
        application.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=False,
            allow_methods=["*"],
            allow_headers=["*"],
        )


add_cors_middleware()
Custom Middleware
import logging
import time
from collections import deque

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
class RequestTimingMiddleware(BaseHTTPMiddleware):
    """Report how long the downstream app took to produce each response.

    The elapsed time in seconds is written to the ``X-Process-Time``
    response header. Because ``call_next`` returns once the response has
    started, time spent streaming a response body is not included.
    """

    async def dispatch(self, request: Request, call_next):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can be adjusted (e.g. by NTP) mid-request and
        # produce negative or wildly wrong durations.
        start = time.perf_counter()
        response = await call_next(request)
        response.headers["X-Process-Time"] = str(time.perf_counter() - start)
        return response
class LoggingMiddleware(BaseHTTPMiddleware):
    """Log every incoming request and the status code of its response."""

    async def dispatch(self, request: Request, call_next):
        # %-style arguments defer formatting until a record is actually
        # emitted, so this per-request path costs little when INFO is off.
        logging.info("Request: %s %s", request.method, request.url)
        response = await call_next(request)
        # Repeat method and path: with concurrent requests the log lines
        # interleave, and a bare status code cannot be matched to its request.
        logging.info(
            "Response: %s %s -> %s",
            request.method,
            request.url.path,
            response.status_code,
        )
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP sliding-window rate limiter kept in process memory.

    A client may make at most ``max_requests`` requests in any trailing
    ``window``-second interval. Further requests get a plain-text 429
    response with a ``Retry-After`` header. State lives on this middleware
    instance, so it is not shared between worker processes.

    Args:
        app: The downstream ASGI application.
        max_requests: Requests allowed per client within one window.
        window: Length of the sliding window, in seconds.
    """

    def __init__(self, app, max_requests: int = 100, window: int = 60):
        super().__init__(app)
        self.max_requests = max_requests
        self.window = window
        # client IP -> deque of time.monotonic() stamps of its accepted
        # requests, oldest first.
        self.requests = {}
        self._last_sweep = time.monotonic()

    def _sweep_idle_clients(self, now: float) -> None:
        """Forget clients whose newest accepted request is outside the window."""
        cutoff = now - self.window
        self.requests = {
            ip: times
            for ip, times in self.requests.items()
            if times and times[-1] > cutoff
        }
        self._last_sweep = now

    async def dispatch(self, request: Request, call_next):
        # Starlette types request.client as Optional; clients without an
        # address share a single bucket rather than crashing.
        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()  # immune to wall-clock adjustments

        # A full sweep is O(number of clients) and only bounds memory, so
        # run it at most once per window instead of on every request.
        if now - self._last_sweep >= self.window:
            self._sweep_idle_clients(now)

        # Drop this client's stamps that have slid out of the window so the
        # length below counts only requests made in the last `window` secs.
        cutoff = now - self.window
        times = self.requests.setdefault(client_ip, deque())
        while times and times[0] <= cutoff:
            times.popleft()

        # No await between this check and the append below, so concurrent
        # requests on the same event loop cannot both slip past the limit.
        if len(times) >= self.max_requests:
            # Seconds until the oldest counted request leaves the window.
            oldest = times[0] if times else now
            retry_after = int(oldest + self.window - now) + 1
            return Response(
                content="Rate limit exceeded",
                status_code=429,
                headers={"Retry-After": str(retry_after)},
            )

        times.append(now)
        return await call_next(request)
# Registration order matters: each add_middleware call wraps the app built
# so far, so the LAST middleware added is the OUTERMOST one and sees the
# request first. The rate limiter is added first (innermost) so that
# rejected 429 requests are still timed and logged by the outer layers.
app.add_middleware(RateLimitMiddleware, max_requests=100, window=60)
app.add_middleware(RequestTimingMiddleware)
app.add_middleware(LoggingMiddleware)
Middleware with Error Handling
from fastapi import Request
from starlette.responses import JSONResponse
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert uncaught downstream exceptions into a generic JSON 500.

    Exceptions that FastAPI's own exception handlers turn into responses
    (e.g. ``HTTPException``) never reach this middleware. Errors raised while
    a streaming response body is being sent are also not caught here,
    because ``call_next`` returns as soon as the response has started.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception:
            # Boundary handler: keep the full traceback in the server log,
            # but never leak exception details to the client.
            logging.exception(
                "Unhandled exception while handling %s %s",
                request.method,
                request.url.path,
            )
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error"},
            )
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a baseline set of security headers to every response.

    Override ``security_headers`` in a subclass to change or extend the set.
    Values set here replace any same-named header the endpoint produced.
    """

    # X-XSS-Protection is "0": the legacy XSS auditor that "1; mode=block"
    # turns on has been removed from modern browsers and could itself be
    # abused in older ones; OWASP and MDN recommend disabling it and relying
    # on a Content-Security-Policy instead.
    # Browsers only honour Strict-Transport-Security on HTTPS responses.
    security_headers = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "0",
        "Strict-Transport-Security": "max-age=31536000",
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        for name, value in self.security_headers.items():
            response.headers[name] = value
        return response
Practice Problems
- Create middleware that validates authentication tokens
- Implement middleware that compresses large response payloads
- Add middleware that tracks request and response sizes
- Build middleware that implements request ID tracing
- Create middleware that handles session management