Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from torchvision.models import VGG16_Weights | |
| from PIL import Image | |
| from cbam import CBAM # CBAM attention | |
| from preprocess import crop_black_background, apply_clahe # Import preprocessing functions | |
# Configuration
import sys  # used only to let the image path be overridden from the command line

# Path to the input image. An explicit path may be passed as the first command-line
# argument; otherwise the original hard-coded default is used.
img_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else r'C:\Users\pc\Desktop\messidor\code\data2\normal\1f31701dd61b.png'
)
model_load_path = r'model/quantized_model.pth'  # Path to the saved *dynamically* quantized model
# Dynamically quantized nn.Linear modules (torch.quantization.quantize_dynamic) only have
# CPU kernels in PyTorch's eager-mode quantization. Picking 'cuda' when a GPU is available
# makes quantization/inference fail, so the model, checkpoint and input all stay on the CPU.
device = torch.device('cpu')
# Define the VGG model with CBAM (same definition as in the quantization script)
class VGGWithAttention(nn.Module):
    """VGG16 backbone with a CBAM attention block after each of its first four stages.

    Args:
        num_classes: Number of logits produced by the replaced final classifier layer.
        pretrained: If True (the default, matching the previous behaviour), initialise the
            backbone from torchvision's ImageNet VGG16 weights. Pass False when a full
            checkpoint is loaded afterwards, to avoid downloading weights that will be
            overwritten anyway.
    """

    def __init__(self, num_classes, *, pretrained=True):
        super().__init__()
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vgg = models.vgg16(weights=weights)
        # One CBAM per stage; channel counts match the feature maps each one receives.
        self.cbam1 = CBAM(in_channels=64)
        self.cbam2 = CBAM(in_channels=128)
        self.cbam3 = CBAM(in_channels=256)
        self.cbam4 = CBAM(in_channels=512)
        # Modify the final layer so the classifier emits `num_classes` logits.
        self.vgg.classifier[6] = nn.Linear(self.vgg.classifier[6].in_features, num_classes)

    def forward(self, x):
        """Return class logits for an image batch.

        Assumes `x` is a normalised (N, 3, H, W) float batch (TODO confirm against callers).
        """
        # In torchvision's VGG16 layout, features[0:5], [5:10], [10:17] and [17:24] each
        # end with a max-pool, so attention is applied to each pooled stage output.
        x = self.vgg.features[0:5](x)
        x = self.cbam1(x)
        x = self.vgg.features[5:10](x)
        x = self.cbam2(x)
        x = self.vgg.features[10:17](x)
        x = self.cbam3(x)
        x = self.vgg.features[17:24](x)
        x = self.cbam4(x)
        x = self.vgg.features[24:](x)
        x = self.vgg.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.vgg.classifier(x)
        return x


# Instantiate the model
num_classes = 2  # Binary classification
# The strict load_state_dict below overwrites every backbone weight, so skip the
# (large) ImageNet weight download.
model = VGGWithAttention(num_classes=num_classes, pretrained=False).to(device)
# Apply dynamic quantization (same as in the saving script) so the module structure
# matches the keys stored in the checkpoint.
quantized_model_dynamic = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8,
)
# Load the state_dict of the dynamically quantized model
quantized_model_dynamic.load_state_dict(torch.load(model_load_path, map_location=device))
quantized_model_dynamic.eval()  # Set the model to evaluation mode
# Tensor transform for the model input: resize to 224x224, convert to a float tensor,
# then normalise each channel with the standard ImageNet mean/std values.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)
# Load the input image as RGB, apply the project's cleanup helpers (background crop
# first, then CLAHE), transform it to a tensor, add a batch dimension and move it
# to the target device.
rgb_image = Image.open(img_path).convert('RGB')
cleaned_image = apply_clahe(crop_black_background(rgb_image))
image = preprocess(cleaned_image).unsqueeze(0).to(device)
# Make prediction: run the forward pass without autograd tracking, show the raw
# logits, and take the index of the highest logit as the predicted class.
with torch.no_grad():
    output = quantized_model_dynamic(image)
print(output)
predicted = output.argmax(dim=1)  # first maximal index per row, same as torch.max(output, 1)[1]
print(predicted)
prediction = predicted.item()
# Output the prediction
print(f'Predicted class: {prediction}')