merge with master

This commit is contained in:
pkufool 2024-06-21 18:05:52 +08:00
parent 3059eb4511
commit 19f88482be
5 changed files with 73 additions and 43 deletions

View File

@ -556,18 +556,14 @@ def save_results(
): ):
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = ( recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = ( errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True f, f"{test_set_name}-{key}", results, enable_log=True
@ -577,9 +573,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
for key, val in test_set_wers: for key, val in test_set_wers:

View File

@ -239,12 +239,22 @@ def add_model_arguments(parser: argparse.ArgumentParser):
chunk left-context frames will be chosen randomly from this list; else not relevant.""", chunk left-context frames will be chosen randomly from this list; else not relevant.""",
) )
parser.add_argument(
def get_parser(): "--use-transducer",
parser = argparse.ArgumentParser( type=str2bool,
formatter_class=argparse.ArgumentDefaultsHelpFormatter default=True,
help="If True, use Transducer head.",
) )
parser.add_argument(
"--use-ctc",
type=str2bool,
default=False,
help="If True, use CTC head.",
)
def add_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--world-size", "--world-size",
type=int, type=int,
@ -302,16 +312,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument( parser.add_argument(
"--base-lr", type=float, default=0.045, help="The base learning rate." "--base-lr", type=float, default=0.045, help="The base learning rate."
) )
@ -379,6 +379,13 @@ def get_parser():
with this parameter before adding to the final loss.""", with this parameter before adding to the final loss.""",
) )
parser.add_argument(
"--ctc-loss-scale",
type=float,
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument( parser.add_argument(
"--seed", "--seed",
type=int, type=int,
@ -444,6 +451,24 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
add_training_arguments(parser)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -1154,26 +1179,6 @@ def run(rank, world_size, args):
# ) # )
return False return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True return True
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -28,5 +28,6 @@ multi_quantization
onnx onnx
onnxmltools onnxmltools
onnxruntime onnxruntime
pypinyin
kaldifst kaldifst
kaldi-decoder kaldi-decoder

View File

@ -3,6 +3,7 @@ kaldilm
kaldialign kaldialign
num2words num2words
kaldi-decoder kaldi-decoder
pypinyin
sentencepiece>=0.1.96 sentencepiece>=0.1.96
pypinyin==0.50.0 pypinyin==0.50.0
tensorboard tensorboard

View File

@ -28,6 +28,7 @@ from icefall.utils import (
encode_supervisions, encode_supervisions,
get_texts, get_texts,
make_pad_mask, make_pad_mask,
text_to_pinyin,
) )
@ -163,3 +164,31 @@ def test_add_eos():
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]] [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
) )
assert str(ragged_eos) == str(expected) assert str(ragged_eos) == str(expected)
def test_text_to_pinyin():
txt = "想吃KFC"
r = text_to_pinyin(txt, mode="full_with_tone")
assert " ".join(r) == "xiǎng chī KFC"
r = text_to_pinyin(txt, mode="full_with_tone", errors="split")
assert " ".join(r) == "xiǎng chī K F C"
r = text_to_pinyin(txt, mode="full_no_tone", errors="default")
assert " ".join(r) == "xiang chi KFC"
r = text_to_pinyin(txt, mode="full_no_tone", errors="split")
assert " ".join(r) == "xiang chi K F C"
r = text_to_pinyin(txt, mode="partial_with_tone")
assert " ".join(r) == "x iǎng ch ī KFC"
r = text_to_pinyin(txt, mode="partial_with_tone", errors="split")
assert " ".join(r) == "x iǎng ch ī K F C"
r = text_to_pinyin(txt, mode="partial_no_tone", errors="default")
assert " ".join(r) == "x iang ch i KFC"
r = text_to_pinyin(txt, mode="partial_no_tone", errors="split")
assert " ".join(r) == "x iang ch i K F C"