| """ |
| 視覺化工具模組 |
| |
| 提供對抗樣本視覺化、實驗結果繪圖等功能 |
| """ |
|
|
| import matplotlib.pyplot as plt |
| import matplotlib |
| matplotlib.use('Agg') |
| import numpy as np |
| import torch |
| from pathlib import Path |
|
|
|
|
def visualize_attack(original, adversarial, perturbation,
                     original_label, orig_pred, adv_pred,
                     epsilon, save_path=None, amplify=10):
    """
    Visualize an adversarial attack as three side-by-side panels.

    Panels, left to right: the original image, the amplified perturbation on
    a diverging colormap, and the adversarial image. The attack is labelled
    successful when ``orig_pred != adv_pred``.

    Args:
        original (torch.Tensor or np.ndarray): original image. Tensors are
            detached and moved to CPU; singleton dimensions are squeezed away
            when the input has more than 2 dims, so (H, W), (1, H, W) and
            (1, 1, H, W) are all accepted.
        adversarial (torch.Tensor or np.ndarray): adversarial example (same
            handling as ``original``).
        perturbation (torch.Tensor or np.ndarray): perturbation (same
            handling as ``original``).
        original_label (int): ground-truth label.
        orig_pred (int): prediction on the original image.
        adv_pred (int): prediction on the adversarial example.
        epsilon (float): perturbation magnitude; sets the colour scale of the
            perturbation panel and is shown in its title.
        save_path (str, optional): if given, the figure is saved there
            (parent directories are created).
        amplify (int): amplification factor applied to the perturbation
            (shown in the panel title).

    Returns:
        matplotlib.figure.Figure: the rendered figure. It is closed in pyplot
        before being returned. With the Agg backend selected at import time,
        the caller can still draw or save it.
    """
    def _to_image(x):
        # Tensors may be on GPU and/or require grad; convert to a plain array.
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        x = np.asarray(x)
        # Drop singleton batch/channel dims so imshow receives a 2-D array.
        return x.squeeze() if x.ndim > 2 else x

    original = _to_image(original)
    adversarial = _to_image(adversarial)
    perturbation = _to_image(perturbation)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Panel 1: original image.
    axes[0].imshow(original, cmap='gray')
    axes[0].set_title(f'原始影像\n真實標籤: {original_label}\n預測: {orig_pred}', fontsize=10)
    axes[0].axis('off')

    # Panel 2: perturbation on a zero-centred diverging colormap.
    # NOTE(review): the colour limits are scaled by the same factor as the
    # data, so for amplify > 0 the rendered colours do not depend on
    # `amplify`; only the title changes. Kept to preserve existing output.
    # The limit falls back to 1.0 when it would be 0 (e.g. epsilon == 0).
    # With vmin == vmax, matplotlib's Normalize maps every value to 0, which
    # would paint a zero perturbation in the colormap's lowest (dark blue)
    # colour instead of its neutral centre. abs() keeps vmin <= vmax.
    amplified_pert = perturbation * amplify
    limit = float(abs(epsilon * amplify)) or 1.0
    axes[1].imshow(amplified_pert, cmap='seismic', vmin=-limit, vmax=limit)
    axes[1].set_title(f'擾動 (放大 {amplify}x)\nε = {epsilon}', fontsize=10)
    axes[1].axis('off')

    # Panel 3: adversarial image; "success" means the prediction changed.
    axes[2].imshow(adversarial, cmap='gray')
    success = "✓ 成功" if orig_pred != adv_pred else "✗ 失敗"
    axes[2].set_title(f'對抗樣本\n預測: {adv_pred}\n攻擊{success}', fontsize=10)
    axes[2].axis('off')

    fig.tight_layout()

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"圖片已儲存至 {save_path}")

    # Close this specific figure (not whatever pyplot considers "current"),
    # so repeated calls don't accumulate open figures.
    plt.close(fig)

    return fig
