OmniSeperate 1.0 (XS, 2-stem version)

After reading about HiDolen's Mini-BS-RoFormer, which has 8.8 million parameters, I decided to make my own tiny stem-separation model! It's currently only ~16 MB in size and uses a BS-RoFormer-like architecture!

Try the HF demo!

Notes

It's only a 2-stem model, meaning it can only separate a song into vocals and instrumentals.

I used the 7-second clip version of MUSDB18 (a music separation dataset), not the full version.

If you want to run it, you must use an FFT size of 2044 (1023 frequency bins); it does not use a standard FFT size like 2048.
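The bin count follows directly from the one-sided STFT formula `n_bins = n_fft // 2 + 1`, which is why 2044 is required here (plain-Python check, names are illustrative):

```python
# A one-sided (real-input) STFT produces n_fft // 2 + 1 frequency bins.
n_fft = 2044
n_bins = n_fft // 2 + 1
print(n_bins)  # 1023 - matches the model's band layout

# The standard size 2048 would give 1025 bins, which the model cannot accept:
print(2048 // 2 + 1)  # 1025
```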

Future Improvements

  • A 4-stem version of OmniSeperate (bass, drums, other instruments, vocals),
  • A better 2-stem version of OmniSeperate XS,
  • And a larger 2-stem version of OmniSeperate 1.0!

Evaluation Results

General Results:

| Epoch | Average SI-SNR Loss | Epoch | Average SI-SNR Loss |
|-------|---------------------|-------|---------------------|
| 1     | 0.06                | 2     | -0.83               |
| 3     | -1.05               | 4     | -1.56               |
| 5     | -1.81               | 6     | -2.18               |
| 7     | -2.60               | 8     | -2.84               |
| 9     | -3.05               | 10    | -3.24               |
| 11    | -3.44               | 12    | -3.60               |
| 13    | -3.68               | 14    | -3.80               |
| 15    | -3.91               | 16    | -4.00               |
| 17    | -4.12               | 18    | -4.23               |
| 19    | -4.27               | 20    | -4.43               |
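For reference, the loss above is the negative of SI-SNR (scale-invariant signal-to-noise ratio), so lower is better. A generic plain-Python sketch of the metric (not the exact training code) looks like this:

```python
import math

def si_snr(est, ref, eps=1e-8):
    """Scale-invariant SNR in dB between an estimated and a reference signal."""
    # Zero-mean both signals
    m_e = sum(est) / len(est)
    m_r = sum(ref) / len(ref)
    est = [x - m_e for x in est]
    ref = [x - m_r for x in ref]
    # Project the estimate onto the reference (optimal scaling of the target)
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref) + eps
    target = [dot / ref_energy * r for r in ref]
    noise = [e - t for e, t in zip(est, target)]
    t_pow = sum(t * t for t in target)
    n_pow = sum(n * n for n in noise) + eps
    return 10 * math.log10(t_pow / n_pow + eps)
```

Because of the projection step, rescaling the estimate does not change the score, which is the "scale-invariant" part.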

Results on MUSDB18 (7-second clip version):

==============================
OVERALL PERFORMANCE SUMMARY
==============================
Aggregated Scores (median over frames, median over tracks)
vocals ==> SDR: 5.149 SIR: 7.358 ISR: 7.492 SAR: 7.426
accompaniment ==> SDR: 9.941 SIR: 11.944 ISR: 14.683 SAR: 14.034
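The "median over frames, median over tracks" aggregation used above (the usual BSS Eval convention, as in museval) can be sketched in plain Python; the framewise numbers below are made up:

```python
import statistics

def aggregate_scores(per_track_frame_scores):
    """Median over frames within each track, then median over tracks."""
    per_track = [statistics.median(frames) for frames in per_track_frame_scores]
    return statistics.median(per_track)

# Hypothetical framewise SDR values for three tracks
tracks = [
    [4.8, 5.2, 5.6],   # track 1 -> median 5.2
    [6.0, 5.0, 7.0],   # track 2 -> median 6.0
    [3.0, 4.0, 3.5],   # track 3 -> median 3.5
]
print(aggregate_scores(tracks))  # 5.2
```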

How to Use

The code below was written with Gemini 3 and Gemini 2.5 Flash:

import torch
import torch.nn as nn
import numpy as np
import soundfile as sf

class MiniStereoRoFormer(nn.Module):
    def __init__(self):
        super().__init__()
        self.band_specs = [2, 2, 4, 8, 16, 32, 64, 128, 256, 256, 128, 127] # Total 1023 bins
        self.hidden_dim = 128

        self.band_extractors = nn.ModuleList([
            nn.Linear(spec * 2, self.hidden_dim) for spec in self.band_specs
        ])

        layer = nn.TransformerEncoderLayer(d_model=self.hidden_dim, nhead=8, batch_first=True, dropout=0.1)
        self.transformer = nn.TransformerEncoder(layer, num_layers=6)

        self.de_band = nn.ModuleList([
            nn.Linear(self.hidden_dim, spec * 4) for spec in self.band_specs
        ])

    def forward(self, x):
        batch, chan, freq, time = x.shape
        x_permuted = x.permute(0, 3, 2, 1) # [B, T, F, C]

        bands_processed = []
        curr_freq_idx = 0
        for i, spec_len in enumerate(self.band_specs):
            b = x_permuted[:, :, curr_freq_idx:curr_freq_idx+spec_len, :].reshape(batch, time, -1)
            bands_processed.append(self.band_extractors[i](b))
            curr_freq_idx += spec_len

        z = torch.stack(bands_processed, dim=1).mean(dim=1) # average band embeddings -> [B, T, hidden]
        z = self.transformer(z)

        all_mask_logits = []
        for i, spec_len in enumerate(self.band_specs):
            out = self.de_band[i](z)
            all_mask_logits.append(out.reshape(batch, time, spec_len, 4))

        combined_logits = torch.cat(all_mask_logits, dim=2)
        combined_logits = combined_logits.permute(0, 3, 2, 1) # [B, 4, F, T]
        reshaped_logits = combined_logits.view(batch, 2, 2, freq, time)

        # COMPETITIVE MASKING
        masks = torch.softmax(reshaped_logits, dim=1)
        return masks[:, 0, ...], masks[:, 1, ...]

def run_omniseparate_perfect_align(input_path, model_path, output_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Load Model
    model = MiniStereoRoFormer()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    # 2. Load Audio (Use soundfile to avoid librosa's hidden offsets)
    audio, sr = sf.read(input_path)
    if len(audio.shape) == 1: 
        audio = np.stack([audio, audio], axis=1)
    
    # Original length in samples
    orig_len = audio.shape[0]
    
    # Convert to [channels, samples]
    audio_torch = torch.from_numpy(audio.T).float().to(device)
    
    # 3. Parameters
    n_fft = 2044
    hop_stft = 512
    window = torch.hann_window(n_fft).to(device)

    print(f"Processing {orig_len} samples...")
    with torch.no_grad():
        # STFT - Using center=True is fine AS LONG AS ISTFT length is set
        spec = torch.stft(
            audio_torch, 
            n_fft=n_fft, 
            hop_length=hop_stft, 
            window=window,
            center=True, 
            return_complex=True
        ).unsqueeze(0) # [1, 2, F, T]

        # Inference
        # Runs on the full spectrogram in one pass to avoid any drift;
        # very long inputs may OOM on small GPUs and would need chunked processing.
        v_mask, _ = model(torch.abs(spec))
        v_spec = spec * v_mask

        # ISTFT
        # The 'length' argument guarantees sample-exact alignment with the input
        v_audio = torch.istft(
            v_spec.squeeze(0), 
            n_fft=n_fft, 
            hop_length=hop_stft, 
            window=window,
            center=True, 
            length=orig_len 
        )

    # 4. Save
    final_vocals = v_audio.cpu().numpy().T
    sf.write(output_path, final_vocals, sr)
    print(f"Saved {output_path}; output is sample-aligned with the input.")

run_omniseparate_perfect_align('y2mate.com - ODETARI KEEP UP Official Visualizer.mp3', 'Tiny_BS-RoFormer-1.0.pth', 'vocals_perfect.wav')
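The "competitive masking" step in the model's forward pass applies a softmax across the two stems, so the vocal and instrumental masks always sum to 1 at every time-frequency bin. A toy plain-Python illustration for a single bin (values made up):

```python
import math

def competitive_masks(vocal_logit, instrumental_logit):
    """Softmax over the stem dimension for a single time-frequency bin."""
    m = max(vocal_logit, instrumental_logit)  # subtract max for numerical stability
    ev = math.exp(vocal_logit - m)
    ei = math.exp(instrumental_logit - m)
    s = ev + ei
    return ev / s, ei / s

v, i = competitive_masks(2.0, -1.0)
print(round(v + i, 6))  # 1.0 - the two masks partition the mixture spectrogram
```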