import streamlit as st
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models, transforms
from torchvision.models import VGG16_Weights
from PIL import Image

from cbam import CBAM  # project-local CBAM attention module
from preprocess import crop_black_background, apply_clahe  # project-local preprocessing

# Configuration
model_load_path = 'model/quantized_model.pth'  # Path to the saved *dynamically* quantized model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class VGGWithAttention(nn.Module):
    """VGG-16 backbone with CBAM attention inserted after each conv stage.

    NOTE: this definition must stay identical to the one used by the
    quantization/saving script so the saved state_dict keys line up.
    """

    def __init__(self, num_classes):
        super(VGGWithAttention, self).__init__()
        self.vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        # CBAM blocks sized to the channel counts of the stages they follow.
        self.cbam1 = CBAM(in_channels=64)
        self.cbam2 = CBAM(in_channels=128)
        self.cbam3 = CBAM(in_channels=256)
        self.cbam4 = CBAM(in_channels=512)
        # Replace the final classifier layer for binary classification.
        self.vgg.classifier[6] = nn.Linear(self.vgg.classifier[6].in_features, num_classes)

    def forward(self, x):
        """Run the backbone and return (logits, attention_maps).

        attention_maps holds the CBAM-refined feature maps (not the raw
        attention weights) from each of the four instrumented stages.
        """
        attention_maps = []
        # Slice indices follow torchvision's vgg16 `features` Sequential layout.
        x = self.vgg.features[0:5](x)
        x = self.cbam1(x)
        attention_maps.append(x)
        x = self.vgg.features[5:10](x)
        x = self.cbam2(x)
        attention_maps.append(x)
        x = self.vgg.features[10:17](x)
        x = self.cbam3(x)
        attention_maps.append(x)
        x = self.vgg.features[17:24](x)
        x = self.cbam4(x)
        attention_maps.append(x)
        x = self.vgg.features[24:](x)
        x = self.vgg.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.vgg.classifier(x)
        return x, attention_maps


num_classes = 2  # binary classification (normal / abnormal)


@st.cache_resource
def _load_quantized_model():
    """Build, dynamically quantize, and load the model exactly once.

    Cached with st.cache_resource so Streamlit reruns (triggered by every
    widget interaction) do not rebuild VGG-16 and re-read the checkpoint.
    """
    model = VGGWithAttention(num_classes=num_classes).to(device)
    # Same dynamic quantization recipe as the saving script, so that the
    # quantized state_dict keys match what was serialized.
    quantized = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    quantized.load_state_dict(torch.load(model_load_path, map_location=device))
    return quantized


quantized_model_dynamic = _load_quantized_model()
quantized_model_dynamic.eval()  # inference mode: disables dropout/batch-norm updates

# Preprocessing pipeline: resize to the VGG input size and apply the
# ImageNet normalization the pretrained backbone expects.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def visualize_attention(attention_maps):
    """Render each CBAM feature map as a channel-mean heatmap in Streamlit.

    squeeze=False keeps `axs` a 2-D array, so indexing also works when a
    single map is passed (plt.subplots would otherwise return a bare Axes
    object and `axs[i]` would raise).
    """
    fig, axs = plt.subplots(1, len(attention_maps), figsize=(20, 5), squeeze=False)
    for i, attention_map in enumerate(attention_maps):
        # Collapse the channel dimension to one 2-D heatmap per stage.
        heatmap = attention_map.mean(dim=1).squeeze().cpu().detach().numpy()
        axs[0][i].imshow(heatmap, cmap='viridis')
        axs[0][i].axis('off')
    st.pyplot(fig)


st.title("Fundus Image Classification with Attention Visualization")

uploaded_file = st.file_uploader("Choose an image...", type="png")

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image.', use_container_width=True)
    st.write("")
    st.write("Classifying...")

    # Crop the black border; the helper signals an all-black image by
    # returning the sentinel string "error". The isinstance guard avoids
    # comparing an image/array object against a string.
    image = crop_black_background(image)
    if isinstance(image, str) and image == "error":
        st.write("Error: The uploaded image is all black.")
    else:
        image = apply_clahe(image)  # CLAHE contrast enhancement
        image = preprocess(image).unsqueeze(0).to(device)  # add batch dim, move to device

        # Predict and collect the CBAM feature maps for visualization.
        with torch.no_grad():
            output, attention_maps = quantized_model_dynamic(image)
            _, predicted = torch.max(output, 1)

        # Class index 1 -> 'abnormal', 0 -> 'normal'.
        prediction = 'abnormal' if predicted.item() == 1 else 'normal'
        st.write(f'Predicted class: {prediction}')

        st.write("Attention Maps:")
        visualize_attention(attention_maps)