"""Single-image inference with a dynamically quantized VGG16 + CBAM classifier.

The model definition below must match the one used when the quantized
state_dict at ``model_load_path`` was saved.
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import VGG16_Weights
from PIL import Image
from cbam import CBAM  # CBAM attention
from preprocess import crop_black_background, apply_clahe  # Import preprocessing functions

# Configuration
img_path = r'C:\Users\pc\Desktop\messidor\code\data2\normal\1f31701dd61b.png'  # Path to the input image
model_load_path = r'model/quantized_model.pth'  # Path to the saved *dynamically* quantized model
# Dynamically quantized (qint8) Linear layers only have CPU kernels in PyTorch,
# so the model and its inputs must stay on the CPU even when a GPU is present.
device = torch.device('cpu')


# Define the VGG model with CBAM (same definition as in the quantization script)
class VGGWithAttention(nn.Module):
    """VGG16 backbone with CBAM attention modules inserted between feature stages.

    Args:
        num_classes: Number of logits produced by the final classifier layer.
        pretrained: If True (default), initialise the VGG16 backbone with the
            torchvision ``IMAGENET1K_V1`` weights. Pass False when every weight
            will be overwritten by ``load_state_dict`` anyway, to skip loading
            (and possibly downloading) the ImageNet checkpoint.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vgg = models.vgg16(weights=weights)
        # Add CBAM after some convolutional layers; each in_channels value must
        # match the channel count of the feature slice it follows in forward().
        self.cbam1 = CBAM(in_channels=64)
        self.cbam2 = CBAM(in_channels=128)
        self.cbam3 = CBAM(in_channels=256)
        self.cbam4 = CBAM(in_channels=512)
        # Replace the final classifier layer so it emits num_classes logits.
        self.vgg.classifier[6] = nn.Linear(self.vgg.classifier[6].in_features, num_classes)

    def forward(self, x):
        """Run the VGG16 features with CBAM after each of the first four slices.

        Returns the raw classifier outputs (no softmax is applied).
        """
        x = self.vgg.features[0:5](x)
        x = self.cbam1(x)
        x = self.vgg.features[5:10](x)
        x = self.cbam2(x)
        x = self.vgg.features[10:17](x)
        x = self.cbam3(x)
        x = self.vgg.features[17:24](x)
        x = self.cbam4(x)
        x = self.vgg.features[24:](x)
        x = self.vgg.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.vgg.classifier(x)
        return x


# Instantiate the model
num_classes = 2  # Binary classification
# pretrained=False: the strict load_state_dict below overwrites every parameter,
# so fetching the ImageNet weights first would be wasted work.
model = VGGWithAttention(num_classes=num_classes, pretrained=False).to(device)

# Apply dynamic quantization (same as in the saving script) so the module
# structure matches the keys stored in the quantized state_dict.
quantized_model_dynamic = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Load the state_dict of the dynamically quantized model
quantized_model_dynamic.load_state_dict(torch.load(model_load_path, map_location=device))
quantized_model_dynamic.eval()  # Set the model to evaluation mode

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet per-channel mean/std, the normalisation torchvision's VGG16 weights expect
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess the image. The context manager closes the file handle
# Image.open keeps open; convert('RGB') returns an independent in-memory copy.
with Image.open(img_path) as img:
    image = img.convert('RGB')
image = crop_black_background(image)  # Apply cropping
image = apply_clahe(image)  # Apply CLAHE
image = preprocess(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

# Make prediction
with torch.no_grad():
    output = quantized_model_dynamic(image)  # raw classifier outputs, one row per image
    print(output)
    # Index of the highest-scoring class for each row of the batch.
    predicted = output.argmax(dim=1)
    print(predicted)
    prediction = predicted.item()  # single image in the batch -> Python int

# Output the prediction
print(f'Predicted class: {prediction}')