---
name: axum-middleware
description: Apply Tower and Axum middleware correctly — ServiceBuilder ordering, from_fn signatures, route_layer vs layer, tower-http layers, and graceful shutdown. Auto-triggers when adding middleware, layers, or shutdown logic to an Axum server.
---

# Axum Middleware

## When to Use

Invoke before:
- Adding authentication, logging, compression, CORS, or timeout middleware.
- Composing multiple layers onto a `Router`.
- Writing a custom `from_fn` or `from_fn_with_state` middleware.
- Implementing graceful shutdown for an Axum server.
- Writing middleware tests with `tower::ServiceExt::oneshot`.

## Core Patterns

### Pattern 1 — ServiceBuilder for Multiple Layers

For two or more layers, prefer `tower::ServiceBuilder`. With `ServiceBuilder`, layers are applied
in **top-to-bottom request order** (intuitive). Multiple `Router::layer()` calls stack
**bottom-to-top** for requests (counterintuitive).

```rust
use axum::{routing::get, Router};
use tower::ServiceBuilder;
use tower_http::{
    trace::TraceLayer,
    compression::CompressionLayer,
    timeout::TimeoutLayer,
};
use std::time::Duration;

let middleware_stack = ServiceBuilder::new()
    .layer(TraceLayer::new_for_http())        // 1st — request tracing
    .layer(TimeoutLayer::new(Duration::from_secs(30))) // 2nd — timeout
    .layer(CompressionLayer::new());          // 3rd — compress response

let app = Router::new()
    .route("/users", get(list_users))
    .layer(middleware_stack)
    .with_state(state);
```

Equivalent (but counterintuitive) multiple `Router::layer()` calls:

```rust
// BAD: hard to reason about (the LAST `.layer()` call is outermost and sees the request first)
let app = Router::new()
    .route("/users", get(list_users))
    .layer(CompressionLayer::new())           // innermost: closest to the handler, sees the request last
    .layer(TimeoutLayer::new(Duration::from_secs(30)))
    .layer(TraceLayer::new_for_http());       // outermost: sees the request first
```
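
The ordering rule can be checked without any Tower types. The sketch below is a toy model, not axum's or Tower's API: `wrap`, `chained_layers`, and `service_builder` are hypothetical names, and each "layer" is just a closure that logs before and after calling the service it wraps.

```rust
/// A toy "service": records what it sees into a shared log.
type Service = Box<dyn Fn(&mut Vec<String>)>;

/// Wrap `inner` in a named layer that logs on the way in and on the way out.
fn wrap(name: &'static str, inner: Service) -> Service {
    Box::new(move |log: &mut Vec<String>| {
        log.push(format!("{name} request"));
        inner(log);
        log.push(format!("{name} response"));
    })
}

/// Models chained `Router::layer` calls: every call wraps everything added
/// so far, so the LAST name in `calls` ends up outermost.
fn chained_layers(calls: &[&'static str]) -> Vec<String> {
    let mut service: Service =
        Box::new(|log: &mut Vec<String>| log.push("handler".to_string()));
    for name in calls.iter().copied() {
        service = wrap(name, service);
    }
    let mut log = Vec::new();
    service(&mut log);
    log
}

/// Models `ServiceBuilder`: the FIRST name listed ends up outermost, which is
/// the same as chaining the calls in reverse order.
fn service_builder(listed: &[&'static str]) -> Vec<String> {
    let reversed: Vec<&'static str> = listed.iter().rev().copied().collect();
    chained_layers(&reversed)
}
```

In this model, `service_builder(&["trace", "timeout", "compression"])` and `chained_layers(&["compression", "timeout", "trace"])` produce the same log: it starts with `trace request`, reaches `handler` in the middle, and ends with `trace response`.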

### Pattern 2 — from_fn and from_fn_with_state

Use `axum::middleware::from_fn` for stateless middleware and `from_fn_with_state` when the
middleware needs access to `AppState`. **`Request` must be the last extractor, immediately before `Next`, and `Next` must be the final argument.**

```rust
use axum::{
    middleware::{self, Next},
    extract::{Request, State},
    response::Response,
};

// Stateless middleware
async fn logging_middleware(request: Request, next: Next) -> Response {
    let method = request.method().clone();
    let uri    = request.uri().clone();
    let response = next.run(request).await;
    tracing::info!("{} {} → {}", method, uri, response.status());
    response
}

// Stateful middleware — State MUST come before Request
async fn auth_middleware(
    State(state): State<AppState>,
    request: Request,     // Request is final-before-Next
    next: Next,
) -> Result<Response, AppError> {
    let token = extract_bearer(&request).ok_or(AppError::Unauthorized)?;
    state.jwt.verify(token).map_err(|_| AppError::Unauthorized)?;
    Ok(next.run(request).await)
}

let app = Router::new()
    .route("/users", get(list_users))
    .layer(middleware::from_fn(logging_middleware))
    .route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
```
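
`extract_bearer` is not defined in the example above. One possible shape for its parsing step is sketched below. `bearer_token` is a hypothetical helper, not part of axum; it works on the raw `Authorization` header value so it can be tested without building a request.

```rust
/// Extract the token from an `Authorization: Bearer <token>` header value.
/// Returns `None` for other schemes or when the token is empty.
fn bearer_token(header_value: &str) -> Option<&str> {
    // Split "<scheme> <credentials>" at the first space.
    let (scheme, rest) = header_value.trim().split_once(' ')?;
    // Auth scheme names are compared case-insensitively.
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
```

Inside the middleware, the input would presumably come from `request.headers().get(AUTHORIZATION)` converted with `to_str()`; a `None` result maps to `AppError::Unauthorized` as in the example above.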

### Pattern 3 — route_layer vs layer

| | `Router::route_layer` | `Router::layer` |
|---|---|---|
| Applies to | Matched routes only | All requests + fallback |
| Unmatched path response | 404 (correct for auth) | Middleware fires → 401/403 |
| Use for | Auth, per-route limits | Tracing, compression, timeout |

```rust
// GOOD: auth as route_layer — unmatched returns 404, not 401
let app = Router::new()
    .route("/admin/{path}", get(admin_handler))
    .route_layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
    .layer(TraceLayer::new_for_http());  // trace applies to everything including fallback

// BAD: auth as layer — leaks 401 on paths that don't exist
let app = Router::new()
    .route("/admin/{path}", get(admin_handler))
    .layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
    // ← GET /nonexistent returns 401, revealing auth is required
```
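
The difference in the table can be reduced to one question: does the auth check run before or after the router knows whether a route matched? The sketch below is a toy model under that assumption. `AuthPlacement` and `respond` are hypothetical names, and routes are matched by exact string rather than by axum's path patterns.

```rust
/// HTTP status codes used by the model.
const OK: u16 = 200;
const UNAUTHORIZED: u16 = 401;
const NOT_FOUND: u16 = 404;

/// Where the auth check is attached in the toy router.
#[derive(Clone, Copy)]
enum AuthPlacement {
    /// Like `route_layer`: runs only after a route has matched.
    RouteLayer,
    /// Like `layer`: also wraps the fallback, so it runs for unmatched paths.
    Layer,
}

/// Status the toy router returns for `path`.
fn respond(routes: &[&str], path: &str, authenticated: bool, placement: AuthPlacement) -> u16 {
    let matched = routes.contains(&path);
    // Decide whether the auth check gets a chance to run at all.
    let auth_runs = match placement {
        AuthPlacement::RouteLayer => matched,
        AuthPlacement::Layer => true,
    };
    if auth_runs && !authenticated {
        return UNAUTHORIZED;
    }
    if matched {
        OK
    } else {
        NOT_FOUND
    }
}
```

With a single route `/admin/dashboard`, an unauthenticated request for `/nonexistent` yields 404 under `AuthPlacement::RouteLayer` and 401 under `AuthPlacement::Layer`, which is the leak the BAD example describes.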

### Pattern 4 — tower-http Layers (0.6)

```rust
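use axum::{routing::get, Router};
use http::{
    header::{AUTHORIZATION, CONTENT_TYPE},
    HeaderValue, Method,
};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{
    trace::TraceLayer,
    compression::CompressionLayer,
    cors::CorsLayer,
    timeout::TimeoutLayer,
};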

// Tracing
let trace = TraceLayer::new_for_http();

// Compression (gzip, br, deflate — negotiated via Accept-Encoding)
let compression = CompressionLayer::new();

// CORS — permissive() for dev only; never ship to prod
let cors_dev = CorsLayer::permissive();
let cors_prod = CorsLayer::new()
    .allow_origin("https://app.example.com".parse::<HeaderValue>().unwrap())
    .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
    .allow_headers([AUTHORIZATION, CONTENT_TYPE]);

// Timeout — applies per-request; handlers exceeding limit get 408
let timeout = TimeoutLayer::new(Duration::from_secs(30));

let app = Router::new()
    .route("/", get(root))
    .layer(
        ServiceBuilder::new()
            .layer(trace)
            .layer(timeout)
            .layer(cors_prod)
            .layer(compression),
    );
```

### Pattern 5 — Graceful Shutdown

```rust
use tokio::net::TcpListener;
use tokio::signal;
#[cfg(unix)]
use tokio::signal::unix::SignalKind;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let listener = TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}

async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(SignalKind::terminate())
            .expect("failed to install SIGTERM handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c    => {},
        _ = terminate => {},
    }
}
```

### Pattern 6 — Testing Middleware with oneshot

```rust
#[cfg(test)]
mod tests {
    use super::*; // assumes build_app and test_state live in the parent module
    use axum::{body::Body, http::Request, http::StatusCode};
    use tower::ServiceExt; // for .oneshot()

    #[tokio::test]
    async fn test_health_returns_200() {
        let app = build_app(test_state());
        let response = app
            .oneshot(Request::get("/health").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_required() {
        let app = build_app(test_state());
        let response = app
            .oneshot(Request::get("/admin/dashboard").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
}
```

## Anti-Patterns to Block

- Multiple `Router::layer()` calls for complex stacks → use `ServiceBuilder` for clear
  top-to-bottom request ordering.
- Wrong `from_fn` signature — `Next` before `Request`, or `State` after `Request` → `Request`
  must be final-before-`Next`; `State` must precede `Request`.
- `CorsLayer::permissive()` in production builds → use explicit `allow_origin` / `allow_methods`.
- Auth middleware applied via `Router::layer()` → use `route_layer()` so unmatched paths
  return 404 not 401.
- Spawning Axum server without graceful shutdown → always add `.with_graceful_shutdown(...)`.
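- Timeout middleware missing → add `TimeoutLayer` for all public endpoints; without it, a slow
  or hung handler can hold its connection and task open indefinitely. The timer starts once the
  request reaches the service, so it does not bound slow header delivery on its own.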
- Reading the `Request` body in middleware and then forwarding the same request → the body is a
  stream and can be consumed only once. Buffer it with `axum::body::to_bytes`, then rebuild the
  request from the buffered bytes before calling `next.run`.
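
The buffer-then-rebuild idea behind the body item above can be modeled with standard-library readers. This is an analogy, not axum code: `buffer_body` is a hypothetical helper, a `Read` value stands in for the request body, and the `limit` parameter plays the role of the size limit that `axum::body::to_bytes` takes.

```rust
use std::io::{Cursor, Error, ErrorKind, Read};

/// Drain a one-shot body into memory, then hand back both the bytes (for
/// inspection) and a fresh reader over the same bytes (to forward downstream).
fn buffer_body(body: impl Read, limit: usize) -> std::io::Result<(Vec<u8>, Cursor<Vec<u8>>)> {
    let mut bytes = Vec::new();
    // Read at most `limit + 1` bytes so an oversized body is detected
    // without buffering all of it.
    let cap = (limit as u64).saturating_add(1);
    body.take(cap).read_to_end(&mut bytes)?;
    if bytes.len() > limit {
        return Err(Error::new(ErrorKind::InvalidData, "body exceeds limit"));
    }
    // The original reader is now drained; downstream gets a replayable copy.
    let replay = Cursor::new(bytes.clone());
    Ok((bytes, replay))
}
```

In real middleware the corresponding steps would be `request.into_parts()`, `axum::body::to_bytes(body, limit)`, and `Request::from_parts(parts, Body::from(bytes))` before calling `next.run`.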

## Verification Hooks

```bash
# Detect wrong argument order in from_fn signatures (State after Request or Next); heuristic, matches single-line signatures only
grep -rn -A5 'async fn.*middleware' src/ | grep -B2 'State(' | grep 'Request.*State\|Next.*State'

# Detect permissive CORS in non-dev code
grep -rn 'CorsLayer::permissive' src/ | grep -v '#\[cfg(test)\]\|test\|dev'

# Detect auth/authz middleware applied via .layer() instead of .route_layer()
grep -rn '\.layer(.*auth\|\.layer(.*jwt\|\.layer(.*require' src/

# Detect missing graceful shutdown: list files that start a server but never mention graceful_shutdown
grep -rl 'axum::serve\|Server::bind' src/ | xargs grep -L 'graceful_shutdown'

# Detect multiple chained .layer() calls (ServiceBuilder preferred)
grep -rc --include='*.rs' '\.layer(' src/ | awk -F: '$2 > 2 {print $1, "has", $2, "layer() calls; consider ServiceBuilder"}'
```

## References

- https://docs.rs/axum/latest/axum/middleware/index.html
- https://docs.rs/tower-http/latest/tower_http/
- https://docs.rs/tower/latest/tower/builder/struct.ServiceBuilder.html
- https://github.com/tokio-rs/axum/blob/main/CHANGELOG.md (0.8 breaking changes)
