OmniSeperate 1.0 (XS, 2-stem version)
After reading HiDolen's Mini-BS-RoFormer, which has 8.8 million parameters, I decided to make my own tiny stem separator model! Currently, it's only ~16 MB in size and it uses a BS-RoFormer-like architecture!
Try the HF demo!
Notes
It's only a 2-stem model, that means it can only separate vocals and instrumentals from any song.
And I used the 7-second clip version of MUSDB18 (a music separation dataset), not the full version.
And if you want to run it, you must use an FFT size of 2044 (1023 bins). It doesn't use a standard FFT size like 2048.
Future Improvements
- A 4-stem version of OmniSeperate (bass, drums, other instruments, vocals),
- A better 2-stem version of OmniSeperate XS,
- And a larger 2-stem version of OmniSeperate 1.0!
Evaluation Results
General Results:
| Epoch | Average SI-SNR Loss | Epoch | Average SI-SNR Loss |
|---|---|---|---|
| 1 | 0.06 | 2 | -0.83 |
| 3 | -1.05 | 4 | -1.56 |
| 5 | -1.81 | 6 | -2.18 |
| 7 | -2.60 | 8 | -2.84 |
| 9 | -3.05 | 10 | -3.24 |
| 11 | -3.44 | 12 | -3.60 |
| 13 | -3.68 | 14 | -3.80 |
| 15 | -3.91 | 16 | -4.00 |
| 17 | -4.12 | 18 | -4.23 |
| 19 | -4.27 | 20 | -4.43 |
Results on MUSDB18 (7-second clip version):
==============================
OVERALL PERFORMANCE SUMMARY
==============================
Aggregated Scores (median over frames, median over tracks)
vocals ==> SDR: 5.149 SIR: 7.358 ISR: 7.492 SAR: 7.426
accompaniment ==> SDR: 9.941 SIR: 11.944 ISR: 14.683 SAR: 14.034
How to Use
Code is by Gemini 3 and 2.5 Flash:
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
class MiniStereoRoFormer(nn.Module):
    """Tiny BS-RoFormer-style 2-stem source separator.

    Takes a magnitude spectrogram of shape [B, 2, 1023, T] (stereo audio,
    1023 frequency bins from an FFT size of 2044) and returns two
    complementary soft masks — presumably (vocals, accompaniment) given the
    caller's usage — each of shape [B, 2, 1023, T].
    """

    def __init__(self):
        super().__init__()
        # Per-band bin widths; they must sum to exactly 1023 frequency bins.
        self.band_specs = [2, 2, 4, 8, 16, 32, 64, 128, 256, 256, 128, 127]  # Total 1023 bins
        self.hidden_dim = 128
        # One projection per band: (bins * 2 stereo channels) -> hidden.
        self.band_extractors = nn.ModuleList([
            nn.Linear(spec * 2, self.hidden_dim) for spec in self.band_specs
        ])
        layer = nn.TransformerEncoderLayer(d_model=self.hidden_dim, nhead=8, batch_first=True, dropout=0.1)
        self.transformer = nn.TransformerEncoder(layer, num_layers=6)
        # One decode head per band: hidden -> (bins * 4 logit channels).
        self.de_band = nn.ModuleList([
            nn.Linear(self.hidden_dim, spec * 4) for spec in self.band_specs
        ])

    def forward(self, x):
        """Return a pair of soft masks for magnitude spectrogram `x`.

        x: [B, C=2, F=1023, T]. Returns two tensors, each [B, 2, F, T],
        that sum to one at every time-frequency bin.
        """
        batch, chan, freq, time = x.shape
        spec_tf = x.permute(0, 3, 2, 1)  # -> [B, T, F, C]

        # Encode each frequency band into the shared hidden space.
        encoded = []
        offset = 0
        for width, proj in zip(self.band_specs, self.band_extractors):
            band = spec_tf[:, :, offset:offset + width, :].reshape(batch, time, -1)
            encoded.append(proj(band))
            offset += width

        # NOTE(review): the mean collapses all bands into one per-frame
        # token sequence before the transformer — band identity is lost here.
        tokens = torch.stack(encoded, dim=1).mean(dim=1)
        tokens = self.transformer(tokens)

        # Decode a (bins x 4) logit slab per band from the shared tokens.
        per_band_logits = [
            head(tokens).reshape(batch, time, width, 4)
            for width, head in zip(self.band_specs, self.de_band)
        ]
        logits = torch.cat(per_band_logits, dim=2)          # [B, T, F, 4]
        logits = logits.permute(0, 3, 2, 1)                 # [B, 4, F, T]
        logits = logits.view(batch, 2, 2, freq, time)       # split 4 -> (stem, channel)

        # COMPETITIVE MASKING: softmax over the stem axis forces the two
        # stems' masks to sum to one everywhere.
        masks = torch.softmax(logits, dim=1)
        return masks[:, 0, ...], masks[:, 1, ...]
