# WavCochCausalV64000100M / modeling_wavcoch.py
# (Hugging Face Hub artifact — uploaded by klemenk, "Upload WavCoch checkpoint", commit c8ca075, verified)
"""
WavCoch model for Hugging Face Transformers.
This implementation is self-contained so HF-hosted WavCoch checkpoints do not
depend on the local auristream package or vector_quantize_pytorch.
"""
import math
import os
from typing import List, Optional, Sequence
os.environ.setdefault("USE_TORCH_XLA", "0")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm
try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError: # pragma: no cover - older PyTorch compatibility
from torch.nn.utils import weight_norm
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
try:
from transformers.tokenization_utils_base import BatchEncoding
except ImportError: # pragma: no cover - compatibility with older Transformers
from transformers.tokenization_utils import BatchEncoding
import transformers.modeling_utils as transformers_modeling_utils
import transformers.utils.import_utils as transformers_import_utils
transformers_import_utils.is_torch_xla_available = lambda *args, **kwargs: False
transformers_modeling_utils.is_torch_xla_available = lambda *args, **kwargs: False
try:
from .configuration_wavcoch import WavCochConfig
except ImportError: # pragma: no cover - compatibility with older repos
from .configure_wavcoch import WavCochConfig
class CausalConv1d(nn.Module):
    """Causal 1D convolution: pads only on the left so that no output
    sample depends on future input samples."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        bias: bool = True,
        groups: int = 1,
        pad_mode: str = "repeat",
        constant_value: float = 0.0,
    ):
        super().__init__()
        # The receptive field reaches dilation * (kernel_size - 1) samples
        # into the past; pad exactly that much on the left, nothing on the right.
        history = dilation * (kernel_size - 1)
        left_only = (history, 0)
        if pad_mode == "repeat":
            self.pad = nn.ReplicationPad1d(left_only)
        elif pad_mode == "constant":
            self.pad = nn.ConstantPad1d(left_only, constant_value)
        else:
            raise ValueError(f"Unsupported pad_mode: {pad_mode}")
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Left-pad then convolve; output length equals input length when stride == 1."""
        padded = self.pad(x)
        return self.conv(padded)
