support RNNLM shallow fusion in stateless5

This commit is contained in:
marcoyang 2022-11-02 16:37:29 +08:00
parent de2f5e3e6d
commit 63d0a52dbd
2 changed files with 124 additions and 61 deletions

View File

@ -23,7 +23,6 @@ import sentencepiece as spm
import torch import torch
from model import Transducer from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding from icefall.decode import Nbest, one_best_decoding
from icefall.rnn_lm.model import RnnLmModel from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import add_eos, add_sos, get_texts from icefall.utils import add_eos, add_sos, get_texts
@ -658,8 +657,6 @@ class Hypothesis:
# It contains only one entry. # It contains only one entry.
log_prob: torch.Tensor log_prob: torch.Tensor
state_cost: Optional[NgramLmStateCost] = None
state: Optional = None
lm_score: Optional=None lm_score: Optional=None
@property @property

View File

@ -19,36 +19,36 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (not recommended) (2) beam search (not recommended)
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (one best) (4) fast beam search (one best)
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 20.0 \ --beam 20.0 \
@ -56,10 +56,10 @@ Usage:
--max-states 64 --max-states 64
(5) fast beam search (nbest) (5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 30 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest \ --decoding-method fast_beam_search_nbest \
--beam 20.0 \ --beam 20.0 \
@ -69,10 +69,10 @@ Usage:
--nbest-scale 0.5 --nbest-scale 0.5
(6) fast beam search (nbest oracle WER) (6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \ --decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \ --beam 20.0 \
@ -82,10 +82,10 @@ Usage:
--nbest-scale 0.5 --nbest-scale 0.5
(7) fast beam search (with LG) (7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \ ./lstm_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 35 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \ --decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \ --beam 20.0 \
@ -115,6 +115,7 @@ from beam_search import (
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
) )
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -125,6 +126,7 @@ from icefall.checkpoint import (
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
@ -183,7 +185,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless5/exp", default="lstm_transducer_stateless2/exp",
help="The experiment dir", help="The experiment dir",
) )
@ -213,6 +215,7 @@ def get_parser():
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG - fast_beam_search_nbest_LG
- modified-beam-search3 # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`. `--lang-dir`, which should contain `LG.pt`.
""", """,
@ -240,16 +243,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
@ -275,6 +268,7 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
@ -302,28 +296,69 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--simulate-streaming", "--rnn-lm-scale",
type=str2bool, type=float,
default=False, default=0.0,
help="""Whether to simulate streaming in decoding, this is a good way to help="""Used only when --method is modified_beam_search3.
test a streaming model. It specifies the scale applied to the RNN LM scores.
""", """,
) )
parser.add_argument( parser.add_argument(
"--decode-chunk-size", "--rnn-lm-exp-dir",
type=int, type=str,
default=16, default="rnn_lm/exp",
help="The chunk size for decoding (in frames after subsampling)", help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
) )
parser.add_argument( parser.add_argument(
"--left-context", "--rnn-lm-epoch",
type=int, type=int,
default=64, default=7,
help="left context can be seen during decoding (in frames after subsampling)", help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
) )
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers in the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -336,6 +371,8 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -361,7 +398,7 @@ def decode_one_batch(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
@ -474,12 +511,21 @@ def decode_one_batch(
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
for i in range(batch_size): for i in range(batch_size):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
hyp = greedy_search( hyp = greedy_search(
@ -523,7 +569,9 @@ def decode_dataset(
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None, word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
Args: Args:
@ -538,7 +586,7 @@ def decode_dataset(
word_table: word_table:
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
@ -564,6 +612,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
logging.info(f"Decoding {batch_idx}-th batch")
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
@ -572,6 +621,8 @@ def decode_dataset(
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
word_table=word_table, word_table=word_table,
batch=batch, batch=batch,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
) )
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
@ -597,7 +648,7 @@ def decode_dataset(
def save_results( def save_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
): ):
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
@ -657,6 +708,7 @@ def main():
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
"modified_beam_search_sf_rnnlm",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -665,10 +717,6 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
@ -686,6 +734,8 @@ def main():
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "-use-averaged-model"
@ -706,11 +756,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
@ -796,6 +841,25 @@ def main():
model.to(device) model.to(device)
model.eval() model.eval()
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search3":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
@ -839,6 +903,8 @@ def main():
sp=sp, sp=sp,
word_table=word_table, word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
) )
save_results( save_results(