SHAP Gradio Demo - Architecture Overview
📐 System Architecture
┌─────────────────────────────────────────────────────────────────┐
│ User's Browser │
│ http://localhost:7860 │
└────────────────────────────┬────────────────────────────────────┘
│
│ HTTP/WebSocket
│
┌────────────────────────────▼────────────────────────────────────┐
│ Gradio Interface │
│ ┌──────────────┬──────────────────┬─────────────────────────┐ │
│ │ Tab 1: │ Tab 2: │ Tab 3: │ │
│ │ MNIST │ ImageNet │ Tabular │ │
│ │ Pixels │ Segmentation │ Waterfall │ │
│ └──────┬───────┴────────┬─────────┴──────────┬──────────────┘ │
└─────────┼────────────────┼────────────────────┼─────────────────┘
│ │ │
│ │ │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│ Explanation Functions │
│ ┌──────────────┬──────────────────┬─────────────────────────┐ │
│ │ explain_ │ explain_ │ explain_ │ │
│ │ mnist_digit()│ imagenet_image() │ tabular_sample() │ │
│ └──────┬───────┴────────┬─────────┴──────────┬──────────────┘ │
└─────────┼────────────────┼────────────────────┼─────────────────┘
│ │ │
│ │ │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│ SHAP Explainers │
│ ┌──────────────┬──────────────────┬─────────────────────────┐ │
│ │ Deep │ Partition │ Tree │ │
│ │ Explainer │ Explainer │ Explainer │ │
│ └──────┬───────┴────────┬─────────┴──────────┬──────────────┘ │
└─────────┼────────────────┼────────────────────┼─────────────────┘
│ │ │
│ │ │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│ ML Models │
│ ┌──────────────┬──────────────────┬─────────────────────────┐ │
│ │ PyTorch CNN │ ResNet50 │ Random Forest │ │
│ │ (MNIST) │ (ImageNet) │ (Adult Income) │ │
│ └──────┬───────┴────────┬─────────┴──────────┬──────────────┘ │
└─────────┼────────────────┼────────────────────┼─────────────────┘
│ │ │
│ │ │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│ Datasets │
│ ┌──────────────┬──────────────────┬─────────────────────────┐ │
│ │ MNIST │ User Uploaded │ Adult Income │ │
│ │ (28x28) │ Images (224x224) │ (Tabular) │ │
│ └──────────────┴──────────────────┴─────────────────────────┘ │
└──────────────────────────────────────────────────────────────────┘
🔄 Data Flow
Tab 1: MNIST Pixel Explanations
User selects digit index (0-9)
│
▼
explain_mnist_digit()
│
├─► Load MNIST test data
│
├─► Get selected image
│
├─► Run through PyTorch CNN
│
├─► Create DeepExplainer with background
│
├─► Calculate SHAP values
│
├─► Generate image_plot visualization
│
└─► Return image + prediction text
│
▼
Display in Gradio UI
Tab 2: ImageNet Segmentation
User uploads image
│
▼
explain_imagenet_image()
│
├─► Resize to 224x224
│
├─► Convert to RGB if needed
│
├─► Preprocess for ResNet50
│
├─► Create Partition Explainer with masker
│
├─► Calculate SHAP values (max_evals=100)
│
├─► Generate image_plot for top 4 classes
│
└─► Return image + status text
│
▼
Display in Gradio UI
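The first three steps of this flow (resize, RGB conversion, preprocessing) can be sketched as below. This is a minimal illustration, not the demo's actual code: the function name `preprocess_upload` is hypothetical, and the normalization constants assume the torchvision-style ImageNet mean and standard deviation, which the overview does not confirm.

```python
import numpy as np
from PIL import Image

# Assumption: torchvision-style ImageNet statistics (per RGB channel).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_upload(image: Image.Image, size: int = 224) -> np.ndarray:
    """Hypothetical helper: turn an uploaded image into a normalized CHW array."""
    image = image.resize((size, size))      # Resize to 224x224
    if image.mode != "RGB":                 # Convert to RGB if needed
        image = image.convert("RGB")        # (drops alpha, expands grayscale)
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # H x W x C in [0, 1]
    normalized = (pixels - IMAGENET_MEAN) / IMAGENET_STD  # per-channel scaling
    return normalized.transpose(2, 0, 1)    # C x H x W, as PyTorch models expect


# Example: a solid red RGBA upload of arbitrary size.
upload = Image.new("RGBA", (640, 480), (255, 0, 0, 255))
tensor = preprocess_upload(upload)
print(tensor.shape, tensor.dtype)
```

If the demo uses a different ResNet50 implementation (for example a Keras one), the channel order and scaling would differ, so treat the constants above as placeholders.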
Tab 3: Tabular Waterfall
User selects sample index (0-99)
│
▼
explain_tabular_sample()
│
├─► Load Adult Income test data
│
├─► Get selected sample
│
├─► Run through Random Forest
│
├─► Create TreeExplainer
│
├─► Calculate SHAP values
│
├─► Generate waterfall plot
│
└─► Return image + prediction text
│
▼
Display in Gradio UI
🧩 Component Interactions
Model Initialization (Lazy Loading)
# First use of each tab triggers that tab's initializer
initialize_mnist_model()
├─► Download MNIST if needed
├─► Load test dataset
├─► Create MNISTNet model
└─► Prepare background samples
initialize_resnet_model()
├─► Load ResNet50 with ImageNet weights
├─► Load class names (if available)
├─► Create image masker
└─► Create Partition Explainer
initialize_tabular_model()
├─► Load Adult Income dataset
├─► Split train/test
├─► Train Random Forest
└─► Create TreeExplainer
🎨 UI Component Hierarchy
gr.Blocks (Main Container)
│
├─► gr.Markdown (Title)
│
├─► gr.Tabs
│ │
│ ├─► gr.Tab("Pixel-level")
│ │ ├─► gr.Markdown (Description)
│ │ ├─► gr.Row
│ │ │ ├─► gr.Column (Input)
│ │ │ │ ├─► gr.Slider (digit_index)
│ │ │ │ └─► gr.Button (Generate)
│ │ │ └─► gr.Column (Output)
│ │ │ ├─► gr.Image (SHAP plot)
│ │ │ └─► gr.Textbox (Prediction)
│ │
│ ├─► gr.Tab("Image Segmentation")
│ │ ├─► gr.Markdown (Description)
│ │ ├─► gr.Row
│ │ │ ├─► gr.Column (Input)
│ │ │ │ ├─► gr.Image (Upload)
│ │ │ │ └─► gr.Button (Generate)
│ │ │ └─► gr.Column (Output)
│ │ │ ├─► gr.Image (SHAP plot)
│ │ │ └─► gr.Textbox (Status)
│ │
│ └─► gr.Tab("Tabular Data")
│ ├─► gr.Markdown (Description)
│ └─► gr.Row
│ ├─► gr.Column (Input)
│ │ ├─► gr.Slider (sample_index)
│ │ └─► gr.Button (Generate)
│ └─► gr.Column (Output)
│ ├─► gr.Image (Waterfall)
│ └─► gr.Textbox (Prediction)
│
└─► gr.Markdown (Footer)
🔐 Error Handling Flow
User Action
│
▼
Try Block
│
├─► Success Path
│ └─► Return (image, success_message)
│
└─► Exception Caught
└─► Return (None, error_message)
│
▼
Display in Gradio UI
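One way to get this behavior without repeating the try/except in every function is a small decorator. The sketch below is an assumption about how it could be structured, not the demo's actual code; `safe_explanation` is a hypothetical name, and the explanation function is a stand-in that returns a string instead of a real plot.

```python
import functools
from typing import Any, Callable, Optional, Tuple


def safe_explanation(fn: Callable[..., Tuple[Any, str]]):
    """Wrap an explanation function so failures become (None, error_message)."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs) -> Tuple[Optional[Any], str]:
        try:
            return fn(*args, **kwargs)                       # Success path
        except Exception as exc:                             # Exception caught
            return None, f"Error: {type(exc).__name__}: {exc}"
    return wrapper


@safe_explanation
def explain_tabular_sample(sample_index: int) -> Tuple[Any, str]:
    # Stand-in body: the real function would compute SHAP values and a plot.
    if not 0 <= sample_index <= 99:
        raise IndexError(f"sample_index {sample_index} is outside 0-99")
    image = f"<waterfall plot for sample {sample_index}>"
    return image, f"Explained sample {sample_index}"


ok = explain_tabular_sample(5)
failed = explain_tabular_sample(500)
print(ok[1])
print(failed[1])
```

Because both paths return a two-element tuple, the same pair of Gradio outputs (image, textbox) can display either result.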
💾 Memory Management
Lazy Loading Strategy
Application Start
│
├─► Global variables = None
│
└─► Models not loaded yet
│
▼
User clicks Tab 1
│
├─► Check if mnist_model is None
│ └─► Yes: initialize_mnist_model()
│
└─► Use cached model for subsequent calls
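The check-then-initialize step can be sketched as follows. The accessor `get_mnist_model` and the counter are hypothetical, and `initialize_mnist_model` is replaced by a stand-in that returns a dictionary instead of downloading data or building a network.

```python
mnist_model = None   # Global starts empty at application start
init_calls = 0       # Illustration only: counts how often the initializer runs


def initialize_mnist_model():
    # Stand-in for the real initializer (download, load data, build model).
    global init_calls
    init_calls += 1
    return {"name": "MNISTNet", "background_size": 100}


def get_mnist_model():
    """Load the model on first use, then reuse the cached instance."""
    global mnist_model
    if mnist_model is None:          # Check if mnist_model is None
        mnist_model = initialize_mnist_model()
    return mnist_model               # Cached for subsequent calls


first = get_mnist_model()
second = get_mnist_model()
print(init_calls, first is second)
```

Users who never open a tab never pay the loading cost for its model, which is what keeps startup light.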
Memory Footprint
┌────────────────────┬────────────────┐
│ Component          │ Memory Usage   │
├────────────────────┼────────────────┤
│ MNIST Model        │ ~50 MB         │
│ MNIST Data         │ ~100 MB        │
│ ResNet50           │ ~100 MB        │
│ Random Forest      │ ~50 MB         │
│ SHAP Explainers    │ ~50 MB         │
│ Gradio Interface   │ ~50 MB         │
├────────────────────┼────────────────┤
│ Total (all loaded) │ ~400 MB        │
└────────────────────┴────────────────┘
🚀 Performance Characteristics
Execution Times
┌────────────────────────┬────────────────┐
│ Operation              │ Time           │
├────────────────────────┼────────────────┤
│ MNIST Explanation      │ 1-2 seconds    │
│ ImageNet Explanation   │ 30-60 seconds  │
│ Tabular Explanation    │ 1 second       │
│ Model Initialization   │ 5-10 seconds   │
│ First MNIST Download   │ 1-2 minutes    │
└────────────────────────┴────────────────┘
Bottlenecks
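- ImageNet Masking: The slowest step. The Partition Explainer re-runs ResNet50 on many masked copies of the image (up to max_evals=100), so the cost is inherent to the method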
- First Run: Dataset downloads take time
- Model Loading: Initial load takes a few seconds
🔧 Configuration Points
Adjustable Parameters
# MNIST
BACKGROUND_SIZE = 100 # Background samples for DeepExplainer
TEST_SAMPLES = 10 # Number of test images
# ImageNet
MAX_EVALS = 100 # SHAP evaluation budget
BATCH_SIZE = 50 # Batch size for masking
TOP_CLASSES = 4 # Number of classes to explain
# Tabular
N_ESTIMATORS = 100 # Random Forest trees
TEST_SIZE = 0.2 # Train/test split
EXPLAIN_SAMPLES = 100 # Samples to explain
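If these constants grow, one option is to group them in a single validated object. The sketch below is only a suggestion: `DemoConfig` and its field names are hypothetical, and the defaults simply mirror the values listed above.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class DemoConfig:
    # MNIST
    background_size: int = 100   # Background samples for DeepExplainer
    test_samples: int = 10       # Number of test images
    # ImageNet
    max_evals: int = 100         # SHAP evaluation budget
    batch_size: int = 50         # Batch size for masking
    top_classes: int = 4         # Number of classes to explain
    # Tabular
    n_estimators: int = 100      # Random Forest trees
    test_size: float = 0.2       # Train/test split
    explain_samples: int = 100   # Samples to explain

    def __post_init__(self):
        # The split fraction must leave data on both sides.
        if not 0.0 < self.test_size < 1.0:
            raise ValueError("test_size must be between 0 and 1 (exclusive)")
        # Every other setting is a count and must be positive.
        for f in fields(self):
            if f.name != "test_size" and getattr(self, f.name) <= 0:
                raise ValueError(f"{f.name} must be positive")


default_config = DemoConfig()
quick_config = DemoConfig(max_evals=50)  # smaller ImageNet budget
print(default_config.max_evals, quick_config.max_evals)
```

A smaller `max_evals` should shorten the ImageNet explanation, most likely at the cost of a coarser attribution map.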
📊 Visualization Pipeline
SHAP Values (numpy arrays)
│
▼
Matplotlib Figure
│
├─► shap.image_plot() or
└─► shap.plots.waterfall()
│
▼
BytesIO Buffer
│
▼
PIL Image
│
▼
Gradio gr.Image
│
▼
User's Browser
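The middle of this pipeline (Matplotlib figure to BytesIO buffer to PIL image) can be sketched as below. The helper name `figure_to_pil` is hypothetical, and the bar chart with made-up values stands in for a real SHAP plot. With SHAP itself, one would typically call the plotting function with `show=False` and pick up the current figure with `plt.gcf()`.

```python
import io

import matplotlib
matplotlib.use("Agg")  # Headless backend: no display is needed on a server
import matplotlib.pyplot as plt
from PIL import Image


def figure_to_pil(fig, dpi: int = 100) -> Image.Image:
    """Render a Matplotlib figure to an in-memory PNG and load it with PIL."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    plt.close(fig)                         # Free the figure's memory
    buf.seek(0)                            # Rewind before reading
    return Image.open(buf).convert("RGB")  # convert() forces a full load


# Stand-in plot with made-up contribution values.
fig, ax = plt.subplots(figsize=(4, 3))
ax.barh(["age", "hours-per-week", "education-num"], [0.12, -0.05, 0.08])
ax.set_xlabel("contribution (made-up values)")

image = figure_to_pil(fig)
print(image.mode, image.size)
```

Closing the figure after saving matters in a long-running server, since Matplotlib otherwise keeps every figure alive.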
🌐 Network Architecture
┌─────────────────────────────────────────┐
│ Client (Browser) │
│ - HTML/CSS/JavaScript │
│ - Gradio Frontend │
└──────────────┬──────────────────────────┘
│
│ HTTP/WebSocket
│ Port 7860
│
┌──────────────▼──────────────────────────┐
│ Server (Python) │
│ - Gradio Backend │
│ - FastAPI │
│ - SHAP Computations │
└─────────────────────────────────────────┘
🎯 Design Principles
- Modularity: Each SHAP method is independent
- Lazy Loading: Models load only when needed
- Error Resilience: Each explanation function wraps its work in try/except and returns an error message instead of crashing
- User Feedback: Clear status messages
- No Icons: Text-only interface (per requirements)
- Simplicity: Minimal dependencies
- Cross-Platform: Works on Windows/Linux/Mac
📝 Code Organization
gradio_shap_demo.py (300 lines)
│
├─► Imports & Setup (30 lines)
├─► Model Definitions (20 lines)
├─► Global Variables (10 lines)
├─► Initialization Functions (60 lines)
├─► Explanation Functions (120 lines)
└─► Gradio Interface (60 lines)
This architecture ensures:
- ✅ Clean separation of concerns
- ✅ Efficient memory usage
- ✅ Fast response times (except ImageNet)
- ✅ Easy maintenance and extension
- ✅ User-friendly experience