mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
include previously added decoding method
This commit is contained in:
parent
6c8d1f9ef5
commit
9a01b9098d
@ -131,11 +131,13 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_ngram_rescoring,
|
||||||
modified_beam_search_rnnlm_shallow_fusion,
|
modified_beam_search_rnnlm_shallow_fusion,
|
||||||
)
|
)
|
||||||
from librispeech import LibriSpeech
|
from librispeech import LibriSpeech
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -232,6 +234,7 @@ def get_parser():
|
|||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
|
- modified_beam_search_ngram_rescoring
|
||||||
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
@ -386,7 +389,23 @@ def get_parser():
|
|||||||
last output linear layer
|
last output linear layer
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
parser.add_argument("--ilm-scale", type=float, default=-0.1)
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="""Token Ngram used for rescoring.
|
||||||
|
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""ID of the backoff symbol.
|
||||||
|
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -399,6 +418,8 @@ def decode_one_batch(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
rnnlm: Optional[RnnLmModel] = None,
|
rnnlm: Optional[RnnLmModel] = None,
|
||||||
rnnlm_scale: float = 1.0,
|
rnnlm_scale: float = 1.0,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
@ -534,6 +555,17 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_ngram_rescoring":
|
||||||
|
hyp_tokens = modified_beam_search_ngram_rescoring(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||||
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
||||||
model=model,
|
model=model,
|
||||||
@ -595,9 +627,11 @@ def decode_dataset(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
rnnlm: Optional[RnnLmModel] = None,
|
rnnlm: Optional[RnnLmModel] = None,
|
||||||
rnnlm_scale: float = 1.0,
|
rnnlm_scale: float = 1.0,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -638,13 +672,6 @@ def decode_dataset(
|
|||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
total_duration = sum(
|
|
||||||
[cut.duration for cut in batch["supervisions"]["cut"]]
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
|
|
||||||
)
|
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -653,6 +680,8 @@ def decode_dataset(
|
|||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
rnnlm=rnnlm,
|
rnnlm=rnnlm,
|
||||||
rnnlm_scale=rnnlm_scale,
|
rnnlm_scale=rnnlm_scale,
|
||||||
)
|
)
|
||||||
@ -680,7 +709,7 @@ def decode_dataset(
|
|||||||
def save_results(
|
def save_results(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
):
|
):
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
@ -740,6 +769,7 @@ def main():
|
|||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_ngram_rescoring",
|
||||||
"modified_beam_search_rnnlm_shallow_fusion",
|
"modified_beam_search_rnnlm_shallow_fusion",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
@ -765,13 +795,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
if "rnnlm" in params.decoding_method:
|
if "rnnlm" in params.decoding_method:
|
||||||
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||||
|
|
||||||
if "ILME" in params.decoding_method:
|
|
||||||
params.suffix += f"-ILME-scale={params.ilm_scale}"
|
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
@ -884,6 +911,14 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
str(params.lang_dir / lm_filename),
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
# only load rnnlm if used
|
# only load rnnlm if used
|
||||||
if "rnnlm" in params.decoding_method:
|
if "rnnlm" in params.decoding_method:
|
||||||
rnn_lm_scale = params.rnn_lm_scale
|
rnn_lm_scale = params.rnn_lm_scale
|
||||||
@ -951,6 +986,8 @@ def main():
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=params.ngram_lm_scale,
|
||||||
rnnlm=rnn_lm_model,
|
rnnlm=rnn_lm_model,
|
||||||
rnnlm_scale=rnn_lm_scale,
|
rnnlm_scale=rnn_lm_scale,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user