# modeling_dual_ctc.py (from the wavlm-base-plus-hiragana-ctc-v2 repo)
import torch
import torch.nn as nn
from transformers import HubertModel, WavLMModel, PreTrainedModel, PreTrainedTokenizerFast
from .configuration_dual_ctc import DualCTCConfig


class DualCTCModel(PreTrainedModel):
    """WavLM encoder with dual CTC heads for kana and phoneme output.

    Implements the Apple Diverse Modeling Units approach (Interspeech 2024)
    adapted for Japanese: phoneme CTC at an intermediate encoder layer,
    kana CTC at the final layer.
    """

    config_class = DualCTCConfig
    base_model_prefix = "encoder"
    def __init__(self, config: DualCTCConfig):
        super().__init__(config)
        self.kana_ctc_layer = config.kana_ctc_layer
        self.phoneme_ctc_layer = config.phoneme_ctc_layer

        # WavLM encoder; HubertModel exposes the same base interface in
        # Transformers and can be swapped in via the commented line below.
        self.encoder = WavLMModel(config)
        # self.encoder = HubertModel(config)

        # Final CTC head for kana output
        self.kana_head = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(config.hidden_size, config.kana_vocab_size),
        )

        # Intermediate CTC head for phoneme output
        self.phoneme_head = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(config.hidden_size, config.phoneme_vocab_size),
        )

        # Initialize weights as defined in PreTrainedModel
        self.post_init()
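
    # Note: the encoder's `hidden_states` tuple has num_hidden_layers + 1
    # entries; index 0 is the pre-transformer feature projection, so
    # `kana_ctc_layer` / `phoneme_ctc_layer` index into that tuple directly.
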
    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Forward pass returning both kana and phoneme logits.

        Args:
            input_values: (B, T) raw audio waveform.
            attention_mask: Optional attention mask.

        Returns:
            Dict with keys:
                - kana_logits: (B, T', kana_vocab_size)
                - phoneme_logits: (B, T', phoneme_vocab_size)
        """
        outputs = self.encoder(
            input_values,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Final layer → kana CTC
        hidden_for_kana = outputs.hidden_states[self.kana_ctc_layer]
        kana_logits = self.kana_head(hidden_for_kana)

        # Intermediate layer → phoneme CTC
        hidden_for_phoneme = outputs.hidden_states[self.phoneme_ctc_layer]
        phoneme_logits = self.phoneme_head(hidden_for_phoneme)

        return {
            "kana_logits": kana_logits,
            "phoneme_logits": phoneme_logits,
        }
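
    # Hypothetical training helper: a sketch of the dual-CTC objective (one
    # CTC loss per head, summed), not part of the uploaded checkpoint. The
    # 0.3 phoneme weight, the argument names, and blank id 0 are assumptions;
    # `input_lengths` are frame counts from get_feat_extract_output_lengths().
    def compute_dual_ctc_loss(
        self,
        logits: dict[str, torch.Tensor],
        input_lengths: torch.Tensor,
        kana_labels: torch.Tensor,
        kana_label_lengths: torch.Tensor,
        phoneme_labels: torch.Tensor,
        phoneme_label_lengths: torch.Tensor,
        phoneme_weight: float = 0.3,
    ) -> torch.Tensor:
        ctc = nn.CTCLoss(blank=0, zero_infinity=True)
        # nn.CTCLoss expects (T, B, V) log-probabilities
        kana_log_probs = logits["kana_logits"].log_softmax(-1).transpose(0, 1)
        phoneme_log_probs = logits["phoneme_logits"].log_softmax(-1).transpose(0, 1)
        kana_loss = ctc(kana_log_probs, kana_labels, input_lengths, kana_label_lengths)
        phoneme_loss = ctc(
            phoneme_log_probs, phoneme_labels, input_lengths, phoneme_label_lengths
        )
        return (1.0 - phoneme_weight) * kana_loss + phoneme_weight * phoneme_loss
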
    @staticmethod
    def ctc_decode(
        indices: list[int], tokenizer: PreTrainedTokenizerFast, is_kana: bool = False
    ) -> str:
        """Performs CTC greedy collapse and decodes the result to text."""
        collapsed_ids = []
        prev = None
        blank_idx = tokenizer.pad_token_id  # <blank> is set as the pad token (0)

        # CTC collapse: drop blanks, merge consecutive repeats of the same id
        for idx in indices:
            if idx == blank_idx:
                prev = idx
                continue
            if idx == prev:
                continue
            collapsed_ids.append(idx)
            prev = idx

        # Decode using the HF tokenizer
        text = tokenizer.decode(collapsed_ids)

        if is_kana:
            # Kana are joined without spaces
            return text.replace(" ", "")
        # Phonemes remain space-separated
        return text
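
    # Example of the greedy collapse above (hypothetical ids, blank = 0):
    #   [0, 7, 7, 0, 7, 12, 0]  →  [7, 7, 12]
    # A blank resets `prev`, so the second run of 7s is kept as a new token.
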
    def get_feat_extract_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        """Compute output sequence lengths after the wav2vec2-style CNN downsampling
        (roughly one frame per 20 ms of 16 kHz audio)."""
        return self.encoder._get_feat_extract_output_lengths(input_lengths)
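

# Hypothetical usage sketch: the repo id and tokenizer subfolder below are
# assumptions, not documented by this file. Because of the relative import at
# the top, run this as part of its package (e.g. `python -m pkg.modeling_dual_ctc`).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo = "TylorShine/wavlm-base-plus-hiragana-ctc-v2"  # assumed repo id
    model = DualCTCModel.from_pretrained(repo)
    model.eval()
    # Assumed location of the kana tokenizer; adjust to the repo's actual layout.
    kana_tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="kana_tokenizer")

    waveform = torch.randn(1, 16000)  # 1 s of dummy 16 kHz audio
    with torch.no_grad():
        outputs = model(waveform)

    # Greedy frame-wise argmax over the kana head, then CTC collapse + decode
    ids = outputs["kana_logits"].argmax(dim=-1)[0].tolist()
    print(DualCTCModel.ctc_decode(ids, kana_tokenizer, is_kana=True))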