Why Create Custom Middleware?
While the Redux ecosystem offers ready-made middleware such as redux-thunk (a separate package that Redux Toolkit's `configureStore` includes by default), creating custom middleware lets you add functionality tailored to your application's specific needs.
🛠️ The Workshop Assembly Line Analogy
Think of custom middleware as specialized workstations in an assembly line:
- Raw Material: Actions entering the system
- Workstations: Custom middleware that process actions
- Quality Control: Validation and error checking
- Customization: Each workstation adds specific value
- Final Product: Processed actions reaching reducers
Just as you can add specialized workstations to handle unique manufacturing needs, you can create custom middleware to handle specific application requirements.
Anatomy of Custom Middleware
graph TD
A[Action Dispatched] --> B[Middleware Function]
B --> C{Process Action}
C --> D[Modify Action]
C --> E[Add Side Effects]
C --> F[Stop Action]
D --> G[next(action)]
E --> G
F --> H[Return Early]
G --> I[Next Middleware]
I --> J[Eventually Reducer]
style B fill:#f96
style C fill:#9cf
style G fill:#9f9
Middleware Structure
// Basic middleware template
// Three curried layers: store API -> next dispatch -> action.
const customMiddleware = (store) => (next) => (action) => {
  // Pre-processing phase: inspect the action before any reducer sees it
  console.log('Before:', action);

  // Hook point for action-specific work (modify, validate, side effects)
  const isSpecial = action.type === 'SPECIAL_ACTION';
  if (isSpecial) {
    // Do something special
  }

  // Hand the action to the next middleware (the last one reaches the reducer)
  const forwarded = next(action);

  // Post-processing phase: by now the reducer has run
  console.log('After:', store.getState());

  // Propagate next()'s return value back to the dispatcher
  return forwarded;
};
// Expanded version for clarity
// The same three layers as the arrow form, written as named inner functions:
// customMiddleware(store) returns wrapDispatch, wrapDispatch(next) returns
// handleAction, and handleAction(action) does the per-action work.
function customMiddleware(store) {
  function wrapDispatch(next) {
    function handleAction(action) {
      // Your middleware logic here
      return next(action);
    }
    return handleAction;
  }
  return wrapDispatch;
}
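// What you have access to:
// - store: a limited middleware API exposing only getState and dispatch (not the full store)
// - store.getState(): Get the current state (after next(action) returns, it includes the reducer's update)
// - store.dispatch(): Dispatch a new action; it re-enters the whole middleware chain from the start, including this middleware
// - next: Pass the action to the next middleware in the chain (the last one reaches the reducer)
// - action: The current action being processed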
Common Custom Middleware Patterns
1. Action Transformation Middleware
// Automatically add timestamps to actions
// Forwards a shallow copy of every action whose `meta` keeps any existing
// fields and gains `timestamp` (ms since epoch). The caller's action object
// is never mutated.
const timestampMiddleware = (store) => (next) => (action) => {
  const { meta } = action;
  const stampedMeta = { ...meta, timestamp: Date.now() };
  return next({ ...action, meta: stampedMeta });
};
// Transform action types based on conditions
// A/B test routing: when the `experiments.newUI` flag is on, SHOW_FEATURE is
// forwarded as SHOW_FEATURE_VARIANT_B (the rest of the action is kept).
// Any other action, or a store without an `experiments` slice, passes
// through unchanged instead of throwing a TypeError.
const actionTransformer = store => next => action => {
  // getState() is only consulted for SHOW_FEATURE (short-circuit).
  if (action.type === 'SHOW_FEATURE' && store.getState().experiments?.newUI) {
    return next({
      ...action,
      type: 'SHOW_FEATURE_VARIANT_B'
    });
  }
  return next(action);
};
// Normalize action payloads
// For FETCH_*_SUCCESS actions, rewrites the payload into
// { data: Array, timestamp, normalized: true } so reducers can always treat
// `data` as a list:
//   - array payload      -> used as-is
//   - null / undefined   -> [] (previously [undefined])
//   - any other value    -> wrapped in a one-element array
// Non-string action types (permitted by Redux 4) are passed through instead
// of crashing on `startsWith`.
const payloadNormalizer = store => next => action => {
  const { type, payload } = action;
  const isFetchSuccess =
    typeof type === 'string' && type.startsWith('FETCH_') && type.endsWith('_SUCCESS');
  if (!isFetchSuccess) {
    return next(action);
  }

  let data;
  if (Array.isArray(payload)) {
    data = payload;
  } else if (payload == null) {
    data = [];
  } else {
    data = [payload];
  }

  return next({
    ...action,
    payload: {
      data,
      timestamp: Date.now(),
      normalized: true
    }
  });
};
2. Validation Middleware
// Action validation middleware
// Structural problems (non-object action, missing type) throw and abort the
// dispatch. A malformed ADD_USER is not thrown; it is replaced with an
// ADD_USER_ERROR action carrying a human-readable reason. Everything else is
// forwarded untouched.
const actionValidator = (store) => (next) => (action) => {
  const isObject = Boolean(action) && typeof action === 'object';
  if (!isObject) {
    throw new Error('Actions must be plain objects');
  }
  if (action.type === undefined) {
    throw new Error('Actions must have a type property');
  }
  if (action.type !== 'ADD_USER') {
    return next(action);
  }

  // Forward an error action in place of the invalid ADD_USER.
  const rejectWith = (reason) => next({
    type: 'ADD_USER_ERROR',
    payload: reason,
    error: true
  });

  const { payload } = action;
  const hasRequiredFields = payload && payload.name && payload.email;
  if (!hasRequiredFields) {
    console.error('Invalid ADD_USER action:', action);
    return rejectWith('Name and email are required');
  }

  // Loose shape check: something@something.something, no spaces
  const emailPattern = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
  if (!emailPattern.test(payload.email)) {
    return rejectWith('Invalid email format');
  }
  return next(action);
};
// Schema validation middleware
import Joi from 'joi';
// Joi schemas keyed by action type; schemaValidator looks them up per action.
const schemas = {
  ADD_PRODUCT: Joi.object({
    type: Joi.string().required(),
    payload: Joi.object({
      name: Joi.string().required().min(3),
      price: Joi.number().required().positive(),
      description: Joi.string().required(),
      category: Joi.string().required()
    }).required(),
    // A Joi object with declared keys rejects any undeclared key, so the
    // other standard action fields are listed explicitly. Without them, an
    // ADD_PRODUCT carrying `meta` (e.g. one stamped by timestampMiddleware)
    // would always fail validation. Joi.object() with no keys accepts any shape.
    meta: Joi.object(),
    error: Joi.boolean()
  })
};
// Validates actions that have an entry in `schemas`. An invalid action is
// replaced by `${type}_VALIDATION_ERROR`, whose payload lists every
// validation message. Actions without a schema pass through unchanged.
const schemaValidator = store => next => action => {
  const schema = schemas[action.type];
  if (!schema) {
    return next(action);
  }
  // Joi stops at the first failure by default; abortEarly: false collects
  // all of them so the error payload below is complete.
  const { error } = schema.validate(action, { abortEarly: false });
  if (!error) {
    return next(action);
  }
  console.error('Validation error:', error.details);
  return next({
    type: `${action.type}_VALIDATION_ERROR`,
    payload: error.details.map(d => d.message),
    error: true
  });
};
3. Feature Flag Middleware
// Feature flag middleware
// 1. An action whose meta.featureFlag names a flag that is off in
//    state.featureFlags is replaced with FEATURE_DISABLED.
// 2. RENDER_DASHBOARD is upgraded to RENDER_NEW_DASHBOARD while the
//    `newDashboard` flag is on.
// Everything else is forwarded unchanged.
const featureFlagMiddleware = (store) => (next) => (action) => {
  const flags = store.getState().featureFlags || {};
  const requiredFlag = action.meta && action.meta.featureFlag;

  if (requiredFlag && !flags[requiredFlag]) {
    console.warn(`Feature ${requiredFlag} is disabled`);
    return next({
      type: 'FEATURE_DISABLED',
      payload: {
        feature: requiredFlag,
        originalAction: action.type
      }
    });
  }

  const useNewDashboard = action.type === 'RENDER_DASHBOARD' && flags.newDashboard;
  return next(useNewDashboard ? { ...action, type: 'RENDER_NEW_DASHBOARD' } : action);
};
// Usage with feature flags
// With featureFlagMiddleware installed, this action only reaches the reducers
// when state.featureFlags.advancedSearch is truthy; otherwise the middleware
// forwards a FEATURE_DISABLED action in its place.
// NOTE(review): `dispatch` is not defined in this snippet — presumably store.dispatch.
dispatch({
type: 'USE_ADVANCED_SEARCH',
payload: { query: 'test' },
meta: { featureFlag: 'advancedSearch' }
});
4. Persistence Middleware
// Local storage persistence middleware
// After an AUTH_*, PREFERENCES_* or CART_* action has been reduced, writes
// the whitelisted fields of those slices to localStorage under 'app_state'.
// Write failures (quota, private mode, no localStorage) are logged, never
// thrown. Returns next()'s result unchanged.
const persistenceMiddleware = store => {
  // Which fields of which slices to persist. Built once per store rather
  // than on every dispatched action.
  const persistConfig = {
    auth: ['token', 'user'],
    preferences: ['theme', 'language'],
    cart: ['items']
  };
  const persistedPrefixes = ['AUTH_', 'PREFERENCES_', 'CART_'];

  return next => action => {
    const result = next(action);

    // Non-string types (permitted by Redux 4) cannot match a prefix.
    const { type } = action;
    const shouldPersist =
      typeof type === 'string' && persistedPrefixes.some(prefix => type.startsWith(prefix));
    if (!shouldPersist) {
      return result;
    }

    const state = store.getState();
    const persistData = {};
    for (const [key, fields] of Object.entries(persistConfig)) {
      const slice = state[key];
      if (!slice) continue;
      const picked = {};
      for (const field of fields) {
        if (slice[field] !== undefined) {
          picked[field] = slice[field];
        }
      }
      persistData[key] = picked;
    }

    try {
      localStorage.setItem('app_state', JSON.stringify(persistData));
    } catch (error) {
      console.error('Failed to persist state:', error);
    }
    return result;
  };
};
// Load persisted state on app start
// Returns the parsed 'app_state' snapshot from localStorage, or undefined
// when nothing is stored or reading/parsing fails (the failure is logged).
// undefined lets createStore fall back to the reducers' initial state.
const loadPersistedState = () => {
  try {
    const raw = localStorage.getItem('app_state');
    if (!raw) {
      return undefined;
    }
    return JSON.parse(raw);
  } catch (error) {
    console.error('Failed to load persisted state:', error);
    return undefined;
  }
};
// Use with store creation
// persistedState is undefined on first run (or after a read/parse error),
// in which case createStore uses the reducers' own initial state.
// NOTE(review): the snapshot only holds the fields whitelisted in
// persistenceMiddleware's persistConfig; with combineReducers each restored
// slice replaces that reducer's initial state, so non-persisted fields of
// auth/preferences/cart will be missing — confirm the reducers tolerate this.
const persistedState = loadPersistedState();
const store = createStore(
rootReducer,
persistedState,
applyMiddleware(persistenceMiddleware)
);
Advanced Middleware Patterns
1. Rate Limiting Middleware
// Rate limiting middleware
// Counts actions per type in fixed time windows (floor(Date.now() / windowMs))
// and replaces any action over its limit with a RATE_LIMIT_EXCEEDED action.
//
// config.limits       - { [actionType]: { limit, windowMs } } per-type limits
// config.limitAll     - when true, types without an entry use the defaults
// config.defaultLimit - max actions per window (default 10; 0 blocks all)
// config.windowMs     - default window length in ms (default 60000)
const rateLimiter = (config = {}) => {
  // One counter per action type: { windowId, count }, reset whenever the
  // window changes. Keying by type alone bounds memory to the number of
  // distinct limited types, so no separate cleanup pass is needed. (The
  // previous `${type}_${window}` keys could not be parsed back for types
  // containing '_', so old entries were never evicted.)
  const counters = new Map();
  // `??` so an explicit limit of 0 is honoured; a 0 window is meaningless,
  // so windows still fall back with `||`.
  const defaultLimit = config.defaultLimit ?? 10;
  const defaultWindowMs = config.windowMs || 60000; // 1 minute

  return store => next => action => {
    // Check if action should be rate limited
    const limitConfig = config.limits?.[action.type];
    if (!limitConfig && !config.limitAll) {
      return next(action);
    }

    const limit = limitConfig?.limit ?? defaultLimit;
    const windowLength = limitConfig?.windowMs || defaultWindowMs;
    const windowId = Math.floor(Date.now() / windowLength);

    const entry = counters.get(action.type);
    const current = entry && entry.windowId === windowId ? entry.count : 0;

    if (current >= limit) {
      console.warn(`Rate limit exceeded for ${action.type}`);
      return next({
        type: 'RATE_LIMIT_EXCEEDED',
        payload: {
          actionType: action.type,
          limit,
          window: windowLength
        },
        error: true
      });
    }

    counters.set(action.type, { windowId, count: current + 1 });
    return next(action);
  };
};
// Usage
// Only API_REQUEST (100 per minute) and SEND_MESSAGE (5 per 10 s) are limited;
// limitAll is not set, so every other action type passes straight through.
// An action over its limit is replaced by RATE_LIMIT_EXCEEDED.
const store = createStore(
rootReducer,
applyMiddleware(
rateLimiter({
limits: {
'API_REQUEST': { limit: 100, windowMs: 60000 },
'SEND_MESSAGE': { limit: 5, windowMs: 10000 }
}
})
)
);
2. Undo/Redo Middleware
// Undo/redo middleware
// Keeps whole-state snapshots and restores them by forwarding
// { type: 'SET_STATE', payload: snapshot }.
// NOTE(review): assumes the root reducer handles SET_STATE by returning
// action.payload — confirm.
//
// config.limit  - max number of undo steps kept (default 50)
// config.filter - (action) => boolean; only matching actions are recorded
const undoable = (config = {}) => {
  const history = {
    past: [],
    present: null,
    future: []
  };
  const { limit = 50, filter = () => true } = config;

  return store => next => action => {
    if (action.type === 'UNDO') {
      if (history.past.length === 0) return;
      const previous = history.past[history.past.length - 1];
      history.past = history.past.slice(0, -1);
      history.future = [history.present, ...history.future];
      history.present = previous;
      return next({
        type: 'SET_STATE',
        payload: previous
      });
    }

    if (action.type === 'REDO') {
      if (history.future.length === 0) return;
      const [nextState, ...remaining] = history.future;
      history.past = [...history.past, history.present].slice(-limit);
      history.present = nextState;
      history.future = remaining;
      return next({
        type: 'SET_STATE',
        payload: nextState
      });
    }

    if (!filter(action) || action.type === 'SET_STATE') {
      return next(action);
    }

    // Seed history with the state that existed before the first recorded action.
    if (history.present === null) {
      history.present = store.getState();
    }

    const result = next(action);

    // `present` must be the state AFTER the action. Recording the pre-action
    // state (as before) made UNDO skip back two steps and lose the newest state.
    history.past = [...history.past, history.present].slice(-limit);
    history.present = store.getState();
    history.future = [];
    return result;
  };
};
// Usage
// Keeps up to 30 undo steps. The filter skips Redux's internal @@redux/
// actions; SET_STATE (the action undoable itself forwards) is never recorded.
// NOTE(review): rootReducer must handle SET_STATE for UNDO/REDO to take effect — confirm.
const store = createStore(
rootReducer,
applyMiddleware(
undoable({
limit: 30,
filter: action => !action.type.startsWith('@@redux/')
})
)
);
3. Action Batching Middleware
// Action batching middleware
// Between BATCH_START and BATCH_END, actions are queued instead of forwarded.
// If no new action arrives for 50 ms, the batch is auto-flushed. A flush
// forwards a single { type: 'BATCH_ACTIONS', payload: [...queued] } to
// `next`. NOTE(review): that action goes straight down the chain, so a
// downstream reducer must unpack BATCH_ACTIONS — confirm.
// A BATCH_ACTIONS dispatched from outside while not batching is unpacked
// here, and each action is forwarded individually.
const batchActions = store => next => {
  let batchedActions = [];
  let isBatching = false;
  let flushTimeout = null;

  const cancelAutoFlush = () => {
    if (flushTimeout !== null) {
      clearTimeout(flushTimeout);
      flushTimeout = null;
    }
  };

  const flushBatch = () => {
    // Without this, a timer left over from a finished batch could fire
    // later and end a *new* batch prematurely.
    cancelAutoFlush();
    if (batchedActions.length === 0) return;
    const actions = batchedActions;
    batchedActions = [];
    next({
      type: 'BATCH_ACTIONS',
      payload: actions
    });
  };

  return action => {
    if (action.type === 'BATCH_START') {
      isBatching = true;
      return;
    }

    if (action.type === 'BATCH_END') {
      isBatching = false;
      flushBatch();
      return;
    }

    if (isBatching) {
      batchedActions.push(action);
      // Debounced auto-flush: restart the 50 ms timer on every queued action
      cancelAutoFlush();
      flushTimeout = setTimeout(() => {
        flushTimeout = null;
        isBatching = false;
        flushBatch();
      }, 50);
      return;
    }

    if (action.type === 'BATCH_ACTIONS') {
      const results = action.payload.map(a => next(a));
      return results[results.length - 1];
    }

    return next(action);
  };
};
// Usage
// Three updates queued between BATCH_START and BATCH_END, dispatched in order.
for (const queuedAction of [
  { type: 'BATCH_START' },
  { type: 'UPDATE_ITEM', payload: { id: 1, value: 'a' } },
  { type: 'UPDATE_ITEM', payload: { id: 2, value: 'b' } },
  { type: 'UPDATE_ITEM', payload: { id: 3, value: 'c' } },
  { type: 'BATCH_END' }
]) {
  dispatch(queuedAction);
}
Real-World Example: Analytics Middleware
// Comprehensive analytics middleware
// Forwards every action first, then reports it through analytics.track().
// Reporting is best-effort: any failure while building or sending the
// payload (throwing enrichData/getSessionId, missing browser globals during
// SSR or tests, an analytics outage) is logged and never propagates to the
// code that dispatched the action.
//
// analytics                - object with a track(eventName, data) method
// config.ignoreActions     - action-type prefixes that are never tracked
// config.sampleRate        - fraction (0..1) of remaining actions to track
// config.getUserId         - (state) => user id
// config.getSessionId      - () => session id
// config.enrichData        - (data, nextState) => final payload
const analyticsMiddleware = (analytics, config = {}) => {
  const {
    ignoreActions = ['@@redux/', 'persist/'],
    sampleRate = 1.0,
    getUserId = (state) => state.auth?.user?.id,
    // globalThis lookup keeps the default usable outside the browser
    getSessionId = () => globalThis.sessionStorage?.getItem('sessionId'),
    enrichData = (data, state) => data
  } = config;

  return store => next => action => {
    const { type } = action;

    // Non-string types (permitted by Redux 4) cannot be prefix-matched; skip them.
    if (typeof type !== 'string' || ignoreActions.some(prefix => type.startsWith(prefix))) {
      return next(action);
    }

    // Apply sampling
    if (Math.random() > sampleRate) {
      return next(action);
    }

    const startTime = performance.now();
    const prevState = store.getState();
    const result = next(action);
    const endTime = performance.now();
    const nextState = store.getState();

    try {
      const analyticsData = {
        action: {
          type,
          payload: action.payload
        },
        timestamp: new Date().toISOString(),
        // Time spent in the rest of the chain (downstream middleware + reducer)
        duration: endTime - startTime,
        userId: getUserId(nextState),
        sessionId: getSessionId(),
        metadata: {
          url: globalThis.location?.href,
          referrer: globalThis.document?.referrer,
          userAgent: globalThis.navigator?.userAgent
        }
      };

      // Track state changes
      if (type.includes('SUCCESS') || type.includes('FAILURE')) {
        analyticsData.stateChange = {
          before: getRelevantState(prevState, type),
          after: getRelevantState(nextState, type)
        };
      }

      const enrichedData = enrichData(analyticsData, nextState);
      analytics.track('redux_action', enrichedData);

      // Track specific events
      if (type === 'CHECKOUT_SUCCESS') {
        analytics.track('purchase', {
          orderId: action.payload.orderId,
          total: action.payload.total,
          items: action.payload.items
        });
      }
      if (type === 'ERROR_OCCURRED') {
        analytics.track('error', {
          error: action.payload,
          context: analyticsData
        });
      }
    } catch (error) {
      console.error('Analytics error:', error);
    }

    return result;
  };
};
// Helper function to get relevant state
// Maps a keyword found in the action type to the state slice worth
// snapshotting. Keywords are checked in order (USER, CART, PRODUCT) and the
// first match wins; returns null when none match.
function getRelevantState(state, actionType) {
  const sliceByKeyword = [
    ['USER', 'user'],
    ['CART', 'cart'],
    ['PRODUCT', 'products']
  ];
  for (const [keyword, sliceKey] of sliceByKeyword) {
    if (actionType.includes(keyword)) {
      return state[sliceKey];
    }
  }
  return null;
}
// Usage with analytics service
import Analytics from 'analytics-service';
// Analytics client configured from environment variables.
const analytics = new Analytics({
apiKey: process.env.ANALYTICS_KEY,
environment: process.env.NODE_ENV
});
// Tracks 10% of actions in production and all of them elsewhere.
// Passing ignoreActions replaces the middleware's default list, so the
// default '@@redux/' and 'persist/' prefixes are repeated here, plus the
// MOUSE_MOVE type. enrichData adds the current experiment and the user's
// subscription type to every tracked payload.
const store = createStore(
rootReducer,
applyMiddleware(
analyticsMiddleware(analytics, {
sampleRate: process.env.NODE_ENV === 'production' ? 0.1 : 1.0,
ignoreActions: ['@@redux/', 'persist/', 'MOUSE_MOVE'],
enrichData: (data, state) => ({
...data,
experiment: state.experiments?.currentExperiment,
subscription: state.user?.subscription?.type
})
})
)
);
Testing Custom Middleware
// Testing custom middleware
import { createStore, applyMiddleware } from 'redux';
describe('timestampMiddleware', () => {
  let store;
  let mockNext;

  // Fresh fakes per test: mockNext echoes the action it receives, and the
  // store stub exposes the two middleware-API methods.
  beforeEach(() => {
    mockNext = jest.fn(action => action);
    store = {
      getState: jest.fn(() => ({})),
      dispatch: jest.fn()
    };
  });

  it('should add timestamp to actions', () => {
    const middleware = timestampMiddleware(store)(mockNext);
    const action = { type: 'TEST_ACTION' };
    middleware(action);
    expect(mockNext).toHaveBeenCalledWith(
      expect.objectContaining({
        type: 'TEST_ACTION',
        meta: expect.objectContaining({
          timestamp: expect.any(Number)
        })
      })
    );
  });

  it('should preserve existing meta', () => {
    const middleware = timestampMiddleware(store)(mockNext);
    const action = {
      type: 'TEST_ACTION',
      meta: { existingProp: 'value' }
    };
    middleware(action);
    expect(mockNext).toHaveBeenCalledWith(
      expect.objectContaining({
        meta: expect.objectContaining({
          existingProp: 'value',
          timestamp: expect.any(Number)
        })
      })
    );
  });

  // Middleware must copy, not mutate, the dispatched action.
  it('should not mutate the original action', () => {
    const middleware = timestampMiddleware(store)(mockNext);
    const action = {
      type: 'TEST_ACTION',
      meta: { existingProp: 'value' }
    };
    middleware(action);
    expect(action).toEqual({
      type: 'TEST_ACTION',
      meta: { existingProp: 'value' }
    });
  });
});
// Testing time-dependent middleware
describe('rateLimiter middleware', () => {
  // rateLimiter buckets actions into fixed windows derived from Date.now().
  // Jest's default ("modern", Jest 27+) fake timers also mock Date, so a test
  // cannot straddle a window boundary by accident. Enabling them per test and
  // restoring real timers afterwards keeps other suites unaffected.
  beforeEach(() => {
    jest.useFakeTimers();
  });

  afterEach(() => {
    jest.useRealTimers();
  });

  it('should allow actions within rate limit', () => {
    const middleware = rateLimiter({
      limits: { 'TEST_ACTION': { limit: 2, windowMs: 1000 } }
    });
    const store = { getState: jest.fn(), dispatch: jest.fn() };
    const next = jest.fn();
    const action = { type: 'TEST_ACTION' };
    const handler = middleware(store)(next);
    handler(action);
    handler(action);
    expect(next).toHaveBeenCalledTimes(2);
    expect(next).toHaveBeenCalledWith(action);
  });

  it('should block actions exceeding rate limit', () => {
    const middleware = rateLimiter({
      limits: { 'TEST_ACTION': { limit: 2, windowMs: 1000 } }
    });
    const store = { getState: jest.fn(), dispatch: jest.fn() };
    const next = jest.fn();
    const action = { type: 'TEST_ACTION' };
    const handler = middleware(store)(next);
    handler(action);
    handler(action);
    handler(action); // This should be blocked
    expect(next).toHaveBeenCalledTimes(3);
    expect(next).toHaveBeenLastCalledWith(
      expect.objectContaining({
        type: 'RATE_LIMIT_EXCEEDED',
        error: true
      })
    );
  });

  it('should allow actions again once the window has passed', () => {
    const middleware = rateLimiter({
      limits: { 'TEST_ACTION': { limit: 1, windowMs: 1000 } }
    });
    const store = { getState: jest.fn(), dispatch: jest.fn() };
    const next = jest.fn();
    const action = { type: 'TEST_ACTION' };
    const handler = middleware(store)(next);
    handler(action);
    jest.advanceTimersByTime(1000); // move Date.now() into the next window
    handler(action);
    expect(next).toHaveBeenCalledTimes(2);
    expect(next).toHaveBeenLastCalledWith(action);
  });
});
Best Practices for Custom Middleware
Do's
- Keep middleware focused on a single responsibility
- Always call next(action) unless intentionally blocking
- Handle errors gracefully
- Make middleware configurable
- Document the middleware's purpose and usage
- Test thoroughly with different action types
Don'ts
- Don't make middleware dependent on action order
- Don't mutate actions or state directly
- Don't perform heavy computations synchronously
- Don't create circular dependencies
- Don't dispatch actions in an infinite loop
Performance Considerations
// ❌ Bad: Heavy synchronous computation
// heavyComputation runs inline on EVERY dispatch, before the action even
// reaches the reducer, so this blocks the thread each time.
const badMiddleware = (store) => (next) => (action) => {
  const computed = heavyComputation(store.getState());
  const enriched = { ...action, computed };
  return next(enriched);
};
// ✅ Good: Defer heavy computation
// COMPUTE_HEAVY is forwarded immediately; the computation runs later, on the
// post-action state, and its result arrives as COMPUTATION_COMPLETE (or
// COMPUTATION_FAILED if it throws).
const goodMiddleware = store => next => action => {
  if (action.type === 'COMPUTE_HEAVY') {
    // setTimeout (a macrotask) yields to the event loop so the browser can
    // handle input and repaint first. A Promise microtask would run before
    // any rendering and block just as much. The work still runs on the main
    // thread — see the worker version below for true offloading.
    setTimeout(() => {
      let result;
      try {
        result = heavyComputation(store.getState());
      } catch (error) {
        // Surface the failure instead of an uncaught exception
        store.dispatch({
          type: 'COMPUTATION_FAILED',
          payload: error instanceof Error ? error.message : String(error),
          error: true
        });
        return;
      }
      store.dispatch({
        type: 'COMPUTATION_COMPLETE',
        payload: result
      });
    }, 0);
  }
  return next(action);
};
// ✅ Better: Use web workers for heavy computation
// For COMPUTE_HEAVY, spawns a short-lived worker and posts it the state as it
// was before this action reached the reducer. The worker's reply becomes
// COMPUTATION_COMPLETE; a worker error becomes COMPUTATION_FAILED. Either
// way the worker is terminated, so a failing script does not leak a thread.
const workerMiddleware = store => next => action => {
  if (action.type === 'COMPUTE_HEAVY') {
    const worker = new Worker('computation-worker.js');

    // Handlers are attached before posting so no reply can be missed.
    worker.onmessage = (event) => {
      worker.terminate();
      store.dispatch({
        type: 'COMPUTATION_COMPLETE',
        payload: event.data
      });
    };
    worker.onerror = (event) => {
      worker.terminate();
      store.dispatch({
        type: 'COMPUTATION_FAILED',
        payload: event.message,
        error: true
      });
    };

    worker.postMessage({
      type: 'COMPUTE',
      data: store.getState()
    });
  }
  return next(action);
};
Practice Exercise
Task: Create a Caching Middleware
Build a middleware that caches API responses and serves them from cache when appropriate:
- Cache successful API responses
- Implement cache expiration
- Support cache invalidation
- Handle cache size limits
- Provide cache statistics
// TODO: Implement caching middleware
// Exercise skeleton. The options are destructured for the implementation
// to use; nothing reads them yet.
const createCacheMiddleware = (options = {}) => {
  const {
    maxAge = 5 * 60 * 1000, // 5 minutes
    maxEntries = 100,
    storage = 'memory' // 'memory' | 'localStorage' | 'sessionStorage'
  } = options;
  // TODO: Initialize cache storage
  return store => next => action => {
    // TODO: Check if action is cacheable API request
    // TODO: Check cache for existing response
    // TODO: Return cached response if valid
    // TODO: For API responses, update cache
    // TODO: Implement cache cleanup
    // TODO: Handle cache statistics actions

    // Until the TODOs are implemented, forward every action unchanged.
    // Middleware that never calls next() swallows every dispatch, so no
    // reducer would ever run.
    return next(action);
  };
};
// Example usage:
// Overrides all three defaults (5 min / 100 entries / 'memory').
// NOTE(review): the skeleton above destructures these options but does not use them yet.
const cacheMiddleware = createCacheMiddleware({
maxAge: 10 * 60 * 1000, // 10 minutes
maxEntries: 200,
storage: 'localStorage'
});
// Actions to implement:
// - API_REQUEST (check cache)
// - API_SUCCESS (update cache)
// - INVALIDATE_CACHE (clear specific cache)
// - CLEAR_CACHE (clear all cache)
// - GET_CACHE_STATS (return cache statistics)