File size: 3,923 Bytes
cb9ad2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
from transformers import HubertModel, WavLMModel, PreTrainedModel, PreTrainedTokenizerFast
from .configuration_dual_ctc import DualCTCConfig

class DualCTCModel(PreTrainedModel):
    """WavLM encoder with dual CTC heads for kana and phoneme output.

    Implements the Apple Diverse Modeling Units approach (Interspeech 2024)
    adapted for Japanese: phoneme CTC at an intermediate layer, kana CTC
    at the final layer.
    """

    config_class = DualCTCConfig
    base_model_prefix = "encoder"

    def __init__(self, config: DualCTCConfig):
        super().__init__(config)
        # Indices into encoder hidden_states used by each head.
        self.kana_ctc_layer = config.kana_ctc_layer
        self.phoneme_ctc_layer = config.phoneme_ctc_layer

        # WavLM backbone. HubertModel shares the same constructor signature in
        # Transformers and can be swapped in here if desired.
        self.encoder = WavLMModel(config)

        # Final-layer CTC head for kana output.
        self.kana_head = self._build_ctc_head(config, config.kana_vocab_size)

        # Intermediate-layer CTC head for phoneme output.
        # BUG FIX: the original projected to config.kana_vocab_size here,
        # which contradicted the documented (B, T', phoneme_vocab_size)
        # output and would break phoneme CTC whenever the two vocabularies
        # differ in size.
        self.phoneme_head = self._build_ctc_head(config, config.phoneme_vocab_size)

        # Initialize weights as defined in PreTrainedModel.
        self.post_init()

    @staticmethod
    def _build_ctc_head(config: DualCTCConfig, vocab_size: int) -> nn.Sequential:
        """Build a CTC projection head: LayerNorm -> Linear -> GELU -> Dropout -> Linear.

        Args:
            config: Model config supplying hidden_size and hidden_dropout.
            vocab_size: Output dimension (size of the target CTC vocabulary).
        """
        return nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(config.hidden_size, vocab_size),
        )

    def forward(
        self, input_values: torch.Tensor, attention_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Forward pass returning both kana and phoneme logits.

        Args:
            input_values: (B, T) raw audio waveform.
            attention_mask: Optional attention mask.

        Returns:
            Dict with keys:
                - kana_logits: (B, T', kana_vocab_size)
                - phoneme_logits: (B, T', phoneme_vocab_size)
        """
        # output_hidden_states=True so intermediate layers are available for
        # the phoneme head; hidden_states[0] is the CNN feature projection,
        # hidden_states[i] is transformer layer i.
        outputs = self.encoder(
            input_values,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Final (or configured) layer -> kana CTC.
        hidden_for_kana = outputs.hidden_states[self.kana_ctc_layer]
        kana_logits = self.kana_head(hidden_for_kana)

        # Intermediate layer -> phoneme CTC.
        hidden_for_phoneme = outputs.hidden_states[self.phoneme_ctc_layer]
        phoneme_logits = self.phoneme_head(hidden_for_phoneme)

        return {
            "kana_logits": kana_logits,
            "phoneme_logits": phoneme_logits,
        }

    @staticmethod
    def ctc_decode(indices: list[int], tokenizer: PreTrainedTokenizerFast, is_kana: bool = False) -> str:
        """Perform CTC greedy collapse on ``indices`` and decode to text.

        Standard CTC collapse: drop blanks, then merge consecutive repeats.
        A blank between two identical tokens still separates them (blank
        resets ``prev``, so the second occurrence is kept).

        Args:
            indices: Per-frame argmax token ids.
            tokenizer: Tokenizer whose pad token doubles as the CTC blank
                (assumed id 0 — confirm against the tokenizer setup).
            is_kana: If True, strip the spaces the tokenizer inserts so kana
                is returned as a contiguous string; phonemes stay
                space-separated.

        Returns:
            The decoded string.
        """
        collapsed_ids: list[int] = []
        prev: int | None = None
        blank_idx = tokenizer.pad_token_id  # <blank> is set as the pad token (0)

        # CTC collapse: skip blanks and repeated tokens.
        for idx in indices:
            if idx == blank_idx:
                prev = idx
                continue
            if idx == prev:
                continue
            collapsed_ids.append(idx)
            prev = idx

        # Decode using the HF tokenizer.
        text = tokenizer.decode(collapsed_ids)

        if is_kana:
            # Kana should be joined without spaces.
            return text.replace(" ", "")
        # Phonemes should remain space-separated.
        return text

    def get_feat_extract_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        """Compute output sequence lengths after the encoder's CNN downsampling.

        Delegates to the encoder's private helper; needed to build CTC input
        lengths that match the frame rate of the logits.
        """
        return self.encoder._get_feat_extract_output_lengths(input_lengths)