def run_omniseparate_perfect_align(input_path, model_path, output_path):
    """Separate vocals from `input_path` and write them to `output_path`.

    Loads a MiniStereoRoFormer checkpoint from `model_path`, runs a single
    whole-song STFT -> mask -> ISTFT pass, and writes a WAV whose length is
    sample-exact with the input (no start-offset drift).

    Parameters:
        input_path: path to any audio file readable by soundfile.
        model_path: path to a state-dict checkpoint for MiniStereoRoFormer.
        output_path: destination WAV path for the isolated vocals.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Load model. weights_only=True restricts the pickle to tensor data,
    # so an untrusted checkpoint file cannot execute arbitrary code.
    model = MiniStereoRoFormer()
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device).eval()

    # 2. Load audio (soundfile avoids librosa's hidden offsets).
    audio, sr = sf.read(input_path)
    if audio.ndim == 1:
        # Mono input: duplicate to stereo, since the model expects 2 channels.
        audio = np.stack([audio, audio], axis=1)

    orig_len = audio.shape[0]  # original length in samples
    audio_torch = torch.from_numpy(audio.T).float().to(device)  # [channels, samples]

    # 3. STFT parameters: n_fft=2044 yields 1023 bins, matching the model's
    # band layout (it does NOT use the standard 2048).
    n_fft = 2044
    hop_stft = 512
    window = torch.hann_window(n_fft).to(device)

    print(f"Processing {orig_len} samples...")
    with torch.no_grad():
        # STFT - using center=True is fine AS LONG AS the ISTFT length is set.
        spec = torch.stft(
            audio_torch,
            n_fft=n_fft,
            hop_length=hop_stft,
            window=window,
            center=True,
            return_complex=True
        ).unsqueeze(0)  # [1, 2, F, T]

        # Inference over the whole song at once — the only way to guarantee
        # 0 ms drift. If the GPU OOMs here, a chunked approach is needed.
        v_mask, _ = model(torch.abs(spec))
        v_spec = spec * v_mask

        # ISTFT: the 'length' parameter is what guarantees exact alignment
        # with the original sample count.
        v_audio = torch.istft(
            v_spec.squeeze(0),
            n_fft=n_fft,
            hop_length=hop_stft,
            window=window,
            center=True,
            length=orig_len
        )

    # 4. Save as [samples, channels], which is what soundfile expects.
    final_vocals = v_audio.cpu().numpy().T
    sf.write(output_path, final_vocals, sr)
    print(f"File saved. Start sample should now be EXACTLY at 0.000 in REAPER.")
if __name__ == "__main__":
    # Guarded entry point: importing this file no longer triggers a full
    # separation run with hard-coded paths.
    run_omniseparate_perfect_align(
        'y2mate.com - ODETARI KEEP UP Official Visualizer.mp3',
        'Tiny_BS-RoFormer-1.0.pth',
        'vocals_perfect.wav'
    )