AbstractPhil commited on
Commit
5601548
·
verified ·
1 Parent(s): 43f961a

Update prototype_v8_refined_soft_hand_loss.py

Browse files
prototype_v8_refined_soft_hand_loss.py CHANGED
@@ -168,12 +168,12 @@ class SVAE(nn.Module):
168
 
169
  # -- Training --
170
 
171
- def train(epochs=100, lr=1e-3, cv_weight=0.3, boost=0.5,
172
  sigma=0.15, device='cuda'):
173
  device = torch.device(device if torch.cuda.is_available() else 'cpu')
174
- train_loader, test_loader = get_cifar10(batch_size=256)
175
 
176
- V, D = 40, 24
177
  target_cv = 0.1250
178
  model = SVAE(matrix_v=V, D=D).to(device)
179
  opt = torch.optim.Adam(model.parameters(), lr=lr)
 
168
 
169
  # -- Training --
170
 
171
+ def train(epochs=400, lr=3e-4, cv_weight=0.3, boost=0.5,
172
  sigma=0.15, device='cuda'):
173
  device = torch.device(device if torch.cuda.is_available() else 'cpu')
174
+ train_loader, test_loader = get_cifar10(batch_size=512)
175
 
176
+ V, D = 256, 24
177
  target_cv = 0.1250
178
  model = SVAE(matrix_v=V, D=D).to(device)
179
  opt = torch.optim.Adam(model.parameters(), lr=lr)