Update prototype_v8_refined_soft_hand_loss.py
Browse files
prototype_v8_refined_soft_hand_loss.py
CHANGED
|
@@ -168,12 +168,12 @@ class SVAE(nn.Module):
|
|
| 168 |
|
| 169 |
# -- Training --
|
| 170 |
|
| 171 |
-
def train(epochs=
|
| 172 |
sigma=0.15, device='cuda'):
|
| 173 |
device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
| 174 |
-
train_loader, test_loader = get_cifar10(batch_size=
|
| 175 |
|
| 176 |
-
V, D =
|
| 177 |
target_cv = 0.1250
|
| 178 |
model = SVAE(matrix_v=V, D=D).to(device)
|
| 179 |
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
|
|
|
| 168 |
|
| 169 |
# -- Training --
|
| 170 |
|
| 171 |
+
def train(epochs=400, lr=3e-4, cv_weight=0.3, boost=0.5,
|
| 172 |
sigma=0.15, device='cuda'):
|
| 173 |
device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
| 174 |
+
train_loader, test_loader = get_cifar10(batch_size=512)
|
| 175 |
|
| 176 |
+
V, D = 256, 24
|
| 177 |
target_cv = 0.1250
|
| 178 |
model = SVAE(matrix_v=V, D=D).to(device)
|
| 179 |
opt = torch.optim.Adam(model.parameters(), lr=lr)
|