support showing WERs of different books

This commit is contained in:
marcoyang1998 2023-08-17 23:59:37 +08:00
parent f23882b9f6
commit 80c54c05e2
3 changed files with 84 additions and 9 deletions

View File

@ -232,6 +232,12 @@ class LibriHeavyAsrDataModule:
"--rare-word-file",
type=str,
)
group.add_argument(
"--long-audio-cuts",
type=str,
default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz",
)
def train_dataloaders(
self,
@ -510,8 +516,16 @@ class LibriHeavyAsrDataModule:
@lru_cache()
def long_audio_cuts(self) -> CutSet:
    """Return the long-audio test CutSet, loaded lazily.

    The manifest path comes from the ``--long-audio-cuts`` command-line
    option (see the argparse group added in this same commit), so different
    long-audio manifests can be evaluated without editing code.

    Returns:
        A lazily-loaded ``CutSet``; cuts are materialized on iteration.
    """
    # NOTE(review): the diff extraction had merged the pre-commit lines
    # (an "About to get medium test cuts" log message and a hardcoded
    # "data/long_audios/..." path) with the post-commit lines; this body
    # keeps only the post-commit version.
    logging.info("About to get long audio cuts")
    cuts = load_manifest_lazy(
        self.args.long_audio_cuts,
    )
    return cuts
@lru_cache()
def test_dev_cuts(self) -> CutSet:
    """Return the Libriheavy test-dev CutSet, loaded lazily.

    Reads ``libriheavy_cuts_test_dev.jsonl.gz`` from the directory given
    by ``self.args.manifest_dir`` (presumably a ``pathlib.Path``, since it
    is joined with ``/`` — confirm against the argparse setup).

    Returns:
        A lazily-loaded ``CutSet``; cuts are materialized on iteration.
    """
    logging.info("About to get test dev cuts")
    cuts = load_manifest_lazy(
        self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz"
    )
    return cuts

View File

@ -249,6 +249,12 @@ def get_parser():
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
)
parser.add_argument(
"--long-audio-recog",
type=str2bool,
default=False,
)
parser.add_argument(
"--use-ls-test-set",
type=str2bool,
@ -434,6 +440,10 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if not params.use_ls_test_set:
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
else:
book_names = ["" for _ in cut_ids]
hyps_dict = decode_one_batch(
params=params,
@ -447,13 +457,14 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
ref_text = ref_text_normalization(
ref_text
)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
if not params.use_ls_test_set:
results[name + " " + book_name].extend(this_batch)
results[name].extend(this_batch)
num_cuts += len(texts)
@ -556,7 +567,11 @@ def main():
"greedy_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.long_audio_recog:
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
else:
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
@ -704,11 +719,16 @@ def main():
test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
long_audio_cuts = libriheavy.long_audio_cuts()
test_dev_cuts = libriheavy.test_dev_cuts()
#test_clean_cuts = test_clean_cuts.filter(lambda c: "Brain Twister" not in c.text_path)
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
test_dev_dl = libriheavy.valid_dataloaders(test_dev_cuts)
if params.use_ls_test_set:
test_sets = ["ls-test-clean", "ls-test-other"]
@ -716,6 +736,13 @@ def main():
else:
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
if params.long_audio_recog:
test_sets = ["long-audio"]
test_dl = [long_audio_dl]
# test_sets = ["test-dev", ]
# test_dl = [test_dev_dl, ]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(

View File

@ -270,6 +270,12 @@ def get_parser():
help="Use style prompt when evaluation"
)
parser.add_argument(
"--max-prompt-lens",
type=int,
default=500,
)
parser.add_argument(
"--use-context-embedding",
type=str2bool,
@ -283,6 +289,12 @@ def get_parser():
default=True,
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
)
parser.add_argument(
"--long-audio-recog",
type=str2bool,
default=False,
)
parser.add_argument(
"--compute-CER",
@ -434,6 +446,7 @@ def decode_one_batch(
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
if params.use_style_prompt:
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@ -605,6 +618,10 @@ def decode_dataset(
texts = _apply_style_transform(texts, params.style_text_transform)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if not params.use_ls_test_set:
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
else:
book_names = ["" for _ in cut_ids]
hyps_dict = decode_one_batch(
params=params,
@ -620,13 +637,14 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
ref_text = ref_text_normalization(
ref_text
) # remove full-width symbols & some book marks
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
if not params.use_ls_test_set:
results[name + "_" + book_name].extend(this_batch)
results[name].extend(this_batch)
num_cuts += len(texts)
@ -731,7 +749,11 @@ def main():
"greedy_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.long_audio_recog:
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
else:
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
@ -757,7 +779,7 @@ def main():
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_pre_text:
params.suffix += f"-pre-text-{params.pre_text_transform}"
params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
if params.use_style_prompt:
params.suffix += f"-style-prompt-{params.style_text_transform}"
@ -892,11 +914,16 @@ def main():
test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
long_audio_cuts = libriheavy.long_audio_cuts()
test_dev_cuts = libriheavy.test_dev_cuts()
#test_clean_cuts = test_clean_cuts.filter(lambda c: "Brain Twister" not in c.text_path)
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling)
test_dev_dl = libriheavy.valid_dataloaders(test_dev_cuts)
if params.use_ls_test_set:
test_sets = ["ls-test-clean", "ls-test-other"]
@ -904,6 +931,13 @@ def main():
else:
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
if params.long_audio_recog:
test_sets = ["long-audio"]
test_dl = [long_audio_dl]
test_sets = ["test-dev", ]
test_dl = [test_dev_dl, ]
for test_set, test_dl in zip(test_sets, test_dl):
if test_set == "ls-test-clean":