import torch
import torch.nn as nn
from transformers import WavLMModel, PreTrainedModel, PreTrainedTokenizerFast
from .configuration_dual_ctc import DualCTCConfig

|
class DualCTCModel(PreTrainedModel):
    """WavLM encoder with dual CTC heads for kana and phoneme output.

    Implements the Apple Diverse Modeling Units approach (Interspeech 2024)
    adapted for Japanese: phoneme CTC at an intermediate layer, kana CTC
    at the final layer.
    """

    config_class = DualCTCConfig
    base_model_prefix = "encoder"

    def __init__(self, config: DualCTCConfig):
        super().__init__(config)
        self.kana_ctc_layer = config.kana_ctc_layer
        self.phoneme_ctc_layer = config.phoneme_ctc_layer

        # Shared WavLM encoder; hidden states from every layer are requested in
        # forward() so each CTC head can tap a different depth.
        self.encoder = WavLMModel(config)

        # Kana CTC head, applied to the hidden states at `kana_ctc_layer`.
        self.kana_head = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(config.hidden_size, config.kana_vocab_size),
        )

        # Phoneme CTC head, applied to the hidden states at `phoneme_ctc_layer`.
        self.phoneme_head = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(config.hidden_size, config.phoneme_vocab_size),
        )

        self.post_init()

    def forward(
        self, input_values: torch.Tensor, attention_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Forward pass returning both kana and phoneme logits.

        Args:
            input_values: (B, T) raw audio waveform.
            attention_mask: Optional attention mask.

        Returns:
            Dict with keys:
                - kana_logits: (B, T', kana_vocab_size)
                - phoneme_logits: (B, T', phoneme_vocab_size)
        """
        outputs = self.encoder(
            input_values,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Kana head reads from its configured encoder layer (the final layer
        # in the default setup).
        hidden_for_kana = outputs.hidden_states[self.kana_ctc_layer]
        kana_logits = self.kana_head(hidden_for_kana)

        # Phoneme head reads from an intermediate encoder layer.
        hidden_for_phoneme = outputs.hidden_states[self.phoneme_ctc_layer]
        phoneme_logits = self.phoneme_head(hidden_for_phoneme)

        return {
            "kana_logits": kana_logits,
            "phoneme_logits": phoneme_logits,
        }

    @staticmethod
    def ctc_decode(indices: list[int], tokenizer: PreTrainedTokenizerFast, is_kana: bool = False) -> str:
        """Performs CTC greedy collapse and decodes to text."""
        collapsed_ids = []
        prev = None
        blank_idx = tokenizer.pad_token_id

        # Standard CTC collapse: drop blanks and merge consecutive repeats.
        # A blank between two identical tokens keeps them as separate symbols.
        for idx in indices:
            if idx == blank_idx:
                prev = idx
                continue
            if idx == prev:
                continue
            collapsed_ids.append(idx)
            prev = idx

        text = tokenizer.decode(collapsed_ids)

        if is_kana:
            # Kana output: token boundaries carry no meaning, so drop spaces.
            return text.replace(" ", "")
        else:
            # Phoneme output: keep spaces as phoneme separators.
            return text

    def get_feat_extract_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        """Compute output sequence lengths after the CNN feature extractor's downsampling."""
        return self.encoder._get_feat_extract_output_lengths(input_lengths)
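

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model itself. Run as
    # `python -m <package>.<this_module>` so the relative config import resolves.
    # It assumes DualCTCConfig accepts the keyword arguments below (the vocab
    # sizes and layer indices are illustrative values, not defaults from the
    # actual config); adjust to the real config signature.
    config = DualCTCConfig(
        kana_vocab_size=90,
        phoneme_vocab_size=45,
        kana_ctc_layer=-1,    # final encoder layer
        phoneme_ctc_layer=8,  # an intermediate encoder layer
    )
    model = DualCTCModel(config).eval()

    # One second of dummy 16 kHz audio.
    waveform = torch.randn(1, 16000)
    with torch.no_grad():
        outputs = model(waveform)
    print(outputs["kana_logits"].shape, outputs["phoneme_logits"].shape)

    # Greedy decoding would then take the per-frame argmax and collapse it with
    # a trained tokenizer (the path below is illustrative only):
    #   tokenizer = PreTrainedTokenizerFast.from_pretrained("path/to/kana_tokenizer")
    #   ids = outputs["kana_logits"].argmax(dim=-1)[0].tolist()
    #   print(DualCTCModel.ctc_decode(ids, tokenizer, is_kana=True))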