| import json |
| import logging |
| import math |
| import os |
| import sys |
| import warnings |
| from abc import ABC, abstractmethod |
| from pathlib import Path |
| from subprocess import CalledProcessError, run |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import hydra |
| import numpy as np |
| import omegaconf |
| import torch |
| import torch.nn.functional as F |
| import torchaudio |
| from hydra.utils import instantiate |
| from sentencepiece import SentencePieceProcessor |
| from torch import Tensor, nn |
| from torch.jit import TracerWarning |
| from transformers import PretrainedConfig, PreTrainedModel |
| from transformers.utils import cached_file |
|
|
| DIR_NAME = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.append(DIR_NAME) |
|
|
|
|
| IMPORT_FLASH = False |
| SAMPLE_RATE = 16000 |
| LONGFORM_THRESHOLD = 25 * SAMPLE_RATE |
| _PIPELINE = None |
|
|
|
|
| |
|
|
|
|
| def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor: |
| """ |
| Load an audio file and resample it to the specified sample rate. |
| """ |
| cmd = [ |
| "ffmpeg", |
| "-nostdin", |
| "-threads", |
| "0", |
| "-i", |
| audio_path, |
| "-f", |
| "s16le", |
| "-ac", |
| "1", |
| "-acodec", |
| "pcm_s16le", |
| "-ar", |
| str(sample_rate), |
| "-", |
| ] |
| try: |
| audio = run(cmd, capture_output=True, check=True).stdout |
| except CalledProcessError as exc: |
| raise RuntimeError("Failed to load audio") from exc |
|
|
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore", category=UserWarning) |
| return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0 |
|
|
|
|
| class SpecScaler(nn.Module): |
| """ |
| Module that applies logarithmic scaling to spectrogram values. |
| This module clamps the input values within a certain range and then applies a natural logarithm. |
| """ |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return torch.log(x.clamp_(1e-9, 1e9)) |
|
|
|
|
| class FeatureExtractor(nn.Module): |
| """ |
| Module for extracting Log-mel spectrogram features from raw audio signals. |
| This module uses Torchaudio's MelSpectrogram transform to extract features |
| and applies logarithmic scaling. |
| """ |
|
|
| def __init__(self, sample_rate: int, features: int, **kwargs): |
| super().__init__() |
| self.hop_length = kwargs.get("hop_length", sample_rate // 100) |
| self.win_length = kwargs.get("win_length", sample_rate // 40) |
| self.n_fft = kwargs.get("n_fft", sample_rate // 40) |
| self.center = kwargs.get("center", True) |
| self.featurizer = nn.Sequential( |
| torchaudio.transforms.MelSpectrogram( |
| sample_rate=sample_rate, |
| n_mels=features, |
| win_length=self.win_length, |
| hop_length=self.hop_length, |
| n_fft=self.n_fft, |
| center=self.center, |
| ), |
| SpecScaler(), |
| ) |
|
|
| def out_len(self, input_lengths: Tensor) -> Tensor: |
| """ |
| Calculates the output length after the feature extraction process. |
| """ |
| if self.center: |
| return ( |
| input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long() |
| ) |
| else: |
| return ( |
| (input_lengths - self.win_length) |
| .div(self.hop_length, rounding_mode="floor") |
| .add(1) |
| .long() |
| ) |
|
|
| def forward(self, input_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]: |
| """ |
| Extract Log-mel spectrogram features from the input audio signal. |
| """ |
| return self.featurizer(input_signal), self.out_len(length) |
|
|
|
|
| |
|
|
|
|
| def onnx_converter( |
| model_name: str, |
| module: torch.nn.Module, |
| out_dir: str, |
| inputs: Optional[Tuple[Tensor, ...]] = None, |
| input_names: Optional[List[str]] = None, |
| output_names: Optional[List[str]] = None, |
| dynamic_axes: Optional[ |
| Union[Dict[str, List[int]], Dict[str, Dict[int, str]]] |
| ] = None, |
| opset_version: int = 17, |
| ): |
| if inputs is None: |
| inputs = module.input_example() |
| if input_names is None: |
| input_names = module.input_names() |
| if output_names is None: |
| output_names = module.output_names() |
|
|
| Path(out_dir).mkdir(exist_ok=True, parents=True) |
| out_path = str(Path(out_dir) / f"{model_name}.onnx") |
| saved_dtype = next(module.parameters()).dtype |
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore", category=UserWarning) |
| warnings.simplefilter("ignore", category=TracerWarning) |
| torch.onnx.export( |
| module.to(torch.float32), |
| inputs, |
| out_path, |
| input_names=input_names, |
| output_names=output_names, |
| dynamic_axes=dynamic_axes, |
| opset_version=opset_version, |
| ) |
| print(f"Succesfully ported onnx {model_name} to {out_path}.") |
| module.to(saved_dtype) |
|
|
|
|
| def format_time(seconds: float) -> str: |
| """ |
| Formats time in seconds to HH:MM:SS:mm format. |
| """ |
| hours = int(seconds // 3600) |
| minutes = int((seconds % 3600) // 60) |
| seconds = seconds % 60 |
| full_seconds = int(seconds) |
| milliseconds = int((seconds - full_seconds) * 100) |
|
|
| if hours > 0: |
| return f"{hours:02}:{minutes:02}:{full_seconds:02}:{milliseconds:02}" |
| return f"{minutes:02}:{full_seconds:02}:{milliseconds:02}" |
|
|
|
|
| def rtt_half(x: Tensor) -> Tensor: |
| x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] |
| return torch.cat([-x2, x1], dim=x1.ndim - 1) |
|
|
|
|
| def apply_rotary_pos_emb( |
| q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0 |
| ) -> Tuple[Tensor, Tensor]: |
| """ |
| Applies Rotary Position Embeddings to query and key tensors. |
| """ |
| cos, sin = ( |
| cos[offset : q.shape[0] + offset, ...], |
| sin[offset : q.shape[0] + offset, ...], |
| ) |
| return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin) |
|
|
|
|
| def _normalize_device(device: Optional[Union[str, torch.device]]) -> torch.device: |
| """Normalize device parameter to torch.device.""" |
| if device is None: |
| device_str = "cuda" if torch.cuda.is_available() else "cpu" |
| return torch.device(device_str) |
| if isinstance(device, str): |
| return torch.device(device) |
| return device |
|
|
|
|
| def download_short_audio(): |
| """Download test audio file if not exists""" |
| audio_file = "example.wav" |
| if not os.path.exists(audio_file): |
| os.system( |
| 'wget -O example.wav "https://cdn.chatwm.opensmodel.sberdevices.ru/GigaAM/example.wav"' |
| ) |
| assert os.path.exists(audio_file), "Short audio file not found" |
| return audio_file |
|
|
|
|
| def download_long_audio(): |
| """Download test audio file if not exists""" |
| audio_file = "long_example.wav" |
| if not os.path.exists(audio_file): |
| os.system( |
| 'wget -O long_example.wav "https://cdn.chatwm.opensmodel.sberdevices.ru/GigaAM/long_example.wav"' |
| ) |
| assert os.path.exists(audio_file), "Long audio file not found" |
| return audio_file |
|
|
|
|
| class AudioDataset(torch.utils.data.Dataset): |
| """ |
| Helper class for creating batched inputs |
| """ |
|
|
| def __init__(self, lst: List[Union[str, np.ndarray, torch.Tensor]]): |
| assert isinstance( |
| lst[0], (str, np.ndarray, torch.Tensor) |
| ), f"Unexpected dtype: {type(lst[0])}" |
| self.lst = lst |
|
|
| def __len__(self): |
| return len(self.lst) |
|
|
| def __getitem__(self, idx): |
| item = self.lst[idx] |
| if isinstance(item, str): |
| wav_tns = load_audio(item) |
| elif isinstance(item, np.ndarray): |
| wav_tns = torch.from_numpy(item) |
| elif isinstance(item, torch.Tensor): |
| wav_tns = item |
| else: |
| raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}") |
| return wav_tns |
|
|
| @staticmethod |
| def collate(wavs): |
| lengths = torch.tensor([len(wav) for wav in wavs]) |
| max_len = lengths.max().item() |
| wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype) |
| for idx, wav in enumerate(wavs): |
| wav_tns[idx, : wav.shape[-1]] = wav.squeeze() |
| return wav_tns, lengths |
|
|
|
|
|
|
| |
|
|
|
|
| def get_pipeline(device: torch.device): |
| """ |
| Retrieves a PyAnnote voice activity detection pipeline and move it to the specified device. |
| The pipeline is loaded only once and reused across subsequent calls. |
| It requires the Hugging Face API token to be set in the HF_TOKEN environment variable. |
| """ |
| global _PIPELINE |
| if _PIPELINE is not None: |
| return _PIPELINE.to(device) |
|
|
| from pyannote.audio import Model |
| from pyannote.audio.pipelines import VoiceActivityDetection |
|
|
| try: |
| hf_token = os.environ["HF_TOKEN"] |
| except KeyError as exc: |
| raise ValueError("HF_TOKEN environment variable is not set") from exc |
|
|
| model = Model.from_pretrained("pyannote/segmentation-3.0", use_auth_token=hf_token) |
| _PIPELINE = VoiceActivityDetection(segmentation=model) |
| _PIPELINE.instantiate({"min_duration_on": 0.0, "min_duration_off": 0.0}) |
|
|
| return _PIPELINE.to(device) |
|
|
|
|
| def segment_audio_file( |
| wav_file: str, |
| sr: int, |
| max_duration: float = 22.0, |
| min_duration: float = 15.0, |
| strict_limit_duration: float = 30.0, |
| new_chunk_threshold: float = 0.2, |
| device: torch.device = torch.device("cpu"), |
| ) -> Tuple[List[torch.Tensor], List[Tuple[float, float]]]: |
| """ |
| Segments an audio waveform into smaller chunks based on speech activity. |
| The segmentation is performed using a PyAnnote voice activity detection pipeline. |
| """ |
|
|
| audio = load_audio(wav_file) |
| pipeline = get_pipeline(device) |
| sad_segments = pipeline(wav_file) |
|
|
| segments: List[torch.Tensor] = [] |
| curr_duration = 0.0 |
| curr_start = 0.0 |
| curr_end = 0.0 |
| boundaries: List[Tuple[float, float]] = [] |
|
|
| def _update_segments(curr_start: float, curr_end: float, curr_duration: float): |
| if curr_duration > strict_limit_duration: |
| max_segments = int(curr_duration / strict_limit_duration) + 1 |
| segment_duration = curr_duration / max_segments |
| curr_end = curr_start + segment_duration |
| for _ in range(max_segments - 1): |
| segments.append(audio[int(curr_start * sr) : int(curr_end * sr)]) |
| boundaries.append((curr_start, curr_end)) |
| curr_start = curr_end |
| curr_end += segment_duration |
| segments.append(audio[int(curr_start * sr) : int(curr_end * sr)]) |
| boundaries.append((curr_start, curr_end)) |
|
|
| |
| |
| for segment in sad_segments.get_timeline().support(): |
| start = max(0, segment.start) |
| end = min(audio.shape[0] / sr, segment.end) |
| if curr_duration > new_chunk_threshold and ( |
| curr_duration + (end - curr_end) > max_duration |
| or curr_duration > min_duration |
| ): |
| _update_segments(curr_start, curr_end, curr_duration) |
| curr_start = start |
| curr_end = end |
| curr_duration = curr_end - curr_start |
|
|
| if curr_duration > new_chunk_threshold: |
| _update_segments(curr_start, curr_end, curr_duration) |
|
|
| return segments, boundaries |
|
|
|
|
| |
|
|
|
|
|
|
| class StridingSubsampling(nn.Module): |
| """ |
| Strided Subsampling layer used to reduce the sequence length. |
| """ |
|
|
| def __init__( |
| self, |
| subsampling: str, |
| kernel_size: int, |
| subsampling_factor: int, |
| feat_in: int, |
| feat_out: int, |
| conv_channels: int, |
| ): |
| super().__init__() |
| self.subsampling_type = subsampling |
| assert self.subsampling_type in ["conv1d", "conv2d"] |
| self._sampling_num = int(math.log(subsampling_factor, 2)) |
| self._stride = 2 |
| self._kernel_size = kernel_size |
| self._padding = (self._kernel_size - 1) // 2 |
|
|
| layers: List[nn.Module] = [] |
| in_channels = 1 if self.subsampling_type == "conv2d" else feat_in |
| subs_conv_class = ( |
| torch.nn.Conv2d if self.subsampling_type == "conv2d" else torch.nn.Conv1d |
| ) |
| for _ in range(self._sampling_num): |
| layers.append( |
| subs_conv_class( |
| in_channels=in_channels, |
| out_channels=conv_channels, |
| kernel_size=self._kernel_size, |
| stride=self._stride, |
| padding=self._padding, |
| ) |
| ) |
| layers.append(nn.ReLU()) |
| in_channels = conv_channels |
|
|
| out_length = self.calc_output_length(torch.tensor(feat_in)) |
| if self.subsampling_type == "conv2d": |
| self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) |
| self.conv = torch.nn.Sequential(*layers) |
|
|
| def calc_output_length(self, lengths: Tensor) -> Tensor: |
| """ |
| Calculates the output length after applying the subsampling. |
| """ |
| lengths = lengths.to(torch.float) |
| add_pad = 2 * self._padding - self._kernel_size |
| for _ in range(self._sampling_num): |
| lengths = torch.div(lengths + add_pad, self._stride) + 1.0 |
| lengths = torch.floor(lengths) |
| return lengths.to(dtype=torch.int) |
|
|
| def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]: |
| if self.subsampling_type == "conv2d": |
| x = self.conv(x.unsqueeze(1)) |
| b, _, t, _ = x.size() |
| x = self.out(x.transpose(1, 2).reshape(b, t, -1)) |
| else: |
| x = self.conv(x.transpose(1, 2)).transpose(1, 2) |
| return x, self.calc_output_length(lengths) |
|
|
|
|
| class MultiHeadAttention(nn.Module, ABC): |
| """ |
| Base class of Multi-Head Attention Mechanisms. |
| """ |
|
|
| def __init__( |
| self, n_head: int, n_feat: int, flash_attn=False, torch_sdpa_attn=False |
| ): |
| super().__init__() |
| assert n_feat % n_head == 0 |
| self.d_k = n_feat // n_head |
| self.h = n_head |
| self.linear_q = nn.Linear(n_feat, n_feat) |
| self.linear_k = nn.Linear(n_feat, n_feat) |
| self.linear_v = nn.Linear(n_feat, n_feat) |
| self.linear_out = nn.Linear(n_feat, n_feat) |
| self.flash_attn = flash_attn |
| self.torch_sdpa_attn = torch_sdpa_attn |
| if self.flash_attn and not IMPORT_FLASH: |
| raise RuntimeError( |
| f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. " |
| "Please install flash_attn or use --no_flash flag. " |
| "If you have already done this, " |
| "--force-reinstall flag might be useful" |
| ) |
|
|
| def forward_qkv( |
| self, query: Tensor, key: Tensor, value: Tensor |
| ) -> Tuple[Tensor, Tensor, Tensor]: |
| """ |
| Projects the inputs into queries, keys, and values for multi-head attention. |
| """ |
| b = query.size(0) |
| q = self.linear_q(query).view(b, -1, self.h, self.d_k) |
| k = self.linear_k(key).view(b, -1, self.h, self.d_k) |
| v = self.linear_v(value).view(b, -1, self.h, self.d_k) |
| if self.flash_attn: |
| return q, k, v |
| return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) |
|
|
| def forward_attention( |
| self, value: Tensor, scores: Tensor, mask: Optional[Tensor] |
| ) -> Tensor: |
| """ |
| Computes the scaled dot-product attention given the projected values and scores. |
| """ |
| b = value.size(0) |
| if mask is not None: |
| mask = mask.unsqueeze(1) |
| scores = scores.masked_fill(mask, -10000.0) |
| attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) |
| else: |
| attn = torch.softmax(scores, dim=-1) |
| x = torch.matmul(attn, value) |
| x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k) |
| return self.linear_out(x) |
|
|
|
|
| class RelPositionMultiHeadAttention(MultiHeadAttention): |
| """ |
| Relative Position Multi-Head Attention module. |
| """ |
|
|
| def __init__(self, n_head: int, n_feat: int): |
| super().__init__(n_head, n_feat) |
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) |
| self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) |
| self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) |
|
|
| def rel_shift(self, x: Tensor) -> Tensor: |
| b, h, qlen, pos_len = x.size() |
| x = torch.nn.functional.pad(x, pad=(1, 0)) |
| x = x.view(b, h, -1, qlen) |
| return x[:, :, 1:].view(b, h, qlen, pos_len) |
|
|
| def forward( |
| self, |
| query: Tensor, |
| key: Tensor, |
| value: Tensor, |
| pos_emb: Tensor, |
| mask: Optional[Tensor] = None, |
| ) -> Tensor: |
| q, k, v = self.forward_qkv(query, key, value) |
| q = q.transpose(1, 2) |
| p = self.linear_pos(pos_emb) |
| p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2) |
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) |
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) |
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) |
| matrix_bd = self.rel_shift(matrix_bd) |
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) |
| matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)] |
| scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) |
| return self.forward_attention(v, scores, mask) |
|
|
|
|
| class RotaryPositionMultiHeadAttention(MultiHeadAttention): |
| """ |
| Rotary Position Multi-Head Attention module. |
| """ |
|
|
| def forward( |
| self, |
| query: Tensor, |
| key: Tensor, |
| value: Tensor, |
| pos_emb: List[Tensor], |
| mask: Optional[Tensor] = None, |
| ) -> Tensor: |
| b, t, _ = value.size() |
| query = query.transpose(0, 1).view(t, b, self.h, self.d_k) |
| key = key.transpose(0, 1).view(t, b, self.h, self.d_k) |
| value = value.transpose(0, 1).view(t, b, self.h, self.d_k) |
|
|
| cos, sin = pos_emb |
| query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0) |
|
|
| q, k, v = self.forward_qkv( |
| query.view(t, b, self.h * self.d_k).transpose(0, 1), |
| key.view(t, b, self.h * self.d_k).transpose(0, 1), |
| value.view(t, b, self.h * self.d_k).transpose(0, 1), |
| ) |
|
|
| if not self.flash_attn and not self.torch_sdpa_attn: |
| scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k)) |
| return self.forward_attention(v, scores, mask) |
| elif self.flash_attn: |
| if mask is None: |
| scores = flash_attn_func(q, k, v) |
| else: |
| scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k) |
| scores = scores.view(b, -1, self.h * self.d_k) |
| return self.linear_out(scores) |
| else: |
| attn_mask = None if mask is None else ~mask.unsqueeze(1) |
| attn_output = F.scaled_dot_product_attention( |
| q, |
| k, |
| v, |
| attn_mask=attn_mask, |
| ) |
| attn_output = attn_output.transpose(1, 2).reshape(b, t, self.h * self.d_k) |
| return self.linear_out(attn_output) |
|
|
|
|
| class PositionalEncoding(nn.Module, ABC): |
| """ |
| Base class of Positional Encodings. |
| """ |
|
|
| def __init__(self, dim: int, base: int): |
| super().__init__() |
| self.dim = dim |
| self.base = base |
|
|
| @abstractmethod |
| def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]: |
| pass |
|
|
| def extend_pe(self, length: int, device: torch.device): |
| """ |
| Extends the positional encoding buffer to process longer sequences. |
| """ |
| pe = self.create_pe(length, device) |
| if pe is None: |
| return |
| if hasattr(self, "pe"): |
| self.pe = pe |
| else: |
| self.register_buffer("pe", pe, persistent=False) |
|
|
|
|
| class RelPositionalEmbedding(PositionalEncoding): |
| """ |
| Relative Positional Embedding module. |
| """ |
|
|
| def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]: |
| """ |
| Creates the relative positional encoding matrix. |
| """ |
| if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1: |
| return None |
| positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1) |
| pos_length = positions.size(0) |
| pe = torch.zeros(pos_length, self.dim, device=positions.device) |
| div_term = torch.exp( |
| torch.arange(0, self.dim, 2, device=pe.device) |
| * -(math.log(10000.0) / self.dim) |
| ) |
| pe[:, 0::2] = torch.sin(positions * div_term) |
| pe[:, 1::2] = torch.cos(positions * div_term) |
| return pe.unsqueeze(0) |
|
|
| def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: |
| input_len = x.size(1) |
| center_pos = self.pe.size(1) // 2 + 1 |
| start_pos = center_pos - input_len |
| end_pos = center_pos + input_len - 1 |
| return x, self.pe[:, start_pos:end_pos] |
|
|
|
|
| class RotaryPositionalEmbedding(PositionalEncoding): |
| """ |
| Rotary Positional Embedding module. |
| """ |
|
|
| def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]: |
| """ |
| Creates or extends the rotary positional encoding matrix. |
| """ |
| if hasattr(self, "pe") and self.pe.size(0) >= 2 * length: |
| return None |
| positions = torch.arange(0, length, dtype=torch.float32, device=device) |
| inv_freq = 1.0 / ( |
| self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) |
| ) |
| t = torch.arange(length, device=positions.device).type_as(inv_freq) |
| freqs = torch.einsum("i,j->ij", t, inv_freq) |
| emb = torch.cat((freqs, freqs), dim=-1).to(positions.device) |
| return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]]) |
|
|
| def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]: |
| cos_emb = self.pe[0 : x.shape[1]] |
| half_pe = self.pe.shape[0] // 2 |
| sin_emb = self.pe[half_pe : half_pe + x.shape[1]] |
| return x, [cos_emb, sin_emb] |
|
|
|
|
| class ConformerConvolution(nn.Module): |
| """ |
| Conformer Convolution module. |
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| kernel_size: int, |
| norm_type: str, |
| ): |
| super().__init__() |
| assert (kernel_size - 1) % 2 == 0 |
| assert norm_type in ["batch_norm", "layer_norm"] |
| self.norm_type = norm_type |
| self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1) |
| self.depthwise_conv = nn.Conv1d( |
| in_channels=d_model, |
| out_channels=d_model, |
| kernel_size=kernel_size, |
| padding=(kernel_size - 1) // 2, |
| groups=d_model, |
| bias=True, |
| ) |
| self.batch_norm = ( |
| nn.BatchNorm1d(d_model) |
| if norm_type == "batch_norm" |
| else nn.LayerNorm(d_model) |
| ) |
| self.activation = nn.SiLU() |
| self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1) |
|
|
| def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor: |
| x = x.transpose(1, 2) |
| x = self.pointwise_conv1(x) |
| x = nn.functional.glu(x, dim=1) |
| if pad_mask is not None: |
| x = x.masked_fill(pad_mask.unsqueeze(1), 0.0) |
| x = self.depthwise_conv(x) |
| if self.norm_type == "batch_norm": |
| x = self.batch_norm(x) |
| else: |
| x = self.batch_norm(x.transpose(1, 2)).transpose(1, 2) |
| x = self.activation(x) |
| x = self.pointwise_conv2(x) |
| return x.transpose(1, 2) |
|
|
|
|
| class ConformerFeedForward(nn.Module): |
| """ |
| Conformer Feed Forward module. |
| """ |
|
|
| def __init__(self, d_model: int, d_ff: int, use_bias=True): |
| super().__init__() |
| self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias) |
| self.activation = nn.SiLU() |
| self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return self.linear2(self.activation(self.linear1(x))) |
|
|
|
|
| class ConformerLayer(nn.Module): |
| """ |
| Conformer Layer module. |
| This module combines several submodules including feed forward networks, |
| depthwise separable convolution, and multi-head self-attention |
| to form a single Conformer block. |
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| d_ff: int, |
| self_attention_model: str, |
| n_heads: int = 16, |
| conv_norm_type: str = "batch_norm", |
| conv_kernel_size: int = 31, |
| flash_attn: bool = False, |
| ): |
| super().__init__() |
| self.fc_factor = 0.5 |
| self.norm_feed_forward1 = nn.LayerNorm(d_model) |
| self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff) |
| self.norm_conv = nn.LayerNorm(d_model) |
| self.conv = ConformerConvolution( |
| d_model=d_model, |
| kernel_size=conv_kernel_size, |
| norm_type=conv_norm_type, |
| ) |
| self.norm_self_att = nn.LayerNorm(d_model) |
| if self_attention_model == "rotary": |
| self.self_attn: nn.Module = RotaryPositionMultiHeadAttention( |
| n_head=n_heads, |
| n_feat=d_model, |
| flash_attn=flash_attn, |
| torch_sdpa_attn=not flash_attn, |
| ) |
| else: |
| assert not flash_attn, "Not supported flash_attn for rel_pos" |
| self.self_attn = RelPositionMultiHeadAttention( |
| n_head=n_heads, |
| n_feat=d_model, |
| ) |
| self.norm_feed_forward2 = nn.LayerNorm(d_model) |
| self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff) |
| self.norm_out = nn.LayerNorm(d_model) |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos_emb: Union[Tensor, List[Tensor]], |
| att_mask: Optional[Tensor] = None, |
| pad_mask: Optional[Tensor] = None, |
| ) -> Tensor: |
| residual = x |
| x = self.norm_feed_forward1(x) |
| x = self.feed_forward1(x) |
| residual = residual + x * self.fc_factor |
|
|
| x = self.norm_self_att(residual) |
| x = self.self_attn(x, x, x, pos_emb, mask=att_mask) |
| residual = residual + x |
|
|
| x = self.norm_conv(residual) |
| x = self.conv(x, pad_mask=pad_mask) |
| residual = residual + x |
|
|
| x = self.norm_feed_forward2(residual) |
| x = self.feed_forward2(x) |
| residual = residual + x * self.fc_factor |
|
|
| x = self.norm_out(residual) |
| return x |
|
|
|
|
| class ConformerEncoder(nn.Module): |
| """ |
| Conformer Encoder module. |
| This module encapsulates the entire Conformer encoder architecture, |
| consisting of a StridingSubsampling layer, positional embeddings, and |
| a stack of Conformer Layers. |
| It serves as the main component responsible for processing speech features. |
| """ |
|
|
| def __init__( |
| self, |
| feat_in: int = 64, |
| n_layers: int = 16, |
| d_model: int = 768, |
| subsampling: str = "conv2d", |
| subs_kernel_size: int = 3, |
| subsampling_factor: int = 4, |
| ff_expansion_factor: int = 4, |
| self_attention_model: str = "rotary", |
| n_heads: int = 16, |
| pos_emb_max_len: int = 5000, |
| conv_norm_type: str = "batch_norm", |
| conv_kernel_size: int = 31, |
| flash_attn: bool = False, |
| ): |
| super().__init__() |
| self.feat_in = feat_in |
| assert self_attention_model in [ |
| "rotary", |
| "rel_pos", |
| ], f"Not supported attn = {self_attention_model}" |
|
|
| self.pre_encode = StridingSubsampling( |
| subsampling=subsampling, |
| kernel_size=subs_kernel_size, |
| subsampling_factor=subsampling_factor, |
| feat_in=feat_in, |
| feat_out=d_model, |
| conv_channels=d_model, |
| ) |
|
|
| self.pos_emb_max_len = pos_emb_max_len |
| if self_attention_model == "rotary": |
| self.pos_enc: PositionalEncoding = RotaryPositionalEmbedding( |
| d_model // n_heads, pos_emb_max_len |
| ) |
| else: |
| self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len) |
|
|
| self.layers = nn.ModuleList() |
| for _ in range(n_layers): |
| layer = ConformerLayer( |
| d_model=d_model, |
| d_ff=d_model * ff_expansion_factor, |
| self_attention_model=self_attention_model, |
| n_heads=n_heads, |
| conv_norm_type=conv_norm_type, |
| conv_kernel_size=conv_kernel_size, |
| flash_attn=flash_attn, |
| ) |
| self.layers.append(layer) |
|
|
| def input_example( |
| self, |
| batch_size: int = 1, |
| seqlen: int = 200, |
| ) -> Tuple[Tensor, Tensor]: |
| device = next(self.parameters()).device |
| features = torch.zeros(batch_size, self.feat_in, seqlen) |
| feature_lengths = torch.full([batch_size], features.shape[-1]) |
| return features.float().to(device), feature_lengths.to(device) |
|
|
| def input_names(self) -> List[str]: |
| return ["audio_signal", "length"] |
|
|
| def output_names(self) -> List[str]: |
| return ["encoded", "encoded_len"] |
|
|
| def dynamic_axes(self) -> Dict[str, Dict[int, str]]: |
| return { |
| "audio_signal": {0: "batch_size", 2: "seq_len"}, |
| "length": {0: "batch_size"}, |
| "encoded": {0: "batch_size", 1: "seq_len"}, |
| "encoded_len": {0: "batch_size"}, |
| } |
|
|
| def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]: |
| if not hasattr(self.pos_enc, "pe"): |
| self.pos_enc.extend_pe(self.pos_emb_max_len, audio_signal.device) |
|
|
| audio_signal, length = self.pre_encode( |
| x=audio_signal.transpose(1, 2), lengths=length |
| ) |
|
|
| max_len = audio_signal.size(1) |
| audio_signal, pos_emb = self.pos_enc(x=audio_signal) |
|
|
| pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand( |
| length.size(0), -1 |
| ) < length.unsqueeze(-1) |
|
|
| att_mask = None |
| if audio_signal.shape[0] > 1: |
| att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1]) |
| att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2)) |
| att_mask = ~att_mask |
|
|
| pad_mask = ~pad_mask |
|
|
| for layer in self.layers: |
| audio_signal = layer( |
| x=audio_signal, |
| pos_emb=pos_emb, |
| att_mask=att_mask, |
| pad_mask=pad_mask, |
| ) |
|
|
| return audio_signal.transpose(1, 2), length |
|
|
|
|
| |
|
|
|
|
| class CTCHead(nn.Module): |
| """ |
| CTC Head module for Connectionist Temporal Classification. |
| """ |
|
|
| def __init__(self, feat_in: int, num_classes: int): |
| super().__init__() |
| self.decoder_layers = torch.nn.Sequential( |
| torch.nn.Conv1d(feat_in, num_classes, kernel_size=1) |
| ) |
|
|
| def forward(self, encoder_output: Tensor) -> Tensor: |
| return torch.nn.functional.log_softmax( |
| self.decoder_layers(encoder_output).transpose(1, 2), dim=-1 |
| ) |
|
|
|
|
| class RNNTJoint(nn.Module): |
| """ |
| RNN-Transducer Joint Network Module. |
| This module combines the outputs of the encoder and the prediction network using |
| a linear transformation followed by ReLU activation and another linear projection. |
| """ |
|
|
| def __init__( |
| self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int |
| ): |
| super().__init__() |
| self.enc_hidden = enc_hidden |
| self.pred_hidden = pred_hidden |
| self.pred = nn.Linear(pred_hidden, joint_hidden) |
| self.enc = nn.Linear(enc_hidden, joint_hidden) |
| self.joint_net = nn.Sequential(nn.ReLU(), nn.Linear(joint_hidden, num_classes)) |
|
|
| def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor: |
| """ |
| Combine the encoder and prediction network outputs into a joint representation. |
| """ |
| enc = self.enc(encoder_out).unsqueeze(2) |
| pred = self.pred(decoder_out).unsqueeze(1) |
| return self.joint_net(enc + pred).log_softmax(-1) |
|
|
| def input_example(self) -> Tuple[Tensor, Tensor]: |
| device = next(self.parameters()).device |
| enc = torch.zeros(1, self.enc_hidden, 1) |
| dec = torch.zeros(1, self.pred_hidden, 1) |
| return enc.float().to(device), dec.float().to(device) |
|
|
| def input_names(self) -> List[str]: |
| return ["enc", "dec"] |
|
|
| def output_names(self) -> List[str]: |
| return ["joint"] |
|
|
| def forward(self, enc: Tensor, dec: Tensor) -> Tensor: |
| return self.joint(enc.transpose(1, 2), dec.transpose(1, 2)) |
|
|
|
|
| class RNNTDecoder(nn.Module): |
| """ |
| RNN-Transducer Decoder Module. |
| This module handles the prediction network part of the RNN-Transducer architecture. |
| """ |
|
|
| def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int): |
| super().__init__() |
| self.blank_id = num_classes - 1 |
| self.pred_hidden = pred_hidden |
| self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id) |
| self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers) |
|
|
| def predict( |
| self, |
| x: Optional[Tensor], |
| state: Optional[Tensor], |
| batch_size: int = 1, |
| ) -> Tuple[Tensor, Tensor]: |
| """ |
| Make predictions based on the current input and previous states. |
| If no input is provided, use zeros as the initial input. |
| """ |
| if x is not None: |
| emb: Tensor = self.embed(x) |
| else: |
| emb = torch.zeros( |
| (batch_size, 1, self.pred_hidden), device=next(self.parameters()).device |
| ) |
| g, hid = self.lstm(emb.transpose(0, 1), state) |
| return g.transpose(0, 1), hid |
|
|
| def input_example(self) -> Tuple[Tensor, Tensor, Tensor]: |
| device = next(self.parameters()).device |
| label = torch.tensor([[0]]).to(device) |
| hidden_h = torch.zeros(1, 1, self.pred_hidden).to(device) |
| hidden_c = torch.zeros(1, 1, self.pred_hidden).to(device) |
| return label, hidden_h, hidden_c |
|
|
| def input_names(self) -> List[str]: |
| return ["x", "h", "c"] |
|
|
| def output_names(self) -> List[str]: |
| return ["dec", "h", "c"] |
|
|
| def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]: |
| """ |
| ONNX-specific forward with x, state = (h, c) -> x, h, c. |
| """ |
| emb = self.embed(x) |
| g, (h, c) = self.lstm(emb.transpose(0, 1), (h, c)) |
| return g.transpose(0, 1), h, c |
|
|
|
|
| class RNNTHead(nn.Module): |
| """ |
| RNN-Transducer Head Module. |
| This module combines the decoder and joint network components of the RNN-Transducer architecture. |
| """ |
|
|
| def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]): |
| super().__init__() |
| self.decoder = RNNTDecoder(**decoder) |
| self.joint = RNNTJoint(**joint) |
|
|
|
|
| |
|
|
|
|
| class Tokenizer: |
| """ |
| Tokenizer for converting between text and token IDs. |
| The tokenizer can operate either character-wise or using a pre-trained SentencePiece model. |
| """ |
|
|
| def __init__(self, vocab: List[str], model_path: Optional[str] = None): |
| self.charwise = model_path is None |
| if self.charwise: |
| self.vocab = vocab |
| else: |
| self.model = SentencePieceProcessor() |
| self.model.load(model_path) |
|
|
| def decode(self, tokens: List[int]) -> str: |
| """ |
| Convert a list of token IDs back to a string. |
| """ |
| if self.charwise: |
| return "".join(self.vocab[tok] for tok in tokens) |
| return self.model.decode(tokens) |
|
|
| def __len__(self): |
| """ |
| Get the total number of tokens in the vocabulary. |
| """ |
| return len(self.vocab) if self.charwise else len(self.model) |
|
|
|
|
| class CTCGreedyDecoding: |
| """ |
| Class for performing greedy decoding of CTC outputs. |
| """ |
|
|
| def __init__(self, vocabulary: List[str], model_path: Optional[str] = None): |
| self.tokenizer = Tokenizer(vocabulary, model_path) |
| self.blank_id = len(self.tokenizer) |
|
|
| @torch.inference_mode() |
| def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]: |
| """ |
| Decode the output of a CTC model into a list of hypotheses. |
| """ |
| log_probs = head(encoder_output=encoded) |
| assert ( |
| len(log_probs.shape) == 3 |
| ), f"Expected log_probs shape {log_probs.shape} == [B, T, C]" |
| b, _, c = log_probs.shape |
| assert ( |
| c == len(self.tokenizer) + 1 |
| ), f"Num classes {c} != len(vocab) + 1 {len(self.tokenizer) + 1}" |
| labels = log_probs.argmax(dim=-1, keepdim=False) |
|
|
| skip_mask = labels != self.blank_id |
| skip_mask[:, 1:] = torch.logical_and( |
| skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1] |
| ) |
| for i, length in enumerate(lengths): |
| skip_mask[i, length:] = 0 |
|
|
| pred_texts: List[str] = [] |
| for i in range(b): |
| pred_texts.append( |
| "".join(self.tokenizer.decode(labels[i][skip_mask[i]].cpu().tolist())) |
| ) |
| return pred_texts |
|
|
|
|
| class RNNTGreedyDecoding: |
| def __init__( |
| self, |
| vocabulary: List[str], |
| model_path: Optional[str] = None, |
| max_symbols_per_step: int = 10, |
| ): |
| """ |
| Class for performing greedy decoding of RNN-T outputs. |
| """ |
| self.tokenizer = Tokenizer(vocabulary, model_path) |
| self.blank_id = len(self.tokenizer) |
| self.max_symbols = max_symbols_per_step |
|
|
| def _greedy_decode(self, head: RNNTHead, x: Tensor, seqlen: Tensor) -> str: |
| """ |
| Internal helper function for performing greedy decoding on a single sequence. |
| """ |
| hyp: List[int] = [] |
| dec_state: Optional[Tensor] = None |
| last_label: Optional[Tensor] = None |
| for t in range(seqlen): |
| f = x[t, :, :].unsqueeze(1) |
| not_blank = True |
| new_symbols = 0 |
| while not_blank and new_symbols < self.max_symbols: |
| g, hidden = head.decoder.predict(last_label, dec_state) |
| k = head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item() |
| if k == self.blank_id: |
| not_blank = False |
| else: |
| hyp.append(int(k)) |
| dec_state = hidden |
| last_label = torch.tensor([[hyp[-1]]]).to(x.device) |
| new_symbols += 1 |
|
|
| return self.tokenizer.decode(hyp) |
|
|
| @torch.inference_mode() |
| def decode(self, head: RNNTHead, encoded: Tensor, enc_len: Tensor) -> List[str]: |
| """ |
| Decode the output of an RNN-T model into a list of hypotheses. |
| """ |
| b = encoded.shape[0] |
| pred_texts = [] |
| encoded = encoded.transpose(1, 2) |
| for i in range(b): |
| inseq = encoded[i, :, :].unsqueeze(1) |
| pred_texts.append(self._greedy_decode(head, inseq, enc_len[i])) |
| return pred_texts |
|
|
|
|
| |
|
|
|
|
| class GigaAM(nn.Module): |
| """ |
| Giga Acoustic Model: Self-Supervised Model for Speech Tasks |
| """ |
|
|
| def __init__(self, cfg: omegaconf.DictConfig): |
| super().__init__() |
| self.cfg = cfg |
| self.preprocessor = hydra.utils.instantiate(self.cfg.preprocessor) |
| self.encoder = hydra.utils.instantiate(self.cfg.encoder) |
|
|
| def forward( |
| self, features: Tensor, feature_lengths: Tensor |
| ) -> Tuple[Tensor, Tensor]: |
| """ |
| Perform forward pass through the preprocessor and encoder. |
| """ |
| features, feature_lengths = self.preprocessor(features, feature_lengths) |
| if self._device.type == "cpu": |
| return self.encoder(features, feature_lengths) |
| with torch.autocast(device_type=self._device.type, dtype=torch.float16): |
| return self.encoder(features, feature_lengths) |
|
|
| @property |
| def _device(self) -> torch.device: |
| return next(self.parameters()).device |
|
|
| @property |
| def _dtype(self) -> torch.dtype: |
| return next(self.parameters()).dtype |
|
|
| def prepare_wav(self, wav_file: str) -> Tuple[Tensor, Tensor]: |
| """ |
| Prepare an audio file for processing by loading it onto |
| the correct device and converting its format. |
| """ |
| wav = load_audio(wav_file) |
| wav = wav.to(self._device).to(self._dtype).unsqueeze(0) |
| length = torch.full([1], wav.shape[-1], device=self._device) |
| return wav, length |
|
|
| def embed_audio(self, wav_file: str) -> Tuple[Tensor, Tensor]: |
| """ |
| Extract audio representations using the GigaAM model. |
| """ |
| wav, length = self.prepare_wav(wav_file) |
| encoded, encoded_len = self.forward(wav, length) |
| return encoded, encoded_len |
|
|
| def to_onnx(self, dir_path: str = ".") -> None: |
| """ |
| Export onnx model encoder to the specified dir. |
| """ |
| self._to_onnx(dir_path) |
| omegaconf.OmegaConf.save(self.cfg, f"{dir_path}/{self.cfg.model_name}.yaml") |
|
|
| def _to_onnx(self, dir_path: str = ".") -> None: |
| """ |
| Export onnx model encoder to the specified dir. |
| """ |
| onnx_converter( |
| model_name=f"{self.cfg.model_name}_encoder", |
| out_dir=dir_path, |
| module=self.encoder, |
| dynamic_axes=self.encoder.dynamic_axes(), |
| ) |
|
|
|
|
| class GigaAMASR(GigaAM): |
| """ |
| Giga Acoustic Model for Speech Recognition |
| """ |
|
|
| def __init__(self, cfg: omegaconf.DictConfig): |
| super().__init__(cfg) |
| self.head = hydra.utils.instantiate(self.cfg.head) |
| self.decoding = hydra.utils.instantiate(self.cfg.decoding) |
|
|
| @torch.inference_mode() |
| def transcribe(self, wav_file: str) -> str: |
| """ |
| Transcribes a short audio file into text. |
| """ |
| wav, length = self.prepare_wav(wav_file) |
| if length.item() > LONGFORM_THRESHOLD: |
| raise ValueError("Too long wav file, use 'transcribe_longform' method.") |
|
|
| encoded, encoded_len = self.forward(wav, length) |
| return self.decoding.decode(self.head, encoded, encoded_len)[0] |
|
|
| def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor: |
| """ |
| Encoder-decoder forward to save model entirely in onnx format. |
| """ |
| return self.head(self.encoder(features, feature_lengths)[0]) |
|
|
| def _to_onnx(self, dir_path: str = ".") -> None: |
| """ |
| Export onnx ASR model. |
| `ctc`: exported entirely in encoder-decoder format. |
| `rnnt`: exported in encoder/decoder/joint parts separately. |
| """ |
| if "ctc" in self.cfg.model_name: |
| saved_forward = self.forward |
| self.forward = self.forward_for_export |
| onnx_converter( |
| model_name=self.cfg.model_name, |
| out_dir=dir_path, |
| module=self, |
| inputs=self.encoder.input_example(), |
| input_names=["features", "feature_lengths"], |
| output_names=["log_probs"], |
| dynamic_axes={ |
| "features": {0: "batch_size", 2: "seq_len"}, |
| "feature_lengths": {0: "batch_size"}, |
| "log_probs": {0: "batch_size", 1: "seq_len"}, |
| }, |
| ) |
| self.forward = saved_forward |
| else: |
| super()._to_onnx(dir_path) |
| onnx_converter( |
| model_name=f"{self.cfg.model_name}_decoder", |
| out_dir=dir_path, |
| module=self.head.decoder, |
| ) |
| onnx_converter( |
| model_name=f"{self.cfg.model_name}_joint", |
| out_dir=dir_path, |
| module=self.head.joint, |
| ) |
|
|
| @torch.inference_mode() |
| def transcribe_longform( |
| self, wav_file: str, **kwargs |
| ) -> List[Dict[str, Union[str, Tuple[float, float]]]]: |
| """ |
| Transcribes a long audio file by splitting it into segments and |
| then transcribing each segment. |
| """ |
| transcribed_segments = [] |
| segments, boundaries = segment_audio_file( |
| wav_file, SAMPLE_RATE, device=self._device, **kwargs |
| ) |
| for segment, segment_boundaries in zip(segments, boundaries): |
| wav = segment.to(self._device).unsqueeze(0).to(self._dtype) |
| length = torch.full([1], wav.shape[-1], device=self._device) |
| encoded, encoded_len = self.forward(wav, length) |
| result = self.decoding.decode(self.head, encoded, encoded_len)[0] |
| transcribed_segments.append( |
| {"transcription": result, "boundaries": segment_boundaries} |
| ) |
| return transcribed_segments |
|
|
|
|
| class GigaAMEmo(GigaAM): |
| """ |
| Giga Acoustic Model for Emotion Recognition |
| """ |
|
|
| def __init__(self, cfg: omegaconf.DictConfig): |
| super().__init__(cfg) |
| self.head = hydra.utils.instantiate(self.cfg.head) |
| self.id2name = cfg.id2name |
|
|
| def get_probs(self, wav_file: str) -> Dict[str, float]: |
| """ |
| Calculate probabilities for each emotion class based on the provided audio file. |
| """ |
| wav, length = self.prepare_wav(wav_file) |
| encoded, _ = self.forward(wav, length) |
| encoded_pooled = nn.functional.avg_pool1d( |
| encoded, kernel_size=encoded.shape[-1] |
| ).squeeze(-1) |
|
|
| logits = self.head(encoded_pooled)[0] |
| probs = nn.functional.softmax(logits, dim=-1).detach().tolist() |
|
|
| return {self.id2name[i]: probs[i] for i in range(len(self.id2name))} |
|
|
| def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor: |
| """ |
| Encoder-decoder forward to save model entirely in onnx format. |
| """ |
| encoded, _ = self.encoder(features, feature_lengths) |
| enc_pooled = encoded.mean(dim=-1) |
| return nn.functional.softmax(self.head(enc_pooled), dim=-1) |
|
|
| def _to_onnx(self, dir_path: str = ".") -> None: |
| """ |
| Export onnx Emo model. |
| """ |
| saved_forward = self.forward |
| self.forward = self.forward_for_export |
| onnx_converter( |
| model_name=self.cfg.model_name, |
| out_dir=dir_path, |
| module=self, |
| inputs=self.encoder.input_example(), |
| input_names=["features", "feature_lengths"], |
| output_names=["probs"], |
| dynamic_axes={ |
| "features": {0: "batch_size", 2: "seq_len"}, |
| "feature_lengths": {0: "batch_size"}, |
| "probs": {0: "batch_size", 1: "seq_len"}, |
| }, |
| ) |
| self.forward = saved_forward |
|
|
|
|
| |
|
|
|
|
| class GigaAMConfig(PretrainedConfig): |
| model_type = "gigaam" |
|
|
| def __init__(self, cfg: omegaconf.DictConfig = None, **kwargs): |
| super().__init__(**kwargs) |
| self.cfg = cfg |
|
|
|
|
| class GigaAMModel(PreTrainedModel): |
| config_class = GigaAMConfig |
| base_model_prefix = "gigaam" |
|
|
| def __init__(self, config: GigaAMConfig): |
| super().__init__(config) |
| self.config = config |
| if "decoding" in self.config.cfg["model"]["cfg"] and "model_path" in self.config.cfg["model"]["cfg"]["decoding"]: |
| resolved_tokenizer_path = cached_file( |
| config.name_or_path, |
| "tokenizer.model", |
| revision=getattr(config, "_commit_hash", None), |
| cache_dir=getattr(config, "cache_dir", None), |
| use_auth_token=getattr(config, "use_auth_token", None), |
| ) |
| self.config.cfg["model"]["cfg"]["decoding"]["model_path"] = resolved_tokenizer_path |
|
|
| self.model = instantiate(config.cfg["model"], _recursive_=False) |
|
|
| def forward(self, features: torch.Tensor, feature_lengths: torch.Tensor): |
| return self.model(features, feature_lengths) |
|
|
| def embed_audio(self, wav_file: str) -> torch.Tensor: |
| return self.model.embed_audio(wav_file) |
|
|
| def transcribe(self, wav_file: str) -> str: |
| return self.model.transcribe(wav_file) |
|
|
| def transcribe_longform(self, wav_file: str) -> List[Dict[str, Union[str, Tuple[float, float]]]]: |
| return self.model.transcribe_longform(wav_file) |
|
|
| def get_probs(self, wav_file: str) -> Dict[str, float]: |
| return self.model.get_probs(wav_file) |
|
|
| @torch.no_grad() |
| def to_onnx(self, dir_path: str = ".") -> None: |
| self.model.to_onnx(dir_path) |
|
|
| @classmethod |
| def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
| return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
|
|