Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import matplotlib.pyplot as plt | |
| from torchvision import models, transforms | |
| from torchvision.models import VGG16_Weights | |
| from PIL import Image | |
| from cbam import CBAM # CBAM attention | |
| from preprocess import crop_black_background, apply_clahe # Import preprocessing functions | |
# Configuration
# Path to the saved state_dict of the *dynamically* quantized model.
model_load_path = 'model/quantized_model.pth'
# Dynamic quantization (torch.quantization.quantize_dynamic) produces quantized
# Linear kernels that only run on the CPU backend, so the quantized model and
# its inputs must stay on the CPU even when a CUDA device is available.
device = torch.device('cpu')
# VGG16 backbone with a CBAM block after each of the first four convolutional
# stages. This must stay identical to the definition used when the quantized
# checkpoint was saved: the attribute names (vgg, cbam1..cbam4) become the
# state_dict keys that load_state_dict matches against.
class VGGWithAttention(nn.Module):
    """VGG16 classifier with CBAM attention inserted between feature stages.

    Args:
        num_classes: number of outputs of the replaced final classifier layer.

    ``forward`` returns ``(logits, stage_outputs)``. ``stage_outputs`` holds the
    CBAM-refined feature maps after stages 1-4 (64/128/256/512 channels). Note
    that these are the refined features themselves, not the raw attention
    weights.
    """

    def __init__(self, num_classes):
        super().__init__()
        # ImageNet-pretrained backbone. Constructing it may trigger a weight
        # download; the weights are replaced if a full state_dict is loaded later.
        self.vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        # One CBAM per stage, sized to that stage's output channel count.
        self.cbam1 = CBAM(in_channels=64)
        self.cbam2 = CBAM(in_channels=128)
        self.cbam3 = CBAM(in_channels=256)
        self.cbam4 = CBAM(in_channels=512)
        # Swap the last fully connected layer for a num_classes-way head.
        last_fc = self.vgg.classifier[6]
        self.vgg.classifier[6] = nn.Linear(last_fc.in_features, num_classes)

    def forward(self, x):
        # (start, stop) slice bounds into vgg.features for each stage. Each
        # slice ends just after a max-pool layer, followed by that stage's CBAM.
        stage_plan = (
            (0, 5, self.cbam1),
            (5, 10, self.cbam2),
            (10, 17, self.cbam3),
            (17, 24, self.cbam4),
        )
        stage_outputs = []
        for start, stop, cbam in stage_plan:
            x = cbam(self.vgg.features[start:stop](x))
            stage_outputs.append(x)
        # Final conv stage (no CBAM), then the stock VGG head.
        x = self.vgg.features[24:](x)
        x = torch.flatten(self.vgg.avgpool(x), start_dim=1)
        return self.vgg.classifier(x), stage_outputs
# Number of model outputs (binary classification). The UI maps index 1 to
# 'abnormal' and anything else to 'normal'.
num_classes = 2


@st.cache_resource
def _load_quantized_model():
    """Build the CBAM-VGG16, apply dynamic int8 quantization and load the checkpoint.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every user interaction. Without caching, the network would be
    rebuilt, re-quantized and re-read from ``model_load_path`` on each rerun.

    Returns:
        The quantized model, in evaluation mode.
    """
    float_model = VGGWithAttention(num_classes=num_classes).to(device)
    # Must mirror the quantization used when the checkpoint was saved so that
    # the state_dict keys and module types match. inplace=True avoids keeping
    # an extra float copy of the network alive in memory.
    quantized = torch.quantization.quantize_dynamic(
        float_model,
        {nn.Linear},
        dtype=torch.qint8,
        inplace=True,
    )
    quantized.load_state_dict(torch.load(model_load_path, map_location=device))
    quantized.eval()  # inference mode (disables dropout in the VGG classifier)
    return quantized


quantized_model_dynamic = _load_quantized_model()

# Input pipeline: resize to 224x224, convert to a float tensor and normalize
# with the standard ImageNet channel mean/std used by the pretrained VGG16.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def visualize_attention(attention_maps):
    """Show one channel-averaged heat map per feature map in Streamlit.

    Args:
        attention_maps: sequence of feature tensors as returned by the model's
            forward pass. Assumed to be shaped (batch, channels, H, W); only the
            first batch element is drawn. An empty sequence draws nothing.
    """
    if not attention_maps:
        return
    # squeeze=False always yields a 2-D Axes array, so a single map still
    # indexes correctly (plain subplots(1, 1) returns a bare Axes object).
    fig, axs = plt.subplots(
        1, len(attention_maps), figsize=(20, 5), squeeze=False
    )
    for ax, feature_map in zip(axs[0], attention_maps):
        # Average over channels for the first sample, giving an (H, W) image.
        heat = feature_map[0].mean(dim=0).detach().cpu().numpy()
        ax.imshow(heat, cmap='viridis')
        ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; close it to avoid accumulating
    # one figure per Streamlit rerun.
    plt.close(fig)
# Streamlit page: upload a fundus image, preprocess it, classify it with the
# quantized model and display the per-stage feature maps.
st.title("Fundus Image Classification with Attention Visualization")
# PIL decodes both PNG and JPEG, and convert('RGB') below normalizes the mode.
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image.', use_container_width=True)
    st.write("")
    st.write("Classifying...")
    # Preprocess the image
    image = crop_black_background(image)  # Apply cropping
    # crop_black_background signals an all-black image by returning the string
    # "error". Check the type first so an array-like image is never compared
    # against a string, which can produce an ambiguous truth value.
    if isinstance(image, str) and image == "error":
        st.write("Error: The uploaded image is all black.")
    else:
        image = apply_clahe(image)  # Apply CLAHE
        # Add a batch dimension and move to the model's device.
        input_tensor = preprocess(image).unsqueeze(0).to(device)
        # Make prediction and get attention maps
        with torch.no_grad():
            output, attention_maps = quantized_model_dynamic(input_tensor)
        _, predicted = torch.max(output, 1)
        predicted_index = predicted.item()
        prediction = 'abnormal' if predicted_index == 1 else 'normal'
        # Output the prediction
        st.write(f'Predicted class: {prediction}')
        # Visualize attention maps
        st.write("Attention Maps:")
        visualize_attention(attention_maps)