| """ |
| WavCoch model for Hugging Face Transformers. |
| |
| This implementation is self-contained so HF-hosted WavCoch checkpoints do not |
| depend on the local auristream package or vector_quantize_pytorch. |
| """ |
|
|
| import math |
| import os |
| from typing import List, Optional, Sequence |
|
|
| os.environ.setdefault("USE_TORCH_XLA", "0") |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import Conv1d, ConvTranspose1d |
| from torch.nn.utils import remove_weight_norm |
| try: |
| from torch.nn.utils.parametrizations import weight_norm |
| except ImportError: |
| from torch.nn.utils import weight_norm |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import BaseModelOutput |
| try: |
| from transformers.tokenization_utils_base import BatchEncoding |
| except ImportError: |
| from transformers.tokenization_utils import BatchEncoding |
| import transformers.modeling_utils as transformers_modeling_utils |
| import transformers.utils.import_utils as transformers_import_utils |
|
|
| transformers_import_utils.is_torch_xla_available = lambda *args, **kwargs: False |
| transformers_modeling_utils.is_torch_xla_available = lambda *args, **kwargs: False |
|
|
| try: |
| from .configuration_wavcoch import WavCochConfig |
| except ImportError: |
| from .configure_wavcoch import WavCochConfig |
|
|
|
|
class CausalConv1d(nn.Module):
    """1D causal convolution with left-only padding.

    All padding is applied on the left of the time axis, so no output sample
    ever depends on future input samples; with stride 1 the output length
    equals the input length.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        bias: bool = True,
        groups: int = 1,
        pad_mode: str = "repeat",
        constant_value: float = 0.0,
    ):
        super().__init__()
        # One less than the receptive field: exactly the left context needed.
        left_context = dilation * (kernel_size - 1)
        if pad_mode == "constant":
            self.pad = nn.ConstantPad1d((left_context, 0), constant_value)
        elif pad_mode == "repeat":
            self.pad = nn.ReplicationPad1d((left_context, 0))
        else:
            raise ValueError(f"Unsupported pad_mode: {pad_mode}")
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,  # padding handled explicitly by self.pad
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Left-pad then convolve."""
        padded = self.pad(x)
        return self.conv(padded)
