fix style

This commit is contained in:
marcoyang 2023-10-10 16:48:17 +08:00
parent 7a9c18fc79
commit 9ef5dfe380

View File

@ -52,18 +52,15 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule from asr_datamodule import LibriHeavyAsrDataModule
from beam_search import ( from beam_search import greedy_search, greedy_search_batch, modified_beam_search
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from ls_text_normalization import word_normalization from ls_text_normalization import word_normalization
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha from text_normalization import (
from train_baseline import ( ref_text_normalization,
add_model_arguments, remove_non_alphabetic,
get_params, upper_only_alpha,
get_transducer_model,
) )
from train_baseline import add_model_arguments, get_params, get_transducer_model
from utils import write_error_stats
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -72,13 +69,7 @@ from icefall.checkpoint import (
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
)
from utils import write_error_stats
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)
@ -260,7 +251,7 @@ def get_parser():
"--use-ls-test-set", "--use-ls-test-set",
type=str2bool, type=str2bool,
default=False, default=False,
help="Use librispeech test set for evaluation." help="Use librispeech test set for evaluation.",
) )
parser.add_argument( parser.add_argument(
@ -373,12 +364,6 @@ def decode_one_batch(
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame, max_sym_per_frame=params.max_sym_per_frame,
) )
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
@ -442,7 +427,9 @@ def decode_dataset(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if not params.use_ls_test_set: if not params.use_ls_test_set:
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]] book_names = [
cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
]
else: else:
book_names = ["" for _ in cut_ids] book_names = ["" for _ in cut_ids]
@ -458,10 +445,10 @@ def decode_dataset(
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts): for cut_id, book_name, hyp_words, ref_text in zip(
ref_text = ref_text_normalization( cut_ids, book_names, hyps, texts
ref_text ):
) ref_text = ref_text_normalization(ref_text)
ref_words = ref_text.split() ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words)) this_batch.append((cut_id, ref_words, hyp_words))
# if not params.use_ls_test_set: # if not params.use_ls_test_set:
@ -496,7 +483,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words, f,
f"{test_set_name}-{key}",
results,
enable_log=True,
biasing_words=biasing_words,
) )
test_set_wers[key] = wer test_set_wers[key] = wer
@ -504,7 +495,9 @@ def save_results(
if params.compute_CER: if params.compute_CER:
# Write CER statistics # Write CER statistics
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" recog_path = (
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results, char_level=True) store_transcripts(filename=recog_path, texts=results, char_level=True)
errs_filename = ( errs_filename = (
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@ -522,9 +515,7 @@ def save_results(
logging.info("Wrote detailed CER stats to {}".format(errs_filename)) logging.info("Wrote detailed CER stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
for key, val in test_set_wers: for key, val in test_set_wers:
@ -539,9 +530,7 @@ def save_results(
if params.compute_CER: if params.compute_CER:
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = ( errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tcER", file=f) print("settings\tcER", file=f)
for key, val in test_set_cers: for key, val in test_set_cers:
@ -580,9 +569,6 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.random_left_padding:
params.suffix += f"random-left-padding"
if "beam_search" in params.decoding_method: if "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else: else:
@ -717,11 +703,17 @@ def main():
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
long_audio_cuts = libriheavy.long_audio_cuts() long_audio_cuts = libriheavy.long_audio_cuts()
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,) test_clean_dl = libriheavy.valid_dataloaders(
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,) test_clean_cuts,
)
test_other_dl = libriheavy.valid_dataloaders(
test_other_cuts,
)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,) long_audio_dl = libriheavy.valid_dataloaders(
long_audio_cuts,
)
if params.use_ls_test_set: if params.use_ls_test_set:
test_sets = ["ls-test-clean", "ls-test-other"] test_sets = ["ls-test-clean", "ls-test-other"]
@ -736,7 +728,9 @@ def main():
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
if params.use_ls_test_set: if params.use_ls_test_set:
f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r') f = open(
"data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
)
biasing_words = f.read().strip().split() biasing_words = f.read().strip().split()
f.close() f.close()
else: else:
@ -766,7 +760,9 @@ def main():
for item in results_dict[k]: for item in results_dict[k]:
id, ref, hyp = item id, ref, hyp = item
if params.use_ls_test_set: if params.use_ls_test_set:
hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens hyp = (
" ".join(hyp).replace("-", " ").split()
) # handle the hypens
hyp = upper_only_alpha(" ".join(hyp)).split() hyp = upper_only_alpha(" ".join(hyp)).split()
hyp = [word_normalization(w.upper()) for w in hyp] hyp = [word_normalization(w.upper()) for w in hyp]
hyp = " ".join(hyp).split() hyp = " ".join(hyp).split()
@ -788,7 +784,6 @@ def main():
if params.suffix.endswith("-post-normalization"): if params.suffix.endswith("-post-normalization"):
params.suffix = params.suffix.replace("-post-normalization", "") params.suffix = params.suffix.replace("-post-normalization", "")
logging.info("Done!") logging.info("Done!")