FGSMDemo / src /visualization.py
hua-nchu's picture
Initial commit: FGSM Demo with Gradio interface
c84101d
"""
視覺化工具模組
提供對抗樣本視覺化、實驗結果繪圖等功能
"""
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非互動式後端
import numpy as np
import torch
from pathlib import Path
def visualize_attack(original, adversarial, perturbation,
original_label, orig_pred, adv_pred,
epsilon, save_path=None, amplify=10):
"""
視覺化對抗攻擊結果 (三圖並排)
Args:
original (torch.Tensor or np.ndarray): 原始影像
adversarial (torch.Tensor or np.ndarray): 對抗樣本
perturbation (torch.Tensor or np.ndarray): 擾動
original_label (int): 真實標籤
orig_pred (int): 原始預測
adv_pred (int): 對抗樣本預測
epsilon (float): 擾動幅度
save_path (str, optional): 儲存路徑
amplify (int): 擾動放大倍數 (方便觀察)
"""
# 轉換為 numpy
if torch.is_tensor(original):
original = original.detach().cpu().numpy()
if torch.is_tensor(adversarial):
adversarial = adversarial.detach().cpu().numpy()
if torch.is_tensor(perturbation):
perturbation = perturbation.detach().cpu().numpy()
# 移除 channel 維度
if original.ndim == 3:
original = original.squeeze()
if adversarial.ndim == 3:
adversarial = adversarial.squeeze()
if perturbation.ndim == 3:
perturbation = perturbation.squeeze()
# 建立圖表
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
# 原始影像
axes[0].imshow(original, cmap='gray')
axes[0].set_title(f'原始影像\n真實標籤: {original_label}\n預測: {orig_pred}', fontsize=10)
axes[0].axis('off')
# 擾動 (放大顯示)
amplified_pert = perturbation * amplify
axes[1].imshow(amplified_pert, cmap='seismic', vmin=-epsilon*amplify, vmax=epsilon*amplify)
axes[1].set_title(f'擾動 (放大 {amplify}x)\nε = {epsilon}', fontsize=10)
axes[1].axis('off')
# 對抗樣本
axes[2].imshow(adversarial, cmap='gray')
success = "✓ 成功" if orig_pred != adv_pred else "✗ 失敗"
axes[2].set_title(f'對抗樣本\n預測: {adv_pred}\n攻擊{success}', fontsize=10)
axes[2].axis('off')
plt.tight_layout()
# 儲存或顯示
if save_path:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"圖片已儲存至 {save_path}")
plt.close()
return fig
def visualize_multiple_attacks(examples, epsilon, save_path=None, max_examples=10):
"""
視覺化多個對抗攻擊範例
Args:
examples (list): 攻擊範例列表,每個元素包含:
- original: 原始影像
- adversarial: 對抗樣本
- perturbation: 擾動
- true_label: 真實標籤
- original_pred: 原始預測
- adv_pred: 對抗預測
epsilon (float): 擾動幅度
save_path (str, optional): 儲存路徑
max_examples (int): 最多顯示幾個範例
"""
n_examples = min(len(examples), max_examples)
fig, axes = plt.subplots(n_examples, 3, figsize=(10, 3*n_examples))
if n_examples == 1:
axes = axes.reshape(1, -1)
for i, example in enumerate(examples[:n_examples]):
# 轉換為 numpy
original = example['original'].squeeze().numpy()
adversarial = example['adversarial'].squeeze().numpy()
perturbation = example['perturbation'].squeeze().numpy()
# 原始影像
axes[i, 0].imshow(original, cmap='gray')
if i == 0:
axes[i, 0].set_title('原始影像', fontsize=12, fontweight='bold')
axes[i, 0].set_ylabel(f'範例 {i+1}\n真實: {example["true_label"]}\n預測: {example["original_pred"]}',
fontsize=10)
axes[i, 0].set_xticks([])
axes[i, 0].set_yticks([])
# 擾動
axes[i, 1].imshow(perturbation * 10, cmap='seismic', vmin=-epsilon*10, vmax=epsilon*10)
if i == 0:
axes[i, 1].set_title(f'擾動 (放大 10x)', fontsize=12, fontweight='bold')
axes[i, 1].set_xticks([])
axes[i, 1].set_yticks([])
# 對抗樣本
axes[i, 2].imshow(adversarial, cmap='gray')
if i == 0:
axes[i, 2].set_title('對抗樣本', fontsize=12, fontweight='bold')
axes[i, 2].set_ylabel(f'預測: {example["adv_pred"]}', fontsize=10)
axes[i, 2].set_xticks([])
axes[i, 2].set_yticks([])
plt.suptitle(f'FGSM 對抗攻擊範例 (ε = {epsilon})', fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
if save_path:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"圖片已儲存至 {save_path}")
plt.close()
return fig
def plot_accuracy_vs_epsilon(epsilons, accuracies, save_path=None):
"""
繪製準確率 vs ε 曲線圖
Args:
epsilons (list): ε 值列表
accuracies (list): 對應的準確率列表
save_path (str, optional): 儲存路徑
"""
plt.figure(figsize=(10, 6))
plt.plot(epsilons, accuracies, 'o-', linewidth=2, markersize=8, color='#2E86AB')
plt.xlabel('擾動幅度 (ε)', fontsize=12, fontweight='bold')
plt.ylabel('準確率 (%)', fontsize=12, fontweight='bold')
plt.title('FGSM 攻擊: 準確率 vs 擾動幅度', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, linestyle='--')
# 標註數值
for eps, acc in zip(epsilons, accuracies):
plt.annotate(f'{acc:.1f}%',
xy=(eps, acc),
xytext=(0, 10),
textcoords='offset points',
ha='center',
fontsize=9)
plt.tight_layout()
if save_path:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"圖片已儲存至 {save_path}")
plt.close()
def plot_attack_success_rate(epsilons, success_rates, save_path=None):
"""
繪製攻擊成功率 vs ε 曲線圖
Args:
epsilons (list): ε 值列表
success_rates (list): 攻擊成功率列表
save_path (str, optional): 儲存路徑
"""
plt.figure(figsize=(10, 6))
plt.plot(epsilons, success_rates, 'o-', linewidth=2, markersize=8, color='#A23B72')
plt.xlabel('擾動幅度 (ε)', fontsize=12, fontweight='bold')
plt.ylabel('攻擊成功率 (%)', fontsize=12, fontweight='bold')
plt.title('FGSM 攻擊成功率 vs 擾動幅度', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, linestyle='--')
# 標註數值
for eps, rate in zip(epsilons, success_rates):
plt.annotate(f'{rate:.1f}%',
xy=(eps, rate),
xytext=(0, 10),
textcoords='offset points',
ha='center',
fontsize=9)
plt.tight_layout()
if save_path:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"圖片已儲存至 {save_path}")
plt.close()
def plot_comparison(epsilons, original_accs, adversarial_accs, save_path=None):
"""
繪製原始準確率和對抗準確率的比較圖
Args:
epsilons (list): ε 值列表
original_accs (list): 原始準確率列表
adversarial_accs (list): 對抗樣本準確率列表
save_path (str, optional): 儲存路徑
"""
plt.figure(figsize=(12, 6))
plt.plot(epsilons, original_accs, 'o-', linewidth=2, markersize=8,
label='原始準確率', color='#2E86AB')
plt.plot(epsilons, adversarial_accs, 's-', linewidth=2, markersize=8,
label='對抗樣本準確率', color='#A23B72')
plt.xlabel('擾動幅度 (ε)', fontsize=12, fontweight='bold')
plt.ylabel('準確率 (%)', fontsize=12, fontweight='bold')
plt.title('FGSM 攻擊效果分析', fontsize=14, fontweight='bold')
plt.legend(fontsize=11, loc='best')
plt.grid(True, alpha=0.3, linestyle='--')
plt.tight_layout()
if save_path:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"圖片已儲存至 {save_path}")
plt.close()
def create_gradio_output(original, adversarial, perturbation,
original_label, orig_pred, adv_pred, epsilon):
"""
為 Gradio 介面建立視覺化輸出
Args:
original (torch.Tensor): 原始影像
adversarial (torch.Tensor): 對抗樣本
perturbation (torch.Tensor): 擾動
original_label (int): 真實標籤
orig_pred (int): 原始預測
adv_pred (int): 對抗預測
epsilon (float): 擾動幅度
Returns:
tuple: (original_dict, adv_dict, visualization_image)
"""
# 預測結果字典
original_dict = {f"數字 {orig_pred}": 1.0}
adv_dict = {f"數字 {adv_pred}": 1.0}
# 建立視覺化圖片
fig = visualize_attack(
original, adversarial, perturbation,
original_label, orig_pred, adv_pred, epsilon
)
# 轉換為 numpy array
fig.canvas.draw()
vis_image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
vis_image = vis_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close(fig)
return original_dict, adv_dict, vis_image
if __name__ == "__main__":
# 測試視覺化功能
print("測試視覺化模組...")
# 建立測試資料
original = torch.rand(1, 28, 28)
adversarial = torch.clamp(original + torch.randn(1, 28, 28) * 0.1, 0, 1)
perturbation = adversarial - original
# 測試單個攻擊視覺化
print("\n測試單個攻擊視覺化...")
visualize_attack(
original, adversarial, perturbation,
original_label=7, orig_pred=7, adv_pred=3,
epsilon=0.3,
save_path='test_single_attack.png'
)
# 測試準確率曲線
print("\n測試準確率曲線...")
epsilons = [0.0, 0.1, 0.2, 0.3]
accuracies = [98.5, 85.2, 65.3, 45.8]
plot_accuracy_vs_epsilon(epsilons, accuracies, 'test_accuracy_curve.png')
# 測試攻擊成功率曲線
print("\n測試攻擊成功率曲線...")
success_rates = [0.0, 15.3, 35.2, 54.7]
plot_attack_success_rate(epsilons, success_rates, 'test_success_rate.png')
print("\n視覺化模組測試完成!")