SHAP Gradio Demo - Architecture Overview

📐 System Architecture

┌─────────────────────────────────────────────────────────────────┐
│                        User's Browser                            │
│                     http://localhost:7860                        │
└────────────────────────────┬────────────────────────────────────┘
                             │
                             │ HTTP (WebSocket/SSE)
                             │
┌────────────────────────────▼────────────────────────────────────┐
│                      Gradio Interface                            │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │   Tab 1:     │     Tab 2:       │      Tab 3:             │  │
│  │   MNIST      │   ImageNet       │      Tabular            │  │
│  │   Pixels     │   Segmentation   │      Waterfall          │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                   Explanation Functions                          │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ explain_     │ explain_         │ explain_                │  │
│  │ mnist_digit()│ imagenet_image() │ tabular_sample()        │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                    SHAP Explainers                               │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ Deep         │ Partition        │ Tree                    │  │
│  │ Explainer    │ Explainer        │ Explainer               │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                      ML Models                                   │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ PyTorch CNN  │ ResNet50         │ Random Forest           │  │
│  │ (MNIST)      │ (ImageNet)       │ (Adult Income)          │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                        Datasets                                  │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ MNIST        │ User Uploaded    │ Adult Income            │  │
│  │ (28x28)      │ Images (224x224) │ (Tabular)               │  │
│  └──────────────┴──────────────────┴─────────────────────────┘  │
└──────────────────────────────────────────────────────────────────┘

🔄 Data Flow

Tab 1: MNIST Pixel Explanations

User selects a test-image index (0-9)
         │
         ▼
explain_mnist_digit()
         │
         ├─► Load MNIST test data
         │
         ├─► Get selected image
         │
         ├─► Run through PyTorch CNN
         │
         ├─► Create DeepExplainer with background
         │
         ├─► Calculate SHAP values
         │
         ├─► Generate image_plot visualization
         │
         └─► Return image + prediction text
                  │
                  ▼
         Display in Gradio UI
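The format of the prediction text returned by explain_mnist_digit() is not specified here. As a minimal sketch, assuming the CNN returns one raw score (logit) per digit, the confidence can be taken from a numerically stable softmax. `softmax` and `prediction_text` are hypothetical helpers, and the message format is an assumption.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def prediction_text(logits, true_label=None):
    """Format a hypothetical prediction string for the Tab 1 textbox."""
    probs = softmax(logits)
    pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
    text = f"Predicted digit: {pred} (confidence {probs[pred]:.1%})"
    if true_label is not None:
        text += f", true label: {true_label}"
    return text

# Example scores for the ten digit classes (made up for illustration)
logits = [0.1, 0.2, 0.0, 0.3, 0.1, 0.0, 0.2, 5.0, 0.1, 0.4]
print(prediction_text(logits, true_label=7))
```

If MNISTNet ends in `log_softmax`, as many PyTorch MNIST examples do, the same helper still works: the softmax of log-probabilities returns the original probabilities.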

Tab 2: ImageNet Segmentation

User uploads image
         │
         ▼
explain_imagenet_image()
         │
         ├─► Resize to 224x224
         │
         ├─► Convert to RGB if needed
         │
         ├─► Preprocess for ResNet50
         │
         ├─► Run Partition Explainer (image masker, created during initialization)
         │
         ├─► Calculate SHAP values (max_evals=100)
         │
         ├─► Generate image_plot for top 4 classes
         │
         └─► Return image + status text
                  │
                  ▼
         Display in Gradio UI
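The step "Generate image_plot for top 4 classes" needs the four highest-scoring ImageNet classes and a label for each. SHAP's own ResNet50 image examples request this directly from the explainer with `outputs=shap.Explanation.argsort.flip[:4]`; whether this demo does the same is not stated. The sketch below shows the selection logic in plain Python, with a generic fallback label for the case where class names could not be loaded. `top_k_classes` is a hypothetical helper.

```python
import heapq

TOP_CLASSES = 4  # mirrors the TOP_CLASSES configuration value

def top_k_classes(scores, class_names=None, k=TOP_CLASSES):
    """Return [(index, label, score), ...] for the k highest-scoring classes."""
    top = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    result = []
    for idx in top:
        # Fall back to a generic label when class names are unavailable.
        if class_names is not None and idx < len(class_names):
            label = class_names[idx]
        else:
            label = f"class {idx}"
        result.append((idx, label, scores[idx]))
    return result
```

With 1000 ImageNet scores and the full class-name list, the same call returns the four labels to show above the SHAP panels.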

Tab 3: Tabular Waterfall

User selects sample index (0-99)
         │
         ▼
explain_tabular_sample()
         │
         ├─► Load Adult Income test data
         │
         ├─► Get selected sample
         │
         ├─► Run through Random Forest
         │
         ├─► Run TreeExplainer (created during initialization)
         │
         ├─► Calculate SHAP values
         │
         ├─► Generate waterfall plot
         │
         └─► Return image + prediction text
                  │
                  ▼
         Display in Gradio UI
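The waterfall plot starts from the explainer's base value (the expected model output) and adds one SHAP value per feature until it reaches the prediction for the chosen sample. TreeExplainer computes these values exactly and efficiently for tree ensembles such as the Random Forest. To make the additivity property concrete, here is a brute-force sketch that enumerates every feature coalition for a toy three-feature function, using a single baseline row as a stand-in for background data. `exact_shapley` and `toy_model` are illustrative only and do not scale beyond a handful of features.

```python
from itertools import combinations
from math import factorial

def exact_shapley(f, x, baseline):
    """Exact Shapley values of f at x; 'absent' features take baseline values.

    Enumerates all coalitions (2**n of them), so it is only practical for tiny n.
    """
    n = len(x)

    def value(coalition):
        # Features in the coalition take x's value; the rest take the baseline's.
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi

def toy_model(z):
    # One additive effect plus one interaction between features 1 and 2.
    return 2 * z[0] + z[1] * z[2]

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)
base_value = toy_model(baseline)
```

Feature 0 receives the whole linear effect (2.0), the interaction (6.0) is split evenly between features 1 and 2 (3.0 each), and `base_value + sum(phi)` equals `toy_model(x)`. That sum is exactly what a waterfall plot draws, bar by bar.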

🧩 Component Interactions

Model Initialization (Lazy Loading)

# The first Generate click in a tab triggers only that tab's initialization

initialize_mnist_model()
    ├─► Download MNIST if needed
    ├─► Load test dataset
    ├─► Create MNISTNet model
    └─► Prepare background samples

initialize_resnet_model()
    ├─► Load ResNet50 with ImageNet weights
    ├─► Load class names (if available)
    ├─► Create image masker
    └─► Create Partition Explainer

initialize_tabular_model()
    ├─► Load Adult Income dataset
    ├─► Split train/test
    ├─► Train Random Forest
    └─► Create TreeExplainer

🎨 UI Component Hierarchy

gr.Blocks (Main Container)
│
├─► gr.Markdown (Title)
│
├─► gr.Tabs
│   │
│   ├─► gr.Tab("Pixel-level")
│   │   ├─► gr.Markdown (Description)
│   │   ├─► gr.Row
│   │   │   ├─► gr.Column (Input)
│   │   │   │   ├─► gr.Slider (digit_index)
│   │   │   │   └─► gr.Button (Generate)
│   │   │   └─► gr.Column (Output)
│   │   │       ├─► gr.Image (SHAP plot)
│   │   │       └─► gr.Textbox (Prediction)
│   │
│   ├─► gr.Tab("Image Segmentation")
│   │   ├─► gr.Markdown (Description)
│   │   ├─► gr.Row
│   │   │   ├─► gr.Column (Input)
│   │   │   │   ├─► gr.Image (Upload)
│   │   │   │   └─► gr.Button (Generate)
│   │   │   └─► gr.Column (Output)
│   │   │       ├─► gr.Image (SHAP plot)
│   │   │       └─► gr.Textbox (Status)
│   │
│   └─► gr.Tab("Tabular Data")
│       ├─► gr.Markdown (Description)
│       └─► gr.Row
│           ├─► gr.Column (Input)
│           │   ├─► gr.Slider (sample_index)
│           │   └─► gr.Button (Generate)
│           └─► gr.Column (Output)
│               ├─► gr.Image (Waterfall)
│               └─► gr.Textbox (Prediction)
│
└─► gr.Markdown (Footer)

🔐 Error Handling Flow

User Action
    │
    ▼
Try Block
    │
    ├─► Success Path
    │   └─► Return (image, success_message)
    │
    └─► Exception Caught
        └─► Return (None, error_message)
            │
            ▼
    Display in Gradio UI
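Each explanation function follows the try/except pattern above. One way to keep that pattern in a single place is a decorator, sketched below under the assumption that every callback returns an `(image, message)` pair. `safe_explanation` and `explain_demo` are hypothetical names; the original code may simply inline the try block in each function. Returning `None` for the image leaves the Gradio image output empty while the textbox shows the error.

```python
import functools
import traceback

def safe_explanation(func):
    """Wrap an explain_* callback so any failure becomes (None, error_message)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            image, message = func(*args, **kwargs)
            return image, message
        except Exception as exc:
            traceback.print_exc()  # full traceback goes to the server log
            return None, f"Error: {type(exc).__name__}: {exc}"
    return wrapper

@safe_explanation
def explain_demo(index):
    # Hypothetical callback standing in for explain_mnist_digit() and friends.
    if not 0 <= index <= 9:
        raise ValueError(f"index {index} out of range 0-9")
    return "fake-image", f"Explained sample {index}"
```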

💾 Memory Management

Lazy Loading Strategy

Application Start
    │
    ├─► Global variables = None
    │
    └─► Models not loaded yet
        │
        ▼
User clicks Generate in Tab 1
    │
    ├─► Check if mnist_model is None
    │   └─► Yes: initialize_mnist_model()
    │
    └─► Use cached model for subsequent calls

Memory Footprint

┌────────────────────┬────────────────┐
│ Component          │ Memory Usage   │
├────────────────────┼────────────────┤
│ MNIST Model        │ ~50 MB         │
│ MNIST Data         │ ~100 MB        │
│ ResNet50           │ ~100 MB        │
│ Random Forest      │ ~50 MB         │
│ SHAP Explainers    │ ~50 MB         │
│ Gradio Interface   │ ~50 MB         │
├────────────────────┼────────────────┤
│ Total (all loaded) │ ~400 MB        │
└────────────────────┴────────────────┘
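As a sanity check on the ResNet50 row, the stored weights alone account for roughly 100 MB at 32-bit precision. The sketch assumes torchvision's ResNet50 (25,557,032 parameters); other implementations differ slightly, and activations and framework overhead come on top of this figure. `float32_weight_size` is a hypothetical helper.

```python
def float32_weight_size(num_params):
    """Return (megabytes, mebibytes) needed to store num_params float32 weights."""
    n_bytes = num_params * 4  # float32 = 4 bytes per parameter
    return n_bytes / 1e6, n_bytes / 2**20

RESNET50_PARAMS = 25_557_032  # assumption: torchvision resnet50 parameter count
mb, mib = float32_weight_size(RESNET50_PARAMS)
print(f"ResNet50 weights: {mb:.1f} MB ({mib:.1f} MiB)")  # ResNet50 weights: 102.2 MB (97.5 MiB)
```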

🚀 Performance Characteristics

Execution Times (approximate; vary with hardware and caching)

┌────────────────────────┬────────────────┐
│ Operation              │ Time           │
├────────────────────────┼────────────────┤
│ MNIST Explanation      │ 1-2 seconds    │
│ ImageNet Explanation   │ 30-60 seconds  │
│ Tabular Explanation    │ 1 second       │
│ Model Initialization   │ 5-10 seconds   │
│ First MNIST Download   │ 1-2 minutes    │
└────────────────────────┴────────────────┘
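The timings in the table depend on CPU or GPU, network speed, and whether datasets and weights are already cached. To measure them on a given machine, each call can be wrapped in a small timing context manager. `timed` and the `timings` dictionary are hypothetical helpers, and `time.sleep` stands in for a real explanation call.

```python
import time
from contextlib import contextmanager

timings = {}  # label -> list of elapsed wall-clock seconds

@contextmanager
def timed(label):
    """Record how long the enclosed block takes under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block raises, so failed runs are still visible.
        timings.setdefault(label, []).append(time.perf_counter() - start)

with timed("Tabular Explanation"):
    time.sleep(0.01)  # placeholder for a real explain_tabular_sample(...) call
```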

Bottlenecks

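  1. ImageNet Masking: Most time-consuming, because the Partition Explainer must run ResNet50 on many masked copies of the image; the cost grows with MAX_EVALS
  2. First Run: Downloading the MNIST dataset (and typically the pretrained ResNet50 weights) takes time
  3. Model Loading: Initial load takes roughly 5-10 seconds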

🔧 Configuration Points

Adjustable Parameters

# MNIST
BACKGROUND_SIZE = 100        # Background samples for DeepExplainer
TEST_SAMPLES = 10            # Test images selectable in Tab 1 (indices 0-9)

# ImageNet
MAX_EVALS = 100              # SHAP evaluation budget (model calls on masked images)
BATCH_SIZE = 50              # Masked images evaluated per model call
TOP_CLASSES = 4              # Number of top predicted classes to explain

# Tabular
N_ESTIMATORS = 100           # Number of trees in the Random Forest
TEST_SIZE = 0.2              # Fraction of data held out for testing
EXPLAIN_SAMPLES = 100        # Test samples selectable in Tab 3 (indices 0-99)
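The constants above can also be grouped into one immutable object that validates itself at startup, so a bad edit fails immediately instead of in the middle of an explanation. This is a sketch: `DemoConfig` and its field names are hypothetical and simply mirror the listed constants.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DemoConfig:
    # MNIST
    background_size: int = 100
    test_samples: int = 10
    # ImageNet
    max_evals: int = 100
    batch_size: int = 50
    top_classes: int = 4
    # Tabular
    n_estimators: int = 100
    test_size: float = 0.2
    explain_samples: int = 100

    def __post_init__(self):
        positive = ("background_size", "test_samples", "max_evals",
                    "batch_size", "n_estimators", "explain_samples")
        for name in positive:
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be at least 1")
        if not 0.0 < self.test_size < 1.0:
            raise ValueError("test_size must be strictly between 0 and 1")
        if not 1 <= self.top_classes <= 1000:  # ImageNet-1k has 1000 classes
            raise ValueError("top_classes must be between 1 and 1000")

    @property
    def mnist_index_max(self):
        # Upper bound of the 0-based Tab 1 slider (9 with the defaults)
        return self.test_samples - 1

    @property
    def tabular_index_max(self):
        # Upper bound of the 0-based Tab 3 slider (99 with the defaults)
        return self.explain_samples - 1
```

Deriving the slider bounds from the config keeps the 0-9 and 0-99 ranges in the UI in step with TEST_SAMPLES and EXPLAIN_SAMPLES.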

📊 Visualization Pipeline

SHAP Values (numpy arrays)
         │
         ▼
Matplotlib Figure
         │
         ├─► shap.image_plot() or
         └─► shap.plots.waterfall()
                  │
                  ▼
         BytesIO Buffer
                  │
                  ▼
         PIL Image
                  │
                  ▼
         Gradio gr.Image
                  │
                  ▼
         User's Browser
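The Figure, BytesIO, and PIL steps can be sketched as below. Both `shap.image_plot()` and `shap.plots.waterfall()` accept `show=False`, after which `plt.gcf()` returns the figure they drew; here a plain bar chart stands in for a SHAP plot. `figure_to_pil` is a hypothetical helper, and the sketch assumes Matplotlib and Pillow are installed.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen; a server has no display
import matplotlib.pyplot as plt
from PIL import Image

def figure_to_pil(fig, dpi=100):
    """Render a Matplotlib figure to PNG in memory and reopen it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    plt.close(fig)  # release the figure so repeated clicks do not accumulate figures
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # read pixel data now, while the buffer is still alive
    return image.convert("RGB")

fig, ax = plt.subplots()
ax.barh(["feature A", "feature B"], [0.3, -0.1])  # stand-in for a SHAP plot
img = figure_to_pil(fig)  # ready to return to a gr.Image output
```

Closing the figure after saving matters in a long-running server: pyplot keeps every open figure alive, so a demo that plots on every click would otherwise keep growing its memory use.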

🌐 Network Architecture

┌─────────────────────────────────────────┐
│ Client (Browser)                        │
│ - HTML/CSS/JavaScript                   │
│ - Gradio Frontend                       │
└──────────────┬──────────────────────────┘
               │
               │ HTTP (WebSocket/SSE)
               │ Port 7860
               │
┌──────────────▼──────────────────────────┐
│ Server (Python)                         │
│ - Gradio Backend                        │
│ - FastAPI (served by Uvicorn)           │
│ - SHAP Computations                     │
└─────────────────────────────────────────┘

🎯 Design Principles

  1. Modularity: Each SHAP method is independent
  2. Lazy Loading: Models load only when needed
  3. Error Resilience: Each explanation function catches exceptions and reports them in the UI
  4. User Feedback: Clear status messages
  5. No Icons: Text-only interface (per requirements)
  6. Simplicity: Minimal dependencies
  7. Cross-Platform: Works on Windows/Linux/Mac

📝 Code Organization

gradio_shap_demo.py (300 lines)
│
├─► Imports & Setup (30 lines)
├─► Model Definitions (20 lines)
├─► Global Variables (10 lines)
├─► Initialization Functions (60 lines)
├─► Explanation Functions (120 lines)
└─► Gradio Interface (60 lines)

This architecture is designed to provide:

  • ✅ Clean separation of concerns
  • ✅ Efficient memory usage
  • ✅ Fast response times (except ImageNet)
  • ✅ Easy maintenance and extension
  • ✅ User-friendly experience