|
|
|
|
def visualize_multiple_attacks(examples, epsilon, save_path=None, max_examples=10,
                               amplify=10):
    """
    Visualize several adversarial-attack examples as a grid.

    Each row shows one example: original image, amplified perturbation, and
    adversarial image. Column titles are drawn on the first row only.

    Args:
        examples (list): attack examples; each element is a mapping with keys:
            - 'original': original image
            - 'adversarial': adversarial example
            - 'perturbation': perturbation
            - 'true_label': ground-truth label
            - 'original_pred': prediction on the original image
            - 'adv_pred': prediction on the adversarial example
            Image entries may be torch tensors (any device, with or without
            grad) or array-likes; singleton dimensions are squeezed away.
        epsilon (float): perturbation magnitude; sets the perturbation colour
            scale and is shown in the figure title.
        save_path (str, optional): if given, the figure is saved there
            (parent directories are created).
        max_examples (int): maximum number of examples (rows) to draw.
        amplify (int): perturbation amplification factor (default 10, the
            value this function previously hard-coded).

    Returns:
        matplotlib.figure.Figure: the figure, already closed in pyplot.

    Raises:
        ValueError: if there is nothing to draw (``examples`` is empty or
            ``max_examples`` < 1).
    """
    n_examples = min(len(examples), max_examples)
    if n_examples < 1:
        # plt.subplots would also reject zero rows; fail early with a clear message.
        raise ValueError(
            "no examples to visualize: examples is empty or max_examples < 1"
        )

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_examples, 3, figsize=(10, 3 * n_examples), squeeze=False)

    # Symmetric colour limits for the perturbation column. Fall back to 1.0
    # when the limit would be 0 (e.g. epsilon == 0). With vmin == vmax,
    # matplotlib's Normalize maps everything to the colormap's lowest colour
    # instead of its neutral centre.
    limit = float(abs(epsilon * amplify)) or 1.0

    for i, example in enumerate(examples[:n_examples]):
        # Convert each image to a squeezed numpy array. Tensors are detached
        # and moved to CPU first so GPU / grad tensors work too.
        images = {}
        for key in ('original', 'adversarial', 'perturbation'):
            value = example[key]
            if torch.is_tensor(value):
                value = value.detach().cpu().numpy()
            images[key] = np.asarray(value).squeeze()

        ax_orig, ax_pert, ax_adv = axes[i]

        # Column 1: original image, with example index / label / prediction.
        ax_orig.imshow(images['original'], cmap='gray')
        if i == 0:
            ax_orig.set_title('原始影像', fontsize=12, fontweight='bold')
        ax_orig.set_ylabel(f'範例 {i+1}\n真實: {example["true_label"]}\n預測: {example["original_pred"]}',
                           fontsize=10)

        # Column 2: amplified perturbation.
        ax_pert.imshow(images['perturbation'] * amplify, cmap='seismic', vmin=-limit, vmax=limit)
        if i == 0:
            ax_pert.set_title(f'擾動 (放大 {amplify}x)', fontsize=12, fontweight='bold')

        # Column 3: adversarial image with its prediction.
        ax_adv.imshow(images['adversarial'], cmap='gray')
        if i == 0:
            ax_adv.set_title('對抗樣本', fontsize=12, fontweight='bold')
        ax_adv.set_ylabel(f'預測: {example["adv_pred"]}', fontsize=10)

        # Hide ticks but keep the axes, so the y-labels above remain visible.
        for ax in (ax_orig, ax_pert, ax_adv):
            ax.set_xticks([])
            ax.set_yticks([])

    fig.suptitle(f'FGSM 對抗攻擊範例 (ε = {epsilon})', fontsize=14, fontweight='bold', y=0.995)
    fig.tight_layout()

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"圖片已儲存至 {save_path}")

    plt.close(fig)

    return fig
|
|
|
|
def plot_accuracy_vs_epsilon(epsilons, accuracies, save_path=None):
    """
    Plot model accuracy against the perturbation magnitude epsilon.

    Every point is annotated with its value (one decimal place, "%").

    Args:
        epsilons (list): epsilon values (x axis).
        accuracies (list): accuracy for each epsilon, in percent (y axis).
        save_path (str, optional): if given, the figure is saved there
            (parent directories are created).

    Returns:
        None. The figure is closed before returning.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(epsilons, accuracies, 'o-', linewidth=2, markersize=8, color='#2E86AB')
    ax.set_xlabel('擾動幅度 (ε)', fontsize=12, fontweight='bold')
    ax.set_ylabel('準確率 (%)', fontsize=12, fontweight='bold')
    ax.set_title('FGSM 攻擊: 準確率 vs 擾動幅度', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')

    # Value labels placed 10 points above each marker.
    label_style = dict(xytext=(0, 10), textcoords='offset points', ha='center', fontsize=9)
    for point in zip(epsilons, accuracies):
        ax.annotate(f'{point[1]:.1f}%', xy=point, **label_style)

    fig.tight_layout()

    if save_path:
        target_dir = Path(save_path).parent
        target_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"圖片已儲存至 {save_path}")

    plt.close(fig)
|
|
|
|
def plot_attack_success_rate(epsilons, success_rates, save_path=None):
    """
    Plot the attack success rate against the perturbation magnitude epsilon.

    Every point is annotated with its value (one decimal place, "%").

    Args:
        epsilons (list): epsilon values (x axis).
        success_rates (list): attack success rate for each epsilon, in
            percent (y axis).
        save_path (str, optional): if given, the figure is saved there
            (parent directories are created).

    Returns:
        None. The figure is closed before returning.
    """
    figure, axis = plt.subplots(figsize=(10, 6))

    axis.plot(
        epsilons, success_rates, 'o-',
        linewidth=2, markersize=8, color='#A23B72',
    )
    bold = {'fontweight': 'bold'}
    axis.set_xlabel('擾動幅度 (ε)', fontsize=12, **bold)
    axis.set_ylabel('攻擊成功率 (%)', fontsize=12, **bold)
    axis.set_title('FGSM 攻擊成功率 vs 擾動幅度', fontsize=14, **bold)
    axis.grid(True, alpha=0.3, linestyle='--')

    for eps_value, rate in zip(epsilons, success_rates):
        axis.annotate(
            f'{rate:.1f}%',
            xy=(eps_value, rate),
            xytext=(0, 10),
            textcoords='offset points',
            ha='center',
            fontsize=9,
        )

    figure.tight_layout()

    if not save_path:
        plt.close(figure)
        return

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"圖片已儲存至 {save_path}")
    plt.close(figure)
|
|
|
|
def plot_comparison(epsilons, original_accs, adversarial_accs, save_path=None):
    """
    Plot clean accuracy and adversarial accuracy against epsilon on one chart.

    Args:
        epsilons (list): epsilon values (x axis).
        original_accs (list): accuracy on original inputs for each epsilon.
        adversarial_accs (list): accuracy on adversarial inputs for each epsilon.
        save_path (str, optional): if given, the figure is saved there
            (parent directories are created).

    Returns:
        None. The figure is closed before returning.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # (values, line format, legend label, colour), drawn in this order.
    series = (
        (original_accs, 'o-', '原始準確率', '#2E86AB'),
        (adversarial_accs, 's-', '對抗樣本準確率', '#A23B72'),
    )
    for values, fmt, label, color in series:
        ax.plot(epsilons, values, fmt, linewidth=2, markersize=8,
                label=label, color=color)

    ax.set_xlabel('擾動幅度 (ε)', fontsize=12, fontweight='bold')
    ax.set_ylabel('準確率 (%)', fontsize=12, fontweight='bold')
    ax.set_title('FGSM 攻擊效果分析', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11, loc='best')
    ax.grid(True, alpha=0.3, linestyle='--')

    fig.tight_layout()

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"圖片已儲存至 {save_path}")

    plt.close(fig)
