mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 00:54:18 +00:00
support showing WERs of different books
This commit is contained in:
parent
f23882b9f6
commit
80c54c05e2
@ -233,6 +233,12 @@ class LibriHeavyAsrDataModule:
|
|||||||
type=str,
|
type=str,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--long-audio-cuts",
|
||||||
|
type=str,
|
||||||
|
default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz",
|
||||||
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
@ -510,8 +516,16 @@ class LibriHeavyAsrDataModule:
|
|||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def long_audio_cuts(self) -> CutSet:
|
def long_audio_cuts(self) -> CutSet:
|
||||||
logging.info("About to get medium test cuts")
|
logging.info("About to get long audio cuts")
|
||||||
cuts = load_manifest_lazy(
|
cuts = load_manifest_lazy(
|
||||||
"data/long_audios/long_audio_pomonastravels_combined.jsonl.gz"
|
self.args.long_audio_cuts,
|
||||||
|
)
|
||||||
|
return cuts
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_dev_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test dev cuts")
|
||||||
|
cuts = load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz"
|
||||||
)
|
)
|
||||||
return cuts
|
return cuts
|
@ -249,6 +249,12 @@ def get_parser():
|
|||||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--long-audio-recog",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-ls-test-set",
|
"--use-ls-test-set",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -434,6 +440,10 @@ def decode_dataset(
|
|||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
if not params.use_ls_test_set:
|
||||||
|
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
|
||||||
|
else:
|
||||||
|
book_names = ["" for _ in cut_ids]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -447,13 +457,14 @@ def decode_dataset(
|
|||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
|
||||||
ref_text = ref_text_normalization(
|
ref_text = ref_text_normalization(
|
||||||
ref_text
|
ref_text
|
||||||
)
|
)
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((cut_id, ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
if not params.use_ls_test_set:
|
||||||
|
results[name + " " + book_name].extend(this_batch)
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
num_cuts += len(texts)
|
num_cuts += len(texts)
|
||||||
@ -556,6 +567,10 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.long_audio_recog:
|
||||||
|
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
|
||||||
|
else:
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -704,11 +719,16 @@ def main():
|
|||||||
test_other_cuts = libriheavy.test_other_cuts()
|
test_other_cuts = libriheavy.test_other_cuts()
|
||||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||||
|
long_audio_cuts = libriheavy.long_audio_cuts()
|
||||||
|
test_dev_cuts = libriheavy.test_dev_cuts()
|
||||||
|
#test_clean_cuts = test_clean_cuts.filter(lambda c: "Brain Twister" not in c.text_path)
|
||||||
|
|
||||||
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
|
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
|
||||||
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
|
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
|
||||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||||
|
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
|
||||||
|
test_dev_dl = libriheavy.valid_dataloaders(test_dev_cuts)
|
||||||
|
|
||||||
if params.use_ls_test_set:
|
if params.use_ls_test_set:
|
||||||
test_sets = ["ls-test-clean", "ls-test-other"]
|
test_sets = ["ls-test-clean", "ls-test-other"]
|
||||||
@ -717,6 +737,13 @@ def main():
|
|||||||
test_sets = ["test-clean", "test-other"]
|
test_sets = ["test-clean", "test-other"]
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
|
if params.long_audio_recog:
|
||||||
|
test_sets = ["long-audio"]
|
||||||
|
test_dl = [long_audio_dl]
|
||||||
|
|
||||||
|
# test_sets = ["test-dev", ]
|
||||||
|
# test_dl = [test_dev_dl, ]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
|
@ -270,6 +270,12 @@ def get_parser():
|
|||||||
help="Use style prompt when evaluation"
|
help="Use style prompt when evaluation"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-prompt-lens",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-context-embedding",
|
"--use-context-embedding",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -284,6 +290,12 @@ def get_parser():
|
|||||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--long-audio-recog",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--compute-CER",
|
"--compute-CER",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -434,6 +446,7 @@ def decode_one_batch(
|
|||||||
|
|
||||||
# apply style transform to the pre_text and style_text
|
# apply style transform to the pre_text and style_text
|
||||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||||
|
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
|
||||||
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||||
@ -605,6 +618,10 @@ def decode_dataset(
|
|||||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||||
|
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
if not params.use_ls_test_set:
|
||||||
|
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
|
||||||
|
else:
|
||||||
|
book_names = ["" for _ in cut_ids]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -620,13 +637,14 @@ def decode_dataset(
|
|||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
|
||||||
ref_text = ref_text_normalization(
|
ref_text = ref_text_normalization(
|
||||||
ref_text
|
ref_text
|
||||||
) # remove full-width symbols & some book marks
|
) # remove full-width symbols & some book marks
|
||||||
ref_words = ref_text.split()
|
ref_words = ref_text.split()
|
||||||
this_batch.append((cut_id, ref_words, hyp_words))
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
if not params.use_ls_test_set:
|
||||||
|
results[name + "_" + book_name].extend(this_batch)
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
num_cuts += len(texts)
|
num_cuts += len(texts)
|
||||||
@ -731,6 +749,10 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.long_audio_recog:
|
||||||
|
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
|
||||||
|
else:
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -757,7 +779,7 @@ def main():
|
|||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
if params.use_pre_text:
|
if params.use_pre_text:
|
||||||
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
|
||||||
|
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||||
@ -892,11 +914,16 @@ def main():
|
|||||||
test_other_cuts = libriheavy.test_other_cuts()
|
test_other_cuts = libriheavy.test_other_cuts()
|
||||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||||
|
long_audio_cuts = libriheavy.long_audio_cuts()
|
||||||
|
test_dev_cuts = libriheavy.test_dev_cuts()
|
||||||
|
#test_clean_cuts = test_clean_cuts.filter(lambda c: "Brain Twister" not in c.text_path)
|
||||||
|
|
||||||
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
|
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||||
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
|
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||||
|
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||||
|
test_dev_dl = libriheavy.valid_dataloaders(test_dev_cuts)
|
||||||
|
|
||||||
if params.use_ls_test_set:
|
if params.use_ls_test_set:
|
||||||
test_sets = ["ls-test-clean", "ls-test-other"]
|
test_sets = ["ls-test-clean", "ls-test-other"]
|
||||||
@ -905,6 +932,13 @@ def main():
|
|||||||
test_sets = ["test-clean", "test-other"]
|
test_sets = ["test-clean", "test-other"]
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
|
if params.long_audio_recog:
|
||||||
|
test_sets = ["long-audio"]
|
||||||
|
test_dl = [long_audio_dl]
|
||||||
|
|
||||||
|
test_sets = ["test-dev", ]
|
||||||
|
test_dl = [test_dev_dl, ]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
if test_set == "ls-test-clean":
|
if test_set == "ls-test-clean":
|
||||||
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
|
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user