NetGuard-AI / src /train.py
Alireza Aminzadeh
Upload folder using huggingface_hub
199e55c verified
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, f1_score
from model import AnomalyDetector, detect_anomaly
def generate_synthetic_data(num_samples, is_anomaly=False, input_dim=41):
    """
    Build a batch of synthetic network-traffic feature vectors.

    Every feature is drawn independently from a Gaussian using NumPy's
    global random state:

    * normal traffic    -> mean 0.0, standard deviation 0.5
    * anomalous traffic -> mean 2.0, standard deviation 1.5

    Args:
        num_samples: Number of rows (samples) to generate.
        is_anomaly: When True, sample from the anomalous distribution.
        input_dim: Number of features per sample.

    Returns:
        A ``torch.float32`` tensor of shape ``(num_samples, input_dim)``.
    """
    # Pick the distribution parameters first so a single sampling call
    # serves both traffic classes.
    mean, std_dev = (2.0, 1.5) if is_anomaly else (0.0, 0.5)
    samples = np.random.normal(
        loc=mean,
        scale=std_dev,
        size=(num_samples, input_dim),
    )
    return torch.tensor(samples, dtype=torch.float32)
def train_autoencoder():
    """
    Train the autoencoder on synthetic normal traffic, save it, and evaluate it.

    Pipeline:
        1. Generate 10,000 synthetic "normal" samples and train
           ``AnomalyDetector`` to reconstruct them (MSE loss, Adam).
        2. Save the model's ``state_dict`` to ``../models/autoencoder.pth``.
        3. Evaluate on a mixed test set (800 normal + 200 anomalous samples)
           with ``detect_anomaly`` at a threshold of 0.5 and print a
           classification report and the F1 score.
        4. Save the training-loss curve to ``../models/training_loss.png``.

    All output paths are relative to the current working directory.
    Returns nothing; results are printed and written to disk.
    """
    print("Starting NetGuard-AI Model Training...")

    # Hyperparameters
    input_dim = 41
    batch_size = 64
    epochs = 20
    learning_rate = 1e-3

    # Output locations (relative to the current working directory).
    output_dir = '../models'
    model_path = '../models/autoencoder.pth'
    plot_path = '../models/training_loss.png'

    # 1. Generate Training Data (Only Normal Traffic for Autoencoder)
    print("Generating synthetic normal traffic for training...")
    train_data = generate_synthetic_data(10000, is_anomaly=False, input_dim=input_dim)
    # Input and target are the same tensor: the model learns to reconstruct
    # its own input.
    train_loader = DataLoader(
        TensorDataset(train_data, train_data),
        batch_size=batch_size,
        shuffle=True,
    )

    # 2. Initialize Model
    model = AnomalyDetector(input_dim=input_dim)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 3. Training Loop
    model.train()
    loss_history = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch_x, _ in train_loader:
            optimizer.zero_grad()
            reconstructed = model(batch_x)
            loss = criterion(reconstructed, batch_x)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_loss = epoch_loss / len(train_loader)
        loss_history.append(avg_loss)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    # 4. Save Model
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print("Model saved to", model_path)

    # 5. Evaluation Phase
    print("\nEvaluating Model on Test Set (Mixed Traffic)...")
    model.eval()

    # Generate test data (80% normal, 20% anomalous). The feature count must
    # match the one the model was built with, so pass input_dim explicitly
    # instead of relying on the generator's default.
    num_normal, num_anomalous = 800, 200
    test_normal = generate_synthetic_data(num_normal, is_anomaly=False, input_dim=input_dim)
    test_anomalous = generate_synthetic_data(num_anomalous, is_anomaly=True, input_dim=input_dim)
    test_data = torch.cat([test_normal, test_anomalous])
    # 0 = Normal, 1 = Anomaly
    true_labels = np.concatenate([np.zeros(num_normal), np.ones(num_anomalous)])

    # Predict. Gradients are not needed for inference; disabling autograd
    # avoids building a graph and guarantees the result can be converted
    # with .numpy() regardless of how detect_anomaly is implemented.
    with torch.no_grad():
        anomalies, _ = detect_anomaly(model, test_data, threshold=0.5)
    pred_labels = anomalies.numpy().astype(int)

    # Metrics
    print("\nClassification Report:")
    print(classification_report(true_labels, pred_labels, target_names=["Normal", "Anomaly"]))
    f1 = f1_score(true_labels, pred_labels)
    print(f"F1 Score: {f1:.4f}")

    # Save evaluation plot
    fig = plt.figure(figsize=(10, 5))
    plt.plot(loss_history, label='Training Loss')
    plt.title('Autoencoder Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.savefig(plot_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    # Report the path actually written to.
    print("Training loss plot saved to", plot_path)
# Run the full train/save/evaluate pipeline only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    train_autoencoder()