|
|
|
|
def create_gradio_output(original, adversarial, perturbation,
                         original_label, orig_pred, adv_pred, epsilon):
    """
    Build the visualization outputs for the Gradio interface.

    Args:
        original (torch.Tensor): original image.
        adversarial (torch.Tensor): adversarial example.
        perturbation (torch.Tensor): perturbation.
        original_label (int): ground-truth label.
        orig_pred (int): prediction on the original image.
        adv_pred (int): prediction on the adversarial example.
        epsilon (float): perturbation magnitude.

    Returns:
        tuple: (original_dict, adv_dict, visualization_image)
            - original_dict / adv_dict: a single entry mapping the predicted
              label string to 1.0. These are placeholders, not model
              probabilities.
            - visualization_image: uint8 RGB array of shape (H, W, 3),
              rendered from the figure produced by ``visualize_attack``.
    """
    # Only the predicted class is known here, so it is reported with 1.0.
    original_dict = {f"數字 {orig_pred}": 1.0}
    adv_dict = {f"數字 {adv_pred}": 1.0}

    fig = visualize_attack(
        original, adversarial, perturbation,
        original_label, orig_pred, adv_pred, epsilon
    )

    # Rasterize through the Agg RGBA buffer. FigureCanvasAgg.tostring_rgb()
    # was deprecated in Matplotlib 3.8 and removed in 3.10; buffer_rgba() is
    # the supported replacement. Taking the shape from the buffer itself
    # avoids depending on get_width_height() matching the pixel size.
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy, so the result does not alias the canvas buffer.
    vis_image = np.ascontiguousarray(rgba[..., :3])

    plt.close(fig)

    return original_dict, adv_dict, vis_image
|
|
|
|
if __name__ == "__main__":
    # Smoke test: render the single-attack view and the two curve plots from
    # synthetic data, writing PNG files into the current working directory.
    print("測試視覺化模組...")

    # Synthetic 1x28x28 image plus Gaussian noise (std 0.1), clipped to
    # [0, 1]. The perturbation is recomputed after clipping.
    original = torch.rand(1, 28, 28)
    noise = torch.randn(1, 28, 28) * 0.1
    adversarial = torch.clamp(original + noise, 0, 1)
    perturbation = adversarial - original

    print("\n測試單個攻擊視覺化...")
    visualize_attack(
        original, adversarial, perturbation,
        original_label=7, orig_pred=7, adv_pred=3,
        epsilon=0.3,
        save_path='test_single_attack.png',
    )

    epsilons = [0.0, 0.1, 0.2, 0.3]
    # (banner, plotting function, y values, output file), run in order.
    curve_jobs = [
        ("\n測試準確率曲線...", plot_accuracy_vs_epsilon,
         [98.5, 85.2, 65.3, 45.8], 'test_accuracy_curve.png'),
        ("\n測試攻擊成功率曲線...", plot_attack_success_rate,
         [0.0, 15.3, 35.2, 54.7], 'test_success_rate.png'),
    ]
    for banner, plot_fn, values, out_path in curve_jobs:
        print(banner)
        plot_fn(epsilons, values, out_path)

    print("\n視覺化模組測試完成!")
|
|