"""
WavCoch model for Hugging Face Transformers.

This implementation is self-contained so HF-hosted WavCoch checkpoints depend on
neither the local auristream package nor vector_quantize_pytorch.
"""

import math
import os
from typing import List, Optional, Sequence

# Must be set before torch/transformers are imported so the setting is visible
# when they initialise.
os.environ.setdefault("USE_TORCH_XLA", "0")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm

try:
    from torch.nn.utils.parametrizations import weight_norm
except ImportError:  # pragma: no cover - older PyTorch compatibility
    from torch.nn.utils import weight_norm

try:
    from torch.nn.utils import parametrize
except ImportError:  # pragma: no cover - very old PyTorch without parametrizations
    parametrize = None

from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

try:
    from transformers.tokenization_utils_base import BatchEncoding
except ImportError:  # pragma: no cover - compatibility with older Transformers
    from transformers.tokenization_utils import BatchEncoding

import transformers.modeling_utils as transformers_modeling_utils
import transformers.utils.import_utils as transformers_import_utils

# Force Transformers' XLA availability probes to report False so no XLA code
# path is taken (presumably to avoid torch_xla side effects where it happens to
# be installed).
transformers_import_utils.is_torch_xla_available = lambda *args, **kwargs: False
transformers_modeling_utils.is_torch_xla_available = lambda *args, **kwargs: False

try:
    from .configuration_wavcoch import WavCochConfig
except ImportError:  # pragma: no cover - compatibility with older repos
    from .configure_wavcoch import WavCochConfig