def _build_conv1d(
in_channels: int,
out_channels: int,
kernel_size: int,
*,
causal: bool,
dilation: int = 1,
pad_mode: str = "repeat",
):
if causal:
return CausalConv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
dilation=dilation,
pad_mode=pad_mode,
)
padding = dilation * (kernel_size - 1) // 2
return nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
dilation=dilation,
padding=padding,
)
class SoundStreamResidualUnit(nn.Module):
    """Pre-activation residual unit with a 1x1 bottleneck around a dilated conv
    (SoundStream-style): ELU -> 1x1 reduce -> ELU -> dilated conv -> ELU ->
    1x1 expand, plus an identity skip connection."""

    def __init__(
        self,
        channels: int,
        *,
        kernel_size: int = 7,
        dilation: int = 1,
        bottleneck_factor: int = 4,
        causal: bool,
        pad_mode: str = "repeat",
    ):
        super().__init__()
        # Shrink the channel count inside the unit, but never below one.
        factor = max(1, int(bottleneck_factor))
        inner = max(1, channels // factor)
        self.pre_act = nn.ELU()
        self.reduce = nn.Conv1d(channels, inner, kernel_size=1)
        self.mid_act = nn.ELU()
        self.dilated = _build_conv1d(
            inner,
            inner,
            kernel_size,
            causal=causal,
            dilation=dilation,
            pad_mode=pad_mode,
        )
        self.post_act = nn.ELU()
        self.expand = nn.Conv1d(inner, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.pre_act(x)
        h = self.reduce(h)
        h = self.mid_act(h)
        h = self.dilated(h)
        h = self.post_act(h)
        return x + self.expand(h)
class SoundStreamConvStack(nn.Module):
    """Stack of SoundStream residual units between 1x1 input/output projections.

    Dilations cycle through ``residual_dilations`` across layers; the output
    projection is the identity when hidden and output widths already match."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        *,
        num_layers: int,
        residual_kernel_size: int = 7,
        residual_dilations: Sequence[int] = (1, 3, 9),
        bottleneck_factor: int = 4,
        causal: bool,
        pad_mode: str = "repeat",
    ):
        super().__init__()
        cycle = [int(d) for d in residual_dilations]
        if not cycle:
            raise ValueError("SoundStream residual dilations must be non-empty")
        self.input_proj = nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        units = []
        for idx in range(int(num_layers)):
            units.append(
                SoundStreamResidualUnit(
                    hidden_channels,
                    kernel_size=residual_kernel_size,
                    dilation=cycle[idx % len(cycle)],
                    bottleneck_factor=bottleneck_factor,
                    causal=causal,
                    pad_mode=pad_mode,
                )
            )
        self.blocks = nn.ModuleList(units)
        self.output_act = nn.ELU()
        if hidden_channels == out_channels:
            self.output_proj = nn.Identity()
        else:
            self.output_proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.input_proj(x)
        for unit in self.blocks:
            h = unit(h)
        return self.output_proj(self.output_act(h))
class FSQ(nn.Module):
    """Finite Scalar Quantization with the subset of functionality needed for inference.

    Each latent dimension is bounded and rounded onto ``levels[i]`` discrete
    values; the mixed-radix product of the per-dimension digits yields a flat
    integer code, so no learned codebook is needed.
    """

    def __init__(self, levels: List[int], dim: int):
        super().__init__()
        if not levels:
            raise ValueError("FSQ levels must be non-empty")
        self.levels = [int(level) for level in levels]
        self.codebook_dim = len(self.levels)
        self.dim = int(dim)
        level_tensor = torch.tensor(self.levels, dtype=torch.int32)
        # Mixed-radix basis: index = sum_i digit_i * prod(levels[:i]).
        basis = torch.cumprod(torch.tensor([1] + self.levels[:-1], dtype=torch.int32), dim=0)
        # Non-persistent: not saved in checkpoints, so they may come back
        # uninitialized (e.g. meta-device loading) — see _refresh_level_buffers.
        self.register_buffer("_levels", level_tensor, persistent=False)
        self.register_buffer("_basis", basis, persistent=False)
        # Project to/from the codebook dimensionality only when they differ.
        if self.dim != self.codebook_dim:
            self.project_in = nn.Linear(self.dim, self.codebook_dim)
            self.project_out = nn.Linear(self.codebook_dim, self.dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

    def _refresh_level_buffers(self, device: Optional[torch.device] = None):
        """Ensure ``_levels``/``_basis`` are real tensors on ``device``.

        The buffers are non-persistent, so they can be stale/meta after
        checkpoint loading. Rebuilding on every call is wasteful, so skip the
        rebuild when the buffers already live on the requested device and are
        not meta tensors (their values never change after __init__).
        """
        if device is None:
            if isinstance(self.project_in, nn.Linear):
                device = self.project_in.weight.device
            elif isinstance(self.project_out, nn.Linear):
                device = self.project_out.weight.device
            else:
                device = self._levels.device
        if self._levels.device == device and self._levels.device.type != "meta":
            return
        level_values = [int(level) for level in self.levels]
        self._levels = torch.tensor(level_values, dtype=torch.int32, device=device)
        self._basis = torch.cumprod(
            torch.tensor([1] + level_values[:-1], dtype=torch.int32, device=device),
            dim=0,
        )

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Squash ``z`` with tanh into the representable range of each level.

        Even levels get a half-step offset so that rounding lands on the
        level grid; ``eps`` slightly widens the range to keep extreme values
        from all collapsing onto the boundary.
        """
        levels = self._levels.to(dtype=z.dtype, device=z.device)
        half_l = (levels - 1) * (1 + eps) / 2
        offset = torch.where(
            (self._levels % 2).to(device=z.device) == 0,
            torch.tensor(0.5, device=z.device, dtype=z.dtype),
            torch.tensor(0.0, device=z.device, dtype=z.dtype),
        )
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalized codes in [-1, 1] to non-negative digits [0, level-1]."""
        half_width = (self._levels // 2).to(dtype=zhat_normalized.dtype, device=zhat_normalized.device)
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`_scale_and_shift`: digits back to normalized codes."""
        half_width = (self._levels // 2).to(dtype=zhat.dtype, device=zhat.device)
        return (zhat - half_width) / half_width

    def quantize_values(self, z: torch.Tensor) -> torch.Tensor:
        """Bound and round ``z``; returns normalized quantized codes."""
        self._refresh_level_buffers(device=z.device)
        half_width = (self._levels // 2).to(dtype=z.dtype, device=z.device)
        return self.bound(z).round() / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Fold per-dimension quantized codes into flat int32 indices."""
        self._refresh_level_buffers(device=zhat.device)
        zhat = self._scale_and_shift(zhat)
        basis = self._basis.to(device=zhat.device, dtype=zhat.dtype)
        return (zhat * basis).sum(dim=-1).to(torch.int32)

    def indices_to_level_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Unfold flat indices into per-dimension digits (mixed-radix decode)."""
        self._refresh_level_buffers(device=indices.device)
        indices = indices.unsqueeze(-1)
        levels = self._levels.to(device=indices.device)
        basis = self._basis.to(device=indices.device)
        return (indices // basis) % levels

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Flat indices -> normalized codes -> (projected) embeddings."""
        self._refresh_level_buffers(device=indices.device)
        level_indices = self.indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices.to(dtype=torch.float32))
        return self.project_out(codes)

    def forward(self, z: torch.Tensor):
        """Quantize ``z``; returns ``(dequantized_output, flat_indices)``.

        Computation runs in float32 and the output is cast back to the
        caller's dtype.
        """
        orig_dtype = z.dtype
        z = self.project_in(z.to(torch.float32))
        q = self.quantize_values(z)
        indices = self.codes_to_indices(q)
        out = self.project_out(q).to(orig_dtype)
        return out, indices.long()
# Negative slope shared by every leaky ReLU in the vocoder stack below.
LRELU_SLOPE = 0.1
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Padding that keeps a stride-1 conv the same length for odd kernel sizes."""
    return int(dilation * (kernel_size - 1) / 2)
def init_weights(module, mean: float = 0.0, std: float = 0.01):
    """Gaussian-initialize the weight of any Conv* module; leave others untouched."""
    name = type(module).__name__
    if "Conv" in name and hasattr(module, "weight"):
        module.weight.data.normal_(mean, std)
class ResBlock1(nn.Module):
    """HiFi-GAN style residual block: three (dilated conv, plain conv) pairs
    with leaky-ReLU pre-activations and an additive skip after each pair."""

    __constants__ = ["lrelu_slope"]

    def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3, 5)):
        super().__init__()
        self.lrelu_slope = LRELU_SLOPE
        ch = channels
        ks = kernel_size
        # First conv of each pair is dilated; the paired conv uses dilation 1.
        self.convs1 = nn.Sequential(
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[0]), dilation[0])),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[1]), dilation[1])),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[2]), dilation[2])),
        )
        self.convs2 = nn.Sequential(
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
            weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
        )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv1, conv2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, self.lrelu_slope)
            xt = conv1(xt)
            xt = F.leaky_relu(xt, self.lrelu_slope)
            xt = conv2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv.

        The module-level ``weight_norm`` may come from either the legacy hook
        API (``torch.nn.utils.weight_norm``) or the parametrization API
        (``torch.nn.utils.parametrizations.weight_norm``, preferred by this
        file's import fallback). The legacy ``remove_weight_norm`` raises
        ValueError on parametrized modules, so fall back to
        ``parametrize.remove_parametrizations`` in that case.
        """
        from torch.nn.utils import parametrize

        for layer in list(self.convs1) + list(self.convs2):
            try:
                remove_weight_norm(layer)
            except ValueError:
                parametrize.remove_parametrizations(layer, "weight")
class ResBlock2(nn.Module):
    """Lightweight HiFi-GAN residual block: two dilated convs, each with a
    leaky-ReLU pre-activation and an additive skip connection."""

    __constants__ = ["lrelu_slope"]

    def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3)):
        super().__init__()
        self.lrelu_slope = LRELU_SLOPE
        ch = channels
        ks = kernel_size
        self.convs = nn.ModuleList(
            [
                weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[0]), dilation[0])),
                weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[1]), dilation[1])),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            xt = F.leaky_relu(x, self.lrelu_slope)
            xt = conv(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv.

        Handles both the legacy hook API and the parametrization API: the
        legacy ``remove_weight_norm`` raises ValueError on modules wrapped by
        ``torch.nn.utils.parametrizations.weight_norm`` (which this file
        prefers when available), so fall back to
        ``parametrize.remove_parametrizations`` in that case.
        """
        from torch.nn.utils import parametrize

        for layer in self.convs:
            try:
                remove_weight_norm(layer)
            except ValueError:
                parametrize.remove_parametrizations(layer, "weight")
class Generator(nn.Module):
    """HiFi-GAN style vocoder: transposed-conv upsampling stages, each
    followed by a group of parallel residual blocks whose outputs are
    averaged, mapping cochleagrams to raw waveforms in [-1, 1]."""

    __constants__ = ["lrelu_slope", "num_kernels", "num_upsamples"]

    def __init__(
        self,
        out_channels: int = 211,
        upsample_rates=None,
        upsample_kernel_sizes=None,
        upsample_initial_channel: int = 512,
        resblock: str = "1",
        resblock_kernel_sizes=None,
        resblock_dilation_sizes=None,
    ):
        super().__init__()
        # Defaults give a total upsampling factor of 5 * 4 * 2 * 2 = 80.
        upsample_rates = list(upsample_rates or [5, 4, 2, 2])
        upsample_kernel_sizes = list(upsample_kernel_sizes or [10, 8, 4, 4])
        resblock_kernel_sizes = list(resblock_kernel_sizes or [11, 7, 3])
        resblock_dilation_sizes = [list(d) for d in (resblock_dilation_sizes or [[1, 3, 5], [1, 3, 5], [1, 3, 5]])]
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.lrelu_slope = LRELU_SLOPE
        self.conv_pre = weight_norm(Conv1d(out_channels, upsample_initial_channel, 7, 1, padding=3))
        resblock_cls = ResBlock1 if resblock == "1" else ResBlock2
        ups = []
        for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel width halves at every upsampling stage.
            ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        kernel,
                        rate,
                        padding=(kernel - rate) // 2,
                    )
                )
            )
        self.ups = nn.Sequential(*ups)
        resblocks = []
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            resblocks.append(
                nn.Sequential(
                    *[
                        resblock_cls(ch, kernel, dilation)
                        for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes)
                    ]
                )
            )
        self.resblocks = nn.Sequential(*resblocks)
        # NOTE(review): padding=0 with kernel 17 trims 16 samples relative to
        # "same" padding; kept as-is to match the trained checkpoint.
        self.conv_post = weight_norm(Conv1d(ch, 1, 17, 1, padding=0))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load weights, remapping legacy flat ``resblocks.<idx>.*`` keys to the
        nested ``resblocks.<group>.<block>.*`` layout and reconciling trailing
        singleton dims introduced by weight-norm parametrization.

        Uses ``self.num_kernels`` for the group size (the original hard-coded
        3, which silently mis-mapped checkpoints with a non-default
        ``resblock_kernel_sizes`` length).
        """
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key
            if "resblocks" in key:
                parts = key.split(".")
                if len(parts) == 5:
                    layer = int(parts[1])
                    # Flat index -> (upsample stage, kernel-within-stage).
                    group, member = divmod(layer, self.num_kernels)
                    new_key = f"resblocks.{group}.{member}.{'.'.join(parts[2:])}"
            new_state_dict[new_key] = value
        current_state = self.state_dict()
        for key, value in list(new_state_dict.items()):
            if key not in current_state:
                continue
            # Weight-norm parametrization stores per-channel norms with an
            # extra trailing dim; add/drop it to match the current module.
            len_diff = value.dim() - current_state[key].dim()
            if len_diff == -1:
                new_state_dict[key] = value.unsqueeze(-1)
            elif len_diff == 1:
                new_state_dict[key] = value.squeeze(-1)
        return super().load_state_dict(new_state_dict, strict=strict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, time, channels) cochleagram to a (batch, 1, samples) waveform."""
        x = self.conv_pre(x.permute(0, 2, 1))
        for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = upsample_layer(x)
            xs = 0
            for resblock in resblock_group:
                xs = xs + resblock(x)
            x = xs / self.num_kernels  # average the parallel kernel branches
        x = F.leaky_relu(x)  # default slope here, matching the reference implementation
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight norm from every wrapped conv.

        Supports both the legacy hook API and the parametrization API: the
        legacy ``remove_weight_norm`` raises ValueError on modules wrapped by
        ``torch.nn.utils.parametrizations.weight_norm`` (preferred by this
        file's import fallback), so fall back to
        ``parametrize.remove_parametrizations``.
        """
        from torch.nn.utils import parametrize

        def _strip(module):
            try:
                remove_weight_norm(module)
            except ValueError:
                parametrize.remove_parametrizations(module, "weight")

        for layer in self.ups:
            _strip(layer)
        for group in self.resblocks:
            for block in group:
                block.remove_weight_norm()
        _strip(self.conv_pre)
        _strip(self.conv_post)
class WavCoch(PreTrainedModel):
    """Causal waveform-to-cochleagram tokenizer with optional vocoder.

    Pipeline: raw waveform -> frozen windowed-DFT conv filters -> conv
    encoder -> FSQ quantizer (discrete token indices) -> conv decoder
    (cochleagram) -> optional bundled HiFi-GAN-style vocoder back to audio.
    """
    config_class = WavCochConfig
    main_input_name = "wav"
    def __init__(self, config: WavCochConfig):
        super().__init__(config)
        self.config = config
        # Analysis frame length and hop, in samples.
        self.N = int(config.window_size)
        self.hop_length = int(config.hop_length)
        # Left padding applied before framing (defaults to window overlap).
        self.window_padding = int(getattr(config, "window_padding", self.N - self.hop_length))
        self.causal_convs = bool(getattr(config, "causal_convs", True))
        self.causal_pad_mode = getattr(config, "causal_pad_mode", "repeat")
        self.encoder_block_type = getattr(config, "encoder_block_type", "plain")
        self.decoder_block_type = getattr(config, "decoder_block_type", "plain")
        # One bin per non-negative DFT frequency.
        out_bins = self.N // 2 + 1
        # Frozen convs implementing the real/imag parts of a Hann-windowed
        # DFT; weights are set (and gradients disabled) in
        # _initialize_conv_filters.
        self.conv_real_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
        self.conv_imag_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
        self._initialize_conv_filters()
        self.encoder = self._build_processing_stack(
            in_channels=out_bins,
            hidden_channels=config.encoder_dim,
            out_channels=config.encoder_dim,
            num_layers=config.encoder_layers,
            kernel_size=config.encoder_kernel_size,
            causal=self.causal_convs,
            block_type=self.encoder_block_type,
            residual_kernel_size=getattr(config, "encoder_residual_kernel_size", 7),
            residual_dilations=getattr(config, "encoder_residual_dilations", (1, 3, 9)),
            residual_bottleneck_factor=getattr(config, "encoder_residual_bottleneck_factor", 4),
        )
        # FSQ levels come from config.channels; their product is the vocab size.
        self.quantizer = FSQ(levels=list(config.channels), dim=config.encoder_dim)
        self.decoder = self._build_processing_stack(
            in_channels=config.decoder_dim,
            hidden_channels=config.decoder_dim,
            out_channels=config.out_channels,
            num_layers=config.decoder_layers,
            kernel_size=config.decoder_kernel_size,
            causal=self.causal_convs,
            block_type=self.decoder_block_type,
            residual_kernel_size=getattr(config, "decoder_residual_kernel_size", 7),
            residual_dilations=getattr(config, "decoder_residual_dilations", (1, 3, 9)),
            residual_bottleneck_factor=getattr(config, "decoder_residual_bottleneck_factor", 4),
        )
        self.has_vocoder = bool(getattr(config, "has_vocoder", False))
        if self.has_vocoder:
            # The bundled Generator defaults are tied to a 211-bin cochleagram.
            if int(config.out_channels) != 211:
                raise ValueError("Bundled vocoder currently expects 211 cochleagram channels")
            self.vocoder = Generator(
                out_channels=config.out_channels,
                upsample_rates=config.vocoder_upsample_rates,
                upsample_kernel_sizes=config.vocoder_upsample_kernel_sizes,
                upsample_initial_channel=config.vocoder_upsample_initial_channel,
                resblock=config.vocoder_resblock,
                resblock_kernel_sizes=config.vocoder_resblock_kernel_sizes,
                resblock_dilation_sizes=config.vocoder_resblock_dilation_sizes,
            )
        else:
            self.vocoder = None
        self._vocab_size = int(config.vocab_size)
        self.post_init()
    def _build_plain_conv_stack(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        kernel_size: int,
        causal: bool,
    ) -> nn.Sequential:
        """Stack of (conv, ReLU) pairs; only the first layer changes width."""
        layers = []
        for layer_idx in range(int(num_layers)):
            input_channels = in_channels if layer_idx == 0 else out_channels
            conv = _build_conv1d(
                input_channels,
                out_channels,
                kernel_size,
                causal=causal,
                pad_mode=self.causal_pad_mode,
            )
            layers.extend([conv, nn.ReLU()])
        return nn.Sequential(*layers)
    def _build_processing_stack(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        kernel_size: int,
        causal: bool,
        block_type: str = "plain",
        residual_kernel_size: int = 7,
        residual_dilations: Sequence[int] = (1, 3, 9),
        residual_bottleneck_factor: int = 4,
    ) -> nn.Module:
        """Dispatch on ``block_type``: "plain" conv stack or "soundstream" residual stack.

        Raises ValueError for any other block type.
        """
        if block_type == "plain":
            return self._build_plain_conv_stack(
                in_channels=in_channels,
                out_channels=out_channels,
                num_layers=num_layers,
                kernel_size=kernel_size,
                causal=causal,
            )
        if block_type != "soundstream":
            raise ValueError(f"Unsupported WavCoch block_type: {block_type}")
        return SoundStreamConvStack(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            residual_kernel_size=residual_kernel_size,
            residual_dilations=residual_dilations,
            bottleneck_factor=residual_bottleneck_factor,
            causal=causal,
            pad_mode=self.causal_pad_mode,
        )
    def _compute_twiddle_factors(self):
        """Return (cos, sin) N x N DFT twiddle matrices, angle = -2*pi*n*k/N."""
        n = torch.arange(self.N, dtype=torch.float32).unsqueeze(1)
        k = torch.arange(self.N, dtype=torch.float32).unsqueeze(0)
        angles = -2.0 * math.pi * n * k / float(self.N)
        return torch.cos(angles), torch.sin(angles)
    def _initialize_conv_filters(self):
        """Fill the analysis convs with Hann-windowed DFT rows and freeze them."""
        with torch.no_grad():
            cos_matrix, sin_matrix = self._compute_twiddle_factors()
            # Keep only the non-negative frequencies (real-input DFT symmetry).
            cos_matrix = cos_matrix[: self.N // 2 + 1, :]
            sin_matrix = sin_matrix[: self.N // 2 + 1, :]
            window = torch.hann_window(self.N, periodic=True).view(1, 1, -1)
            real_weights = (cos_matrix.unsqueeze(1) * window).to(dtype=self.conv_real_filters.weight.dtype)
            imag_weights = (sin_matrix.unsqueeze(1) * window).to(dtype=self.conv_imag_filters.weight.dtype)
            self.conv_real_filters.weight.copy_(real_weights)
            self.conv_imag_filters.weight.copy_(imag_weights)
        # The analysis filters are fixed, never trained.
        for param in self.conv_real_filters.parameters():
            param.requires_grad_(False)
        for param in self.conv_imag_filters.parameters():
            param.requires_grad_(False)
    def _normalize_sample_rate(self, sample_rate: Optional[int], sampling_rate: Optional[int]) -> int:
        """Accept either keyword spelling; reject conflicts and non-native rates."""
        if sample_rate is not None and sampling_rate is not None and sample_rate != sampling_rate:
            raise ValueError(f"sample_rate ({sample_rate}) and sampling_rate ({sampling_rate}) conflict")
        resolved = int(sample_rate or sampling_rate or self.config.sample_rate)
        if resolved != int(self.config.sample_rate):
            raise ValueError(
                f"WavCoch expects {self.config.sample_rate} Hz audio, but received {resolved} Hz"
            )
        return resolved
    def _prepare_wav_batch(self, wav) -> torch.Tensor:
        """Coerce a list of waveforms or a 1D/2D/3D tensor to float32 (batch, 1, samples).

        Lists are right-padded to the longest element; 2D tensors are treated
        as (batch, samples).
        """
        if isinstance(wav, list):
            wav = [item if isinstance(item, torch.Tensor) else torch.tensor(item) for item in wav]
            normalized = []
            for item in wav:
                if item.ndim == 1:
                    normalized.append(item)
                elif item.ndim == 2 and 1 in item.shape:
                    # Squeeze away a singleton channel/batch axis.
                    normalized.append(item.reshape(-1))
                else:
                    raise ValueError(f"Unexpected list element shape {tuple(item.shape)}")
            wav = torch.nn.utils.rnn.pad_sequence(normalized, batch_first=True).unsqueeze(1)
        elif isinstance(wav, torch.Tensor):
            if wav.ndim == 1:
                wav = wav.unsqueeze(0).unsqueeze(0)
            elif wav.ndim == 2:
                wav = wav.unsqueeze(1)
            elif wav.ndim != 3:
                raise ValueError(f"Unexpected tensor shape {tuple(wav.shape)}, expected 1D, 2D or 3D")
        else:
            raise TypeError(f"Unsupported input type: {type(wav)}")
        return wav.to(dtype=torch.float32)
    @property
    def vocab_size(self) -> int:
        # Number of distinct FSQ token indices (from config.vocab_size).
        return self._vocab_size
    def _resolve_wav_input(
        self,
        wav: Optional[torch.Tensor],
        input_values: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Pick exactly one of ``wav`` / ``input_values`` (HF-compatible alias)."""
        if wav is not None and input_values is not None:
            raise ValueError("Provide either `wav` or `input_values`, not both")
        resolved = wav if wav is not None else input_values
        if resolved is None:
            raise ValueError("WavCoch requires waveform input via `wav` or `input_values`")
        return resolved
    def _encode_quantized(
        self,
        wav: torch.Tensor,
        pad: bool = True,
    ):
        """Waveform -> (quantized latents (batch, frames, dim), token indices)."""
        wav = self._prepare_wav_batch(wav)
        if pad:
            # Left-pad so the first analysis frame only sees past samples.
            wav = F.pad(wav, (self.window_padding, 0), mode="constant", value=0.0)
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)
        # NOTE(review): the real and imaginary projections are summed rather
        # than combined into a magnitude; presumably intentional (the trained
        # encoder expects this representation) — confirm against training code.
        x = real_part + imag_part
        x = self.encoder(x).permute(0, 2, 1)
        quantized, indices = self.quantizer(x)
        return quantized, indices
    def forward(
        self,
        wav: Optional[torch.Tensor] = None,
        coch: Optional[torch.Tensor] = None,
        return_tensors: str = "pt",
        sample_rate: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        pad: bool = True,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        input_values: Optional[torch.Tensor] = None,
    ):
        """Tokenize audio, or compute an L1 reconstruction loss against ``coch``.

        Returns (by branch):
          * ``output_hidden_states=True`` — quantized latents as a
            BaseModelOutput (or a tuple when ``return_dict=False``);
          * ``coch is None`` — a BatchEncoding with token indices under both
            ``input_values`` and ``input_ids`` (tokenizer-like API);
          * otherwise — ``(predicted_cochleagram, l1_loss, None)``.
        """
        del return_tensors  # unused, kept for tokenizer-like API compatibility
        self._normalize_sample_rate(sample_rate, sampling_rate)
        wav = self._resolve_wav_input(wav, input_values)
        quantized, indices = self._encode_quantized(wav, pad=pad)
        if output_hidden_states:
            hidden_states = (quantized,)
            if not return_dict:
                return quantized, hidden_states
            return BaseModelOutput(last_hidden_state=quantized, hidden_states=hidden_states)
        if coch is None:
            codes = indices.long()
            return BatchEncoding({"input_values": codes, "input_ids": codes})
        pred_coch = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
        loss = F.l1_loss(pred_coch, coch)
        return pred_coch, loss, None
    @torch.no_grad()
    def quantize(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Waveform -> int64 token indices (batch, frames)."""
        _, indices = self._encode_quantized(wav, pad=pad)
        return indices.long()
    @torch.no_grad()
    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Token indices -> predicted cochleagram (batch, frames, channels)."""
        if indices.ndim == 1:
            indices = indices.unsqueeze(0)
        emb = self.quantizer.indices_to_codes(indices.long())
        return self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)
    @torch.no_grad()
    def wav2coch(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Waveform -> cochleagram via encode + quantize + decode."""
        quantized, _ = self._encode_quantized(wav, pad=pad)
        return self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
    @torch.no_grad()
    def vocode(self, coch: torch.Tensor) -> torch.Tensor:
        """Cochleagram -> waveform via the bundled vocoder.

        Accepts (frames, channels) or batched 3D input; transposes when the
        channel axis is in position 1 instead of last.
        """
        if self.vocoder is None:
            raise ValueError("This WavCoch checkpoint does not include a bundled vocoder")
        if coch.ndim == 2:
            coch = coch.unsqueeze(0)
        elif coch.ndim != 3:
            raise ValueError(f"Unexpected cochleagram shape {tuple(coch.shape)}")
        if coch.shape[-1] != self.config.out_channels and coch.shape[1] == self.config.out_channels:
            coch = coch.transpose(1, 2)
        return self.vocoder(coch)
    @torch.no_grad()
    def decode_audio(self, indices: torch.Tensor) -> torch.Tensor:
        """Token indices -> waveform (decode to cochleagram, then vocode)."""
        return self.vocode(self.decode(indices))