Middleware Examples
This document provides practical examples of how to use middleware in MCP Composer. Each example includes complete, runnable code that demonstrates real-world usage patterns.
Example 1: Logging and Monitoring Middleware
A comprehensive logging middleware that tracks performance metrics and provides observability.
Complete Example
python
import time
import asyncio
from dataclasses import dataclass, field
from typing import Dict, Any
from fastmcp import FastMCP
from fastmcp.server.middleware import Middleware, MiddlewareContext, CallNext
from mcp_composer.core.utils.logger import LoggerFactory
# Module-level logger from MCP Composer's LoggerFactory; the middleware below logs through it.
logger = LoggerFactory.get_logger()
@dataclass
class Metrics:
    """Running counters collected by ``LoggingMiddleware``.

    Every counter starts at zero. ``call_durations`` holds the most recent
    per-call durations in seconds (the middleware caps its length).
    """

    # Number of tool calls observed, successful or not.
    total_calls: int = 0
    # Calls whose downstream handler returned normally.
    successful_calls: int = 0
    # Calls whose downstream handler raised.
    failed_calls: int = 0
    # Sum of all observed call durations, in seconds.
    total_duration: float = 0.0
    # Recent individual durations; each instance gets its own list.
    call_durations: list = field(default_factory=list)
class LoggingMiddleware(Middleware):
    """
    Comprehensive logging and monitoring middleware.

    Features:
    - Request/response logging
    - Performance metrics
    - Error tracking
    - Custom metrics collection

    Args:
        log_arguments: Log each call's arguments at DEBUG level.
        log_responses: Log each call's result at DEBUG level.
        collect_metrics: Record counters and durations in ``self.metrics``.
        max_durations: How many of the most recent per-call durations to keep
            in ``self.metrics.call_durations`` (default 1000).
    """

    def __init__(self,
                 log_arguments: bool = True,
                 log_responses: bool = False,
                 collect_metrics: bool = True,
                 max_durations: int = 1000):
        self.log_arguments = log_arguments
        self.log_responses = log_responses
        self.collect_metrics = collect_metrics
        self.max_durations = max_durations
        self.metrics = Metrics()
        # Serialises metric updates between coroutines (asyncio.Lock does not
        # protect against OS threads).
        self.lock = asyncio.Lock()

    async def on_call_tool(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Log a tool call, time it, and record success/failure metrics.

        Exceptions raised downstream are logged, counted and re-raised unchanged.
        """
        tool_name = getattr(context.message, "name", "<unknown>")
        arguments = getattr(context.message, "arguments", {})
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()

        # Log request
        logger.info(f"🔄 Starting tool call: {tool_name}")
        if self.log_arguments:
            logger.debug(f"📥 Arguments: {arguments}")

        # Only the downstream call is guarded: a failure in our own logging or
        # bookkeeping must not be reported as a tool failure.
        try:
            result = await call_next(context)
        except Exception as e:
            duration = time.perf_counter() - start_time
            logger.error(f"❌ Tool call failed: {tool_name} after {duration:.3f}s - {e}")
            if self.collect_metrics:
                await self._update_metrics(duration, success=False)
            raise

        duration = time.perf_counter() - start_time
        logger.info(f"✅ Tool call completed: {tool_name} in {duration:.3f}s")
        if self.log_responses:
            logger.debug(f"📤 Response: {result}")
        if self.collect_metrics:
            await self._update_metrics(duration, success=True)
        return result

    async def _update_metrics(self, duration: float, success: bool) -> None:
        """Record one call's outcome and duration.

        Updates are serialised between coroutines by ``self.lock``.
        """
        async with self.lock:
            metrics = self.metrics
            metrics.total_calls += 1
            metrics.total_duration += duration
            if success:
                metrics.successful_calls += 1
            else:
                metrics.failed_calls += 1

            durations = metrics.call_durations
            durations.append(duration)
            # Bound memory: drop the oldest entries beyond ``max_durations``.
            # A value of 0 (or less) keeps none.
            excess = len(durations) - self.max_durations
            if excess > 0:
                del durations[:excess]

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the collected metrics.

        The snapshot is built without any ``await``. When called from the
        event-loop thread it therefore cannot interleave with
        ``_update_metrics`` and needs no lock. It works the same inside and
        outside a running event loop.

        Returns:
            dict with ``total_calls``, ``successful_calls``, ``failed_calls``,
            ``success_rate`` (0.0 when no calls yet), ``average_duration``
            (seconds; 0.0 when no calls yet) and ``total_duration`` (seconds).
        """
        metrics = self.metrics
        total = metrics.total_calls
        return {
            "total_calls": total,
            "successful_calls": metrics.successful_calls,
            "failed_calls": metrics.failed_calls,
            "success_rate": metrics.successful_calls / total if total else 0.0,
            "average_duration": metrics.total_duration / total if total else 0.0,
            "total_duration": metrics.total_duration,
        }
# Example usage
# Demo server for the logging middleware.
app = FastMCP("Logging Demo")
# Add the logging middleware
# Bound to a module-level name so its get_metrics() can be read later.
logging_middleware = LoggingMiddleware(
    log_arguments=True,   # arguments are logged at DEBUG level
    log_responses=False,  # tool results are not logged
    collect_metrics=True  # counters/durations are recorded in .metrics
)
app.add_middleware(logging_middleware)


@app.tool()
async def slow_operation(data: dict) -> dict:
    """Simulate a slow operation."""
    await asyncio.sleep(0.1) # Simulate work
    return {"result": "success", "processed": data}


@app.tool()
async def fast_operation(data: dict) -> dict:
    """Simulate a fast operation."""
    # Returns immediately; a baseline for the duration metrics.
    return {"result": "success", "processed": data}


@app.tool()
async def failing_operation(data: dict) -> dict:
    """Simulate a failing operation."""
    # Always raises, exercising the middleware's failure path (failed_calls).
    raise ValueError("Simulated error")


if __name__ == "__main__":
    # NOTE(review): run() is called without a transport argument, while the
    # curl examples assume an HTTP server on localhost:8000 — confirm the
    # transport/port this setup actually serves.
    app.run()
Running the Example
bash
# Start the server
# NOTE(review): confirm the server is exposed over HTTP on port 8000.
python logging_example.py
# In another terminal, test the tools
curl -X POST http://localhost:8000/tools/slow_operation \
  -H "Content-Type: application/json" \
  -d '{"data": {"test": "value"}}'
# Check metrics: get_metrics() reads the middleware's in-memory counters, so it
# must be called from Python running inside the server process, not from this
# shell, e.g.:
#   print(logging_middleware.get_metrics())
Example 2: Request Transformation Middleware
A middleware that transforms requests and responses, useful for API versioning, data validation, and format conversion.
Complete Example
python
import json
from typing import Any, Dict, Optional
from fastmcp import FastMCP
from fastmcp.server.middleware import Middleware, MiddlewareContext, CallNext
from mcp_composer.core.utils.logger import LoggerFactory
# Module-level logger from MCP Composer's LoggerFactory; the middleware below logs through it.
logger = LoggerFactory.get_logger()
class RequestTransformationMiddleware(Middleware):
    """
    Transforms requests and responses for API compatibility.

    Features:
    - Request validation and sanitization
    - Response formatting
    - API versioning support
    - Data type conversion

    Args:
        validate_requests: Reject calls missing any argument listed for the
            tool in ``required_fields``.
        format_responses: Add ``_metadata`` to dict results and normalise the
            ``user`` object returned by ``process_user_data``.
        api_version: Version string injected into requests (when absent) and
            into response metadata.
        required_fields: Mapping of tool name -> argument names that must be present.
    """

    def __init__(self,
                 validate_requests: bool = True,
                 format_responses: bool = True,
                 api_version: str = "v1",
                 required_fields: Optional[Dict[str, list]] = None):
        self.validate_requests = validate_requests
        self.format_responses = format_responses
        self.api_version = api_version
        self.required_fields = required_fields or {}

    async def on_call_tool(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Rewrite the call's arguments, run the tool, then format its result.

        The original arguments are put back on ``context.message`` afterwards,
        even if the tool raises.

        Raises:
            ValueError: propagated from ``_transform_request`` on invalid input.
        """
        message = context.message
        tool_name = getattr(message, "name", "<unknown>")
        # ``arguments`` is optional in an MCP tool call and may be None;
        # treat a missing/None value as "no arguments".
        original_args = getattr(message, "arguments", None)

        # Transform request (may raise before anything is modified)
        transformed_args = await self._transform_request(tool_name, original_args or {})
        message.arguments = transformed_args
        try:
            result = await call_next(context)
            if self.format_responses:
                result = await self._transform_response(tool_name, result)
            return result
        finally:
            # Restore original arguments
            message.arguments = original_args

    async def _transform_request(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Return a validated, sanitised copy of ``arguments``.

        Adds ``api_version`` when absent, enforces ``required_fields``, strips
        surrounding whitespace from string values, and converts a string
        ``age`` to ``int`` for ``process_user_data``.

        Raises:
            ValueError: if a required field is missing, or ``age`` is a string
                that is not a valid integer.
        """
        transformed = dict(arguments or {})

        # NOTE(review): this injects an argument the example tools do not
        # declare; confirm the server's argument validation accepts/ignores
        # unknown arguments before relying on this.
        transformed.setdefault("api_version", self.api_version)

        if self.validate_requests:
            missing_fields = [
                field_name
                for field_name in self.required_fields.get(tool_name, ())
                if field_name not in transformed
            ]
            if missing_fields:
                raise ValueError(f"Missing required fields: {missing_fields}")

        # Sanitize string inputs
        transformed = {
            key: value.strip() if isinstance(value, str) else value
            for key, value in transformed.items()
        }

        # Convert data types based on tool
        if tool_name == "process_user_data" and isinstance(transformed.get("age"), str):
            try:
                transformed["age"] = int(transformed["age"])
            except ValueError as err:
                raise ValueError("Age must be a valid integer") from err

        logger.debug(f"Transformed request for {tool_name}: {transformed}")
        return transformed

    async def _transform_response(self, tool_name: str, result: Any) -> Any:
        """Return a copy of a dict ``result`` with ``_metadata`` attached.

        Non-dict results are returned unchanged. For ``process_user_data`` the
        nested ``user`` dict is copied before defaults are filled in, so the
        value the tool returned is never mutated.
        """
        from datetime import datetime, timezone

        # NOTE(review): depending on the FastMCP version, call_next may return a
        # ToolResult/content object rather than the tool's raw dict, in which
        # case this method is a no-op — confirm.
        if not isinstance(result, dict):
            return result

        transformed = dict(result)
        transformed["_metadata"] = {
            "api_version": self.api_version,
            "tool_name": tool_name,
            # UTC time in the same "YYYY-MM-DDTHH:MM:SSZ" shape as before.
            "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        }

        # Format specific responses
        user = transformed.get("user")
        if tool_name == "process_user_data" and isinstance(user, dict):
            # Copy first: result.copy() above is shallow, so setdefault on the
            # original nested dict would leak into the tool's own return value.
            user = dict(user)
            for key, default in (("id", None), ("name", ""), ("email", ""), ("created_at", None)):
                user.setdefault(key, default)
            transformed["user"] = user

        logger.debug(f"Transformed response for {tool_name}: {transformed}")
        return transformed
# Example usage
app = FastMCP("Transformation Demo")
# Add the transformation middleware
app.add_middleware(
    RequestTransformationMiddleware(
        validate_requests=True,
        format_responses=True,
        api_version="v2",
        required_fields={
            "process_user_data": ["name", "email"],
            "update_user": ["user_id"]
        }
    )
)


@app.tool()
async def process_user_data(name: str, email: str, age: Optional[int] = None) -> dict:
    """Process user data with validation."""
    # The middleware converts a string "age" to int before this tool runs, so
    # the parameter is typed int: a str-typed parameter would reject that int
    # under pydantic-style validation (presumably what the server uses —
    # confirm), while an int parameter also accepts a numeric string.
    return {
        "status": "success",
        "user": {
            "name": name,
            "email": email,
            "age": age,
            "processed": True
        }
    }


@app.tool()
async def update_user(user_id: str, data: dict) -> dict:
    """Update user information."""
    return {
        "status": "success",
        "user_id": user_id,
        "updated_fields": list(data.keys()),
        "message": "User updated successfully"
    }


@app.tool()
async def simple_operation(data: dict) -> dict:
    """Simple operation without special requirements."""
    return {"result": "success", "data": data}


if __name__ == "__main__":
    app.run()
Testing the Transformation Middleware
bash
# Test with valid data
curl -X POST http://localhost:8000/tools/process_user_data \
  -H "Content-Type: application/json" \
  -d '{"name": "John Doe", "email": "john@example.com", "age": "30"}'
# Test with missing required field (should fail)
curl -X POST http://localhost:8000/tools/process_user_data \
  -H "Content-Type: application/json" \
  -d '{"name": "John Doe"}'
# Test simple operation (its only parameter is named "data")
curl -X POST http://localhost:8000/tools/simple_operation \
  -H "Content-Type: application/json" \
  -d '{"data": {"test": "value"}}'
Example 3: Circuit Breaker with Health Check
A practical example showing how to implement a circuit breaker with health checks and graceful degradation.
Complete Example
python
import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set

from fastmcp import FastMCP
from fastmcp.exceptions import ToolError
from fastmcp.server.middleware import Middleware, MiddlewareContext, CallNext

from mcp_composer.core.utils.logger import LoggerFactory
# Module-level logger from MCP Composer's LoggerFactory; the circuit breaker below logs through it.
logger = LoggerFactory.get_logger()
@dataclass
class CircuitState:
    """Per-tool bookkeeping for ``AdvancedCircuitBreakerMiddleware``.

    ``state`` holds one of the strings ``"CLOSED"``, ``"OPEN"`` or ``"HALF_OPEN"``.
    """

    # Current breaker state; a new circuit starts closed.
    state: str = "CLOSED"  # CLOSED, OPEN, HALF_OPEN
    # time.time() stamps of recent failures; each instance gets its own list.
    failures: list = field(default_factory=list)
    # time.time() at which the circuit last opened, or None.
    opened_at: Optional[float] = None
    # time.time() of the last successful call, or None.
    last_success: Optional[float] = None
    # Health checks passed since the last successful call.
    health_check_count: int = 0
class AdvancedCircuitBreakerMiddleware(Middleware):
    """
    Advanced circuit breaker with health checks and graceful degradation.

    Features:
    - Configurable failure thresholds
    - Health check probes
    - Graceful degradation with fallback responses
    - Detailed monitoring

    Each tool gets its own circuit:

    - CLOSED: calls flow. The circuit opens once ``failure_threshold``
      failures fall within the last ``failure_window`` seconds.
    - OPEN: calls are rejected with ``ToolError``, or answered with the tool's
      fallback response. This lasts until ``open_timeout`` seconds have passed
      or a health check passes.
    - HALF_OPEN: calls are let through. This admits all calls, not just a
      single probe. The next success closes the circuit; the next failure
      re-opens it immediately.

    Args:
        failure_threshold: Failures within the window that open a CLOSED circuit.
        open_timeout: Seconds an OPEN circuit waits before letting calls through.
        health_check_interval: Seconds between health-check sweeps.
        health_check_timeout: Maximum seconds a single health probe may take.
        exempt_tools: Tool names that bypass the breaker entirely.
        fallback_responses: Tool name -> value returned while its circuit is OPEN.
        failure_window: Seconds of failure history counted toward the threshold.
    """

    def __init__(self,
                 failure_threshold: int = 5,
                 open_timeout: float = 30.0,
                 health_check_interval: float = 10.0,
                 health_check_timeout: float = 5.0,
                 exempt_tools: Optional[Set[str]] = None,
                 fallback_responses: Optional[Dict[str, Any]] = None,
                 failure_window: float = 60.0):
        self.failure_threshold = failure_threshold
        self.open_timeout = open_timeout
        self.health_check_interval = health_check_interval
        self.health_check_timeout = health_check_timeout
        self.exempt_tools = exempt_tools or set()
        self.fallback_responses = fallback_responses or {}
        self.failure_window = failure_window

        self._circuits: Dict[str, CircuitState] = {}
        self._lock = asyncio.Lock()
        self._health_check_task: Optional[asyncio.Task] = None

    async def on_call_tool(self, context: MiddlewareContext, call_next: CallNext) -> Any:
        """Gate a tool call through its circuit and record the outcome.

        Raises:
            ToolError: if the circuit is OPEN and no fallback is configured.
        """
        tool_name = getattr(context.message, "name", "<unknown>")

        # Skip circuit breaker for exempt tools
        if tool_name in self.exempt_tools:
            return await call_next(context)

        circuit = await self._get_circuit(tool_name)

        # Decide and transition under the lock so concurrent calls see one state.
        async with self._lock:
            rejected = circuit.state == "OPEN" and not self._should_attempt_reset(circuit)
            if circuit.state == "OPEN" and not rejected:
                circuit.state = "HALF_OPEN"
                logger.info(f"Circuit {tool_name} moved to HALF_OPEN")

        if rejected:
            # NOTE(review): confirm the server accepts a plain value returned
            # from on_call_tool in place of call_next's normal result type.
            fallback = self.fallback_responses.get(tool_name)
            # ``is not None`` so a configured-but-falsy fallback (e.g. {}) is still used.
            if fallback is not None:
                logger.warning(f"Circuit {tool_name} OPEN, returning fallback response")
                return fallback
            raise ToolError(f"Circuit breaker OPEN for {tool_name}. Try again later.")

        start_time = time.perf_counter()
        try:
            result = await call_next(context)
        except Exception as e:
            await self._record_failure(tool_name, circuit, start_time, str(e))
            raise
        await self._record_success(tool_name, circuit, start_time)
        return result

    async def _get_circuit(self, tool_name: str) -> CircuitState:
        """Get or create circuit state for a tool."""
        async with self._lock:
            if tool_name not in self._circuits:
                self._circuits[tool_name] = CircuitState()
            return self._circuits[tool_name]

    def _should_attempt_reset(self, circuit: CircuitState) -> bool:
        """Return True once an OPEN circuit has been open for ``open_timeout`` seconds."""
        if circuit.opened_at is None:
            return False
        return time.time() - circuit.opened_at >= self.open_timeout

    async def _record_success(self, tool_name: str, circuit: CircuitState, start_time: float):
        """Close the circuit after a successful call.

        ``start_time`` is a ``time.perf_counter()`` reading taken before the call.
        """
        async with self._lock:
            circuit.state = "CLOSED"
            circuit.last_success = time.time()
            circuit.failures.clear()
            circuit.health_check_count = 0
        duration = time.perf_counter() - start_time
        logger.info(f"Circuit {tool_name} call succeeded in {duration:.3f}s")

    async def _record_failure(self, tool_name: str, circuit: CircuitState,
                              start_time: float, error: str):
        """Record a failed call and open the circuit when warranted.

        A failure while HALF_OPEN means the trial failed, so the circuit
        re-opens immediately. While CLOSED, the circuit opens once
        ``failure_threshold`` failures fall within ``failure_window`` seconds.
        ``start_time`` is a ``time.perf_counter()`` reading taken before the call.
        """
        async with self._lock:
            now = time.time()
            cutoff = now - self.failure_window
            # Keep only failures inside the window, then add this one.
            circuit.failures = [t for t in circuit.failures if t > cutoff]
            circuit.failures.append(now)

            if circuit.state == "HALF_OPEN":
                circuit.state = "OPEN"
                circuit.opened_at = now
                logger.warning(f"Circuit {tool_name} re-opened: call failed while HALF_OPEN")
            elif circuit.state == "CLOSED" and len(circuit.failures) >= self.failure_threshold:
                circuit.state = "OPEN"
                circuit.opened_at = now
                logger.warning(f"Circuit {tool_name} opened after {len(circuit.failures)} failures")

        duration = time.perf_counter() - start_time
        logger.error(f"Circuit {tool_name} call failed in {duration:.3f}s: {error}")

    async def start_health_checks(self):
        """Start the periodic health-check loop unless it is already running.

        Must be awaited inside a running event loop (uses ``asyncio.create_task``).
        A loop task that has already finished is replaced.
        """
        if self._health_check_task is None or self._health_check_task.done():
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def stop_health_checks(self):
        """Cancel the health-check loop and wait for it to finish."""
        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
            self._health_check_task = None

    async def _health_check_loop(self):
        """Run a health-check sweep every ``health_check_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.health_check_interval)
                await self._perform_health_checks()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # Keep the loop alive; one bad sweep should not stop future checks.
                logger.error(f"Health check error: {e}")

    async def _perform_health_checks(self):
        """Probe every circuit that is currently OPEN."""
        # Iterate a snapshot: concurrent calls may add circuits while we await.
        for tool_name, circuit in list(self._circuits.items()):
            if circuit.state == "OPEN":
                await self._check_circuit_health(tool_name, circuit)

    async def _probe(self, tool_name: str) -> None:
        """Simulated health probe; replace with a real check (e.g. a health endpoint)."""
        await asyncio.sleep(0.1)  # Simulate health check

    async def _check_circuit_health(self, tool_name: str, circuit: CircuitState):
        """Probe one OPEN circuit and move it to HALF_OPEN if the probe passes.

        The probe is bounded by ``health_check_timeout``. A timeout or error
        leaves the circuit OPEN.
        """
        try:
            await asyncio.wait_for(self._probe(tool_name), timeout=self.health_check_timeout)
        except asyncio.TimeoutError:
            logger.warning(f"Health check timed out for {tool_name}")
            return
        except Exception as e:
            logger.warning(f"Health check failed for {tool_name}: {e}")
            return

        async with self._lock:
            # A concurrent call may have moved the circuit out of OPEN while
            # the probe ran; don't overwrite that newer state.
            if circuit.state != "OPEN":
                return
            circuit.state = "HALF_OPEN"
            circuit.health_check_count += 1
        logger.info(f"Health check passed for {tool_name}, moved to HALF_OPEN")

    def get_circuit_status(self) -> Dict[str, Dict[str, Any]]:
        """Return a per-tool snapshot of circuit state, failure count and timestamps."""
        return {
            tool_name: {
                "state": circuit.state,
                "failure_count": len(circuit.failures),
                "last_success": circuit.last_success,
                "opened_at": circuit.opened_at,
                "health_check_count": circuit.health_check_count,
            }
            for tool_name, circuit in self._circuits.items()
        }
# Example usage
# Circuit breaker middleware; module-level so get_circuit_status() can be read later.
circuit_breaker = AdvancedCircuitBreakerMiddleware(
    failure_threshold=3,
    open_timeout=10.0,
    health_check_interval=5.0,
    exempt_tools={"health_check"},
    fallback_responses={
        "unreliable_service": {"status": "degraded", "message": "Service temporarily unavailable"},
        "external_api": {"status": "fallback", "data": "cached_data"}
    }
)


async def startup():
    """Start the circuit breaker's periodic health checks."""
    await circuit_breaker.start_health_checks()


async def shutdown():
    """Stop the circuit breaker's periodic health checks."""
    await circuit_breaker.stop_health_checks()


@asynccontextmanager
async def lifespan(server):
    """Run the health-check loop for as long as the server is up.

    FastMCP documents startup/shutdown work through its ``lifespan``
    constructor argument, so the hooks are wired here.
    """
    # NOTE(review): under some transports the lifespan may be entered per
    # client session. start_health_checks() does nothing while a loop task is
    # already running, but confirm shutdown() timing for multi-session servers.
    await startup()
    try:
        yield {}
    finally:
        await shutdown()


app = FastMCP("Circuit Breaker Demo", lifespan=lifespan)
app.add_middleware(circuit_breaker)


@app.tool()
async def reliable_service(data: dict) -> dict:
    """A reliable service that rarely fails."""
    await asyncio.sleep(0.1)
    return {"status": "success", "data": data}


@app.tool()
async def unreliable_service(data: dict) -> dict:
    """An unreliable service that fails frequently."""
    import random
    if random.random() < 0.7:  # 70% failure rate
        raise Exception("Service temporarily unavailable")
    await asyncio.sleep(0.1)
    return {"status": "success", "data": data}


@app.tool()
async def external_api(data: dict) -> dict:
    """External API that might be slow or fail."""
    import random
    if random.random() < 0.5:  # 50% failure rate
        await asyncio.sleep(2.0)  # Simulate timeout
        raise Exception("External API timeout")
    await asyncio.sleep(0.2)
    return {"status": "success", "external_data": data}


@app.tool()
async def health_check() -> dict:
    """Health check endpoint (exempt from circuit breaker)."""
    return {"status": "healthy", "timestamp": time.time()}


if __name__ == "__main__":
    app.run()
Testing the Circuit Breaker
bash
# Test reliable service (should work consistently)
for i in {1..10}; do
  curl -X POST http://localhost:8000/tools/reliable_service \
    -H "Content-Type: application/json" \
    -d '{"data": {"test": "value"}}'
done
# Test unreliable service (should trigger circuit breaker)
for i in {1..10}; do
  curl -X POST http://localhost:8000/tools/unreliable_service \
    -H "Content-Type: application/json" \
    -d '{"data": {"test": "value"}}'
done
# Check circuit status: get_circuit_status() reads the middleware's in-memory
# state, so it must be called from Python running inside the server process,
# not from this shell, e.g.:
#   print(circuit_breaker.get_circuit_status())
Summary
These examples demonstrate three key middleware patterns:
- Logging and Monitoring: Provides observability and metrics collection
- Request Transformation: Handles data validation, sanitization, and format conversion
- Circuit Breaker: Implements fault tolerance and graceful degradation
Each example includes:
- Complete, runnable code
- Comprehensive error handling
- Configuration options
- Testing instructions
- Best practices
These patterns can be combined and customized to build robust, observable, and fault-tolerant MCP applications.