| |
| """ |
| Example: Using SDXL Detector from HuggingFace |
| ============================================== |
| |
| Simple example showing how to use the SDXL detector |
| to classify images as real or SDXL-generated. |
| """ |
|
|
| import torch |
| from torchvision import transforms |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| import torch.nn as nn |
| import torchvision.models as models |
|
|
| |
| |
| |
|
|
class SDXLDetector(nn.Module):
    """ResNet-50 based SDXL detector.

    The stock ``fc`` layer of torchvision's ResNet-50 is replaced by a small
    head (dropout -> linear -> batch norm -> ReLU -> dropout -> linear) that
    emits two logits per input image.
    """

    def __init__(self):
        super().__init__()
        # Build the architecture without pretrained weights. The explicit
        # ``pretrained=False`` keyword is deprecated since torchvision 0.13
        # (it warns and is slated for removal); the bare call yields the same
        # randomly initialised network on both old and new releases.
        self.backbone = models.resnet50()
        num_features = self.backbone.fc.in_features

        # Replace the final fully connected layer with a two-class head.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.15),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Return the backbone's raw two-logit output for the batch ``x``."""
        return self.backbone(x)
|
|
| |
| |
| |
|
|
def load_model(device='cpu'):
    """Download the detector checkpoint from the HuggingFace Hub and load it.

    Args:
        device: Device used as ``map_location`` when reading the checkpoint
            and as the target of ``model.to``. Defaults to ``'cpu'``.

    Returns:
        An ``SDXLDetector`` holding the checkpoint weights, moved to
        ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    model_path = hf_hub_download(
        repo_id="ash12321/sdxl-detector-resnet50",
        filename="best.pth",
    )

    # NOTE(review): torch.load unpickles the downloaded file, so only use
    # checkpoints from a repository you trust. Passing ``weights_only=True``
    # would restrict what can be unpickled -- TODO confirm the checkpoint
    # contents are compatible before enabling it.
    checkpoint = torch.load(model_path, map_location=device)

    model = SDXLDetector()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"✅ Model loaded from {model_path}")

    # Training metadata is informational only: a checkpoint that lacks it
    # must still produce a usable model instead of failing after loading.
    epoch = checkpoint.get('epoch')
    if epoch is not None:
        print(f" Trained for {epoch + 1} epochs")
    best_val_acc = checkpoint.get('best_val_acc')
    if best_val_acc is not None:
        print(f" Best validation accuracy: {best_val_acc:.2f}%")

    return model
|
|
| |
| |
| |
|
|
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    Returns:
        A ``transforms.Compose`` that applies, in order, ``Resize(256)``,
        ``CenterCrop(224)``, ``ToTensor()`` and a per-channel ``Normalize``.
    """
    # Per-channel normalization statistics (these are the values torchvision
    # documents for its ImageNet-trained models).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
|
|
| |
| |
| |
|
|
def predict_image(model, image_path, device='cpu'):
    """
    Predict if an image is real or SDXL-generated.

    Args:
        model: Loaded SDXLDetector model. It is called under
            ``torch.no_grad()``; presumably it should already be in eval
            mode (``load_model`` does this) -- verify for other callers.
        image_path: Path to image file.
        device: Device to run inference on; the input tensor is moved there.

    Returns:
        dict with:
            'prediction': ``'Real'`` or ``'SDXL-generated'`` (argmax class),
            'confidence': softmax probability of the predicted class,
            'probabilities': ``{'real': p0, 'sdxl': p1}`` softmax values.
    """
    # Open inside a context manager so the file handle is always released,
    # including for multi-frame formats where Pillow keeps the file open
    # after loading. convert() loads the pixel data and returns a separate
    # in-memory image, so it stays valid once the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    transform = get_transform()
    # unsqueeze(0) adds the batch dimension the model expects.
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(probs, dim=1).item()
        confidence = probs[0][prediction].item()

    # Class index 0 is reported as real, index 1 as SDXL-generated.
    labels = ['Real', 'SDXL-generated']

    return {
        'prediction': labels[prediction],
        'confidence': confidence,
        'probabilities': {
            'real': probs[0][0].item(),
            'sdxl': probs[0][1].item(),
        },
    }
|
|
| |
| |
| |
|
|
def main(image_path="test_image.jpg"):
    """Example usage: classify one image and print the result.

    Args:
        image_path: Path of the image to classify. Defaults to
            ``"test_image.jpg"``, the path this example used before the
            parameter existed, so ``main()`` behaves as it always did.
    """
    # Prefer the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    model = load_model(device)
    result = predict_image(model, image_path, device)
    probabilities = result['probabilities']

    print(f"\n📊 Results for {image_path}:")
    print(f" Prediction: {result['prediction']}")
    print(f" Confidence: {result['confidence']*100:.2f}%")
    print(" \nProbabilities:")
    print(f" Real: {probabilities['real']*100:.2f}%")
    print(f" SDXL: {probabilities['sdxl']*100:.2f}%")


if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default image
    # path; with no argument the slice is empty and the default is used.
    main(*sys.argv[1:2])
|
|