# SHAP Gradio Demo - Architecture Overview

## 📐 System Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                         User's Browser                          │
│                      http://localhost:7860                      │
└────────────────────────────┬────────────────────────────────────┘
                             │
                             │ HTTP/WebSocket
                             │
┌────────────────────────────▼────────────────────────────────────┐
│                        Gradio Interface                         │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ Tab 1:       │ Tab 2:           │ Tab 3:                  │  │
│  │ MNIST        │ ImageNet         │ Tabular                 │  │
│  │ Pixels       │ Segmentation     │ Waterfall               │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                      Explanation Functions                      │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ explain_     │ explain_         │ explain_                │  │
│  │ mnist_digit()│ imagenet_image() │ tabular_sample()        │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                         SHAP Explainers                         │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ Deep         │ Partition        │ Tree                    │  │
│  │ Explainer    │ Explainer        │ Explainer               │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                            ML Models                            │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ PyTorch CNN  │ ResNet50         │ Random Forest           │  │
│  │ (MNIST)      │ (ImageNet)       │ (Adult Income)          │  │
│  └──────┬───────┴────────┬─────────┴──────────┬──────────────┘  │
└─────────┼────────────────┼────────────────────┼─────────────────┘
          │                │                    │
┌─────────▼────────────────▼────────────────────▼─────────────────┐
│                            Datasets                             │
│  ┌──────────────┬──────────────────┬─────────────────────────┐  │
│  │ MNIST        │ User Uploaded    │ Adult Income            │  │
│  │ (28x28)      │ Images (224x224) │ (Tabular)               │  │
│  └──────────────┴──────────────────┴─────────────────────────┘  │
└─────────────────────────────────────────────────────────────────┘
```

---

## 🔄 Data Flow

### Tab 1: MNIST Pixel Explanations

```
User selects digit index (0-9)
    │
    ▼
explain_mnist_digit()
    │
    ├─► Load MNIST test data
    │
    ├─► Get selected image
    │
    ├─► Run through PyTorch CNN
    │
    ├─► Create DeepExplainer with background
    │
    ├─► Calculate SHAP values
    │
    ├─► Generate image_plot visualization
    │
    └─► Return image + prediction text
            │
            ▼
    Display in Gradio UI
```

### Tab 2: ImageNet Segmentation

```
User uploads image
    │
    ▼
explain_imagenet_image()
    │
    ├─► Resize to 224x224
    │
    ├─► Convert to RGB if needed
    │
    ├─► Preprocess for ResNet50
    │
    ├─► Create Partition Explainer with masker
    │
    ├─► Calculate SHAP values (max_evals=100)
    │
    ├─► Generate image_plot for top 4 classes
    │
    └─► Return image + status text
            │
            ▼
    Display in Gradio UI
```

### Tab 3: Tabular Waterfall

```
User selects sample index (0-99)
    │
    ▼
explain_tabular_sample()
    │
    ├─► Load Adult Income test data
    │
    ├─► Get selected sample
    │
    ├─► Run through Random Forest
    │
    ├─► Create TreeExplainer
    │
    ├─► Calculate SHAP values
    │
    ├─► Generate waterfall plot
    │
    └─► Return image + prediction text
            │
            ▼
    Display in Gradio UI
```

---

## 🧩 Component Interactions

### Model Initialization (Lazy Loading)

```
# The first explanation request in a tab triggers that tab's initializer;
# the other models stay unloaded until their own tabs are used.

initialize_mnist_model()
├─► Download MNIST if needed
├─► Load test dataset
├─► Create MNISTNet model
└─► Prepare background samples

initialize_resnet_model()
├─► Load ResNet50 with ImageNet weights
├─► Load class names (if available)
├─► Create image masker
└─► Create Partition Explainer

initialize_tabular_model()
├─► Load Adult Income dataset
├─► Split train/test
├─► Train Random Forest
└─► Create TreeExplainer
```

---

## 🎨 UI Component Hierarchy

```
gr.Blocks (Main Container)
│
├─► gr.Markdown (Title)
│
├─► gr.Tabs
│   │
│   ├─► gr.Tab("Pixel-level")
│   │   ├─► gr.Markdown (Description)
│   │   └─► gr.Row
│   │       ├─► gr.Column (Input)
│   │       │   ├─► gr.Slider (digit_index)
│   │       │   └─► gr.Button (Generate)
│   │       └─► gr.Column (Output)
│   │           ├─► gr.Image (SHAP plot)
│   │           └─► gr.Textbox (Prediction)
│   │
│   ├─► gr.Tab("Image Segmentation")
│   │   ├─► gr.Markdown (Description)
│   │   └─► gr.Row
│   │       ├─► gr.Column (Input)
│   │       │   ├─► gr.Image (Upload)
│   │       │   └─► gr.Button (Generate)
│   │       └─► gr.Column (Output)
│   │           ├─► gr.Image (SHAP plot)
│   │           └─► gr.Textbox (Status)
│   │
│   └─► gr.Tab("Tabular Data")
│       ├─► gr.Markdown (Description)
│       └─► gr.Row
│           ├─► gr.Column (Input)
│           │   ├─► gr.Slider (sample_index)
│           │   └─► gr.Button (Generate)
│           └─► gr.Column (Output)
│               ├─► gr.Image (Waterfall)
│               └─► gr.Textbox (Prediction)
│
└─► gr.Markdown (Footer)
```

---

## 🔐 Error Handling Flow

```
User Action
    │
    ▼
Try Block
    │
    ├─► Success Path
    │     └─► Return (image, success_message)
    │
    └─► Exception Caught
          └─► Return (None, error_message)

Either return value
    │
    ▼
Display in Gradio UI
```

---

## 💾 Memory Management

### Lazy Loading Strategy

```
Application Start
    │
    ├─► Global variables = None
    │
    └─► Models not loaded yet
            │
            ▼
User clicks Generate in Tab 1
    │
    ├─► Check if mnist_model is None
    │     └─► Yes: initialize_mnist_model()
    │
    └─► Use cached model for subsequent calls
```

### Memory Footprint

```
┌────────────────────┬────────────────┐
│ Component          │ Memory Usage   │
├────────────────────┼────────────────┤
│ MNIST Model        │ ~50 MB         │
│ MNIST Data         │ ~100 MB        │
│ ResNet50           │ ~100 MB        │
│ Random Forest      │ ~50 MB         │
│ SHAP Explainers    │ ~50 MB         │
│ Gradio Interface   │ ~50 MB         │
├────────────────────┼────────────────┤
│ Total (all loaded) │ ~400 MB        │
└────────────────────┴────────────────┘
```

---

## 🚀 Performance Characteristics

### Execution Times

```
┌────────────────────────┬────────────────┐
│ Operation              │ Time           │
├────────────────────────┼────────────────┤
│ MNIST Explanation      │ 1-2 seconds    │
│ ImageNet Explanation   │ 30-60 seconds  │
│ Tabular Explanation    │ 1 second       │
│ Model Initialization   │ 5-10 seconds   │
│ First MNIST Download   │ 1-2 minutes    │
└────────────────────────┴────────────────┘
```

### Bottlenecks

1. **ImageNet Masking**: The most time-consuming step. The Partition Explainer evaluates ResNet50 on up to `MAX_EVALS` (100) masked variants of the image, so the cost is inherent to the method.
2. **First Run**: The MNIST dataset must be downloaded before the first explanation.
3. **Model Loading**: Each model takes a few seconds to initialize on first use.

---

## 🔧 Configuration Points

### Adjustable Parameters

```python
# MNIST
BACKGROUND_SIZE = 100  # Background samples for DeepExplainer
TEST_SAMPLES = 10      # Number of test images

# ImageNet
MAX_EVALS = 100        # SHAP evaluation budget
BATCH_SIZE = 50        # Batch size for masking
TOP_CLASSES = 4        # Number of classes to explain

# Tabular
N_ESTIMATORS = 100     # Random Forest trees
TEST_SIZE = 0.2        # Train/test split
EXPLAIN_SAMPLES = 100  # Samples to explain
```

---

## 📊 Visualization Pipeline

```
SHAP Values (numpy arrays)
    │
    ▼
Matplotlib Figure
    ├─► shap.image_plot()          (Tabs 1 and 2)
    └─► shap.plots.waterfall()     (Tab 3)
    │
    ▼
BytesIO Buffer
    │
    ▼
PIL Image
    │
    ▼
Gradio gr.Image
    │
    ▼
User's Browser
```

---

## 🌐 Network Architecture

```
┌─────────────────────────────────────────┐
│ Client (Browser)                        │
│  - HTML/CSS/JavaScript                  │
│  - Gradio Frontend                      │
└──────────────┬──────────────────────────┘
               │
               │ HTTP/WebSocket
               │ Port 7860
               │
┌──────────────▼──────────────────────────┐
│ Server (Python)                         │
│  - Gradio Backend                       │
│  - FastAPI                              │
│  - SHAP Computations                    │
└─────────────────────────────────────────┘
```

---

## 🎯 Design Principles

1. **Modularity**: Each SHAP method (tab, explanation function, explainer, and model) is independent of the others.
2. **Lazy Loading**: Models load only when first needed.
3. **Error Resilience**: Every explanation function wraps its work in a try/except block and returns an error message instead of crashing.
4. **User Feedback**: Clear status messages.
5. **No Icons**: Text-only interface (per requirements).
6. **Simplicity**: Minimal dependencies.
7. **Cross-Platform**: Works on Windows, Linux, and macOS.

---

## 📝 Code Organization

```
gradio_shap_demo.py (~300 lines)
│
├─► Imports & Setup (30 lines)
├─► Model Definitions (20 lines)
├─► Global Variables (10 lines)
├─► Initialization Functions (60 lines)
├─► Explanation Functions (120 lines)
└─► Gradio Interface (60 lines)
```

---

This architecture ensures:

- ✅ Clean separation of concerns
- ✅ Efficient memory usage
- ✅ Fast response times (except ImageNet)
- ✅ Easy maintenance and extension
- ✅ User-friendly experience
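
The Visualization Pipeline (Matplotlib figure, then `BytesIO` buffer, then PIL image for `gr.Image`) can be sketched as a single helper. This is an assumption about how the conversion might be written; `figure_to_pil` is a hypothetical name, and a trivial line plot stands in for a SHAP figure.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend: no display is needed on a server
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image


def figure_to_pil(fig: Figure) -> Image.Image:
    """Render a Matplotlib figure to PNG in memory and reopen it with PIL."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)   # release the figure so repeated requests do not pile up
    buf.seek(0)
    image = Image.open(buf)
    image.load()     # decode now so the image no longer depends on the buffer
    return image


if __name__ == "__main__":
    # Stand-in for a SHAP plot: any Matplotlib figure converts the same way.
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.plot([0, 1, 2], [0, 1, 4])
    print(figure_to_pil(fig).size)
```

In recent SHAP versions, `shap.plots.waterfall` and `shap.image_plot` accept `show=False`, so one could draw the plot, take the current figure with `plt.gcf()`, and pass it to this helper; check the installed SHAP version before relying on that. Closing each figure matters in a long-running Gradio server, where every request would otherwise leave a figure in memory.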