diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 635272e17..0c082c801 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -232,6 +232,12 @@ class LibriHeavyAsrDataModule: "--rare-word-file", type=str, ) + + group.add_argument( + "--long-audio-cuts", + type=str, + default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz", + ) def train_dataloaders( self, @@ -510,8 +516,16 @@ class LibriHeavyAsrDataModule: @lru_cache() def long_audio_cuts(self) -> CutSet: - logging.info("About to get medium test cuts") + logging.info("About to get long audio cuts") cuts = load_manifest_lazy( - "data/long_audios/long_audio_pomonastravels_combined.jsonl.gz" + self.args.long_audio_cuts, + ) + return cuts + + @lru_cache() + def test_dev_cuts(self) -> CutSet: + logging.info("About to get test dev cuts") + cuts = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz" ) return cuts \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py index cf59e7503..a992b5747 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py @@ -249,6 +249,12 @@ def get_parser(): help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", ) + parser.add_argument( + "--long-audio-recog", + type=str2bool, + default=False, + ) + parser.add_argument( "--use-ls-test-set", type=str2bool, @@ -434,6 +440,10 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if not params.use_ls_test_set: + book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]] + else: + book_names = ["" for _ in cut_ids] hyps_dict = decode_one_batch( params=params, @@ -447,13 +457,14 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts): ref_text = ref_text_normalization( ref_text ) ref_words = ref_text.split() this_batch.append((cut_id, ref_words, hyp_words)) - + if not params.use_ls_test_set: + results[name + " " + book_name].extend(this_batch) results[name].extend(this_batch) num_cuts += len(texts) @@ -556,7 +567,11 @@ def main(): "greedy_search", "modified_beam_search", ) - params.res_dir = params.exp_dir / params.decoding_method + + if params.long_audio_recog: + params.res_dir = params.exp_dir / (params.decoding_method + "long_audio") + else: + params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" @@ -704,11 +719,16 @@ def main(): test_other_cuts = libriheavy.test_other_cuts() ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() + long_audio_cuts = libriheavy.long_audio_cuts() + test_dev_cuts = libriheavy.test_dev_cuts() + #test_clean_cuts = test_clean_cuts.filter(lambda c: "Brain Twister" not in c.text_path) test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,) test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,) ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) + long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,) + test_dev_dl = libriheavy.valid_dataloaders(test_dev_cuts) if params.use_ls_test_set: test_sets = ["ls-test-clean", "ls-test-other"] @@ -716,6 +736,13 @@ def main(): else: test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] + + if params.long_audio_recog: + test_sets = ["long-audio"] + test_dl = [long_audio_dl] + + # test_sets = ["test-dev", ] + # test_dl = [test_dev_dl, ] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py index 90203a954..df31621b7 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py @@ -270,6 +270,12 @@ def get_parser(): help="Use style prompt when evaluation" ) + parser.add_argument( + "--max-prompt-lens", + type=int, + default=500, + ) + parser.add_argument( "--use-context-embedding", type=str2bool, @@ -283,6 +289,12 @@ def get_parser(): default=True, help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", ) + + parser.add_argument( + "--long-audio-recog", + type=str2bool, + default=False, + ) parser.add_argument( "--compute-CER", @@ -434,6 +446,7 @@ def decode_one_batch( # apply style transform to the pre_text and style_text pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + pre_texts = [t[:params.max_prompt_lens] for t in pre_texts] #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) if params.use_style_prompt: style_texts = _apply_style_transform(style_texts, params.style_text_transform) @@ -605,6 +618,10 @@ def decode_dataset( texts = _apply_style_transform(texts, params.style_text_transform) cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if not params.use_ls_test_set: + book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]] + else: + book_names = ["" for _ in cut_ids] hyps_dict = decode_one_batch( params=params, @@ -620,13 +637,14 @@ def decode_dataset( for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts): ref_text = ref_text_normalization( ref_text ) # remove full-width symbols & some book marks ref_words = ref_text.split() this_batch.append((cut_id, ref_words, hyp_words)) - + if not params.use_ls_test_set: + results[name + "_" + book_name].extend(this_batch) results[name].extend(this_batch) num_cuts += len(texts) @@ -731,7 +749,11 @@ def main(): "greedy_search", "modified_beam_search", ) - params.res_dir = params.exp_dir / params.decoding_method + + if params.long_audio_recog: + params.res_dir = params.exp_dir / (params.decoding_method + "long_audio") + else: + params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" @@ -757,7 +779,7 @@ def main(): params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" if params.use_pre_text: - params.suffix += f"-pre-text-{params.pre_text_transform}" + params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}" if params.use_style_prompt: params.suffix += f"-style-prompt-{params.style_text_transform}" @@ -892,11 +914,16 @@ def main(): test_other_cuts = libriheavy.test_other_cuts() ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() + long_audio_cuts = libriheavy.long_audio_cuts() + test_dev_cuts = libriheavy.test_dev_cuts() + #test_clean_cuts = test_clean_cuts.filter(lambda c: "Brain Twister" not in c.text_path) test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling) test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling) ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) + long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling) + test_dev_dl = libriheavy.valid_dataloaders(test_dev_cuts) if params.use_ls_test_set: test_sets = ["ls-test-clean", "ls-test-other"] @@ -904,6 +931,13 @@ def main(): else: test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] + + if params.long_audio_recog: + test_sets = ["long-audio"] + test_dl = [long_audio_dl] + + test_sets = ["test-dev", ] + test_dl = [test_dev_dl, ] for test_set, test_dl in zip(test_sets, test_dl): if test_set == "ls-test-clean":