from tqdm import tqdm
from typing import Dict, List
from pydiardecode import build_diardecoder
import numpy as np
import copy
import os
import json
import concurrent.futures
import kenlm


__INFO_TAG__ = "[BeamSearchUtil INFO]"


class SpeakerTaggingBeamSearchDecoder:
    """
    Realigns the speaker labels of ASR words by beam search over an n-gram
    (KenLM) language model.
    """

    def __init__(self, loaded_kenlm_model: kenlm.Model, cfg: dict):
        self.realigning_lm_params = cfg
        self.realigning_lm = self._load_realigning_LM(loaded_kenlm_model=loaded_kenlm_model)
        self._SPLITSYM = "@"

    def _load_realigning_LM(self, loaded_kenlm_model: kenlm.Model):
        """
        Load an ARPA language model for realigning the speaker labels of words.
        """
        diar_decoder = build_diardecoder(
            loaded_kenlm_model=loaded_kenlm_model,
            kenlm_model_path=self.realigning_lm_params['arpa_language_model'],
            alpha=self.realigning_lm_params['alpha'],
            beta=self.realigning_lm_params['beta'],
            word_window=self.realigning_lm_params['word_window'],
            use_ngram=self.realigning_lm_params['use_ngram'],
        )
        return diar_decoder
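
    # Illustrative only: the keys this class reads from `cfg`. The values below
    # are hypothetical placeholders, not recommended settings.
    #
    # example_cfg = {
    #     'arpa_language_model': '/path/to/model.arpa',  # hypothetical path
    #     'alpha': 0.5,          # LM score weight
    #     'beta': 0.05,          # word-insertion weight
    #     'word_window': 16,     # context words used per decoding step
    #     'use_ngram': True,
    #     'beam_width': 16,      # read by `realign_words_with_lm` below
    # }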

    def realign_words_with_lm(self, word_dict_seq_list: List[Dict[str, float]], speaker_count: int = None, port_num: int = None) -> List[Dict[str, float]]:
        if speaker_count is None:
            # Collect the speaker labels that actually occur in the word sequence.
            spk_list = [line_dict['speaker'] for line_dict in word_dict_seq_list]
        else:
            spk_list = [f"speaker_{k}" for k in range(speaker_count)]

        realigned_list = self.realigning_lm.decode_beams(
            beam_width=self.realigning_lm_params['beam_width'],
            speaker_list=sorted(set(spk_list)),
            word_dict_seq_list=word_dict_seq_list,
            port_num=port_num,
        )
        return realigned_list
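
    # Illustrative only: the per-word dict shape this method consumes, matching
    # what `add_placeholder_speaker_softmax` below produces (values hypothetical):
    #
    # word_dict_seq_list = [
    #     {'word': 'hello', 'start_time': None, 'end_time': None,
    #      'speaker': 0, 'speaker_softmax': np.array([0.94, 0.02, 0.02, 0.02])},
    #     ...
    # ]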

    def beam_search_diarization(
        self,
        trans_info_dict: Dict[str, Dict[str, list]],
        port_num: int = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Realign the diarization result with the ASR output by running LM beam
        search over each session's word sequence.

        Args:
            trans_info_dict (dict):
                Dictionary containing word timestamps, speaker labels and words
                from all sessions. Each session is indexed by a unique ID.
            port_num (int, optional):
                Port number of the LM server to query; `None` when no server is
                used (e.g., pure n-gram decoding).

        Returns:
            trans_info_dict (dict):
                The same dictionary with each session's `words` replaced by the
                top beam-search hypothesis.
        """
        for uniq_id, session_dict in tqdm(trans_info_dict.items(), total=len(trans_info_dict), disable=True):
            word_dict_seq_list = session_dict['words']
            output_beams = self.realign_words_with_lm(
                word_dict_seq_list=word_dict_seq_list,
                speaker_count=session_dict['speaker_count'],
                port_num=port_num,
            )
            # Keep the word sequence of the first (top) beam.
            word_dict_seq_list = output_beams[0][2]
            trans_info_dict[uniq_id]['words'] = word_dict_seq_list
        return trans_info_dict

    def merge_div_inputs(self, div_trans_info_dict, org_trans_info_dict, win_len=250, word_window=16):
        """
        Merge the per-chunk outputs of parallel processing back into full sessions.

        Chunks after the first carry `word_window` extra context words at the
        front (added by `divide_chunks`), which are dropped before concatenation.
        """
        uniq_id_list = list(org_trans_info_dict.keys())
        sub_div_dict = {}
        for seq_id in div_trans_info_dict.keys():
            # Chunk IDs have the form "<uniq_id>@<chunk_index>@<total_chunk_count>".
            div_info = seq_id.split(self._SPLITSYM)
            uniq_id, sub_idx, total_count = div_info[0], int(div_info[1]), int(div_info[2])
            if uniq_id not in sub_div_dict:
                sub_div_dict[uniq_id] = [None] * total_count
            sub_div_dict[uniq_id][sub_idx] = div_trans_info_dict[seq_id]['words']

        for uniq_id in uniq_id_list:
            org_trans_info_dict[uniq_id]['words'] = []
            for k, div_words in enumerate(sub_div_dict[uniq_id]):
                if k == 0:
                    div_words = div_words[:win_len]
                else:
                    # Drop the overlap words prepended for left context.
                    div_words = div_words[word_window:]
                org_trans_info_dict[uniq_id]['words'].extend(div_words)
        return org_trans_info_dict

    def divide_chunks(self, trans_info_dict, win_len, word_window, port):
        """
        Divide each session's word sequence into chunks of length `win_len`
        for parallel processing.

        Args:
            trans_info_dict (dict): Session transcriptions indexed by unique ID.
            win_len (int or None): Chunk length in words. If None, each word
                sequence is split evenly across the available workers.
            word_window (int): Number of overlap words prepended to every chunk
                after the first, giving the decoder left context.
            port (list): List of LM-server ports; its length sets the worker count.

        Returns:
            div_trans_info_dict (dict): Chunked transcriptions indexed by
                "<uniq_id>@<chunk_index>@<total_chunk_count>".
        """
        num_workers = len(port) if len(port) > 1 else 1
        div_trans_info_dict = {}
        for uniq_id in trans_info_dict.keys():
            uniq_trans = trans_info_dict[uniq_id]
            # Drop session-level fields that should not be copied into every chunk.
            del uniq_trans['status']
            del uniq_trans['transcription']
            del uniq_trans['sentences']
            word_seq = uniq_trans['words']

            div_word_seq = []
            # Compute the chunk length per session when none is given, so one
            # session's value does not leak into the next iteration.
            session_win_len = win_len if win_len is not None else int(np.ceil(len(word_seq) / num_workers))
            n_chunks = int(np.ceil(len(word_seq) / session_win_len))

            for k in range(n_chunks):
                # Every chunk after the first starts `word_window` words early.
                div_word_seq.append(word_seq[max(k * session_win_len - word_window, 0):(k + 1) * session_win_len])

            total_count = len(div_word_seq)
            for k, w_seq in enumerate(div_word_seq):
                seq_id = uniq_id + f"{self._SPLITSYM}{k}{self._SPLITSYM}{total_count}"
                div_trans_info_dict[seq_id] = dict(uniq_trans)
                div_trans_info_dict[seq_id]['words'] = w_seq
        return div_trans_info_dict
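
    # Illustrative chunking arithmetic (hypothetical numbers): for a 600-word
    # session with win_len=250 and word_window=16, n_chunks = ceil(600/250) = 3
    # and the chunks cover words [0:250], [234:500], and [484:600];
    # `merge_div_inputs` later drops the 16 overlap words from chunks 1 and 2.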


def run_mp_beam_search_decoding(
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    div_mp,
    win_len,
    word_window,
    port=None,
    use_ngram=False,
):
    if port is not None and len(port) > 1:
        port = [int(p) for p in port]
    if use_ngram:
        # Pure n-gram decoding needs no LM server, so a single placeholder
        # "port" is used and the worker count is fixed.
        port = [None]
        num_workers = 36
    else:
        num_workers = len(port)

    uniq_id_list = sorted(div_trans_info_dict.keys())
    tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers)
    futures = []

    count = 0
    for uniq_id in uniq_id_list:
        print(f"{__INFO_TAG__} Running beam search decoding for {uniq_id}...")
        if port is not None:
            # Assign sessions to the available LM-server ports round-robin.
            port_num = port[count % len(port)]
        else:
            port_num = None
        count += 1
        uniq_trans_info_dict = {uniq_id: div_trans_info_dict[uniq_id]}
        futures.append(tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num))

    pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
    output_trans_info_dict = {}
    for done_future in concurrent.futures.as_completed(futures):
        pbar.update()
        output_trans_info_dict.update(done_future.result())
    pbar.close()
    tp.shutdown()
    if div_mp:
        output_trans_info_dict = speaker_beam_search_decoder.merge_div_inputs(
            div_trans_info_dict=output_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            win_len=win_len,
            word_window=word_window,
        )
    return output_trans_info_dict
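
# Illustrative parallel usage (hypothetical port numbers; assumes LM servers
# are already listening on them and `decoder`/`lm` exist as in the sketch at
# the end of this file):
#
# div_dict = decoder.divide_chunks(trans_info_dict, win_len=250,
#                                  word_window=16, port=[5555, 5556])
# out = run_mp_beam_search_decoding(decoder, lm, div_dict, trans_info_dict,
#                                   div_mp=True, win_len=250, word_window=16,
#                                   port=[5555, 5556], use_ngram=False)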


def count_num_of_spks(json_trans_list):
    """
    Map each speaker label found in the transcript to an integer index.
    """
    spk_set = set()
    for sentence_dict in json_trans_list:
        spk_set.add(sentence_dict['speaker'])
    # Sort the labels so the label-to-index mapping is deterministic across runs.
    speaker_map = {spk_str: idx for idx, spk_str in enumerate(sorted(spk_set))}
    return speaker_map
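
# Illustrative only (hypothetical labels):
# count_num_of_spks([{'speaker': 'spk1'}, {'speaker': 'spk0'}])
# -> {'spk0': 0, 'spk1': 1}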


def add_placeholder_speaker_softmax(json_trans_list, peak_prob=0.94, max_spks=4):
    """
    Build NeMo-style word-level entries from sentence-level SegLST input,
    attaching a placeholder speaker-probability vector to each word: the
    labeled speaker gets `peak_prob` and the remaining mass is spread evenly
    over the other `max_spks - 1` slots.
    """
    nemo_json_dict = {}
    word_dict_seq_list = []
    if peak_prob > 1 or peak_prob < 0:
        raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
    speaker_map = count_num_of_spks(json_trans_list)
    base_array = np.ones(max_spks) * (1 - peak_prob) / (max_spks - 1)
    # Word-level timestamps are unknown at this stage, so they stay None.
    stt_sec, end_sec = None, None
    for sentence_dict in json_trans_list:
        word_list = sentence_dict['words'].split()
        speaker = sentence_dict['speaker']
        for word in word_list:
            speaker_softmax = copy.deepcopy(base_array)
            speaker_softmax[speaker_map[speaker]] = peak_prob
            word_dict_seq_list.append({'word': word,
                                       'start_time': stt_sec,
                                       'end_time': end_sec,
                                       'speaker': speaker_map[speaker],
                                       'speaker_softmax': speaker_softmax,
                                       })
    nemo_json_dict.update({'words': word_dict_seq_list,
                           'status': "success",
                           'sentences': json_trans_list,
                           'speaker_count': len(speaker_map),
                           'transcription': None,
                           })
    return nemo_json_dict
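
# Illustrative arithmetic: with the defaults peak_prob=0.94 and max_spks=4,
# every non-target slot gets (1 - 0.94) / 3 = 0.02, so a word labeled with
# speaker index 1 receives speaker_softmax = [0.02, 0.94, 0.02, 0.02].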


def convert_nemo_json_to_seglst(trans_info_dict):
    """
    Convert NeMo-style word-level output into SegLST segments by grouping
    consecutive words that share a speaker label.
    """
    seg_lst_dict, spk_wise_trans_sessions = {}, {}
    for uniq_id in trans_info_dict.keys():
        spk_wise_trans_sessions[uniq_id] = {}
        seglst_seq_list = []
        word_seq_list = trans_info_dict[uniq_id]['words']
        prev_speaker, sentence = None, ''
        for widx, word_dict in enumerate(word_seq_list):
            curr_speaker = word_dict['speaker']
            word = word_dict['word']

            # Accumulate a per-speaker running transcript for the session.
            if curr_speaker not in spk_wise_trans_sessions[uniq_id]:
                spk_wise_trans_sessions[uniq_id][curr_speaker] = word
            else:
                spk_wise_trans_sessions[uniq_id][curr_speaker] = f"{spk_wise_trans_sessions[uniq_id][curr_speaker]} {word}"

            # Close the running segment whenever the speaker changes.
            if curr_speaker != prev_speaker and prev_speaker is not None:
                seglst_seq_list.append({'session_id': uniq_id,
                                        'words': sentence.strip(),
                                        'start_time': 0.0,
                                        'end_time': 0.0,
                                        'speaker': prev_speaker,
                                        })
                sentence = word
            else:
                sentence = f"{sentence} {word}"
            prev_speaker = curr_speaker

            # Flush the final segment at the end of the word sequence.
            if widx == len(word_seq_list) - 1:
                seglst_seq_list.append({'session_id': uniq_id,
                                        'words': sentence.strip(),
                                        'start_time': 0.0,
                                        'end_time': 0.0,
                                        'speaker': curr_speaker,
                                        })
        seg_lst_dict[uniq_id] = seglst_seq_list
    return seg_lst_dict
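
# Illustrative only: one resulting segment (timestamps are 0.0 placeholders
# because word-level times are not tracked here; values hypothetical):
#
# {'session_id': 'session_0', 'words': 'hello world', 'start_time': 0.0,
#  'end_time': 0.0, 'speaker': 0}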


def load_input_jsons(input_error_src_list_path, ext_str=".seglst.json", peak_prob=0.94, max_spks=4):
    """
    Load the hypothesis SegLST JSON files listed in `input_error_src_list_path`
    and convert each session into the NeMo-style word-level format.
    """
    trans_info_dict = {}
    with open(input_error_src_list_path) as list_file:
        json_filepath_list = list_file.readlines()
    for json_path in json_filepath_list:
        json_path = json_path.strip()
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if os.path.exists(json_path):
            with open(json_path, "r") as file:
                json_trans = json.load(file)
        else:
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        nemo_json_dict = add_placeholder_speaker_softmax(json_trans, peak_prob=peak_prob, max_spks=max_spks)
        trans_info_dict[uniq_id] = nemo_json_dict
    return trans_info_dict
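
# Illustrative only: the `.list` files consumed here and in
# `load_reference_jsons` are plain text with one SegLST JSON path per line,
# e.g. (hypothetical paths):
#
#   /data/session_0.seglst.json
#   /data/session_1.seglst.json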


def load_reference_jsons(reference_seglst_list_path, ext_str=".seglst.json"):
    """
    Load the reference SegLST JSON files listed in `reference_seglst_list_path`
    and tag every sentence with its session ID.
    """
    reference_info_dict = {}
    with open(reference_seglst_list_path) as list_file:
        json_filepath_list = list_file.readlines()
    for json_path in json_filepath_list:
        json_path = json_path.strip()
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if os.path.exists(json_path):
            with open(json_path, "r") as file:
                json_trans = json.load(file)
        else:
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        json_trans_uniq_id = []
        for sentence_dict in json_trans:
            sentence_dict['session_id'] = uniq_id
            json_trans_uniq_id.append(sentence_dict)
        reference_info_dict[uniq_id] = json_trans_uniq_id
    return reference_info_dict
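
# Illustrative only: one reference SegLST entry after loading (hypothetical
# values; `session_id` is injected above from the file name):
#
# {'words': 'hello world', 'speaker': 'spk0', 'start_time': 0.31,
#  'end_time': 1.02, 'session_id': 'session_0'}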


def write_seglst_jsons(
    seg_lst_sessions_dict: dict,
    input_error_src_list_path: str,
    diar_out_path: str,
    ext_str: str,
    write_individual_seglst_jsons=True,
):
    """
    Writes the segment list (SegLST) JSON files to the output directory.

    Parameters:
        seg_lst_sessions_dict (dict): A dictionary mapping session IDs to their segment lists.
        input_error_src_list_path (str): The path to the input error source list file.
        diar_out_path (str): The output directory where the SegLST JSON files are written.
        ext_str (str): A string marking the type of the SegLST JSON files (e.g., 'hyp' for hypothesis or 'ref' for reference).
        write_individual_seglst_jsons (bool, optional): Whether to also write one SegLST JSON file per session. Defaults to True.

    Returns:
        None
    """
    total_infer_list = []
    total_output_filename = os.path.split(input_error_src_list_path)[-1].replace(".list", "")
    for session_id, seg_lst_list in seg_lst_sessions_dict.items():
        total_infer_list.extend(seg_lst_list)
        if write_individual_seglst_jsons:
            print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
            with open(f'{diar_out_path}/{session_id}.seglst.json', 'w') as file:
                json.dump(seg_lst_list, file, indent=4)

    # Write the combined SegLST file covering all sessions.
    total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
    print(f"{__INFO_TAG__} Writing {diar_out_path}/{total_output_filename}.seglst.json")
    with open(f'{diar_out_path}/{total_output_filename}.seglst.json', 'w') as file:
        json.dump(total_infer_list, file, indent=4)
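
# A minimal end-to-end sketch (illustrative; the paths, cfg values and KenLM
# model are hypothetical placeholders, not shipped defaults):
#
# import kenlm
# cfg = {'arpa_language_model': '/path/to/model.arpa', 'alpha': 0.5,
#        'beta': 0.05, 'word_window': 16, 'use_ngram': True, 'beam_width': 16}
# lm = kenlm.Model(cfg['arpa_language_model'])
# decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=lm, cfg=cfg)
# trans_info_dict = load_input_jsons('hyp_src.list')
# trans_info_dict = decoder.beam_search_diarization(trans_info_dict)
# seglst = convert_nemo_json_to_seglst(trans_info_dict)
# write_seglst_jsons(seglst, 'hyp_src.list', diar_out_path='.', ext_str='hyp')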