diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
index 28e1472af..a75a79fef 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
@@ -57,6 +57,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from ls_text_normalization import word_normalization
 from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
 from train_baseline import (
     add_model_arguments,
@@ -76,8 +77,8 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )
+from utils import write_error_stats

 LOG_EPS = math.log(1e-10)

@@ -480,6 +481,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -494,7 +496,7 @@
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
             )
         test_set_wers[key] = wer

@@ -740,6 +742,12 @@
         test_dl = [long_audio_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
+        if params.use_ls_test_set:
+            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            biasing_words = f.read().strip().split()
+            f.close()
+        else:
+            biasing_words = None
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -781,6 +789,7 @@
             params=params,
             test_set_name=test_set,
             results_dict=new_res,
+            biasing_words=biasing_words,
         )

     if params.suffix.endswith("-post-normalization"):
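The rare-word list loading added above opens and closes the file by hand. An equivalent, more idiomatic form (same path as in the patch, behavior unchanged) would be:

    # Context manager closes the file even if reading raises.
    with open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt") as f:
        biasing_words = f.read().strip().split()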
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py
index 59daf0bc4..0aa23d49a 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py
@@ -61,7 +61,7 @@ from beam_search import (
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
-from utils import get_facebook_biasing_list
+from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
 from ls_text_normalization import word_normalization
 from text_normalization import (
     ref_text_normalization,
@@ -92,7 +92,6 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )

 LOG_EPS = math.log(1e-10)
@@ -334,11 +333,18 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )

+    parser.add_argument(
+        "--biasing-level",
+        type=str,
+        default="utterance",
+        choices=["utterance", "Book", "Chapter"],
+    )
+
     parser.add_argument(
         "--ls-distractors",
-        type=str2bool,
-        default=True,
-        help="If add distractors into context list for LibriSpeech decoding"
+        type=int,
+        default=0,
+        help="The number of distractors added to the context list for LibriSpeech decoding"
     )

     add_model_arguments(parser)
@@ -430,13 +436,20 @@ def decode_one_batch(
         pre_texts = ["" for _ in range(batch_size)]

     if params.use_ls_context_list:
-        pre_texts = [biasing_dict[id] for id in cut_ids]
+        if params.biasing_level == "utterance":
+            pre_texts = [biasing_dict[id] for id in cut_ids]
+        elif params.biasing_level == "Chapter":
+            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            pre_texts = [biasing_dict[id] for id in chapter_ids]
+        elif params.biasing_level == "Book":
+            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            pre_texts = [biasing_dict[id] for id in chapter_ids]
         if params.pre_text_transform == "mixed-punc":
             pre_texts = [t.lower() for t in pre_texts]

     # get style_text
     if params.use_style_prompt:
-        fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
+        fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
         style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
         style_texts = [train_text_normalization(t) for t in style_texts]
     else:
@@ -447,7 +460,8 @@

     # apply style transform to the pre_text and style_text
     pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-    pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
+    if not params.use_ls_context_list:
+        pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
     #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
     if params.use_style_prompt:
         style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@@ -461,7 +475,9 @@
         style_texts=style_texts,
         tokenizer=tokenizer,
         device=device,
+        no_limit=True,
     )
+    logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")

     memory, memory_key_padding_mask = model.encode_text(
         encoded_inputs=encoded_inputs,
@@ -666,6 +682,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -792,9 +809,9 @@ def main():
         params.suffix += f"-use-context-fuser"

     if params.use_ls_context_list:
-        params.suffix += f"-use-ls-context-list"
-        if params.ls_distractors:
-            params.suffix += f"-add-ls-context-distractors"
+        params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
+        if params.biasing_level == "utterance" and params.ls_distractors:
+            params.suffix += f"-ls-context-distractors-{params.ls_distractors}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -919,6 +936,7 @@ def main():
     ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
     ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
     long_audio_cuts = libriheavy.long_audio_cuts()
+    npr1_dev_cuts = libriheavy.npr1_dev_cuts()
     npr1_test_cuts = libriheavy.npr1_test_cuts()
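For reference, the Chapter/Book branches above key the biasing dict on the second '-'-separated field of the cut id. LibriSpeech cut ids follow "<speaker>-<chapter>-<utterance>", so a quick sketch (ids made up):

    cut_ids = ["1089-134686-0001", "1089-134691-0003"]
    chapter_ids = [c.split('-')[1] for c in cut_ids]
    assert chapter_ids == ["134686", "134691"]
    # Both the Chapter and Book branches use the chapter id as the key, so a
    # Book-level biasing_dict presumably maps each chapter id to its book-wide
    # context list inside brian_biasing_list (not shown in this diff).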
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_subformer_with_style.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_subformer_with_style.py
index 75f570fa4..956da517a 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_subformer_with_style.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_subformer_with_style.py
@@ -60,7 +60,8 @@ from beam_search import (
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset, get_substring
-from utils import get_facebook_biasing_list
+from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
+from ls_text_normalization import word_normalization
 from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, train_text_normalization
 from train_subformer_with_style import (
     add_model_arguments,
@@ -82,7 +83,6 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )

 LOG_EPS = math.log(1e-10)
@@ -324,11 +324,18 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )

+    parser.add_argument(
+        "--biasing-level",
+        type=str,
+        default="utterance",
+        choices=["utterance", "Book", "Chapter"],
+    )
+
     parser.add_argument(
         "--ls-distractors",
-        type=str2bool,
-        default=True,
-        help="If add distractors into context list for LibriSpeech decoding"
+        type=int,
+        default=0,
+        help="The number of distractors added to the context list for LibriSpeech decoding"
     )

     add_model_arguments(parser)
@@ -414,11 +421,19 @@ def decode_one_batch(
     if "pre_text" in batch["supervisions"] and params.use_pre_text:
         pre_texts = batch["supervisions"]["pre_text"]
+        pre_texts = [train_text_normalization(t) for t in pre_texts]
     else:
         pre_texts = ["" for _ in range(batch_size)]

     if params.use_ls_context_list:
-        pre_texts = [biasing_dict[id] for id in cut_ids]
+        if params.biasing_level == "utterance":
+            pre_texts = [biasing_dict[id] for id in cut_ids]
+        elif params.biasing_level == "Chapter":
+            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            pre_texts = [biasing_dict[id] for id in chapter_ids]
+        elif params.biasing_level == "Book":
+            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            pre_texts = [biasing_dict[id] for id in chapter_ids]

         if params.pre_text_transform == "mixed-punc":
             pre_texts = [t.lower() for t in pre_texts]
@@ -434,7 +449,8 @@

     # apply style transform to the pre_text and style_text
     pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-    pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
+    if not params.use_ls_context_list:
+        pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
     #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
     if params.use_style_prompt:
         style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@@ -448,8 +464,9 @@
         style_texts=style_texts,
         bpe_model=text_encoder_bpe_model,
         device=device,
-        max_tokens=1000,
+        max_tokens=8000,
     )
+    logging.info(f"Shape of the encoded prompts: {pre_texts.shape}")

     memory, memory_key_padding_mask = model.encode_text(
         text=pre_texts,
@@ -608,6 +625,13 @@ def decode_dataset(
             texts = _apply_style_transform(texts, params.style_text_transform)

         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        if not params.use_ls_test_set:
+            try:
+                book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            except:
+                book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
+        else:
+            book_names = ["" for _ in cut_ids]

         hyps_dict = decode_one_batch(
             params=params,
@@ -623,13 +647,14 @@
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
                 ref_text = ref_text_normalization(
                     ref_text
                 )  # remove full-width symbols & some book marks
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
-
+            # if not params.use_ls_test_set:
+            #     results[name + "_" + book_name].extend(this_batch)
             results[name].extend(this_batch)

         num_cuts += len(texts)
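The try/except added in decode_dataset infers a book name per cut, falling back to the cut id when text_path is unavailable. A sketch of the two code paths, with an invented Libriheavy-style path and cut id (the real layouts may differ):

    # Primary path: assume cut.text_path ends in ".../<book>/<chapter>.txt"
    text_path = "download/libriheavy/texts/sea_fairies_librivox/chapter_01.txt"
    book_name = text_path.split('/')[-2]   # -> "sea_fairies_librivox"

    # Fallback: take the first '/'-separated field of the cut id
    cut_id = "sea_fairies_librivox/recording_01-42"
    book_name = cut_id.split('/')[0]       # -> "sea_fairies_librivox"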
f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words, ) test_set_wers[key] = wer @@ -769,9 +795,9 @@ def main(): params.suffix += f"-use-context-fuser" if params.use_ls_context_list: - params.suffix += f"-use-ls-context-list" - if params.ls_distractors: - params.suffix += f"-add-ls-context-distractors" + params.suffix += f"-use-{params.biasing_level}-level-ls-context-list" + if params.biasing_level == "utterance" and params.ls_distractors: + params.suffix += f"-ls-context-distractors-{params.ls_distractors}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -921,16 +947,28 @@ def main(): test_sets = ["long-audio"] test_dl = [long_audio_dl] - #test_sets = ["npr1-dev", "npr1-test"] - #test_dl = [npr1_dev_dl, npr1_test_dl] + if params.long_audio_recog: + test_sets = ["long-audio"] + test_dl = [long_audio_dl] for test_set, test_dl in zip(test_sets, test_dl): - if test_set == "ls-test-clean": - biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors) - elif test_set == "ls-test-other": - biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors) + if params.biasing_level == "utterance": + if test_set == "ls-test-clean": + biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors) + elif test_set == "ls-test-other": + biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors) + else: + biasing_dict = None + f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r') + biasing_words = f.read().strip().split() + f.close() + else: - biasing_dict = None + if params.use_ls_test_set: + biasing_dict = brian_biasing_list(params.biasing_level) + else: + biasing_dict = None + biasing_words = None results_dict = decode_dataset( dl=test_dl, @@ -961,7 +999,7 @@ def main(): if params.use_ls_test_set: hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens hyp = upper_only_alpha(" ".join(hyp)).split() - hyp = [word_normalization(w.upper()) for w in hyp] + hyp = [word_normalization(str(w).upper()) for w in hyp] hyp = " ".join(hyp).split() hyp = [w for w in hyp if w != ""] ref = upper_only_alpha(" ".join(ref)).split() @@ -975,6 +1013,7 @@ def main(): params=params, test_set_name=test_set, results_dict=new_res, + biasing_words=biasing_words, ) if params.suffix.endswith("-post-normalization"):