import torch
import torch.nn as nn
from transformers import HubertModel, WavLMModel, PreTrainedModel, PreTrainedTokenizerFast

from .configuration_dual_ctc import DualCTCConfig


class DualCTCModel(PreTrainedModel):
    """Speech encoder with dual CTC heads for kana and phoneme output.

    Implements the Apple Diverse Modeling Units approach (Interspeech 2024)
    adapted for Japanese: a phoneme CTC head reads an intermediate encoder
    layer, while a kana CTC head reads the (configurable) final layer.
    """

    config_class = DualCTCConfig
    base_model_prefix = "encoder"

    def __init__(self, config: DualCTCConfig):
        super().__init__(config)
        # Hidden-state indices (into `outputs.hidden_states`) each head reads from.
        self.kana_ctc_layer = config.kana_ctc_layer
        self.phoneme_ctc_layer = config.phoneme_ctc_layer

        # WavLM encoder backbone. HubertModel is a drop-in alternative
        # (same config interface), which is why it is imported above:
        # self.encoder = HubertModel(config)
        self.encoder = WavLMModel(config)

        # Final-layer CTC head -> kana vocabulary.
        self.kana_head = self._make_ctc_head(config, config.kana_vocab_size)

        # Intermediate-layer CTC head -> phoneme vocabulary.
        # BUG FIX: the original projected to `config.kana_vocab_size` here,
        # so phoneme logits had the wrong width (contradicting the docstring
        # and breaking CTC loss against phoneme targets).
        self.phoneme_head = self._make_ctc_head(config, config.phoneme_vocab_size)

        # Weight initialization defined by PreTrainedModel.
        self.post_init()

    @staticmethod
    def _make_ctc_head(config: DualCTCConfig, vocab_size: int) -> nn.Sequential:
        """Build one CTC projection head: LN -> Linear -> GELU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(config.hidden_size, vocab_size),
        )

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Forward pass returning both kana and phoneme logits.

        Args:
            input_values: (B, T) raw audio waveform.
            attention_mask: Optional attention mask.

        Returns:
            Dict with keys:
                - kana_logits: (B, T', kana_vocab_size)
                - phoneme_logits: (B, T', phoneme_vocab_size)
        """
        outputs = self.encoder(
            input_values,
            attention_mask=attention_mask,
            output_hidden_states=True,  # needed to tap the intermediate layer
            return_dict=True,
        )

        # Kana CTC from the configured (typically final) layer.
        hidden_for_kana = outputs.hidden_states[self.kana_ctc_layer]
        kana_logits = self.kana_head(hidden_for_kana)

        # Phoneme CTC from the configured intermediate layer.
        hidden_for_phoneme = outputs.hidden_states[self.phoneme_ctc_layer]
        phoneme_logits = self.phoneme_head(hidden_for_phoneme)

        return {
            "kana_logits": kana_logits,
            "phoneme_logits": phoneme_logits,
        }

    @staticmethod
    def ctc_decode(
        indices: list[int],
        tokenizer: PreTrainedTokenizerFast,
        is_kana: bool = False,
    ) -> str:
        """Perform CTC greedy collapse on `indices` and decode to text.

        Args:
            indices: Per-frame argmax token ids.
            tokenizer: Tokenizer whose pad token id serves as the CTC blank.
            is_kana: If True, join decoded tokens without spaces (kana text);
                otherwise keep the tokenizer's space-separated output (phonemes).

        Returns:
            The collapsed, decoded string.
        """
        collapsed_ids: list[int] = []
        prev = None
        # Blank is set as the pad token (0) by convention in this project.
        blank_idx = tokenizer.pad_token_id

        # CTC collapse: drop blanks, merge consecutive repeats.
        for idx in indices:
            if idx == blank_idx:
                prev = idx
                continue
            if idx == prev:
                continue
            collapsed_ids.append(idx)
            prev = idx

        text = tokenizer.decode(collapsed_ids)

        if is_kana:
            # Kana should be joined without spaces.
            return text.replace(" ", "")
        # Phonemes should remain space-separated.
        return text

    def get_feat_extract_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        """Compute output sequence lengths after the encoder CNN downsampling."""
        return self.encoder._get_feat_extract_output_lengths(input_lengths)