replace file

Kinan Martin 2025-04-14 08:27:50 +09:00
parent 1e9bb87305
commit a4be3cb3db
2 changed files with 67 additions and 1149 deletions

@@ -103,6 +103,9 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
+
+# import sentencepiece as spm
+from tokenizer import Tokenizer
 import torch
 import torch.nn as nn
 from asr_datamodule import MLSEnglishHFAsrDataModule
@@ -120,7 +123,7 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
-from tokenizer import Tokenizer
+# from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_model, get_params
 
 from icefall import ContextGraph, LmScorer, NgramLm
@@ -194,18 +197,25 @@ def get_parser():
         help="The experiment dir",
     )
 
-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
-    )
+    # parser.add_argument(
+    #     "--bpe-model",
+    #     type=str,
+    #     default="data/lang_bpe_500/bpe.model",
+    #     help="Path to the BPE model",
+    # )
+
+    # parser.add_argument(
+    #     "--lang-dir",
+    #     type=Path,
+    #     default="data/lang_bpe_500",
+    #     help="The lang dir containing word table and LG graph",
+    # )
 
     parser.add_argument(
         "--lang-dir",
-        type=Path,
+        type=str,
         default="data/lang_char",
-        help="The lang dir containing word table and LG graph",
+        help="Path to the lang dir with the BPE model (`bpe.model`)",
     )
 
     parser.add_argument(
@@ -370,23 +380,24 @@ def get_parser():
         modified_beam_search_LODR.
         """,
     )
-    parser.add_argument(
-        "--blank-penalty",
-        type=float,
-        default=0.0,
-        help="""
-        The penalty applied on blank symbol during decoding.
-        Note: It is a positive value that would be applied to logits like
-        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
-        [batch_size, vocab] and blank id is 0).
-        """,
-    )
 
     add_model_arguments(parser)
 
     return parser
 
 
+def asr_text_post_processing(inp):
+    return inp
+
+
+def post_processing(
+    results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+    new_results = []
+    for key, ref, hyp in results:
+        new_ref = asr_text_post_processing(" ".join(ref)).split()
+        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+        new_results.append((key, new_ref, new_hyp))
+    return new_results
+
+
 def decode_one_batch(
     params: AttributeDict,
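
The new post_processing hook appears to stand in for the gigaspeech recipe's scoring normalization (note the commented gigaspeech_scoring import above), with asr_text_post_processing stubbed out as the identity for now. A minimal usage sketch; the cut id and token lists are made-up illustrative data:

    # Each (key, ref, hyp) triple is joined into a string, passed through
    # the text post-processor, and re-split into words.
    results = [("cut-001", ["hello", "world"], ["hello", "word"])]
    normalized = post_processing(results)
    assert normalized == results  # the identity stub leaves triples unchanged
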
@@ -470,10 +481,9 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
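
This hunk and the following ones drop the blank_penalty keyword from every search call, matching the removal of the --blank-penalty option above. Per that option's help text, the mechanism was a constant subtracted from the blank logit before symbols are chosen; a standalone sketch, with the tensor shape and penalty value chosen purely for illustration:

    import torch

    blank_penalty = 0.5            # value the removed flag used to supply
    logits = torch.randn(8, 500)   # (batch_size, vocab); blank id assumed 0
    if blank_penalty != 0:
        logits[:, 0] -= blank_penalty  # suppress blank, as in the help text
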
@@ -485,7 +495,6 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -500,10 +509,9 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -516,19 +524,17 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -536,10 +542,9 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
         hyp_tokens = modified_beam_search_lm_shallow_fusion(
             model=model,
@@ -549,7 +554,7 @@ def decode_one_batch(
             LM=LM,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_LODR":
         hyp_tokens = modified_beam_search_LODR(
             model=model,
@@ -562,7 +567,7 @@ def decode_one_batch(
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_lm_rescore":
         lm_scale_list = [0.01 * i for i in range(10, 50)]
         ans_dict = modified_beam_search_lm_rescore(
@@ -608,9 +613,8 @@ def decode_one_batch(
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.text2word(sp.decode(hyp)))
+            hyps.append(sp.decode(hyp).split())
 
-    key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
     elif "fast_beam_search" in params.decoding_method:
@@ -714,7 +718,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = sp.text2word(ref_text)
+                ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
 
             results[name].extend(this_batch)
@@ -738,6 +742,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -776,7 +781,7 @@ def save_results(
 def main():
     parser = get_parser()
     MLSEnglishHFAsrDataModule.add_arguments(parser)
-    Tokenizer.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -847,8 +852,6 @@ def main():
             f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
         )
 
-    params.suffix += f"-blank-penalty-{params.blank_penalty}"
-
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -861,9 +864,13 @@ def main():
     logging.info(f"Device: {device}")
 
-    sp = Tokenizer.load(Path(args.lang_dir), "bpe")
+    # sp = spm.SentencePieceProcessor()
+    # sp.load(params.bpe_model)
+    sp = Tokenizer.load(Path(args.lang_dir), "bpe")  # force bpe model
 
-    # <blk> and <unk> are defined in local/prepare_lang_char.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
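
The commented lines preserve the raw sentencepiece calls that Tokenizer.load now wraps. For reference, a sketch of that equivalent flow; the model path is the old --bpe-model default and is assumed here purely for illustration:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # illustrative path
    blank_id = sp.piece_to_id("<blk>")      # same lookups as in the hunk
    unk_id = sp.piece_to_id("<unk>")
    vocab_size = sp.get_piece_size()
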
@@ -1022,9 +1029,9 @@ def main():
         if os.path.exists(params.context_file):
             contexts = []
             for line in open(params.context_file).readlines():
-                contexts.append((sp.encode(line.strip()), 0.0))
+                contexts.append(line.strip())
             context_graph = ContextGraph(params.context_score)
-            context_graph.build(contexts)
+            context_graph.build(sp.encode(contexts))
         else:
             context_graph = None
     else:
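
The context-graph change swaps a per-line sp.encode call for one batched call over all phrases. A sketch of the new flow, assuming sp is the Tokenizer loaded earlier in main(); the phrase file name and context score value are illustrative:

    from icefall import ContextGraph  # as imported at the top of the file

    contexts = []
    for line in open("contexts.txt").readlines():  # hypothetical phrase list
        contexts.append(line.strip())              # keep raw text here
    context_graph = ContextGraph(0.1)              # params.context_score stand-in
    context_graph.build(sp.encode(contexts))       # encode all phrases at once
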
@@ -1038,16 +1045,21 @@ def main():
     mls_english_corpus = MLSEnglishHFAsrDataModule(args)
     mls_english_corpus.load_hf_dataset("/root/datasets/parler-tts--mls_eng")
 
+    # dev_cuts = mls_english_corpus.dev_cuts()
     test_cuts = mls_english_corpus.test_cuts()
 
-    for subdir in ["valid"]:
+    # dev_dl = mls_english_corpus.test_dataloaders(dev_cuts)
+    test_dl = mls_english_corpus.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dls = [test_dl]
+    # test_sets = ["dev", "test"]
+    # test_dls = [dev_dl, test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
-            dl=mls_english_corpus.test_dataloaders(
-                test_cuts,
-            ),
-            # dl=mls_english_corpus.test_dataloaders(
-            #     getattr(mls_english_corpus, f"{subdir}_cuts")()
-            # ),
+            dl=test_dl,
             params=params,
             model=model,
             sp=sp,
@@ -1058,22 +1070,12 @@ def main():
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
         )
-        tot_err = save_results(
+
+        save_results(
             params=params,
-            test_set_name=subdir,
+            test_set_name=test_set,
             results_dict=results_dict,
         )
-        # with (
-        #     params.res_dir
-        #     / (
-        #         f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
-        #         f"_{params.avg}_{params.epoch}.cer"
-        #     )
-        # ).open("w") as fout:
-        #     if len(tot_err) == 1:
-        #         fout.write(f"{tot_err[0][1]}")
-        #     else:
-        #         fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
 
     logging.info("Done!")
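
The commented dev-split lines in the hunks above suggest how a second evaluation set would slot into the new parallel-lists loop. A hedged expansion of those comments, using only the datamodule calls they name:

    dev_cuts = mls_english_corpus.dev_cuts()
    dev_dl = mls_english_corpus.test_dataloaders(dev_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]
    for test_set, test_dl in zip(test_sets, test_dls):
        ...  # decode_dataset(...) and save_results(...) as above
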

File diff suppressed because it is too large.