| |
|
|
| """ |
| This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier v2 model. |
| """ |
|
|
| import torch |
| import librosa |
| import numpy as np |
| import argparse |
| from transformers import WavLMForSequenceClassification |
|
|
|
|
def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Split a waveform into fixed-length windows for wavLM.

    Parameters
    ----------
    wav : str or array-like
        Path to an audio file (loaded via ``librosa.core.load(wav, sr=sr)``),
        or the samples themselves. Array-like input is converted with
        ``np.array(wav).squeeze()``; it is expected to be one-dimensional
        after squeezing.
    sr : int, optional
        Sample rate, by default 16_000.
    win_len : float, optional
        Window length in seconds, by default 15.0.
    win_stride : float, optional
        Hop between consecutive window starts in seconds, by default 15.0.
    do_normalize : bool, optional
        Whether to normalize each window to zero mean and unit standard
        deviation, by default False.

    Returns
    -------
    np.ndarray
        Batched input to wavLM. When the signal is longer than ``win_len``
        seconds the result has one row per window, each
        ``int(win_len * sr)`` samples long, with the last window
        zero-padded on the right. Otherwise the result is a single row
        holding the whole (unpadded) signal.

    Raises
    ------
    RuntimeError
        If ``wav`` is not a path and cannot be converted to an array.
    ValueError
        If the signal needs windowing and ``int(win_stride * sr)`` is not
        positive.
    """
    if isinstance(wav, str):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as err:
            # Keep the original failure attached instead of printing it.
            raise RuntimeError(
                f"could not convert input of type {type(wav).__name__} "
                "to a numpy array"
            ) from err

    def _standardize(samples):
        # The small epsilon guards against division by zero on silent audio.
        return (samples - np.mean(samples)) / (np.std(samples) + 1e-7)

    n_samples = len(signal)

    # Short input: return it as a single, unpadded row.
    if not n_samples / sr > win_len:
        if do_normalize:
            signal = _standardize(signal)
        return np.stack([signal])

    window = int(win_len * sr)
    stride = int(win_stride * sr)
    if stride <= 0:
        raise ValueError(
            f"win_stride * sr must be at least one sample, got {stride}"
        )

    batched_input = []
    for start in range(0, n_samples, stride):
        is_tail = start + window > n_samples
        chunk = signal[start : start + window]
        if is_tail:
            # Right-pad the final partial window with zeros to full length.
            chunk = np.pad(chunk, (0, window - len(chunk)))
        if do_normalize:
            chunk = _standardize(chunk)
        batched_input.append(chunk)
        if is_tail:
            # Everything up to the end of the signal is now covered.
            break
    return np.stack(batched_input)
|
|
|
|
def infer(model, inputs):
    """Run the model on a batch and return per-label probabilities.

    Parameters
    ----------
    model : callable
        Invoked as ``model(inputs)``; the returned object must expose a
        ``logits`` attribute.
    inputs : torch.Tensor
        Batched input, passed to ``model`` unchanged.

    Returns
    -------
    torch.Tensor
        Element-wise sigmoid of the logits, with the same shape as the
        logits. The result does not track gradients.
    """
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(inputs)
        # as_tensor keeps an existing tensor as-is (dtype and device
        # untouched) while still accepting array-like logits.
        logits = torch.as_tensor(output.logits)
        return torch.sigmoid(logits)
|
|
|
|
if __name__ == "__main__":
    # Command-line demo: load an audio file, window it, run the classifier
    # and print the per-label probabilities for every window.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,  # without a path there is nothing to load
        help="File to run inference",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier-v2",
        help="checkpoint file of model",
    )
    args = parser.parse_args()

    # Label names, indexed in the same order as the model's output columns
    # (assumed to match the checkpoint's label order).
    labels_name_list = [
        "Discrimination",
        "Harassment",
        "Sexual",
        "IllegalAndRegulated",
        "DatingAndRomantic",
        "Profanity",
    ]

    # One sample rate for both loading and windowing.
    sample_rate = 16000

    audio, _ = librosa.core.load(args.audio_file, sr=sample_rate)
    input_np = feature_extract_simple(audio, sr=sample_rate)
    input_pt = torch.Tensor(input_np)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    # One row per window, one column per label.
    probs = probs.reshape(-1, len(labels_name_list)).detach().tolist()
    print(f"Probabilities for {args.audio_file}:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")
|
|