| |
| """Last_model.ipynb |
| |
| Automatically generated by Colab. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1AdRILP1oqdiVuRSQr2dZZy0QgU8insn_ |
| |
| 🚗 TwinCar Project: SOTA Training, Full Visuals, and Advanced Reporting |
| |
| |
| --- |
| |
| |
| --- |
| |
| 1. Environment Setup and Imports |
| Explanation: |
| We start by importing all necessary libraries and prepping our working environment for advanced data handling and visualization. |
| |
| --- |
| """ |
|
|
| |
| import os |
| import zipfile |
| import numpy as np |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
| from torchvision import transforms |
|
|
| from sklearn.model_selection import train_test_split |
| from sklearn.utils.class_weight import compute_class_weight |
| from sklearn.metrics import ( |
| accuracy_score, precision_score, recall_score, f1_score, hamming_loss, |
| cohen_kappa_score, matthews_corrcoef, jaccard_score, |
| confusion_matrix, classification_report |
| ) |
|
|
| import timm |
| import scipy.io |
|
|
| """2. Data Extraction and Preparation |
| Explanation: |
| We extract and organize the Stanford Cars dataset, parse .mat files to CSV for class and label mapping, and prepare all paths. |
| |
| --- |
| |
| |
| """ |
|
|
| |
# Mount Google Drive so the dataset zip and all saved artifacts (checkpoints,
# plots, CSV logs) persist across Colab sessions.
from google.colab import drive
drive.mount('/content/drive')


# Extract the Stanford Cars zip into local Colab storage, skipping the
# extraction if the target directory already exists from a previous run.
zip_path = '/content/drive/MyDrive/stanford_cars.zip'
extract_dir = '/content/stanford_cars'
if not os.path.exists(extract_dir):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
print("✅ Dataset extracted at", extract_dir)


# cars_meta.mat holds the human-readable class names; NUM_CLASSES is derived
# from it rather than hard-coded.
meta = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_meta.mat")
class_names = [x[0] for x in meta['class_names'][0]]
NUM_CLASSES = len(meta['class_names'][0]) if False else len(class_names)


# Train annotations: field 5 is the image filename, field 4 the 1-based class
# id (converted to 0-based here for PyTorch losses).
train_annos = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_train_annos.mat")['annotations'][0]
train_rows = [[x[5][0], int(x[4][0]) - 1] for x in train_annos]
df_train = pd.DataFrame(train_rows, columns=["filename", "label"])
df_train.to_csv('/content/train_labels.csv', index=False)


# Test annotations carry no labels; field 4 is the filename.
test_annos = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_test_annos.mat")['annotations'][0]
test_rows = [[x[4][0]] for x in test_annos]
df_test = pd.DataFrame(test_rows, columns=["filename"])
df_test.to_csv('/content/test_labels.csv', index=False)


