from huggingface_hub import hf_hub_download
from torch import nn
from transformers import Wav2Vec2ConformerModel
from safetensors.torch import load_file
from torch_state_bridge import state_bridge
import torch
import torch.nn.functional as F
import torchaudio
import librosa
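
# Wav2Vec2Conformer re-wired as an RNN-T (transducer) speech recognizer:
# init_weights() swaps in a log-mel front-end, a 2-D conv subsampling stack,
# an LSTM prediction network and a joint network, and _load_pretrained_model()
# remaps NeMo-style checkpoint keys (preprocessor.featurizer, joint.joint_net,
# decoder.prediction.dec_rnn, ...) onto this module tree via state_bridge.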


class Op(nn.Module):
    """Wraps a plain callable as an nn.Module; with allow_self=True the
    callable also receives the module itself (used for the cached pad mask)."""
    def __init__(self, func, allow_self=False):
        super().__init__()
        self.func = func
        self.allow_self = allow_self

    def forward(self, x):
        if self.allow_self:
            return self.func(self, x)
        return self.func(x)


class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
    def init_weights(self):
        del self.encoder.pos_conv_embed
        config = self.config
        self.cache_length = None
        # RNN-T heads: encoder/prediction projections into the joint space,
        # the joint output layer, and the LSTM prediction network with a
        # blank-padded embedding table.
        self.enc = nn.Linear(config.hidden_size, config.joint_hidden)
        self.pred = nn.Linear(config.pred_hidden, config.joint_hidden)
        self.joint = nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1)
        self.embed = nn.Embedding(config.vocab_size + 1, config.pred_hidden, padding_idx=config.vocab_size)
        self.lstm = nn.LSTM(config.pred_hidden, config.pred_hidden, config.lstm_layer, batch_first=True)
        self.act = nn.ReLU()
        # Log-mel front-end replacing the raw-waveform feature extractor.
        self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
        self.mask_layer = Op(lambda self_obj, x: x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0), True)
        self.register_buffer("mel_fb", torch.tensor(librosa.filters.mel(sr=config.sampling_rate, n_fft=self.spec.n_fft, n_mels=80)))
        # Swap the 1-D conv feature extractor for a 2-D conv subsampling stack;
        # multilingual checkpoints use depthwise-separable convs after layer 0.
        for idx, l in enumerate(self.feature_extractor.conv_layers):
            if not config.multilingual or idx == 0:
                l.conv = nn.Conv2d(l.conv.in_channels, l.conv.out_channels, l.conv.kernel_size[0], l.conv.stride, 1)
                l.layer_norm = nn.Identity()
            else:
                l.conv = nn.Sequential(
                    nn.Conv2d(l.conv.in_channels, l.conv.out_channels, l.conv.kernel_size[0], l.conv.stride, 1, groups=l.conv.out_channels),
                    nn.Conv2d(l.conv.in_channels, l.conv.out_channels, 1),
                )
        self.feature_extractor.conv_layers.append(Op(lambda x: x.transpose(1, 2)))
        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * self.calc_length(80, repeat_num=config.num_feat_extract_layers), config.hidden_size)
        self.feature_projection.layer_norm = Op(lambda x: x.permute(0, 2, 1, 3).flatten(2))
        for l in self.encoder.layers:
            # Zero out padded frames after the GLU, and add the conv-module
            # biases present in the source checkpoint (loaded later).
            l.conv_module.glu = nn.Sequential(l.conv_module.glu, self.mask_layer)
            l.conv_module.pointwise_conv1.bias = nn.Parameter(torch.empty(l.conv_module.pointwise_conv1.out_channels))
            l.conv_module.pointwise_conv2.bias = nn.Parameter(torch.empty(l.conv_module.pointwise_conv2.out_channels))
            l.conv_module.depthwise_conv.bias = nn.Parameter(torch.empty(l.conv_module.depthwise_conv.out_channels))
        self.encoder.layer_norm = nn.Identity()
        if config.multilingual:
            self.lang_joint_net = nn.ModuleDict({l: nn.Linear(config.joint_hidden, config.vocab_size // len(config.languages) + 1) for l in config.languages.values()})
        self.eps = 2 ** -24
        # Seconds per encoder frame: conv subsampling factor times hop length.
        self.denorm = (2 ** config.num_feat_extract_layers) * self.spec.hop_length / config.sampling_rate
        self.scaler = config.hidden_size ** 0.5
        return super().init_weights()
    def _mask_hidden_states(self, hidden_states, mask_time_indices=None, attention_mask=None):
        # Scale embeddings by sqrt(hidden_size) and cache the padding mask
        # consumed by self.mask_layer inside the conformer conv modules.
        hidden_states = hidden_states * self.scaler
        self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
        return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
    def calc_length(self, lengths, padding=1, kernel_size=3, stride=2, repeat_num=1):
        # Standard conv output-length recurrence, applied once per subsampling layer.
        for _ in range(repeat_num):
            lengths = (lengths + 2 * padding - kernel_size) // stride + 1
        return lengths
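
    # For the defaults above, three subsampling layers reduce the 80 mel bins
    # to 10: 80 -> (80 + 2 - 3) // 2 + 1 = 40 -> 20 -> 10, which is why
    # feature_projection.projection in init_weights() takes conv_dim[-1] * 10
    # input features.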
    def preprocessing(self, x):
        x, l = x
        l = (l // self.spec.hop_length + 1).long()
        # Pre-emphasis, log-mel features, then per-utterance mean/std
        # normalization computed over the unpadded frames only.
        x = torch.cat((x[:, :1], x[:, 1:] - self.config.preemph * x[:, :-1]), 1)
        x = (self.mel_fb @ self.spec(x) + self.eps).log()
        T = x.size(-1)
        m = torch.arange(T, device=x.device)[None] >= l[:, None]
        x = x.masked_fill(m[:, None], 0)
        μ = x.sum(-1) / l[:, None]
        denom = torch.clamp(l[:, None] - 1, min=1)
        σ = (((x - μ[..., None]) ** 2).sum(-1) / denom + 1e-5).sqrt()
        x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
        self.cache_length = self.calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
        return F.pad(x, (0, (-T) % self.config.pad_to)).transpose(1, 2)
    def forward(self, input_values):
        return self.postprocessing(super().forward(self.preprocessing(input_values)).last_hidden_state)
    def postprocessing(self, enc_out):
        # Batched greedy RNN-T decoding: at each encoder frame, emit up to
        # max_symbols_per_step non-blank tokens per sequence.
        B, T, _ = enc_out.shape
        H = self.lstm.hidden_size
        blank = self.config.blank_id
        pad = self.config.pad_id
        max_len = T * self.config.max_symbols_per_step
        tokens = torch.full((B, max_len), pad, dtype=torch.long, device=enc_out.device)
        starts = torch.full((B, max_len), -1.0, dtype=enc_out.dtype, device=enc_out.device)
        lengths = torch.zeros(B, dtype=torch.long, device=enc_out.device)
        hx = torch.zeros(self.config.lstm_layer, B, H, dtype=enc_out.dtype, device=enc_out.device)
        cx = torch.zeros_like(hx)
        last = torch.full((B, 1), blank, dtype=torch.long, device=enc_out.device)
        enc_proj = self.enc(enc_out)  # (B, T, D)
        for t in range(T):
            e = enc_proj[:, t:t + 1]
            t_sec = torch.full((B, 1), t * self.denorm, dtype=enc_out.dtype, device=enc_out.device)
            for _ in range(self.config.max_symbols_per_step):
                hx_prev, cx_prev = hx, cx
                p, (hx, cx) = self.lstm(self.embed(last), (hx, cx))
                n = self.joint(self.act(e + self.pred(p))).squeeze(1).argmax(-1)  # (B,)
                emitted = n.ne(blank)
                # revert the LSTM state for sequences that emitted blank
                mask = emitted.view(1, B, 1)
                hx = torch.where(mask, hx, hx_prev)
                cx = torch.where(mask, cx, cx_prev)
                if not emitted.any():
                    break  # every sequence predicted blank: advance to the next frame
                last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
                idx = lengths[emitted].unsqueeze(1).clamp(max=max_len - 1)
                tokens[emitted] = tokens[emitted].scatter(1, idx, n[emitted].unsqueeze(1))
                starts[emitted] = starts[emitted].scatter(1, idx, t_sec[emitted])
                lengths[emitted] += 1
        return tokens, starts, lengths
    def change_language(self, language):
        # Copy the selected language head into the active joint projection.
        self.joint.load_state_dict(self.lang_joint_net[language].state_dict())
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, cache_dir=None,
                        ignore_mismatched_sizes=False, force_download=False, local_files_only=False,
                        token=None, revision="main", use_safetensors=None, weights_only=True, **kwargs):
        config.language = kwargs.pop("language", None)
        config.multilingual = not config.language
        if config.multilingual:
            config.hidden_size = 1024
            config.num_hidden_layers = 24
        config.conv_depthwise_kernel_size = 9
        config.conv_stride = [2, 2, 2]
        config.conv_kernel = [3, 3, 3]
        config.conv_dim = [256, 256, 256]
        config.feat_extract_norm = "group"
        config.intermediate_size = config.hidden_size * 4
        config.num_feat_extract_layers = len(config.conv_dim)
        config.lstm_layer = 2
        kwargs['state_dict'] = load_file(hf_hub_download(pretrained_model_name_or_path, f"{config.language or 'all'}.safetensors"))
        return super().from_pretrained(None, *model_args, config=config, cache_dir=cache_dir,
                                       ignore_mismatched_sizes=ignore_mismatched_sizes, force_download=force_download,
                                       local_files_only=local_files_only, token=token, revision=revision,
                                       use_safetensors=use_safetensors, weights_only=weights_only, **kwargs)
    @staticmethod
    def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
        # Key-rename table ("source key,target key" per line) mapping NeMo-style
        # checkpoint names onto this module tree; applied via state_bridge below.
        changes = """
        preprocessor.featurizer.fb,mel_fb
        preprocessor.featurizer.window,spec.window
        norm_feed_forward1,ffn1_layer_norm
        norm_feed_forward2,ffn2_layer_norm
        feed_forward1.linear1,ffn1.intermediate_dense
        feed_forward1.linear2,ffn1.output_dense
        feed_forward2.linear1,ffn2.intermediate_dense
        feed_forward2.linear2,ffn2.output_dense
        norm_self_att,self_attn_layer_norm
        norm_out,final_layer_norm
        norm_conv,conv_module.layer_norm
        .conv.,.conv_module.
        decoder.prediction.dec_rnn.lstm,lstm
        decoder.prediction.embed,embed
        joint.enc,enc
        joint.pred,pred
        joint.joint_net.2,lang_joint_net
        encoder.pre_encode.conv_module.0,feature_extractor.conv_layers.0.conv
        encoder.pre_encode.out,feature_projection.projection
        """
        if not model.config.multilingual:
            # Convs sit at even indices of the source stack (odd indices are
            # activations), so conv_module.{2k} maps to conv_layers.{k}.conv.
            changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//2)}.conv\n"
            changes += f"lang_joint_net.{model.config.language},joint\n"
        else:
            # First shift indices past the leading conv+activation, then remap
            # the remaining separable-conv weights into the nn.Sequential slots.
            changes += "encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}\n"
            changes += "encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}\n"
        state_dict = state_bridge(state_dict, changes)
        if not model.config.multilingual:
            # Single-language checkpoints keep only the active joint head.
            state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
        state_dict['mel_fb'] = state_dict['mel_fb'].squeeze(0)
        # Drop the auxiliary CTC head; this model has no counterpart for it.
        state_dict.pop('ctc_decoder.decoder_layers.0.bias', None)
        state_dict.pop('ctc_decoder.decoder_layers.0.weight', None)
        return super()._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
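
# ---------------------------------------------------------------------------
# Minimal usage sketch. The repo id, audio file, and the assumption that the
# checkpoint repo also serves a Wav2Vec2ConformerConfig are illustrative
# guesses, not part of the original module.
# ---------------------------------------------------------------------------
# from transformers import Wav2Vec2ConformerConfig
#
# repo = "user/conformer-rnnt"  # hypothetical checkpoint repo
# config = Wav2Vec2ConformerConfig.from_pretrained(repo)
# model = Wav2Vec2ConformerRNNT.from_pretrained(repo, config=config, language="en").eval()
# wav, sr = torchaudio.load("sample.wav")  # must match config.sampling_rate
# lengths = torch.tensor([wav.size(1)])
# with torch.no_grad():
#     tokens, starts, n_tokens = model((wav, lengths))
# # tokens[b, :n_tokens[b]] are token ids; starts[b, :n_tokens[b]] are emission
# # times in seconds (frame index * model.denorm).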