|
|
|
|
| def _build_conv1d( |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int, |
| *, |
| causal: bool, |
| dilation: int = 1, |
| pad_mode: str = "repeat", |
| ): |
| if causal: |
| return CausalConv1d( |
| in_channels, |
| out_channels, |
| kernel_size=kernel_size, |
| stride=1, |
| dilation=dilation, |
| pad_mode=pad_mode, |
| ) |
|
|
| padding = dilation * (kernel_size - 1) // 2 |
| return nn.Conv1d( |
| in_channels, |
| out_channels, |
| kernel_size=kernel_size, |
| stride=1, |
| dilation=dilation, |
| padding=padding, |
| ) |
|
|
|
|
class SoundStreamResidualUnit(nn.Module):
    """Bottlenecked residual unit: ELU -> 1x1 reduce -> ELU -> dilated conv -> ELU -> 1x1 expand, plus skip."""

    def __init__(
        self,
        channels: int,
        *,
        kernel_size: int = 7,
        dilation: int = 1,
        bottleneck_factor: int = 4,
        causal: bool,
        pad_mode: str = "repeat",
    ):
        super().__init__()
        # Inner width shrinks by the bottleneck factor, but never below one channel.
        inner = max(1, channels // max(1, int(bottleneck_factor)))
        self.pre_act = nn.ELU()
        self.reduce = nn.Conv1d(channels, inner, kernel_size=1)
        self.mid_act = nn.ELU()
        self.dilated = _build_conv1d(
            inner,
            inner,
            kernel_size,
            causal=causal,
            dilation=dilation,
            pad_mode=pad_mode,
        )
        self.post_act = nn.ELU()
        self.expand = nn.Conv1d(inner, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the bottlenecked transform and add the input back (residual)."""
        skip = x
        out = self.reduce(self.pre_act(x))
        out = self.dilated(self.mid_act(out))
        out = self.expand(self.post_act(out))
        return out + skip
|
|
|
|
class SoundStreamConvStack(nn.Module):
    """1x1 input projection, a chain of residual units with cycling dilations, then ELU and an optional 1x1 output projection."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        *,
        num_layers: int,
        residual_kernel_size: int = 7,
        residual_dilations: Sequence[int] = (1, 3, 9),
        bottleneck_factor: int = 4,
        causal: bool,
        pad_mode: str = "repeat",
    ):
        super().__init__()
        dilation_cycle = [int(d) for d in residual_dilations]
        if not dilation_cycle:
            raise ValueError("SoundStream residual dilations must be non-empty")

        self.input_proj = nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        units = []
        for layer_idx in range(int(num_layers)):
            # Dilations repeat cyclically across the stack (e.g. 1, 3, 9, 1, 3, 9, ...).
            units.append(
                SoundStreamResidualUnit(
                    hidden_channels,
                    kernel_size=residual_kernel_size,
                    dilation=dilation_cycle[layer_idx % len(dilation_cycle)],
                    bottleneck_factor=bottleneck_factor,
                    causal=causal,
                    pad_mode=pad_mode,
                )
            )
        self.blocks = nn.ModuleList(units)
        self.output_act = nn.ELU()
        if hidden_channels == out_channels:
            self.output_proj = nn.Identity()
        else:
            self.output_proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project in, run every residual unit in order, activate, project out."""
        out = self.input_proj(x)
        for block in self.blocks:
            out = block(out)
        return self.output_proj(self.output_act(out))
|
|
|
|
class FSQ(nn.Module):
    """Finite Scalar Quantization (inference subset).

    Each latent dimension is squashed into a small symmetric range and rounded
    onto a per-dimension grid of ``levels[i]`` values; the joint grid cell is
    addressed by a single mixed-radix integer index.
    """

    def __init__(self, levels: List[int], dim: int):
        super().__init__()
        if not levels:
            raise ValueError("FSQ levels must be non-empty")

        self.levels = [int(level) for level in levels]
        self.codebook_dim = len(self.levels)
        self.dim = int(dim)

        # Mixed-radix basis: index = sum_i code_i * prod_{j < i} levels_j.
        self.register_buffer(
            "_levels", torch.tensor(self.levels, dtype=torch.int32), persistent=False
        )
        self.register_buffer(
            "_basis",
            torch.cumprod(torch.tensor([1] + self.levels[:-1], dtype=torch.int32), dim=0),
            persistent=False,
        )

        # Project between the model width and the (usually narrower) codebook width.
        if self.dim == self.codebook_dim:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()
        else:
            self.project_in = nn.Linear(self.dim, self.codebook_dim)
            self.project_out = nn.Linear(self.codebook_dim, self.dim)

    def _refresh_level_buffers(self, device: Optional[torch.device] = None):
        """Rebuild the non-persistent level/basis buffers on *device*.

        Defaults to the projection weights' device when available, so the
        buffers follow the module across ``.to(...)`` moves.
        """
        values = [int(level) for level in self.levels]
        if device is None:
            if isinstance(self.project_in, nn.Linear):
                device = self.project_in.weight.device
            elif isinstance(self.project_out, nn.Linear):
                device = self.project_out.weight.device
            else:
                device = self._levels.device

        self._levels = torch.tensor(values, dtype=torch.int32, device=device)
        self._basis = torch.cumprod(
            torch.tensor([1] + values[:-1], dtype=torch.int32, device=device), dim=0
        )

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Squash ``z`` with tanh so rounding lands inside each dimension's level range."""
        levels = self._levels.to(dtype=z.dtype, device=z.device)
        half_l = (levels - 1) * (1 + eps) / 2
        # Dimensions with an even level count are offset by half a grid step.
        offset = torch.where(
            (self._levels % 2).to(device=z.device) == 0,
            torch.tensor(0.5, device=z.device, dtype=z.dtype),
            torch.tensor(0.0, device=z.device, dtype=z.dtype),
        )
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalized codes in [-1, 1] to non-negative grid coordinates."""
        half_width = (self._levels // 2).to(
            dtype=zhat_normalized.dtype, device=zhat_normalized.device
        )
        return zhat_normalized * half_width + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Map non-negative grid coordinates back to normalized [-1, 1] codes."""
        half_width = (self._levels // 2).to(dtype=zhat.dtype, device=zhat.device)
        return (zhat - half_width) / half_width

    def quantize_values(self, z: torch.Tensor) -> torch.Tensor:
        """Round bounded latents to the nearest grid point, normalized to [-1, 1]."""
        self._refresh_level_buffers(device=z.device)
        half_width = (self._levels // 2).to(dtype=z.dtype, device=z.device)
        return self.bound(z).round() / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Fold per-dimension codes into a single mixed-radix integer index."""
        self._refresh_level_buffers(device=zhat.device)
        shifted = self._scale_and_shift(zhat)
        basis = self._basis.to(device=shifted.device, dtype=shifted.dtype)
        return (shifted * basis).sum(dim=-1).to(torch.int32)

    def indices_to_level_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Unfold flat indices into per-dimension integer grid coordinates."""
        self._refresh_level_buffers(device=indices.device)
        expanded = indices.unsqueeze(-1)
        levels = self._levels.to(device=expanded.device)
        basis = self._basis.to(device=expanded.device)
        return (expanded // basis) % levels

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode flat indices back to (projected) continuous code vectors."""
        self._refresh_level_buffers(device=indices.device)
        level_indices = self.indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices.to(dtype=torch.float32))
        return self.project_out(codes)

    def forward(self, z: torch.Tensor):
        """Quantize ``z``; return (dequantized output in input dtype, flat int64 indices)."""
        orig_dtype = z.dtype
        projected = self.project_in(z.to(torch.float32))
        quantized = self.quantize_values(projected)
        indices = self.codes_to_indices(quantized)
        return self.project_out(quantized).to(orig_dtype), indices.long()
|
|
|
|
| LRELU_SLOPE = 0.1 |
|
|
|
|
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Symmetric padding that keeps a stride-1 dilated conv length-preserving (odd kernels)."""
    span = (kernel_size - 1) * dilation
    return int(span / 2)
|
|
|
|
def init_weights(module, mean: float = 0.0, std: float = 0.01):
    """Initialise conv-like modules' weights in place from N(mean, std); leave others untouched."""
    if "Conv" in module.__class__.__name__ and hasattr(module, "weight"):
        module.weight.data.normal_(mean, std)
|
|
|
|
class ResBlock1(nn.Module):
    """HiFi-GAN-style residual block: per dilation, a dilated conv followed by a
    dilation-1 conv, each preceded by a leaky ReLU, wrapped in a skip connection."""

    __constants__ = ["lrelu_slope"]

    def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3, 5)):
        super().__init__()
        self.lrelu_slope = LRELU_SLOPE

        def make_conv(d):
            # weight-normed conv that preserves length for odd kernels.
            return weight_norm(
                Conv1d(channels, channels, kernel_size, 1, get_padding(kernel_size, d), d)
            )

        self.convs1 = nn.Sequential(
            make_conv(dilation[0]),
            make_conv(dilation[1]),
            make_conv(dilation[2]),
        )
        self.convs2 = nn.Sequential(make_conv(1), make_conv(1), make_conv(1))
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run each (dilated conv, plain conv) pair with pre-activations and a skip."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            residual = x
            x = F.leaky_relu(x, self.lrelu_slope)
            x = dilated_conv(x)
            x = F.leaky_relu(x, self.lrelu_slope)
            x = plain_conv(x)
            x = x + residual
        return x

    def remove_weight_norm(self):
        # The bare name below resolves to the module-level torch.nn.utils
        # remove_weight_norm, not this method.
        # NOTE(review): if `weight_norm` was imported from
        # torch.nn.utils.parametrizations, the legacy remove_weight_norm may
        # raise on those layers — confirm against the installed torch version.
        for layer in list(self.convs1) + list(self.convs2):
            remove_weight_norm(layer)
|
|
|
|
class ResBlock2(nn.Module):
    """Lighter residual block: one leaky-ReLU + dilated conv per dilation, each with a skip."""

    __constants__ = ["lrelu_slope"]

    def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3)):
        super().__init__()
        self.lrelu_slope = LRELU_SLOPE
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        get_padding(kernel_size, d),
                        d,
                    )
                )
                for d in (dilation[0], dilation[1])
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pre-activated dilated convs, each added back onto its input."""
        for conv in self.convs:
            x = conv(F.leaky_relu(x, self.lrelu_slope)) + x
        return x

    def remove_weight_norm(self):
        # The bare name below resolves to the module-level torch.nn.utils
        # remove_weight_norm, not this method.
        for conv in self.convs:
            remove_weight_norm(conv)
|
|
|
|
class Generator(nn.Module):
    """HiFi-GAN-style vocoder generator.

    Takes a (batch, time, out_channels) cochleagram (permuted internally to
    channels-first), upsamples it through transposed convolutions interleaved
    with multi-kernel residual blocks, and emits a (batch, 1, samples)
    waveform squashed to [-1, 1] by tanh.
    """

    __constants__ = ["lrelu_slope", "num_kernels", "num_upsamples"]

    def __init__(
        self,
        out_channels: int = 211,
        upsample_rates=None,
        upsample_kernel_sizes=None,
        upsample_initial_channel: int = 512,
        resblock: str = "1",
        resblock_kernel_sizes=None,
        resblock_dilation_sizes=None,
    ):
        super().__init__()
        upsample_rates = list(upsample_rates or [5, 4, 2, 2])
        upsample_kernel_sizes = list(upsample_kernel_sizes or [10, 8, 4, 4])
        resblock_kernel_sizes = list(resblock_kernel_sizes or [11, 7, 3])
        resblock_dilation_sizes = [list(d) for d in (resblock_dilation_sizes or [[1, 3, 5], [1, 3, 5], [1, 3, 5]])]

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.lrelu_slope = LRELU_SLOPE

        self.conv_pre = weight_norm(Conv1d(out_channels, upsample_initial_channel, 7, 1, padding=3))
        resblock_cls = ResBlock1 if resblock == "1" else ResBlock2

        # Each stage halves the channel count while upsampling time by `rate`.
        ups = []
        for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        kernel,
                        rate,
                        padding=(kernel - rate) // 2,
                    )
                )
            )
        self.ups = nn.Sequential(*ups)

        # One group of `num_kernels` parallel residual blocks per upsampling stage.
        resblocks = []
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            resblocks.append(
                nn.Sequential(
                    *[
                        resblock_cls(ch, kernel, dilation)
                        for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes)
                    ]
                )
            )
        self.resblocks = nn.Sequential(*resblocks)

        # Fix: compute the final channel count explicitly instead of reusing the
        # loop variable `ch`, which is undefined when `upsample_rates` is empty.
        final_channels = upsample_initial_channel // (2 ** len(self.ups))
        self.conv_post = weight_norm(Conv1d(final_channels, 1, 17, 1, padding=0))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load weights, remapping legacy checkpoint layouts.

        Flat ``resblocks.<i>.`` keys are remapped to the nested
        ``resblocks.<stage>.<member>.`` layout used here, and tensors saved
        with/without a trailing singleton dimension are reshaped to match.
        Returns the result of ``nn.Module.load_state_dict``.
        """
        # Fix: group size was hardcoded to 3, which is wrong whenever
        # `resblock_kernel_sizes` has a different length.
        group_size = self.num_kernels
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key
            if "resblocks" in key:
                parts = key.split(".")
                if len(parts) == 5:
                    layer = int(parts[1])
                    new_key = (
                        f"resblocks.{layer // group_size}.{layer % group_size}."
                        f"{'.'.join(parts[2:])}"
                    )
            new_state_dict[new_key] = value

        current_state = self.state_dict()
        for key, value in list(new_state_dict.items()):
            if key not in current_state:
                continue
            # Align checkpoints that differ from this module by exactly one
            # trailing singleton dimension (e.g. weight-norm `g` tensors).
            len_diff = value.dim() - current_state[key].dim()
            if len_diff == -1:
                new_state_dict[key] = value.unsqueeze(-1)
            elif len_diff == 1:
                new_state_dict[key] = value.squeeze(-1)

        # Fix: propagate super()'s result (missing/unexpected key report)
        # instead of silently dropping it; callers ignoring it are unaffected.
        return super().load_state_dict(new_state_dict, strict=strict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Vocode a (batch, time, channels) cochleagram to a (batch, 1, samples) waveform."""
        x = self.conv_pre(x.permute(0, 2, 1))

        for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = upsample_layer(x)
            # Average the parallel multi-kernel residual branches.
            xs = 0
            for resblock in resblock_group:
                xs = xs + resblock(x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)  # default slope here, matching upstream HiFi-GAN
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrization from all layers (for inference/export)."""
        for layer in self.ups:
            remove_weight_norm(layer)
        for group in self.resblocks:
            for block in group:
                block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
|
|
|
|
class WavCoch(PreTrainedModel):
    """Causal waveform-to-cochleagram tokenizer with optional vocoder.

    Pipeline: frozen windowed-DFT analysis convolutions -> conv encoder ->
    FSQ quantizer (discrete token ids) -> conv decoder (cochleagram) ->
    optionally a bundled ``Generator`` vocoder back to a waveform.
    """

    config_class = WavCochConfig
    main_input_name = "wav"

    def __init__(self, config: WavCochConfig):
        super().__init__(config)
        self.config = config

        # Frame-analysis parameters: window length N and hop between frames.
        self.N = int(config.window_size)
        self.hop_length = int(config.hop_length)
        self.window_padding = int(getattr(config, "window_padding", self.N - self.hop_length))
        self.causal_convs = bool(getattr(config, "causal_convs", True))
        self.causal_pad_mode = getattr(config, "causal_pad_mode", "repeat")
        self.encoder_block_type = getattr(config, "encoder_block_type", "plain")
        self.decoder_block_type = getattr(config, "decoder_block_type", "plain")

        # Frozen conv banks holding the real and imaginary parts of a
        # Hann-windowed real DFT (see _initialize_conv_filters).
        out_bins = self.N // 2 + 1
        self.conv_real_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
        self.conv_imag_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
        self._initialize_conv_filters()

        # Encoder: DFT bins -> encoder_dim latent channels.
        self.encoder = self._build_processing_stack(
            in_channels=out_bins,
            hidden_channels=config.encoder_dim,
            out_channels=config.encoder_dim,
            num_layers=config.encoder_layers,
            kernel_size=config.encoder_kernel_size,
            causal=self.causal_convs,
            block_type=self.encoder_block_type,
            residual_kernel_size=getattr(config, "encoder_residual_kernel_size", 7),
            residual_dilations=getattr(config, "encoder_residual_dilations", (1, 3, 9)),
            residual_bottleneck_factor=getattr(config, "encoder_residual_bottleneck_factor", 4),
        )
        self.quantizer = FSQ(levels=list(config.channels), dim=config.encoder_dim)
        # Decoder: quantized latents -> out_channels cochleagram channels.
        self.decoder = self._build_processing_stack(
            in_channels=config.decoder_dim,
            hidden_channels=config.decoder_dim,
            out_channels=config.out_channels,
            num_layers=config.decoder_layers,
            kernel_size=config.decoder_kernel_size,
            causal=self.causal_convs,
            block_type=self.decoder_block_type,
            residual_kernel_size=getattr(config, "decoder_residual_kernel_size", 7),
            residual_dilations=getattr(config, "decoder_residual_dilations", (1, 3, 9)),
            residual_bottleneck_factor=getattr(config, "decoder_residual_bottleneck_factor", 4),
        )

        self.has_vocoder = bool(getattr(config, "has_vocoder", False))
        if self.has_vocoder:
            if int(config.out_channels) != 211:
                raise ValueError("Bundled vocoder currently expects 211 cochleagram channels")
            self.vocoder = Generator(
                out_channels=config.out_channels,
                upsample_rates=config.vocoder_upsample_rates,
                upsample_kernel_sizes=config.vocoder_upsample_kernel_sizes,
                upsample_initial_channel=config.vocoder_upsample_initial_channel,
                resblock=config.vocoder_resblock,
                resblock_kernel_sizes=config.vocoder_resblock_kernel_sizes,
                resblock_dilation_sizes=config.vocoder_resblock_dilation_sizes,
            )
        else:
            self.vocoder = None

        self._vocab_size = int(config.vocab_size)
        self.post_init()

    def _build_plain_conv_stack(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        kernel_size: int,
        causal: bool,
    ) -> nn.Sequential:
        """Stack of (conv, ReLU) pairs; the first conv maps in_channels -> out_channels."""
        layers = []
        for layer_idx in range(int(num_layers)):
            input_channels = in_channels if layer_idx == 0 else out_channels
            conv = _build_conv1d(
                input_channels,
                out_channels,
                kernel_size,
                causal=causal,
                pad_mode=self.causal_pad_mode,
            )
            layers.extend([conv, nn.ReLU()])
        return nn.Sequential(*layers)

    def _build_processing_stack(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        kernel_size: int,
        causal: bool,
        block_type: str = "plain",
        residual_kernel_size: int = 7,
        residual_dilations: Sequence[int] = (1, 3, 9),
        residual_bottleneck_factor: int = 4,
    ) -> nn.Module:
        """Dispatch between the "plain" conv stack and the SoundStream residual stack.

        Raises ValueError for any other ``block_type``.
        """
        if block_type == "plain":
            return self._build_plain_conv_stack(
                in_channels=in_channels,
                out_channels=out_channels,
                num_layers=num_layers,
                kernel_size=kernel_size,
                causal=causal,
            )

        if block_type != "soundstream":
            raise ValueError(f"Unsupported WavCoch block_type: {block_type}")

        return SoundStreamConvStack(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            residual_kernel_size=residual_kernel_size,
            residual_dilations=residual_dilations,
            bottleneck_factor=residual_bottleneck_factor,
            causal=causal,
            pad_mode=self.causal_pad_mode,
        )

    def _compute_twiddle_factors(self):
        """Return (cos, sin) N x N matrices of the DFT angles -2*pi*n*k/N."""
        n = torch.arange(self.N, dtype=torch.float32).unsqueeze(1)
        k = torch.arange(self.N, dtype=torch.float32).unsqueeze(0)
        angles = -2.0 * math.pi * n * k / float(self.N)
        return torch.cos(angles), torch.sin(angles)

    def _initialize_conv_filters(self):
        """Fill the analysis convs with Hann-windowed DFT rows and freeze them.

        Only the first N//2 + 1 rows (the non-redundant real-DFT bins) are kept.
        """
        with torch.no_grad():
            cos_matrix, sin_matrix = self._compute_twiddle_factors()
            cos_matrix = cos_matrix[: self.N // 2 + 1, :]
            sin_matrix = sin_matrix[: self.N // 2 + 1, :]
            window = torch.hann_window(self.N, periodic=True).view(1, 1, -1)
            real_weights = (cos_matrix.unsqueeze(1) * window).to(dtype=self.conv_real_filters.weight.dtype)
            imag_weights = (sin_matrix.unsqueeze(1) * window).to(dtype=self.conv_imag_filters.weight.dtype)
            self.conv_real_filters.weight.copy_(real_weights)
            self.conv_imag_filters.weight.copy_(imag_weights)

        # The analysis filters are fixed features, never trained.
        for param in self.conv_real_filters.parameters():
            param.requires_grad_(False)
        for param in self.conv_imag_filters.parameters():
            param.requires_grad_(False)

    def _normalize_sample_rate(self, sample_rate: Optional[int], sampling_rate: Optional[int]) -> int:
        """Resolve the two accepted rate kwargs and enforce the configured rate.

        Raises ValueError if the kwargs conflict or differ from config.sample_rate.
        """
        if sample_rate is not None and sampling_rate is not None and sample_rate != sampling_rate:
            raise ValueError(f"sample_rate ({sample_rate}) and sampling_rate ({sampling_rate}) conflict")
        resolved = int(sample_rate or sampling_rate or self.config.sample_rate)
        if resolved != int(self.config.sample_rate):
            raise ValueError(
                f"WavCoch expects {self.config.sample_rate} Hz audio, but received {resolved} Hz"
            )
        return resolved

    def _prepare_wav_batch(self, wav) -> torch.Tensor:
        """Coerce list / 1D / 2D / 3D waveform input to a float32 (batch, 1, samples) tensor.

        Lists are zero-padded to the longest element; raises for unsupported
        shapes or types.
        """
        if isinstance(wav, list):
            wav = [item if isinstance(item, torch.Tensor) else torch.tensor(item) for item in wav]
            normalized = []
            for item in wav:
                if item.ndim == 1:
                    normalized.append(item)
                elif item.ndim == 2 and 1 in item.shape:
                    # Accept (1, T) or (T, 1) mono tensors and flatten them.
                    normalized.append(item.reshape(-1))
                else:
                    raise ValueError(f"Unexpected list element shape {tuple(item.shape)}")
            wav = torch.nn.utils.rnn.pad_sequence(normalized, batch_first=True).unsqueeze(1)
        elif isinstance(wav, torch.Tensor):
            if wav.ndim == 1:
                wav = wav.unsqueeze(0).unsqueeze(0)
            elif wav.ndim == 2:
                wav = wav.unsqueeze(1)
            elif wav.ndim != 3:
                raise ValueError(f"Unexpected tensor shape {tuple(wav.shape)}, expected 1D, 2D or 3D")
        else:
            raise TypeError(f"Unsupported input type: {type(wav)}")

        return wav.to(dtype=torch.float32)

    @property
    def vocab_size(self) -> int:
        """Number of discrete tokens the quantizer can emit (from config.vocab_size)."""
        return self._vocab_size

    def _resolve_wav_input(
        self,
        wav: Optional[torch.Tensor],
        input_values: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Pick exactly one of the two accepted waveform kwargs; raise otherwise."""
        if wav is not None and input_values is not None:
            raise ValueError("Provide either `wav` or `input_values`, not both")
        resolved = wav if wav is not None else input_values
        if resolved is None:
            raise ValueError("WavCoch requires waveform input via `wav` or `input_values`")
        return resolved

    def _encode_quantized(
        self,
        wav: torch.Tensor,
        pad: bool = True,
    ):
        """Analyse, encode and quantize a waveform.

        Returns (quantized latents of shape (batch, frames, encoder_dim),
        integer code indices of shape (batch, frames)).
        """
        wav = self._prepare_wav_batch(wav)
        if pad:
            # Left-pad so the first analysis frame is aligned causally.
            wav = F.pad(wav, (self.window_padding, 0), mode="constant", value=0.0)

        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)

        # NOTE(review): the real and imaginary filter responses are summed
        # rather than combined into a magnitude — presumably intentional for
        # this pipeline; confirm against the training code.
        x = real_part + imag_part
        x = self.encoder(x).permute(0, 2, 1)
        quantized, indices = self.quantizer(x)
        return quantized, indices

    def forward(
        self,
        wav: Optional[torch.Tensor] = None,
        coch: Optional[torch.Tensor] = None,
        return_tensors: str = "pt",
        sample_rate: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        pad: bool = True,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        input_values: Optional[torch.Tensor] = None,
    ):
        """Tokenize a waveform; three return modes depending on the arguments.

        - ``output_hidden_states=True``: BaseModelOutput with the quantized
          latents (or a ``(quantized, hidden_states)`` tuple if
          ``return_dict=False``).
        - ``coch is None`` (default): BatchEncoding with the discrete token
          ids under both ``input_values`` and ``input_ids``.
        - ``coch`` provided: ``(predicted cochleagram, L1 loss vs. coch, None)``.

        ``return_tensors`` is accepted for API compatibility but ignored.
        """
        del return_tensors
        self._normalize_sample_rate(sample_rate, sampling_rate)
        wav = self._resolve_wav_input(wav, input_values)
        quantized, indices = self._encode_quantized(wav, pad=pad)

        if output_hidden_states:
            hidden_states = (quantized,)
            if not return_dict:
                return quantized, hidden_states
            return BaseModelOutput(last_hidden_state=quantized, hidden_states=hidden_states)

        if coch is None:
            codes = indices.long()
            return BatchEncoding({"input_values": codes, "input_ids": codes})

        pred_coch = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
        loss = F.l1_loss(pred_coch, coch)
        return pred_coch, loss, None

    @torch.no_grad()
    def quantize(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Waveform -> int64 token ids of shape (batch, frames)."""
        _, indices = self._encode_quantized(wav, pad=pad)
        return indices.long()

    @torch.no_grad()
    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Token ids ((frames,) or (batch, frames)) -> cochleagram (batch, frames, out_channels)."""
        if indices.ndim == 1:
            indices = indices.unsqueeze(0)
        emb = self.quantizer.indices_to_codes(indices.long())
        return self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)

    @torch.no_grad()
    def wav2coch(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Waveform -> decoded cochleagram, bypassing the discrete index step."""
        quantized, _ = self._encode_quantized(wav, pad=pad)
        return self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)

    @torch.no_grad()
    def vocode(self, coch: torch.Tensor) -> torch.Tensor:
        """Cochleagram -> waveform via the bundled vocoder; raises if none is bundled."""
        if self.vocoder is None:
            raise ValueError("This WavCoch checkpoint does not include a bundled vocoder")

        if coch.ndim == 2:
            coch = coch.unsqueeze(0)
        elif coch.ndim != 3:
            raise ValueError(f"Unexpected cochleagram shape {tuple(coch.shape)}")

        # Accept channels-first input and flip it to channels-last for the vocoder.
        if coch.shape[-1] != self.config.out_channels and coch.shape[1] == self.config.out_channels:
            coch = coch.transpose(1, 2)

        return self.vocoder(coch)

    @torch.no_grad()
    def decode_audio(self, indices: torch.Tensor) -> torch.Tensor:
        """Token ids -> waveform (decode to a cochleagram, then vocode)."""
        return self.vocode(self.decode(indices))
|