# Image roots — the zip nests each split's folder one level deep.
train_root = f"{extract_dir}/cars_train/cars_train"
test_root = f"{extract_dir}/cars_test/cars_test"
|
|
| """3. Advanced Dataset and Augmentations |
| Explanation: |
| We build a flexible dataset class, apply advanced augmentations, and lay the foundation for Mixup/CutMix later. |
| |
| --- |
| |
| |
| """ |
|
|
| |
|
|
class StanfordCarsFromCSV(Dataset):
    """Stanford Cars dataset backed by a CSV index file.

    Each CSV row names an image file (relative to ``root_dir``) and, when
    ``has_labels`` is True, its 0-based integer class label.

    Args:
        root_dir: Directory containing the image files.
        csv_file: Path to a CSV with a ``filename`` column and, if labeled,
            a ``label`` column.
        transform: Optional torchvision transform applied to each PIL image.
        has_labels: False for the unlabeled test split; ``__getitem__`` then
            returns the filename instead of a label.
    """

    def __init__(self, root_dir, csv_file, transform=None, has_labels=True):
        self.root_dir = root_dir
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.has_labels = has_labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` or, for unlabeled data, ``(image, filename)``."""
        row = self.data.iloc[idx]
        img_path = os.path.join(self.root_dir, row['filename'])
        # Context manager closes the underlying file handle promptly; a bare
        # Image.open leaks handles because of PIL's lazy loading. convert()
        # forces the pixel data to be read before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        if self.has_labels:
            return image, int(row['label'])
        return image, row['filename']
|
|
# ImageNet channel statistics — required because the backbone below is
# ImageNet-pretrained.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
# Training pipeline: aggressive spatial/color augmentation for regularization;
# final size 224x224 to match the backbone's expected input.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.2),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.15),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])
# Evaluation pipeline: deterministic resize + center crop only.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])
|
|
| """4. Data Splitting, Weighted Sampling, and DataLoader |
| Explanation: |
| We split the data into train and validation sets with stratification for balanced classes, |
| use class weighting to counter imbalance, and create PyTorch DataLoaders for efficient training and evaluation. |
| |
| --- |
| |
| |
| """ |
|
|
| |
|
|
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight


BATCH_SIZE = 32
VAL_RATIO = 0.1     # 10% of the labeled data held out for validation
RANDOM_SEED = 42    # fixed seed so the split is reproducible


# Stratified split keeps per-class proportions equal in train and val.
df_all = pd.read_csv('/content/train_labels.csv')
df_train, df_val = train_test_split(
    df_all,
    test_size=VAL_RATIO,
    stratify=df_all['label'],
    random_state=RANDOM_SEED
)
df_train.to_csv('/content/train_split.csv', index=False)
df_val.to_csv('/content/val_split.csv', index=False)


train_dataset = StanfordCarsFromCSV(train_root, '/content/train_split.csv', train_transform)
val_dataset = StanfordCarsFromCSV(train_root, '/content/val_split.csv', val_transform)
test_dataset = StanfordCarsFromCSV(test_root, '/content/test_labels.csv', val_transform, has_labels=False)


# Labels are read straight from the split DataFrame, whose row order is the
# same as train_split.csv and hence the same as train_dataset.  (The previous
# version iterated train_dataset itself, which decoded and augmented every
# training image just to read its label — minutes of wasted work.)
labels = df_train['label'].tolist()
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
sample_weights = [class_weights[label] for label in labels]
# Oversample rare classes so each batch is roughly class-balanced.
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)


# drop_last=True keeps every training batch full-sized, which Mixup/CutMix
# and batch-norm statistics prefer.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=2,
    pin_memory=True,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    drop_last=False
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    drop_last=False
)


print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)} | Test samples: {len(test_dataset)}")
print(f"Train loader batches (per epoch): {len(train_loader)} (should be integer and even-sized)")
|
|
| """5. Model Initialization: EfficientNetV2 + Mixup/CutMix Ready |
| Explanation: |
| We load EfficientNetV2 with ImageNet weights for best transfer learning, |
| set up optimizer, scheduler, and prepare for Mixup/CutMix advanced augmentation. |
| |
| --- |
| |
| |
| """ |
|
|
| |
|
|
from timm.data import Mixup


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# EfficientNetV2-small with ImageNet weights; the classifier head is resized
# to NUM_CLASSES and 30% dropout is applied before it.
model = timm.create_model('efficientnetv2_rw_s', pretrained=True, num_classes=NUM_CLASSES, drop_rate=0.3)
model = model.to(device)


optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
# Cosine decay over the full planned 25-epoch run (matches EPOCHS below).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
# Label smoothing is 0 here because Mixup below applies its own smoothing
# (label_smoothing=0.1) when building the soft targets.
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)


# Mixup/CutMix: always applied (prob=1.0), choosing between the two with 50%
# probability per batch; produces soft (mixed one-hot) training targets.
mixup_fn = Mixup(
    mixup_alpha=0.4, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=1.0, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=NUM_CLASSES
)
|
|
| """6. Advanced Training Loop: Full Metrics, Early Stopping, and Mixup |
| Explanation: |
| This loop supports Mixup/CutMix, logs all advanced metrics, and uses early stopping with automatic best model saving. |
| Ready for real production—and all your plots and reporting. |
| |
| --- |
| |
| |
| """ |
|
|
| |
|
|
EPOCHS = 25
# Early stopping: abort if macro-F1 fails to improve for `patience` epochs.
patience, counter = 7, 0
best_val_f1 = 0


# One list per tracked metric, appended once per epoch (exported to CSV later).
metrics_dict = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'val_precision_macro': [], 'val_precision_weighted': [],
    'val_recall_macro': [], 'val_recall_weighted': [],
    'val_f1_macro': [], 'val_f1_weighted': [],
    'val_hamming': [], 'val_cohen_kappa': [],
    'val_mcc': [], 'val_jaccard_macro': [],
    'val_top3': [], 'val_top5': [],
}


for epoch in range(EPOCHS):
    # ---------- training pass (Mixup/CutMix active) ----------
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Mixup replaces integer labels with soft (mixed one-hot) targets.
        imgs, labels = mixup_fn(imgs, labels)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        # Labels are soft here, so compare against their dominant class;
        # train accuracy is therefore only approximate under Mixup.
        correct += (outputs.argmax(1) == labels.argmax(1)).sum().item()
        total += labels.size(0)
    train_loss = total_loss / total
    train_acc = correct / total
    metrics_dict['train_loss'].append(train_loss)
    metrics_dict['train_acc'].append(train_acc)

    # ---------- validation pass (no augmentation, hard labels) ----------
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    val_probs, val_preds, val_targets = [], [], []
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Val Epoch {epoch+1}"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            v_loss = criterion(outputs, labels)
            val_loss += v_loss.item() * imgs.size(0)
            probs = torch.softmax(outputs, dim=1)
            preds = outputs.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
            val_probs.extend(probs.cpu().numpy())
            val_preds.extend(preds.cpu().numpy())
            val_targets.extend(labels.cpu().numpy())
    val_loss /= val_total
    val_acc = val_correct / val_total
    val_preds_np = np.array(val_preds)
    val_targets_np = np.array(val_targets)
    val_probs_np = np.array(val_probs)

    # ---------- epoch-level validation metrics ----------
    val_precision_macro = precision_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_precision_weighted = precision_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    val_recall_macro = recall_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_recall_weighted = recall_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    val_f1_macro = f1_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_f1_weighted = f1_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    # Top-k accuracy: true label among the k highest-probability classes.
    top3_acc = np.mean([
        label in np.argsort(prob)[-3:] for prob, label in zip(val_probs_np, val_targets_np)
    ])
    top5_acc = np.mean([
        label in np.argsort(prob)[-5:] for prob, label in zip(val_probs_np, val_targets_np)
    ])
    val_hamming = hamming_loss(val_targets_np, val_preds_np)
    val_cohen_kappa = cohen_kappa_score(val_targets_np, val_preds_np)
    val_mcc = matthews_corrcoef(val_targets_np, val_preds_np)
    val_jaccard_macro = jaccard_score(val_targets_np, val_preds_np, average='macro', zero_division=0)

    # ---------- log this epoch's metrics ----------
    metrics_dict['val_loss'].append(val_loss)
    metrics_dict['val_acc'].append(val_acc)
    metrics_dict['val_precision_macro'].append(val_precision_macro)
    metrics_dict['val_precision_weighted'].append(val_precision_weighted)
    metrics_dict['val_recall_macro'].append(val_recall_macro)
    metrics_dict['val_recall_weighted'].append(val_recall_weighted)
    metrics_dict['val_f1_macro'].append(val_f1_macro)
    metrics_dict['val_f1_weighted'].append(val_f1_weighted)
    metrics_dict['val_hamming'].append(val_hamming)
    metrics_dict['val_cohen_kappa'].append(val_cohen_kappa)
    metrics_dict['val_mcc'].append(val_mcc)
    metrics_dict['val_jaccard_macro'].append(val_jaccard_macro)
    metrics_dict['val_top3'].append(top3_acc)
    metrics_dict['val_top5'].append(top5_acc)

    scheduler.step()
    print(f"Epoch {epoch+1:2d} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | F1(macro): {val_f1_macro:.4f} | Top3: {top3_acc:.3f} | Top5: {top5_acc:.3f}")

    # ---------- checkpoint best model + early stopping ----------
    if val_f1_macro > best_val_f1:
        best_val_f1 = val_f1_macro
        torch.save(model.state_dict(), '/content/drive/MyDrive/efficientnetv2_best_model.pth')
        counter = 0  # improvement seen: reset the patience counter
    else:
        counter += 1
        if counter >= patience:
            print("⏹️ Early stopping triggered.")
            break


print("✅ Training complete. Best model saved.")
|
|
| """7.Explanation |
| After training, all metrics (accuracy, loss, precision, recall, F1, top-k, etc.) are saved as a CSV for analysis and reporting. |
| |
| We plot core metrics (accuracy, F1, loss, precision/recall, top-3/top-5 accuracy) with: |
| |
| Large, clear fonts |
| |
| Annotations for best epoch |
| |
| Colorful, pro-style Seaborn plots |
| |
| Publication-ready grid and tight layouts |
| |
| --- |
| |
| |
| """ |
|
|
| |
|
|
| import seaborn as sns |
|
|
| |
# Persist the full per-epoch metric history for offline analysis/reporting.
metrics_df = pd.DataFrame(metrics_dict)
metrics_df.to_csv('/content/drive/MyDrive/metrics_log.csv', index_label='epoch')
print("✅ metrics_log.csv saved.")


# Global Seaborn theme for all plots below.
sns.set(style='whitegrid', font_scale=1.3)
|
|
| |
# Accuracy / macro-F1 curves with the best-F1 epoch highlighted.
plt.figure(figsize=(12,7))
plt.plot(metrics_df['train_acc'], label='Train Acc', lw=2)
plt.plot(metrics_df['val_acc'], label='Val Acc', lw=2)
plt.plot(metrics_df['val_f1_macro'], label='Val F1 (macro)', lw=2)
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Score', fontsize=16)
plt.title('Accuracy and Macro F1 per Epoch', fontsize=18)
plt.grid(True, alpha=0.3)
best_epoch = metrics_df['val_f1_macro'].idxmax()
plt.scatter(best_epoch, metrics_df['val_f1_macro'][best_epoch], c='red', s=90, label='Best Epoch')
plt.annotate(f'Best\n{metrics_df["val_f1_macro"][best_epoch]:.2f}',
             (best_epoch, metrics_df["val_f1_macro"][best_epoch]),
             textcoords="offset points", xytext=(-5,10), ha='right', fontsize=14, color='red')
# Legend is built after the scatter so the 'Best Epoch' marker is included;
# previously legend() ran first and silently dropped that entry.
plt.legend(loc='lower right')
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/metrics_acc_f1_beautiful.png')
plt.show()
|
|
| |
# Loss curves (train vs. validation).
plt.figure(figsize=(12,7))
plt.plot(metrics_df['train_loss'], label='Train Loss', lw=2)
plt.plot(metrics_df['val_loss'], label='Val Loss', lw=2)
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Loss', fontsize=16)
plt.title('Train & Validation Loss per Epoch', fontsize=18)
plt.legend(loc='upper right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/metrics_loss_beautiful.png')
plt.show()


# Validation precision/recall, both macro- and weighted-averaged.
plt.figure(figsize=(12,7))
plt.plot(metrics_df['val_precision_macro'], label='Val Precision (macro)', lw=2)
plt.plot(metrics_df['val_recall_macro'], label='Val Recall (macro)', lw=2)
plt.plot(metrics_df['val_precision_weighted'], label='Val Precision (weighted)', lw=2)
plt.plot(metrics_df['val_recall_weighted'], label='Val Recall (weighted)', lw=2)
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Score', fontsize=16)
plt.title('Validation Precision & Recall per Epoch', fontsize=18)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/metrics_precision_recall_beautiful.png')
plt.show()


# Top-3 / top-5 accuracy as filled area charts plus line overlays.
plt.figure(figsize=(12,7))
plt.fill_between(metrics_df.index, metrics_df['val_top3'], alpha=0.3, label='Val Top-3 Acc')
plt.fill_between(metrics_df.index, metrics_df['val_top5'], alpha=0.2, label='Val Top-5 Acc', color='orange')
plt.plot(metrics_df['val_top3'], lw=2, color='blue')
plt.plot(metrics_df['val_top5'], lw=2, color='orange')
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Accuracy', fontsize=16)
plt.title('Top-3 and Top-5 Validation Accuracy per Epoch', fontsize=18)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/metrics_topk_beautiful.png')
plt.show()
|
|
| """8.Confusion Matrix & Per-Class Analysis with Advanced Visuals |
| Explanation |
| After training, it's crucial to understand not just overall metrics, but where your model succeeds and fails. |
| We: |
| |
| Save a detailed classification report (per-class precision/recall/F1). |
| |
| Draw a high-contrast confusion matrix with large ticks, tight color scaling, and readable value overlays. |
| |
| Plot Top 20 Most Confused Classes for targeted debugging. |
| |
| Show Top 20 Most Accurate Classes with horizontal barplots (values on bars, sorted). |
| |
| |
| |
| --- |
| |
| |
| """ |
|
|
| |
|
|
| from sklearn.metrics import classification_report, confusion_matrix |
| import seaborn as sns |
|
|
| |
# Restore the best (highest validation macro-F1) checkpoint before analysis.
model.load_state_dict(torch.load('/content/drive/MyDrive/efficientnetv2_best_model.pth', map_location=device))
model.eval()


# Collect hard predictions for the whole validation set.
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)


# Per-class precision/recall/F1/support, persisted as CSV for reporting.
report = classification_report(
    all_labels, all_preds, target_names=class_names, output_dict=True
)
pd.DataFrame(report).transpose().to_csv('/content/drive/MyDrive/classification_report.csv')
print("✅ classification_report.csv saved.")
|
|
| |
# Full confusion matrix (raw counts) over all classes.
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(18,18))
sns.heatmap(
    cm,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={"shrink": 0.5, "label": "Count"},
    linewidths=.2
)
plt.title('Confusion Matrix', fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(fontsize=8, rotation=90)
plt.yticks(fontsize=8)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/confusion_matrix_beautiful.png', dpi=300)
plt.show()


# Rank classes by total off-diagonal (misclassified) count and zoom into the
# 20 worst offenders with an annotated sub-matrix.
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
most_confused = np.argsort(off_diag.sum(axis=1))[::-1][:20]
cm_top = cm[np.ix_(most_confused, most_confused)]
labels_top = [class_names[i] for i in most_confused]


plt.figure(figsize=(12,10))
sns.heatmap(
    cm_top,
    annot=True, fmt='d', cmap="Blues",
    xticklabels=labels_top, yticklabels=labels_top,
    linewidths=.2, cbar=False, annot_kws={"size":14}
)
plt.title('Most Confused Classes (Top 20)', fontsize=18)
plt.xlabel('Predicted label', fontsize=15)
plt.ylabel('True label', fontsize=15)
plt.xticks(fontsize=11, rotation=90)
plt.yticks(fontsize=11)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/confused_top20_beautiful.png', dpi=300)
plt.show()


# Per-class accuracy (diagonal / row sum); the epsilon guards division by
# zero for any class absent from the validation split.
acc_per_class = cm.diagonal() / (cm.sum(axis=1) + 1e-8)
df_acc = pd.DataFrame({'class': class_names, 'accuracy': acc_per_class})
top_acc = df_acc.sort_values('accuracy', ascending=False).head(20)
plt.figure(figsize=(10,8))
sns.barplot(
    data=top_acc, y='class', x='accuracy', palette='Blues_d', orient='h'
)
plt.title('Top 20 Classes by Accuracy', fontsize=18)
plt.xlabel('Accuracy', fontsize=15)
plt.ylabel('Class', fontsize=15)
# Value labels just past the end of each bar.
for i, v in enumerate(top_acc['accuracy']):
    plt.text(v + 0.01, i, f"{v:.2f}", color='blue', va='center', fontsize=13)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/top20_accuracy_beautiful.png', dpi=300)
plt.show()
|
|
| """9. Test-Time Augmentation (TTA) & Batch Prediction |
| Explanation |
| Test-Time Augmentation boosts prediction robustness by averaging predictions over multiple random transformations of each test image. |
| Batch Prediction allows you to efficiently label a folder of test images with class names—production style. |
| """ |
|
|
| |
|
|
# Four TTA views: the plain eval pipeline plus horizontal-flip, small-rotation
# and mild color-jitter variants.  All four end with the same ToTensor +
# Normalize, so each expects an UN-normalized PIL image as input.
tta_transforms = [
    val_transform,
    transforms.Compose([
        transforms.Resize(256),
        transforms.RandomHorizontalFlip(p=1.0),  # p=1.0 -> deterministic flip
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.RandomRotation(10),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ])
]
|
|
def tta_predict(model, img_pil, tta_transforms, device='cuda'):
    """Average the model's logits over every TTA view of a single image.

    Each transform in ``tta_transforms`` is applied to ``img_pil``, run
    through ``model`` as a batch of one, and the resulting logits are
    averaged across views.  Returns a tensor of shape (1, num_classes).
    """
    model.eval()
    with torch.no_grad():
        per_view = [
            model(tform(img_pil).unsqueeze(0).to(device))
            for tform in tta_transforms
        ]
    return torch.stack(per_view).mean(dim=0)
|
|
| |
# TTA over the validation set.  val_loader yields tensors that have ALREADY
# been normalized by val_transform, so they must be de-normalized before being
# converted back to PIL — otherwise every TTA pipeline (which ends in
# Normalize) would normalize twice and feed the model corrupted images.
_mean_t = torch.tensor(imagenet_mean).view(3, 1, 1)
_std_t = torch.tensor(imagenet_std).view(3, 1, 1)

tta_val_preds, tta_val_labels = [], []
for imgs, labels in tqdm(val_loader, desc="TTA Validation"):
    batch_preds = []
    for i in range(imgs.size(0)):
        # Undo Normalize (x*std + mean) and clamp to [0, 1] for ToPILImage.
        img_denorm = (imgs[i].cpu() * _std_t + _mean_t).clamp(0, 1)
        img_pil = transforms.ToPILImage()(img_denorm)
        avg_logits = tta_predict(model, img_pil, tta_transforms, device)
        pred = avg_logits.argmax(dim=1).cpu().item()
        batch_preds.append(pred)
    tta_val_preds.extend(batch_preds)
    tta_val_labels.extend(labels.cpu().numpy())


tta_val_preds = np.array(tta_val_preds)
tta_val_labels = np.array(tta_val_labels)


# Aggregate TTA metrics (macro-averaged across classes).
tta_f1_macro = f1_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
tta_acc = accuracy_score(tta_val_labels, tta_val_preds)
tta_precision = precision_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
tta_recall = recall_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
print(f"TTA Validation Accuracy: {tta_acc:.4f}")
print(f"TTA Validation F1 (macro): {tta_f1_macro:.4f}")
print(f"TTA Validation Precision (macro): {tta_precision:.4f}")
print(f"TTA Validation Recall (macro): {tta_recall:.4f}")
|
|
| |
# Confusion matrix of the TTA predictions, for side-by-side comparison with
# the single-view matrix produced earlier.
cm_tta = confusion_matrix(tta_val_labels, tta_val_preds)
plt.figure(figsize=(18,18))
sns.heatmap(
    cm_tta,
    cmap="Purples",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={"shrink": 0.5, "label": "Count"},
    linewidths=.2
)
plt.title('TTA Confusion Matrix (Validation)', fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(fontsize=8, rotation=90)
plt.yticks(fontsize=8)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/tta_confusion_matrix_beautiful.png', dpi=300)
plt.show()
|
|
| """10. Extraordinary Grad-CAM++ Overlays (Grid) |
| Explanation |
| We generate Grad-CAM++ visualizations for a set of sample images. |
| Each visualization shows:The input image,The Grad-CAM++ heatmap overlay,The true and predicted class for easy comparison. |
| All visualizations are saved both individually and as a large, labeled grid. |
| |
| |
| |
| --- |
| """ |
|
|
| |
|
|
| from pytorch_grad_cam import GradCAMPlusPlus |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
|
| import os |
|
|
os.makedirs('/content/drive/MyDrive/gradcam_outputs', exist_ok=True)


model.eval()
model.to(device)


# Target layer for CAM: timm EfficientNet models expose `blocks` (last conv
# stage); the fallback covers ResNet-style models with `layer4`.
target_layer = model.blocks[-1] if hasattr(model, "blocks") else model.layer4[-1]


cam = GradCAMPlusPlus(model=model, target_layers=[target_layer])
|
|
# Grad-CAM++ overlays for the first `num_images` validation samples, saved
# individually and assembled into a 3x4 labeled grid.
num_images = 12
fig, axes = plt.subplots(3, 4, figsize=(18, 14))
fig.suptitle('Grad-CAM++ Explanations: True vs. Predicted', fontsize=22, weight='bold')


for idx in range(num_images):
    img_tensor, label = val_dataset[idx]
    input_tensor = img_tensor.unsqueeze(0).to(device)
    # The predicted class drives the CAM target: we explain what the model
    # actually chose.  (The unused tensor->PIL round-trip that the previous
    # version performed here has been removed.)
    with torch.no_grad():
        output = model(input_tensor)
    pred = output.argmax(1).item()
    targets = [ClassifierOutputTarget(pred)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
    # Undo ImageNet normalization so the overlay sits on the real image.
    image_np = img_tensor.permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * np.array(imagenet_std)) + np.array(imagenet_mean)
    image_np = np.clip(image_np, 0, 1)
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)

    # Save each overlay individually, named by true/predicted class ...
    overlay_path = f"/content/drive/MyDrive/gradcam_outputs/val_{idx}_true_{class_names[label]}_pred_{class_names[pred]}.png"
    plt.imsave(overlay_path, cam_image)

    # ... and place it into the summary grid; green title = correct,
    # red = misclassified.
    ax = axes[idx // 4, idx % 4]
    ax.imshow(cam_image)
    ax.set_title(
        f"True: {class_names[label][:18]}\nPred: {class_names[pred][:18]}",
        fontsize=12,
        color="green" if pred == label else "red",
        weight="bold"
    )
    ax.axis('off')


plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig('/content/drive/MyDrive/gradcam_outputs/gradcam_grid.png', dpi=250)
plt.show()
|
|
| """11. Gradio Interactive Demo: Model + Grad-CAM++""" |
|
|
| |
|
|
| import gradio as gr |
| from PIL import Image as PILImage |
|
|
def predict_and_explain(img):
    """Classify an uploaded car photo and return a Grad-CAM++ overlay.

    Returns a (PIL image, text) pair: the heatmap overlay and a
    human-readable prediction string.
    """
    rgb = img.convert("RGB").resize((224, 224))
    batch = val_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    pred_idx = logits.argmax().item()
    heatmap = cam(input_tensor=batch, targets=[ClassifierOutputTarget(pred_idx)])[0]
    base = np.array(rgb).astype(np.float32) / 255.0
    overlay = show_cam_on_image(base, heatmap, use_rgb=True)
    pred_name = class_names[pred_idx]
    return PILImage.fromarray(overlay), f"Prediction: {pred_name} (class index {pred_idx})"
|
|
# Minimal Gradio UI: image in -> (Grad-CAM++ overlay, prediction text) out.
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="pil", label="Upload Car Image"),
    outputs=[gr.Image(label="Grad-CAM++ Output"), gr.Text(label="Prediction")],
    title="🚗 TwinCar: Car Make/Model Classifier + Explainability Demo",
    description="Upload a car photo. See the prediction (make/model/year) and a Grad-CAM++ heatmap showing what influenced the model.",
    allow_flagging='never'
)
# share=True exposes a temporary public URL (required to reach the app
# from outside Colab).
demo.launch(share=True)