Creating Custom Redux Middleware

Why Create Custom Middleware?

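Redux itself ships without any middleware: common ones like Thunk come from separate packages (`redux-thunk`, which Redux Toolkit's `configureStore` adds by default). Writing custom middleware lets you add functionality tailored to your application's specific needs.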

🛠️ The Workshop Assembly Line Analogy

Think of custom middleware as specialized workstations in an assembly line:

  • Raw Material: Actions entering the system
  • Workstations: Custom middleware that process actions
  • Quality Control: Validation and error checking
  • Customization: Each workstation adds specific value
  • Final Product: Processed actions reaching reducers

Just as you can add specialized workstations to handle unique manufacturing needs, you can create custom middleware to handle specific application requirements.

Anatomy of Custom Middleware

graph TD
  A[Action Dispatched] --> B[Middleware Function]
  B --> C{Process Action}
  C --> D[Modify Action]
  C --> E[Add Side Effects]
  C --> F[Stop Action]
  D --> G["next(action)"]
  E --> G
  F --> H[Return Early]
  G --> I[Next Middleware]
  I --> J[Eventually Reducer]
  style B fill:#f96
  style C fill:#9cf
  style G fill:#9f9

Middleware Structure


// Basic middleware template: a curried function store => next => action
const customMiddleware = store => next => action => {
  // Runs before the rest of the chain (and the reducers) see the action
  console.log('Before:', action);

  // Hook point: inspect the action, trigger side effects, and so on
  const isSpecial = action.type === 'SPECIAL_ACTION';
  if (isSpecial) {
    // Do something special
  }

  // Hand off to the next middleware; its return value becomes dispatch's return value
  const returnValue = next(action);

  // Runs after the downstream chain has handled the action
  console.log('After:', store.getState());

  return returnValue;
};

// Expanded version for clarity: the same three-level curried shape written as
// named nested functions. NOTE: it reuses the name `customMiddleware`, so treat
// it as an alternative to the arrow version, not as a second declaration in
// the same module.
function customMiddleware(store) {
  // Level 1: receives the store API
  return function wrapDispatch(next) {
    // Level 2: receives the next link of the chain
    return function handleAction(action) {
      // Level 3: called for every dispatched action. Your middleware logic goes here.
      const forwarded = next(action);
      return forwarded;
    };
  };
}

// What you have access to inside the middleware:
// - store.getState(): read the current state
// - store.dispatch(): dispatch a new action (it runs through the whole chain again)
// - next: hand the current action to the next middleware in the chain
// - action: the action currently being processed
            

Common Custom Middleware Patterns

1. Action Transformation Middleware


// Automatically add timestamps to actions.
// The incoming action is copied, never mutated: existing meta fields are kept
// and meta.timestamp is set to the current epoch milliseconds.
const timestampMiddleware = store => next => action => {
  const stampedMeta = { ...action.meta, timestamp: Date.now() };
  return next({ ...action, meta: stampedMeta });
};

// Transform action types based on conditions (A/B testing).
// SHOW_FEATURE is renamed to SHOW_FEATURE_VARIANT_B when
// state.experiments.newUI is truthy; every other action passes through.
const actionTransformer = store => next => action => {
  // The state is only consulted for SHOW_FEATURE. `experiments` may not exist
  // in every store, so read it defensively instead of throwing mid-dispatch.
  if (action.type === 'SHOW_FEATURE' && store.getState().experiments?.newUI) {
    return next({
      ...action,
      type: 'SHOW_FEATURE_VARIANT_B'
    });
  }

  return next(action);
};

// Normalize action payloads.
// For FETCH_*_SUCCESS actions, the payload is replaced by
// { data, timestamp, normalized: true }, where `data` is always an array:
// arrays are kept as-is, a missing (null/undefined) payload becomes [], and any
// other single value is wrapped in a one-element array.
const payloadNormalizer = store => next => action => {
  const { type, payload } = action;

  // Only string types can follow the FETCH_*_SUCCESS convention. Anything else
  // (e.g. a Symbol type) passes through instead of throwing on startsWith.
  const isFetchSuccess =
    typeof type === 'string' && type.startsWith('FETCH_') && type.endsWith('_SUCCESS');
  if (!isFetchSuccess) {
    return next(action);
  }

  let data;
  if (Array.isArray(payload)) {
    data = payload;
  } else if (payload == null) {
    // No payload means "no items", not a list containing `undefined`
    data = [];
  } else {
    data = [payload];
  }

  return next({
    ...action,
    payload: {
      data,
      timestamp: Date.now(),
      normalized: true
    }
  });
};
            

2. Validation Middleware


// Action validation middleware.
//
// Throws for anything that is not a plain action object or lacks a `type`.
// Because functions are rejected too, place it AFTER thunk middleware.
// Invalid ADD_USER actions are replaced by an ADD_USER_ERROR action.

// Compiled once instead of on every ADD_USER (no /g flag, so reuse is safe)
const EMAIL_PATTERN = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;

const actionValidator = store => next => action => {
  // Arrays are objects too, but they are never valid actions
  if (!action || typeof action !== 'object' || Array.isArray(action)) {
    throw new Error('Actions must be plain objects');
  }

  if (typeof action.type === 'undefined') {
    throw new Error('Actions must have a type property');
  }

  // Validate specific action types
  if (action.type === 'ADD_USER') {
    const { payload } = action;
    if (!payload || !payload.name || !payload.email) {
      console.error('Invalid ADD_USER action:', action);
      return next({
        type: 'ADD_USER_ERROR',
        payload: 'Name and email are required',
        error: true
      });
    }

    if (!EMAIL_PATTERN.test(payload.email)) {
      return next({
        type: 'ADD_USER_ERROR',
        payload: 'Invalid email format',
        error: true
      });
    }
  }

  return next(action);
};

// Schema validation middleware
import Joi from 'joi';

// Joi schemas keyed by action type, each validating the whole action object.
// A Joi object with declared keys rejects any undeclared key, so the
// Flux-standard `meta` field is declared explicitly. Without it, any earlier
// middleware that adds `meta` (such as a timestamp stamper) would make every
// ADD_PRODUCT fail validation. `Joi.object()` with no keys accepts any contents.
const schemas = {
  ADD_PRODUCT: Joi.object({
    type: Joi.string().required(),
    payload: Joi.object({
      name: Joi.string().required().min(3),
      price: Joi.number().required().positive(),
      description: Joi.string().required(),
      category: Joi.string().required()
    }).required(),
    meta: Joi.object()
  })
};

// Validates actions that have an entry in `schemas`. An invalid action is
// replaced by `${type}_VALIDATION_ERROR` carrying the list of Joi messages;
// actions without a schema pass through untouched.
const schemaValidator = store => next => action => {
  // Own-property lookup only: a type such as 'constructor' must not resolve to
  // an inherited Object.prototype member and be mistaken for a schema.
  const schema = Object.prototype.hasOwnProperty.call(schemas, action.type)
    ? schemas[action.type]
    : undefined;

  if (!schema) {
    return next(action);
  }

  const { error } = schema.validate(action);
  if (!error) {
    return next(action);
  }

  console.error('Validation error:', error.details);
  return next({
    type: `${action.type}_VALIDATION_ERROR`,
    payload: error.details.map(d => d.message),
    error: true
  });
};
            

3. Feature Flag Middleware


// Feature flag middleware.
// Reads state.featureFlags (falling back to {}):
// - an action whose meta.featureFlag names a disabled flag is replaced by a
//   FEATURE_DISABLED action describing the feature and the original type;
// - RENDER_DASHBOARD becomes RENDER_NEW_DASHBOARD when flags.newDashboard is on.
const featureFlagMiddleware = store => next => action => {
  const flags = store.getState().featureFlags || {};
  const requiredFlag = action.meta?.featureFlag;

  // Gate: swap the action for FEATURE_DISABLED when its flag is off
  if (requiredFlag && !flags[requiredFlag]) {
    console.warn(`Feature ${requiredFlag} is disabled`);
    return next({
      type: 'FEATURE_DISABLED',
      payload: {
        feature: requiredFlag,
        originalAction: action.type
      }
    });
  }

  // Route the dashboard render to the new implementation when enabled
  const useNewDashboard = action.type === 'RENDER_DASHBOARD' && flags.newDashboard;
  return next(useNewDashboard ? { ...action, type: 'RENDER_NEW_DASHBOARD' } : action);
};

// Usage with feature flags: featureFlagMiddleware replaces this action with
// FEATURE_DISABLED unless state.featureFlags.advancedSearch is truthy
const advancedSearchAction = {
  type: 'USE_ADVANCED_SEARCH',
  payload: { query: 'test' },
  meta: { featureFlag: 'advancedSearch' }
};
dispatch(advancedSearchAction);
            

4. Persistence Middleware


// Local storage persistence middleware.
//
// PERSIST_CONFIG lists which fields of which state slices are saved under the
// 'app_state' localStorage key. The selected fields are re-saved after any
// action whose type starts with a slice key upper-cased plus "_"
// (auth -> AUTH_, preferences -> PREFERENCES_, cart -> CART_).
// NOTE(security): everything written to localStorage, including auth.token,
// is readable by any script on the page, so an XSS bug would expose it.
const PERSIST_CONFIG = {
  auth: ['token', 'user'],
  preferences: ['theme', 'language'],
  cart: ['items']
};
const PERSIST_TRIGGER_PREFIXES = Object.keys(PERSIST_CONFIG).map(key => `${key.toUpperCase()}_`);

const persistenceMiddleware = store => next => action => {
  // Let the reducers run first so the post-action state is what gets saved
  const result = next(action);

  // Only string types can carry a prefix. Checking the type first keeps a
  // non-string type (e.g. a Symbol) from making dispatch throw after the
  // reducers have already applied the action.
  const { type } = action;
  const shouldPersist =
    typeof type === 'string' && PERSIST_TRIGGER_PREFIXES.some(prefix => type.startsWith(prefix));
  if (!shouldPersist) {
    return result;
  }

  const state = store.getState();
  const persistData = {};

  for (const [key, fields] of Object.entries(PERSIST_CONFIG)) {
    const slice = state[key];
    if (!slice) continue;

    const picked = {};
    for (const field of fields) {
      if (slice[field] !== undefined) {
        picked[field] = slice[field];
      }
    }
    persistData[key] = picked;
  }

  try {
    localStorage.setItem('app_state', JSON.stringify(persistData));
  } catch (error) {
    // Best effort (e.g. quota exceeded): persistence must never break dispatch
    console.error('Failed to persist state:', error);
  }

  return result;
};

// Load persisted state on app start.
// Returns the object previously saved under 'app_state', or `undefined` when
// nothing usable is stored, so the store falls back to its reducers' defaults.
const loadPersistedState = () => {
  try {
    const raw = localStorage.getItem('app_state');
    if (!raw) {
      return undefined;
    }

    const parsed = JSON.parse(raw);
    // Only `undefined` makes the reducers use their own initial state. A stored
    // null, array or primitive would otherwise be handed to them as real state.
    const isPlainRecord = parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed);
    return isPlainRecord ? parsed : undefined;
  } catch (error) {
    console.error('Failed to load persisted state:', error);
    return undefined;
  }
};

// Use with store creation.
// The saved snapshot holds only SOME fields of each slice, and preloaded state
// replaces a slice wholesale rather than merging it with the slice's defaults.
// So each persisted slice is shallow-merged over the reducer's initial state
// (persisted fields win). Only keys known to the root reducer are kept.
// Assumes rootReducer is a combineReducers root whose slices are plain objects;
// adjust the merge if that is not the case.
const persistedState = loadPersistedState();
const defaultState = rootReducer(undefined, { type: '@@app/PERSIST_DEFAULTS' });
const preloadedState = persistedState && Object.fromEntries(
  Object.entries(defaultState).map(([key, slice]) => [
    key,
    persistedState[key] ? { ...slice, ...persistedState[key] } : slice
  ])
);

const store = createStore(
  rootReducer,
  preloadedState,
  applyMiddleware(persistenceMiddleware)
);
            

Advanced Middleware Patterns

1. Rate Limiting Middleware


// Rate limiting middleware: a fixed-window counter per action type.
//
// config.limits       - { [actionType]: { limit, windowMs } } for specific types
// config.limitAll     - when true, unlisted types are limited with the defaults
// config.defaultLimit - fallback limit (10)
// config.windowMs     - fallback window length in ms (60000)
// config.maxEntries   - counter count above which expired windows are pruned (1000)
//
// Over-limit actions are replaced by a RATE_LIMIT_EXCEEDED error action.
const rateLimiter = (config = {}) => {
  // actionType -> { windowStart, windowMs, count } for that type's current window.
  // Keying by type alone keeps at most one counter per type and never has to
  // parse window info back out of a "type_bucket" string (action types often
  // contain underscores themselves).
  const counters = new Map();
  // `??` so an explicit limit of 0 (block everything) is honoured
  const defaultLimit = config.defaultLimit ?? 10;
  const defaultWindowMs = config.windowMs || 60000; // 1 minute
  const maxEntries = config.maxEntries ?? 1000;

  // Drop counters whose window has already ended
  const pruneExpired = (now) => {
    for (const [type, entry] of counters) {
      if (now - entry.windowStart >= entry.windowMs) {
        counters.delete(type);
      }
    }
  };

  return store => next => action => {
    // Check if action should be rate limited
    const limitConfig = config.limits?.[action.type];
    if (!limitConfig && !config.limitAll) {
      return next(action);
    }

    const limit = limitConfig?.limit ?? defaultLimit;
    const windowMs = limitConfig?.windowMs || defaultWindowMs;
    const now = Date.now();
    // Windows are aligned to multiples of windowMs since the epoch
    const windowStart = Math.floor(now / windowMs) * windowMs;

    let entry = counters.get(action.type);
    if (!entry || entry.windowStart !== windowStart || entry.windowMs !== windowMs) {
      entry = { windowStart, windowMs, count: 0 };
      counters.set(action.type, entry);
    }

    if (entry.count >= limit) {
      console.warn(`Rate limit exceeded for ${String(action.type)}`);
      return next({
        type: 'RATE_LIMIT_EXCEEDED',
        payload: {
          actionType: action.type,
          limit,
          window: windowMs
        },
        error: true
      });
    }

    entry.count += 1;

    if (counters.size > maxEntries) {
      pruneExpired(now);
    }

    return next(action);
  };
};

// Usage: only the listed action types are limited (limitAll is not set)
const rateLimitConfig = {
  limits: {
    API_REQUEST: { limit: 100, windowMs: 60000 }, // 100 per minute
    SEND_MESSAGE: { limit: 5, windowMs: 10000 }   // 5 per 10 seconds
  }
};

const store = createStore(
  rootReducer,
  applyMiddleware(rateLimiter(rateLimitConfig))
);
            

2. Undo/Redo Middleware


// Undo/redo middleware.
//
// Keeps whole-state snapshots in this closure and answers UNDO/REDO by
// forwarding { type: 'SET_STATE', payload: snapshot }. The root reducer must
// handle SET_STATE by returning action.payload for the restore to take effect.
// UNDO/REDO with nothing to restore are swallowed (dispatch returns undefined).
//
// config.limit  - maximum number of undo steps kept (default 50; 0 disables undo)
// config.filter - predicate choosing which actions are recorded (default: all)
const undoable = (config = {}) => {
  // past:   states as they were BEFORE each recorded change, oldest first
  // future: states that were undone, most recently undone first
  const history = {
    past: [],
    future: []
  };

  const { limit = 50, filter = () => true } = config;

  // Keep only the newest `limit` snapshots. slice(-0) would keep everything,
  // so a limit of 0 is handled explicitly.
  const capToLimit = (states) => (limit > 0 ? states.slice(-limit) : []);

  return store => next => action => {
    if (action.type === 'UNDO') {
      if (history.past.length === 0) return;

      const previous = history.past[history.past.length - 1];
      history.past = history.past.slice(0, -1);
      // The live state becomes the first REDO target
      history.future = [store.getState(), ...history.future];

      return next({
        type: 'SET_STATE',
        payload: previous
      });
    }

    if (action.type === 'REDO') {
      if (history.future.length === 0) return;

      const [upcoming, ...remaining] = history.future;
      history.past = capToLimit([...history.past, store.getState()]);
      history.future = remaining;

      return next({
        type: 'SET_STATE',
        payload: upcoming
      });
    }

    // SET_STATE is the restore action itself and must never be recorded
    if (!filter(action) || action.type === 'SET_STATE') {
      return next(action);
    }

    // Snapshot the state before the action so UNDO returns exactly here
    const before = store.getState();
    const result = next(action);

    // Reducers return the same reference when nothing changed; recording such
    // no-ops would make a later UNDO appear to do nothing.
    if (store.getState() !== before) {
      history.past = capToLimit([...history.past, before]);
      history.future = [];
    }

    return result;
  };
};

// Usage: keep the last 30 steps and never record Redux's internal @@redux/* actions.
// rootReducer must handle the SET_STATE action that undoable emits on UNDO/REDO.
const isUserAction = action => !action.type.startsWith('@@redux/');

const store = createStore(
  rootReducer,
  applyMiddleware(undoable({ limit: 30, filter: isUserAction }))
);
            

3. Action Batching Middleware


// Action batching middleware.
//
// Between BATCH_START and BATCH_END, actions are queued instead of forwarded.
// If BATCH_END never arrives, the queue is flushed 50 ms after the last queued
// action. The queue is sent down the chain as ONE
// { type: 'BATCH_ACTIONS', payload: [...] } action, so the reducers must
// understand BATCH_ACTIONS. A BATCH_ACTIONS action dispatched directly (outside
// a batch) is instead unpacked and forwarded action by action.
const batchActions = store => next => {
  let batchedActions = [];
  let isBatching = false;
  let flushTimeout = null;

  // Cancel a pending auto-flush so it cannot fire later and end a NEW batch early
  const cancelAutoFlush = () => {
    if (flushTimeout !== null) {
      clearTimeout(flushTimeout);
      flushTimeout = null;
    }
  };

  // Send the queued actions downstream as one BATCH_ACTIONS action and return
  // next()'s result (undefined when the queue was empty)
  const flushBatch = () => {
    cancelAutoFlush();
    if (batchedActions.length === 0) return undefined;

    const actions = batchedActions;
    batchedActions = [];

    return next({
      type: 'BATCH_ACTIONS',
      payload: actions
    });
  };

  return action => {
    // Start batching
    if (action.type === 'BATCH_START') {
      isBatching = true;
      return;
    }

    // End batching and flush
    if (action.type === 'BATCH_END') {
      isBatching = false;
      return flushBatch();
    }

    // Queue while batching; restart the auto-flush safety timer on each action
    if (isBatching) {
      batchedActions.push(action);

      cancelAutoFlush();
      flushTimeout = setTimeout(() => {
        flushTimeout = null;
        isBatching = false;
        flushBatch();
      }, 50);

      return;
    }

    // Unpack a directly dispatched batch; return the last action's result
    if (action.type === 'BATCH_ACTIONS') {
      const results = action.payload.map(a => next(a));
      return results[results.length - 1];
    }

    return next(action);
  };
};

// Usage: the three updates are queued between BATCH_START and BATCH_END and
// forwarded together as a single BATCH_ACTIONS action
const itemUpdates = [
  { id: 1, value: 'a' },
  { id: 2, value: 'b' },
  { id: 3, value: 'c' }
];

dispatch({ type: 'BATCH_START' });
for (const payload of itemUpdates) {
  dispatch({ type: 'UPDATE_ITEM', payload });
}
dispatch({ type: 'BATCH_END' });
            

Real-World Example: Analytics Middleware


// Comprehensive analytics middleware.
//
// Forwards every action, then (unless ignored or sampled out) reports it via
// analytics.track('redux_action', data). It also sends dedicated 'purchase' and
// 'error' events for CHECKOUT_SUCCESS and ERROR_OCCURRED.
//
// config.ignoreActions - type prefixes that are never tracked
// config.sampleRate    - fraction of actions tracked, 0..1 (default 1.0)
// config.getUserId / config.getSessionId / config.enrichData - customization hooks
//
// Analytics is best-effort: any failure while collecting or sending data is
// logged and swallowed. It can never make dispatch throw after the reducers
// have already applied the action.
const analyticsMiddleware = (analytics, config = {}) => {
  const {
    ignoreActions = ['@@redux/', 'persist/'],
    sampleRate = 1.0,
    getUserId = (state) => state.auth?.user?.id,
    // Guarded so non-browser environments (SSR, tests) get null instead of a ReferenceError
    getSessionId = () =>
      (typeof sessionStorage !== 'undefined' ? sessionStorage.getItem('sessionId') : null),
    enrichData = (data, state) => data
  } = config;

  // Page context; fields are undefined when the browser globals are absent
  const collectMetadata = () => ({
    url: typeof window !== 'undefined' ? window.location.href : undefined,
    referrer: typeof document !== 'undefined' ? document.referrer : undefined,
    userAgent: typeof navigator !== 'undefined' ? navigator.userAgent : undefined
  });

  return store => next => action => {
    const { type } = action;

    // Only string types can be prefix-matched and reported by name
    if (typeof type !== 'string' || ignoreActions.some(prefix => type.startsWith(prefix))) {
      return next(action);
    }

    // Math.random() is in [0, 1): `>=` tracks everything at 1.0 and nothing at 0
    if (Math.random() >= sampleRate) {
      return next(action);
    }

    const startTime = performance.now();
    const prevState = store.getState();

    const result = next(action);

    const endTime = performance.now();
    const nextState = store.getState();

    try {
      const analyticsData = {
        action: {
          type,
          payload: action.payload
        },
        timestamp: new Date().toISOString(),
        duration: endTime - startTime,
        userId: getUserId(nextState),
        sessionId: getSessionId(),
        metadata: collectMetadata()
      };

      // Attach before/after snapshots of the affected slice for outcome actions
      if (type.includes('SUCCESS') || type.includes('FAILURE')) {
        analyticsData.stateChange = {
          before: getRelevantState(prevState, type),
          after: getRelevantState(nextState, type)
        };
      }

      analytics.track('redux_action', enrichData(analyticsData, nextState));

      if (type === 'CHECKOUT_SUCCESS') {
        analytics.track('purchase', {
          orderId: action.payload.orderId,
          total: action.payload.total,
          items: action.payload.items
        });
      }

      if (type === 'ERROR_OCCURRED') {
        analytics.track('error', {
          error: action.payload,
          context: analyticsData
        });
      }
    } catch (error) {
      console.error('Analytics error:', error);
    }

    return result;
  };
};

// Helper: pick the state slice an action type refers to, judged by a keyword in
// the type. Keywords are checked in order (USER, CART, PRODUCT) and the first
// match wins; returns null when none matches.
function getRelevantState(state, actionType) {
  const sliceByKeyword = [
    ['USER', 'user'],
    ['CART', 'cart'],
    ['PRODUCT', 'products']
  ];
  const match = sliceByKeyword.find(([keyword]) => actionType.includes(keyword));
  return match ? state[match[1]] : null;
}

// Usage with analytics service
import Analytics from 'analytics-service';

const analytics = new Analytics({
  apiKey: process.env.ANALYTICS_KEY,
  environment: process.env.NODE_ENV
});

const isProduction = process.env.NODE_ENV === 'production';

// Track 10% of actions in production (all of them elsewhere), skip noisy
// MOUSE_MOVE actions, and tag every event with experiment/subscription info
const analyticsConfig = {
  sampleRate: isProduction ? 0.1 : 1.0,
  ignoreActions: ['@@redux/', 'persist/', 'MOUSE_MOVE'],
  enrichData: (data, state) => ({
    ...data,
    experiment: state.experiments?.currentExperiment,
    subscription: state.user?.subscription?.type
  })
};

const store = createStore(
  rootReducer,
  applyMiddleware(analyticsMiddleware(analytics, analyticsConfig))
);
            

Testing Custom Middleware


// Testing custom middleware
import { createStore, applyMiddleware } from 'redux';

describe('timestampMiddleware', () => {
  let store;
  let mockNext;

  beforeEach(() => {
    // next echoes its argument so the middleware's return value can be asserted
    mockNext = jest.fn(action => action);
    store = {
      getState: jest.fn(() => ({})),
      dispatch: jest.fn()
    };
  });

  it('should add timestamp to actions', () => {
    const middleware = timestampMiddleware(store)(mockNext);
    const action = { type: 'TEST_ACTION' };

    const result = middleware(action);

    expect(mockNext).toHaveBeenCalledTimes(1);
    expect(mockNext).toHaveBeenCalledWith(
      expect.objectContaining({
        type: 'TEST_ACTION',
        meta: expect.objectContaining({
          timestamp: expect.any(Number)
        })
      })
    );
    // Middleware must hand back whatever next() returned
    expect(result).toBe(mockNext.mock.results[0].value);
  });

  it('should preserve existing meta', () => {
    const middleware = timestampMiddleware(store)(mockNext);
    const action = {
      type: 'TEST_ACTION',
      meta: { existingProp: 'value' }
    };

    middleware(action);

    expect(mockNext).toHaveBeenCalledWith(
      expect.objectContaining({
        meta: expect.objectContaining({
          existingProp: 'value',
          timestamp: expect.any(Number)
        })
      })
    );
  });

  it('should not mutate the dispatched action', () => {
    const middleware = timestampMiddleware(store)(mockNext);
    const action = {
      type: 'TEST_ACTION',
      meta: { existingProp: 'value' }
    };

    middleware(action);

    expect(action).toEqual({ type: 'TEST_ACTION', meta: { existingProp: 'value' } });
    expect(mockNext.mock.calls[0][0]).not.toBe(action);
  });
});

// Testing time-dependent (but synchronous) middleware: rateLimiter reads
// Date.now(), so Jest's fake timers pin the clock. Every test starts at a window
// boundary and cannot accidentally straddle two windows.
describe('rateLimiter middleware', () => {
  beforeEach(() => {
    jest.useFakeTimers();
    jest.setSystemTime(1000000);
  });

  afterEach(() => {
    // Don't leak fake timers into other test suites
    jest.useRealTimers();
  });

  // A fresh limiter (2 TEST_ACTIONs per 1000 ms) wired to a mock next
  const setup = () => {
    const middleware = rateLimiter({
      limits: { 'TEST_ACTION': { limit: 2, windowMs: 1000 } }
    });
    const store = { getState: jest.fn(), dispatch: jest.fn() };
    const next = jest.fn();
    return { next, handler: middleware(store)(next) };
  };

  it('should allow actions within rate limit', () => {
    const { next, handler } = setup();
    const action = { type: 'TEST_ACTION' };

    handler(action);
    handler(action);

    expect(next).toHaveBeenCalledTimes(2);
    expect(next).toHaveBeenCalledWith(action);
  });

  it('should block actions exceeding rate limit', () => {
    const warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {});
    const { next, handler } = setup();
    const action = { type: 'TEST_ACTION' };

    handler(action);
    handler(action);
    handler(action); // This should be blocked

    expect(next).toHaveBeenCalledTimes(3);
    expect(next).toHaveBeenLastCalledWith(
      expect.objectContaining({
        type: 'RATE_LIMIT_EXCEEDED',
        error: true
      })
    );
    warnSpy.mockRestore();
  });

  it('should allow actions again in the next window', () => {
    const warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {});
    const { next, handler } = setup();
    const action = { type: 'TEST_ACTION' };

    handler(action);
    handler(action);
    jest.advanceTimersByTime(1000);
    handler(action);

    expect(next).toHaveBeenCalledTimes(3);
    expect(next).toHaveBeenLastCalledWith(action);
    warnSpy.mockRestore();
  });
});
            

Best Practices for Custom Middleware

Do's

  • Keep middleware focused on a single responsibility
  • Always call next(action) unless intentionally blocking
  • Handle errors gracefully
  • Make middleware configurable
  • Document the middleware's purpose and usage
  • Test thoroughly with different action types

Don'ts

  • Don't make middleware dependent on action order
  • Don't mutate actions or state directly
  • Don't perform heavy computations synchronously
  • Don't create circular dependencies
  • Don't dispatch actions in an infinite loop

Performance Considerations


// ❌ Bad: Heavy synchronous computation
const badMiddleware = store => next => action => {
  // Runs for EVERY action and blocks the thread until it finishes
  const computed = heavyComputation(store.getState());
  const decorated = { ...action, computed };
  return next(decorated);
};

// ✅ Good: Defer heavy computation
// COMPUTE_HEAVY passes through immediately. The work runs later and reports back
// with COMPUTATION_COMPLETE, or COMPUTATION_FAILED if it throws.
const goodMiddleware = store => next => action => {
  if (action.type === 'COMPUTE_HEAVY') {
    // setTimeout queues a task, letting the browser render and handle input
    // first. A Promise.then microtask would run before either and so would not
    // free the thread at all. The computation itself still blocks while it
    // runs; use a worker for truly off-thread work.
    setTimeout(() => {
      let result;
      try {
        result = heavyComputation(store.getState());
      } catch (error) {
        // Without this the failure would be uncaught and invisible to the store
        store.dispatch({
          type: 'COMPUTATION_FAILED',
          payload: error.message,
          error: true
        });
        return;
      }
      store.dispatch({
        type: 'COMPUTATION_COMPLETE',
        payload: result
      });
    }, 0);
  }
  return next(action);
};

// ✅ Better: Use web workers for heavy computation
// COMPUTE_HEAVY passes through immediately. A dedicated worker receives the
// current state and the store later gets COMPUTATION_COMPLETE (with the worker's
// reply) or COMPUTATION_FAILED. The worker is terminated in every case.
const workerMiddleware = store => next => action => {
  if (action.type === 'COMPUTE_HEAVY') {
    const worker = new Worker('computation-worker.js');

    // Terminate before dispatching so the worker is released even if a
    // reducer throws while handling the result
    const finish = (resultAction) => {
      worker.terminate();
      store.dispatch(resultAction);
    };

    worker.onmessage = (event) => {
      finish({
        type: 'COMPUTATION_COMPLETE',
        payload: event.data
      });
    };

    // Without this, a failing worker script would stay alive and the store
    // would never hear back
    worker.onerror = (event) => {
      finish({
        type: 'COMPUTATION_FAILED',
        payload: event.message,
        error: true
      });
    };

    try {
      // The state is structured-cloned into the worker; non-cloneable values
      // such as functions make postMessage throw synchronously
      worker.postMessage({
        type: 'COMPUTE',
        data: store.getState()
      });
    } catch (error) {
      finish({
        type: 'COMPUTATION_FAILED',
        payload: error.message,
        error: true
      });
    }
  }
  return next(action);
};
                

Practice Exercise

Task: Create a Caching Middleware

Build a middleware that caches API responses and serves them from cache when appropriate:

  • Cache successful API responses
  • Implement cache expiration
  • Support cache invalidation
  • Handle cache size limits
  • Provide cache statistics

// TODO: Implement caching middleware
const createCacheMiddleware = (options = {}) => {
  const {
    maxAge = 5 * 60 * 1000, // 5 minutes
    maxEntries = 100,
    storage = 'memory' // 'memory' | 'localStorage' | 'sessionStorage'
  } = options;

  // TODO: Initialize cache storage

  return store => next => action => {
    // TODO: Check if action is cacheable API request
    // TODO: Check cache for existing response
    // TODO: Return cached response if valid
    // TODO: For API responses, update cache
    // TODO: Implement cache cleanup
    // TODO: Handle cache statistics actions

    // Until the caching logic exists, forward every action unchanged. Without
    // this the skeleton would silently swallow all dispatched actions.
    return next(action);
  };
};

// Example usage: 10-minute maxAge, up to 200 entries, 'localStorage' storage
const cacheOptions = {
  maxAge: 10 * 60 * 1000, // 10 minutes
  maxEntries: 200,
  storage: 'localStorage'
};
const cacheMiddleware = createCacheMiddleware(cacheOptions);

// Actions to implement:
// - API_REQUEST (check cache)
// - API_SUCCESS (update cache)
// - INVALIDATE_CACHE (clear specific cache)
// - CLEAR_CACHE (clear all cache)
// - GET_CACHE_STATS (return cache statistics)
                

Additional Resources