class CausalConv1d(nn.Module):
    """1D causal convolution with left-only padding.

    The input is padded with ``dilation * (kernel_size - 1)`` samples on the
    left only (the convolution itself uses ``padding=0``), so no output frame
    depends on input samples that come after it.

    Args:
        pad_mode: ``"repeat"`` replicates the first sample
            (``nn.ReplicationPad1d``); ``"constant"`` pads with
            ``constant_value`` (``nn.ConstantPad1d``).

    Raises:
        ValueError: if ``pad_mode`` is neither ``"repeat"`` nor ``"constant"``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        bias: bool = True,
        groups: int = 1,
        pad_mode: str = "repeat",
        constant_value: float = 0.0,
    ):
        super().__init__()
        left_pad = dilation * (kernel_size - 1)
        if pad_mode == "repeat":
            self.pad = nn.ReplicationPad1d((left_pad, 0))
        elif pad_mode == "constant":
            self.pad = nn.ConstantPad1d((left_pad, 0), constant_value)
        else:
            raise ValueError(f"Unsupported pad_mode: {pad_mode}")
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.pad(x))


def _build_conv1d(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    *,
    causal: bool,
    dilation: int = 1,
    pad_mode: str = "repeat",
):
    """Build a stride-1 1D convolution, either causal or symmetrically padded.

    The non-causal branch pads ``dilation * (kernel_size - 1) // 2`` on both
    sides. This preserves sequence length whenever
    ``dilation * (kernel_size - 1)`` is even (e.g. odd kernel sizes); otherwise
    the output is one frame shorter.
    """
    if causal:
        return CausalConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            pad_mode=pad_mode,
        )
    padding = dilation * (kernel_size - 1) // 2
    return nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=1,
        dilation=dilation,
        padding=padding,
    )


class SoundStreamResidualUnit(nn.Module):
    """Pre-activation bottleneck residual unit.

    Computes ``x + expand(ELU(dilated(ELU(reduce(ELU(x))))))``. ``reduce`` and
    ``expand`` are 1x1 convolutions to/from
    ``max(1, channels // bottleneck_factor)`` channels, and ``dilated`` is a
    (optionally causal) convolution with the given kernel size and dilation.
    """

    def __init__(
        self,
        channels: int,
        *,
        kernel_size: int = 7,
        dilation: int = 1,
        bottleneck_factor: int = 4,
        causal: bool,
        pad_mode: str = "repeat",
    ):
        super().__init__()
        bottleneck_channels = max(1, channels // max(1, int(bottleneck_factor)))
        self.pre_act = nn.ELU()
        self.reduce = nn.Conv1d(channels, bottleneck_channels, kernel_size=1)
        self.mid_act = nn.ELU()
        self.dilated = _build_conv1d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size,
            causal=causal,
            dilation=dilation,
            pad_mode=pad_mode,
        )
        self.post_act = nn.ELU()
        self.expand = nn.Conv1d(bottleneck_channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.pre_act(x)
        x = self.reduce(x)
        x = self.mid_act(x)
        x = self.dilated(x)
        x = self.post_act(x)
        x = self.expand(x)
        return x + residual


class SoundStreamConvStack(nn.Module):
    """1x1 input projection, ``num_layers`` residual units, ELU, output projection.

    Residual unit ``i`` uses dilation ``residual_dilations[i % len(residual_dilations)]``.
    The output projection is an identity when ``hidden_channels == out_channels``.

    Raises:
        ValueError: if ``residual_dilations`` is empty.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        *,
        num_layers: int,
        residual_kernel_size: int = 7,
        residual_dilations: Sequence[int] = (1, 3, 9),
        bottleneck_factor: int = 4,
        causal: bool,
        pad_mode: str = "repeat",
    ):
        super().__init__()
        dilations = [int(d) for d in residual_dilations]
        if not dilations:
            raise ValueError("SoundStream residual dilations must be non-empty")
        self.input_proj = nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        self.blocks = nn.ModuleList(
            [
                SoundStreamResidualUnit(
                    hidden_channels,
                    kernel_size=residual_kernel_size,
                    dilation=dilations[layer_idx % len(dilations)],
                    bottleneck_factor=bottleneck_factor,
                    causal=causal,
                    pad_mode=pad_mode,
                )
                for layer_idx in range(int(num_layers))
            ]
        )
        self.output_act = nn.ELU()
        self.output_proj = (
            nn.Identity()
            if hidden_channels == out_channels
            else nn.Conv1d(hidden_channels, out_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x)
        x = self.output_act(x)
        return self.output_proj(x)


class FSQ(nn.Module):
    """Finite Scalar Quantization with the subset of functionality needed for inference.

    Each of the ``len(levels)`` code dimensions is bounded with ``tanh`` and
    rounded to one of ``levels[i]`` values. Token indices are mixed-radix
    numbers: ``index = sum_i level_idx_i * prod_{j<i} levels[j]``. When ``dim``
    differs from ``len(levels)``, linear layers project in and out of the code
    space. No straight-through estimator is applied, so this is inference-only.
    """

    def __init__(self, levels: List[int], dim: int):
        super().__init__()
        if not levels:
            raise ValueError("FSQ levels must be non-empty")
        self.levels = [int(level) for level in levels]
        self.codebook_dim = len(self.levels)
        self.dim = int(dim)

        level_tensor = torch.tensor(self.levels, dtype=torch.int32)
        basis = torch.cumprod(torch.tensor([1] + self.levels[:-1], dtype=torch.int32), dim=0)
        self.register_buffer("_levels", level_tensor, persistent=False)
        self.register_buffer("_basis", basis, persistent=False)

        if self.dim != self.codebook_dim:
            self.project_in = nn.Linear(self.dim, self.codebook_dim)
            self.project_out = nn.Linear(self.codebook_dim, self.dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

    def _refresh_level_buffers(self, device: Optional[torch.device] = None):
        """Rebuild ``_levels``/``_basis`` from the Python-side ``self.levels``.

        Both buffers are non-persistent, so they are never restored from a
        checkpoint. Recomputing them before use presumably guards against
        buffers left uninitialised by deferred/meta-device construction.
        NOTE(review): confirm that this is the motivation.
        """
        level_values = [int(level) for level in self.levels]
        if device is None:
            if isinstance(self.project_in, nn.Linear):
                device = self.project_in.weight.device
            elif isinstance(self.project_out, nn.Linear):
                device = self.project_out.weight.device
            else:
                device = self._levels.device
        self._levels = torch.tensor(level_values, dtype=torch.int32, device=device)
        self._basis = torch.cumprod(
            torch.tensor([1] + level_values[:-1], dtype=torch.int32, device=device),
            dim=0,
        )

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Squash ``z`` with ``tanh`` so rounding yields ``levels[i]`` distinct values per dim.

        Even level counts are shifted by 0.5 so the rounded grid stays symmetric.
        """
        levels = self._levels.to(dtype=z.dtype, device=z.device)
        half_l = (levels - 1) * (1 + eps) / 2
        offset = torch.where(
            (self._levels % 2).to(device=z.device) == 0,
            torch.tensor(0.5, device=z.device, dtype=z.dtype),
            torch.tensor(0.0, device=z.device, dtype=z.dtype),
        )
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalised codes to non-negative per-dimension level indices."""
        half_width = (self._levels // 2).to(dtype=zhat_normalized.dtype, device=zhat_normalized.device)
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`_scale_and_shift`."""
        half_width = (self._levels // 2).to(dtype=zhat.dtype, device=zhat.device)
        return (zhat - half_width) / half_width

    def quantize_values(self, z: torch.Tensor) -> torch.Tensor:
        """Bound, round and normalise ``z`` (last dim is the code dimension)."""
        self._refresh_level_buffers(device=z.device)
        half_width = (self._levels // 2).to(dtype=z.dtype, device=z.device)
        return self.bound(z).round() / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Convert normalised codes (last dim = code dim) to int32 token indices."""
        self._refresh_level_buffers(device=zhat.device)
        zhat = self._scale_and_shift(zhat)
        basis = self._basis.to(device=zhat.device, dtype=zhat.dtype)
        # The level indices come out of a float "/half_width * half_width"
        # round trip and may land just below an integer. Round before casting,
        # since ``.to(int32)`` truncates toward zero and would yield the wrong
        # token.
        return (zhat * basis).sum(dim=-1).round().to(torch.int32)

    def indices_to_level_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Decompose mixed-radix token indices into per-dimension level indices."""
        self._refresh_level_buffers(device=indices.device)
        indices = indices.unsqueeze(-1)
        levels = self._levels.to(device=indices.device)
        basis = self._basis.to(device=indices.device)
        return (indices // basis) % levels

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Map token indices back to (projected) float32 code vectors."""
        self._refresh_level_buffers(device=indices.device)
        level_indices = self.indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices.to(dtype=torch.float32))
        return self.project_out(codes)

    def forward(self, z: torch.Tensor):
        """Quantize ``z``.

        Returns:
            ``(quantized, indices)``: ``quantized`` has ``z``'s shape and dtype
            (it is computed in float32); ``indices`` are int64 tokens with the
            last dimension of ``z`` reduced away.
        """
        orig_dtype = z.dtype
        z = self.project_in(z.to(torch.float32))
        q = self.quantize_values(z)
        indices = self.codes_to_indices(q)
        out = self.project_out(q).to(orig_dtype)
        return out, indices.long()


LRELU_SLOPE = 0.1


def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Symmetric padding that keeps length for an odd ``kernel_size`` at stride 1."""
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(module, mean: float = 0.0, std: float = 0.01):
    """Normal-initialise ``weight`` of any module whose class name contains ``"Conv"``.

    NOTE(review): for weight-normalised modules, ``weight`` is derived from the
    underlying weight-norm parameters, so writing to ``weight.data`` may not
    persist. This only matters for freshly initialised (not loaded) vocoders.
    """
    classname = module.__class__.__name__
    if classname.find("Conv") != -1 and hasattr(module, "weight"):
        module.weight.data.normal_(mean, std)


def _remove_weight_norm(module: nn.Module) -> None:
    """Strip weight normalisation from ``module``, whichever PyTorch API applied it.

    ``weight_norm`` above comes from ``torch.nn.utils.parametrizations`` when
    available. That variant must be removed with ``remove_parametrizations``;
    the legacy ``remove_weight_norm`` only handles the hook-based variant.
    """
    if parametrize is not None and parametrize.is_parametrized(module, "weight"):
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)
    else:
        remove_weight_norm(module)


class ResBlock1(nn.Module):
    """Residual block with three (dilated conv, conv) pairs and leaky-ReLU pre-activations."""

    __constants__ = ["lrelu_slope"]

    def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3, 5)):
        super().__init__()
        self.lrelu_slope = LRELU_SLOPE
        ch = channels
        ks = kernel_size
        self.convs1 = nn.Sequential(
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[0]), dilation[0])),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[1]), dilation[1])),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[2]), dilation[2])),
        )
        self.convs2 = nn.Sequential(
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
        )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv1, conv2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, self.lrelu_slope)
            xt = conv1(xt)
            xt = F.leaky_relu(xt, self.lrelu_slope)
            xt = conv2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Remove weight normalisation from every convolution in the block."""
        for layer in self.convs1:
            _remove_weight_norm(layer)
        for layer in self.convs2:
            _remove_weight_norm(layer)


class ResBlock2(nn.Module):
    """Lighter residual block: two dilated convs, each with its own skip connection."""

    __constants__ = ["lrelu_slope"]

    def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3)):
        super().__init__()
        self.lrelu_slope = LRELU_SLOPE
        ch = channels
        ks = kernel_size
        self.convs = nn.ModuleList(
            [
                weight_norm(Conv1d(ch, ch, ks, 1, get_padding(kernel_size, dilation[0]), dilation[0])),
                weight_norm(Conv1d(ch, ch, ks, 1, get_padding(kernel_size, dilation[1]), dilation[1])),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            xt = F.leaky_relu(x, self.lrelu_slope)
            xt = conv(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Remove weight normalisation from every convolution in the block."""
        for layer in self.convs:
            _remove_weight_norm(layer)


class Generator(nn.Module):
    """HiFi-GAN-style vocoder mapping cochleagram frames to a waveform.

    Input is ``(batch, frames, out_channels)`` (it is permuted before
    ``conv_pre``). Output is ``(batch, 1, samples)`` in ``[-1, 1]`` (``tanh``).
    Each upsampling stage is followed by ``len(resblock_kernel_sizes)`` residual
    blocks whose outputs are averaged.
    """

    __constants__ = ["lrelu_slope", "num_kernels", "num_upsamples"]

    def __init__(
        self,
        out_channels: int = 211,
        upsample_rates=None,
        upsample_kernel_sizes=None,
        upsample_initial_channel: int = 512,
        resblock: str = "1",
        resblock_kernel_sizes=None,
        resblock_dilation_sizes=None,
    ):
        super().__init__()
        upsample_rates = list(upsample_rates or [5, 4, 2, 2])
        upsample_kernel_sizes = list(upsample_kernel_sizes or [10, 8, 4, 4])
        resblock_kernel_sizes = list(resblock_kernel_sizes or [11, 7, 3])
        resblock_dilation_sizes = [
            list(d) for d in (resblock_dilation_sizes or [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
        ]

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.lrelu_slope = LRELU_SLOPE
        self.conv_pre = weight_norm(Conv1d(out_channels, upsample_initial_channel, 7, 1, padding=3))
        resblock_cls = ResBlock1 if resblock == "1" else ResBlock2

        ups = []
        for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        kernel,
                        rate,
                        padding=(kernel - rate) // 2,
                    )
                )
            )
        self.ups = nn.Sequential(*ups)

        resblocks = []
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            resblocks.append(
                nn.Sequential(
                    *[
                        resblock_cls(ch, kernel, dilation)
                        for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes)
                    ]
                )
            )
        self.resblocks = nn.Sequential(*resblocks)

        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(Conv1d(ch, 1, 17, 1, padding=0))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def load_state_dict(self, state_dict, strict: bool = True, **kwargs):
        """Load weights, also accepting a flat ``resblocks.<flat_idx>.*`` layout.

        This class nests residual blocks as ``resblocks.<stage>.<kernel>``. Keys
        in a flat layout (presumably from a reference HiFi-GAN checkpoint) are
        remapped with ``flat_idx -> (flat_idx // num_kernels, flat_idx % num_kernels)``.
        Tensors whose rank differs by one from the matching current tensor get a
        trailing singleton dimension added or removed.

        Extra keyword arguments (e.g. ``assign`` on newer PyTorch) are forwarded
        to ``nn.Module.load_state_dict``, whose result is returned.
        """
        num_kernels = self.num_kernels
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key
            parts = key.split(".")
            # Flat layout: parts[2] is a sub-module name ("convs1", "convs", ...).
            # In the nested layout parts[2] is an integer index and stays as-is.
            # This works for both legacy (weight_g/weight_v) and parametrized
            # weight-norm key names.
            if (
                len(parts) > 2
                and parts[0] == "resblocks"
                and parts[1].isdigit()
                and not parts[2].isdigit()
            ):
                flat_idx = int(parts[1])
                new_key = (
                    f"resblocks.{flat_idx // num_kernels}.{flat_idx % num_kernels}."
                    f"{'.'.join(parts[2:])}"
                )
            new_state_dict[new_key] = value

        current_state = self.state_dict()
        for key, value in list(new_state_dict.items()):
            if key not in current_state:
                continue
            len_diff = value.dim() - current_state[key].dim()
            if len_diff == -1:
                new_state_dict[key] = value.unsqueeze(-1)
            elif len_diff == 1:
                new_state_dict[key] = value.squeeze(-1)
        return super().load_state_dict(new_state_dict, strict=strict, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_pre(x.permute(0, 2, 1))
        for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = upsample_layer(x)
            xs = 0
            for resblock in resblock_group:
                xs = xs + resblock(x)
            x = xs / self.num_kernels
        # Default leaky-ReLU slope (0.01) here, not LRELU_SLOPE; kept as-is for
        # checkpoint compatibility.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove weight normalisation from every layer (e.g. before export)."""
        for layer in self.ups:
            _remove_weight_norm(layer)
        for group in self.resblocks:
            for block in group:
                block.remove_weight_norm()
        _remove_weight_norm(self.conv_pre)
        _remove_weight_norm(self.conv_post)


class WavCoch(PreTrainedModel):
    """Causal waveform-to-cochleagram tokenizer with optional vocoder.

    Pipeline: a windowed-DFT front end (two frozen ``Conv1d`` filter banks of
    ``window_size // 2 + 1`` bins, hop ``hop_length``), then an encoder stack,
    then FSQ tokens, then a decoder stack that predicts ``out_channels``
    cochleagram channels. An optional bundled vocoder turns the result back
    into audio.
    """

    config_class = WavCochConfig
    main_input_name = "wav"

    def __init__(self, config: WavCochConfig):
        super().__init__(config)
        self.config = config

        self.N = int(config.window_size)
        self.hop_length = int(config.hop_length)
        self.window_padding = int(getattr(config, "window_padding", self.N - self.hop_length))
        self.causal_convs = bool(getattr(config, "causal_convs", True))
        self.causal_pad_mode = getattr(config, "causal_pad_mode", "repeat")
        self.encoder_block_type = getattr(config, "encoder_block_type", "plain")
        self.decoder_block_type = getattr(config, "decoder_block_type", "plain")

        out_bins = self.N // 2 + 1
        self.conv_real_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
        self.conv_imag_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
        self._initialize_conv_filters()

        self.encoder = self._build_processing_stack(
            in_channels=out_bins,
            hidden_channels=config.encoder_dim,
            out_channels=config.encoder_dim,
            num_layers=config.encoder_layers,
            kernel_size=config.encoder_kernel_size,
            causal=self.causal_convs,
            block_type=self.encoder_block_type,
            residual_kernel_size=getattr(config, "encoder_residual_kernel_size", 7),
            residual_dilations=getattr(config, "encoder_residual_dilations", (1, 3, 9)),
            residual_bottleneck_factor=getattr(config, "encoder_residual_bottleneck_factor", 4),
        )
        self.quantizer = FSQ(levels=list(config.channels), dim=config.encoder_dim)
        # NOTE(review): the decoder consumes quantizer output (encoder_dim
        # channels) but is built with in_channels=decoder_dim, so this assumes
        # decoder_dim == encoder_dim in the config.
        self.decoder = self._build_processing_stack(
            in_channels=config.decoder_dim,
            hidden_channels=config.decoder_dim,
            out_channels=config.out_channels,
            num_layers=config.decoder_layers,
            kernel_size=config.decoder_kernel_size,
            causal=self.causal_convs,
            block_type=self.decoder_block_type,
            residual_kernel_size=getattr(config, "decoder_residual_kernel_size", 7),
            residual_dilations=getattr(config, "decoder_residual_dilations", (1, 3, 9)),
            residual_bottleneck_factor=getattr(config, "decoder_residual_bottleneck_factor", 4),
        )

        self.has_vocoder = bool(getattr(config, "has_vocoder", False))
        if self.has_vocoder:
            if int(config.out_channels) != 211:
                raise ValueError("Bundled vocoder currently expects 211 cochleagram channels")
            self.vocoder = Generator(
                out_channels=config.out_channels,
                upsample_rates=config.vocoder_upsample_rates,
                upsample_kernel_sizes=config.vocoder_upsample_kernel_sizes,
                upsample_initial_channel=config.vocoder_upsample_initial_channel,
                resblock=config.vocoder_resblock,
                resblock_kernel_sizes=config.vocoder_resblock_kernel_sizes,
                resblock_dilation_sizes=config.vocoder_resblock_dilation_sizes,
            )
        else:
            self.vocoder = None

        self._vocab_size = int(config.vocab_size)
        self.post_init()
        # post_init() may run a generic weight initialiser in some Transformers
        # versions. Re-apply the fixed DFT filters so a freshly constructed
        # model keeps them; this is idempotent, and checkpoint loading still
        # overwrites parameters afterwards.
        self._initialize_conv_filters()

    def _build_plain_conv_stack(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        kernel_size: int,
        causal: bool,
    ) -> nn.Sequential:
        """Stack ``num_layers`` (conv, ReLU) pairs; only the first conv changes channel count."""
        layers = []
        for layer_idx in range(int(num_layers)):
            input_channels = in_channels if layer_idx == 0 else out_channels
            conv = _build_conv1d(
                input_channels,
                out_channels,
                kernel_size,
                causal=causal,
                pad_mode=self.causal_pad_mode,
            )
            layers.extend([conv, nn.ReLU()])
        return nn.Sequential(*layers)

    def _build_processing_stack(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        kernel_size: int,
        causal: bool,
        block_type: str = "plain",
        residual_kernel_size: int = 7,
        residual_dilations: Sequence[int] = (1, 3, 9),
        residual_bottleneck_factor: int = 4,
    ) -> nn.Module:
        """Build an encoder/decoder stack of type ``"plain"`` or ``"soundstream"``.

        Raises:
            ValueError: for any other ``block_type``.
        """
        if block_type == "plain":
            return self._build_plain_conv_stack(
                in_channels=in_channels,
                out_channels=out_channels,
                num_layers=num_layers,
                kernel_size=kernel_size,
                causal=causal,
            )
        if block_type != "soundstream":
            raise ValueError(f"Unsupported WavCoch block_type: {block_type}")
        return SoundStreamConvStack(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            residual_kernel_size=residual_kernel_size,
            residual_dilations=residual_dilations,
            bottleneck_factor=residual_bottleneck_factor,
            causal=causal,
            pad_mode=self.causal_pad_mode,
        )

    def _compute_twiddle_factors(self):
        """Return ``(cos, sin)`` of ``-2*pi*n*k/N`` as ``(N, N)`` float32 matrices."""
        n = torch.arange(self.N, dtype=torch.float32).unsqueeze(1)
        k = torch.arange(self.N, dtype=torch.float32).unsqueeze(0)
        angles = -2.0 * math.pi * n * k / float(self.N)
        return torch.cos(angles), torch.sin(angles)

    def _initialize_conv_filters(self):
        """Fill the front-end filters with Hann-windowed DFT rows and freeze them.

        Only the first ``N // 2 + 1`` frequency rows are kept. Biases are left as
        constructed and are frozen too.
        """
        with torch.no_grad():
            cos_matrix, sin_matrix = self._compute_twiddle_factors()
            cos_matrix = cos_matrix[: self.N // 2 + 1, :]
            sin_matrix = sin_matrix[: self.N // 2 + 1, :]
            window = torch.hann_window(self.N, periodic=True).view(1, 1, -1)
            real_weights = (cos_matrix.unsqueeze(1) * window).to(dtype=self.conv_real_filters.weight.dtype)
            imag_weights = (sin_matrix.unsqueeze(1) * window).to(dtype=self.conv_imag_filters.weight.dtype)
            self.conv_real_filters.weight.copy_(real_weights)
            self.conv_imag_filters.weight.copy_(imag_weights)
        for param in self.conv_real_filters.parameters():
            param.requires_grad_(False)
        for param in self.conv_imag_filters.parameters():
            param.requires_grad_(False)

    def _normalize_sample_rate(self, sample_rate: Optional[int], sampling_rate: Optional[int]) -> int:
        """Resolve the two sample-rate aliases and require ``config.sample_rate``.

        Raises:
            ValueError: if the aliases conflict or the rate differs from the config.
        """
        if sample_rate is not None and sampling_rate is not None and sample_rate != sampling_rate:
            raise ValueError(f"sample_rate ({sample_rate}) and sampling_rate ({sampling_rate}) conflict")
        resolved = int(sample_rate or sampling_rate or self.config.sample_rate)
        if resolved != int(self.config.sample_rate):
            raise ValueError(
                f"WavCoch expects {self.config.sample_rate} Hz audio, but received {resolved} Hz"
            )
        return resolved

    def _prepare_wav_batch(self, wav) -> torch.Tensor:
        """Normalise waveform input to a float32 ``(batch, 1, samples)`` tensor on the model device.

        Accepted inputs:
            * a list of 1D sequences (or ``(1, N)`` / ``(N, 1)`` 2D ones), which
              are zero-padded at the end to the longest element;
            * a tensor or array-like (anything exposing ``__array__``, e.g. a
              numpy array) of shape ``(samples,)``, ``(batch, samples)`` or
              ``(batch, 1, samples)``. 3D input is passed through unchanged.

        Raises:
            ValueError: for an empty list or an unexpected shape.
            TypeError: for unsupported input types.
        """
        if isinstance(wav, list):
            if not wav:
                raise ValueError("Received an empty list of waveforms")
            normalized = []
            for item in wav:
                if not isinstance(item, torch.Tensor):
                    item = torch.tensor(item)
                if item.ndim == 1:
                    pass
                elif item.ndim == 2 and 1 in item.shape:
                    item = item.reshape(-1)
                else:
                    raise ValueError(f"Unexpected list element shape {tuple(item.shape)}")
                # Cast per item so mixed-dtype lists pad into one float32 batch.
                normalized.append(item.to(dtype=torch.float32))
            wav = torch.nn.utils.rnn.pad_sequence(normalized, batch_first=True).unsqueeze(1)
        else:
            if not isinstance(wav, torch.Tensor):
                if not hasattr(wav, "__array__"):
                    raise TypeError(f"Unsupported input type: {type(wav)}")
                wav = torch.tensor(wav)
            if wav.ndim == 1:
                wav = wav.unsqueeze(0).unsqueeze(0)
            elif wav.ndim == 2:
                wav = wav.unsqueeze(1)
            elif wav.ndim != 3:
                raise ValueError(f"Unexpected tensor shape {tuple(wav.shape)}, expected 1D, 2D or 3D")
        # Move to the front end's device so CPU audio works with an accelerator model.
        return wav.to(device=self.conv_real_filters.weight.device, dtype=torch.float32)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    def _resolve_wav_input(
        self,
        wav: Optional[torch.Tensor],
        input_values: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return whichever of ``wav`` / ``input_values`` was given (exactly one is required)."""
        if wav is not None and input_values is not None:
            raise ValueError("Provide either `wav` or `input_values`, not both")
        resolved = wav if wav is not None else input_values
        if resolved is None:
            raise ValueError("WavCoch requires waveform input via `wav` or `input_values`")
        return resolved

    def _encode_quantized(
        self,
        wav: torch.Tensor,
        pad: bool = True,
    ):
        """Run front end, encoder and FSQ.

        When ``pad`` is true, ``window_padding`` zeros are prepended so the first
        analysis frame ends at the first real sample.

        Returns:
            ``(quantized, indices)`` shaped ``(batch, frames, encoder_dim)`` and
            ``(batch, frames)``.
        """
        wav = self._prepare_wav_batch(wav)
        if pad:
            wav = F.pad(wav, (self.window_padding, 0), mode="constant", value=0.0)
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)
        # Real and imaginary responses are summed (not combined into a
        # magnitude); kept as-is because checkpoints were trained on this input.
        x = real_part + imag_part
        x = self.encoder(x).permute(0, 2, 1)
        quantized, indices = self.quantizer(x)
        return quantized, indices

    def forward(
        self,
        wav: Optional[torch.Tensor] = None,
        coch: Optional[torch.Tensor] = None,
        return_tensors: str = "pt",
        sample_rate: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        pad: bool = True,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        input_values: Optional[torch.Tensor] = None,
    ):
        """Tokenize audio, or predict cochleagrams and an L1 loss when ``coch`` is given.

        Returns:
            * if ``output_hidden_states``: ``BaseModelOutput`` with the quantized
              latents (or a ``(quantized, (quantized,))`` tuple when
              ``return_dict`` is false; ``None`` follows the config);
            * elif ``coch`` is None: ``BatchEncoding`` whose ``input_values``
              and ``input_ids`` both hold the ``(batch, frames)`` token indices;
            * else: ``(pred_coch, l1_loss, None)``.
        """
        del return_tensors  # unused, kept for tokenizer-like API compatibility
        self._normalize_sample_rate(sample_rate, sampling_rate)
        wav = self._resolve_wav_input(wav, input_values)
        quantized, indices = self._encode_quantized(wav, pad=pad)

        if output_hidden_states:
            if return_dict is None:
                # Hugging Face convention: None means "use the config default".
                return_dict = getattr(self.config, "use_return_dict", True)
            hidden_states = (quantized,)
            if not return_dict:
                return quantized, hidden_states
            return BaseModelOutput(last_hidden_state=quantized, hidden_states=hidden_states)

        if coch is None:
            codes = indices.long()
            return BatchEncoding({"input_values": codes, "input_ids": codes})

        pred_coch = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
        loss = F.l1_loss(pred_coch, coch)
        return pred_coch, loss, None

    @torch.no_grad()
    def quantize(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Return int64 token indices of shape ``(batch, frames)``."""
        _, indices = self._encode_quantized(wav, pad=pad)
        return indices.long()

    @torch.no_grad()
    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode token indices (``(frames,)`` or ``(batch, frames)``) to ``(batch, frames, out_channels)``."""
        if indices.ndim == 1:
            indices = indices.unsqueeze(0)
        emb = self.quantizer.indices_to_codes(indices.long())
        return self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)

    @torch.no_grad()
    def wav2coch(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Waveform to predicted cochleagram ``(batch, frames, out_channels)``."""
        quantized, _ = self._encode_quantized(wav, pad=pad)
        return self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)

    @torch.no_grad()
    def vocode(self, coch: torch.Tensor) -> torch.Tensor:
        """Run the bundled vocoder on a cochleagram.

        Accepts ``(frames, C)`` or ``(batch, frames, C)``. A ``(batch, C, frames)``
        input is transposed when only its middle axis matches ``out_channels``.

        Raises:
            ValueError: if there is no vocoder or the input rank is not 2 or 3.
        """
        if self.vocoder is None:
            raise ValueError("This WavCoch checkpoint does not include a bundled vocoder")
        if coch.ndim == 2:
            coch = coch.unsqueeze(0)
        elif coch.ndim != 3:
            raise ValueError(f"Unexpected cochleagram shape {tuple(coch.shape)}")
        if coch.shape[-1] != self.config.out_channels and coch.shape[1] == self.config.out_channels:
            coch = coch.transpose(1, 2)
        return self.vocoder(coch)

    @torch.no_grad()
    def decode_audio(self, indices: torch.Tensor) -> torch.Tensor:
        """Token indices to waveform via :meth:`decode` and :meth:`vocode`."""
        return self.vocode(self.decode(indices))