add rnn lm decoding

Author: Quandwang
Date:   2022-07-21 17:49:51 +08:00
parent 40eb8c43c9
commit 4634911dc2


@@ -46,13 +46,16 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
+    rescore_with_rnn_lm,
     rescore_with_whole_lattice,
 )
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
     get_texts,
+    load_averaged_model,
     setup_logger,
     store_transcripts,
     str2bool,
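The new rescore_with_rnn_lm helper imported here lives in icefall.decode; its call site appears later in this diff. A signature sketch inferred from that call site (the type annotations and docstring are assumptions, not the actual implementation):

from typing import Dict, Optional

import k2
import torch
import torch.nn as nn

def rescore_with_rnn_lm(
    lattice: k2.Fsa,
    num_paths: int,
    rnn_lm_model: nn.Module,
    model: nn.Module,
    memory: torch.Tensor,
    memory_key_padding_mask: Optional[torch.Tensor],
    sos_id: int,
    eos_id: int,
    blank_id: int,
    nbest_scale: float,
) -> Dict[str, k2.Fsa]:
    """Extract n-best paths from the lattice, score each path with both the
    attention decoder and the RNN LM, and return a dict mapping a
    scale-combination key to the best path under that combination."""
    ...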
@@ -116,7 +119,9 @@ def get_parser():
           is the decoding result.
         - (6) attention-decoder. Extract n paths from the LM rescored
           lattice, the path with the highest score is the decoding result.
-        - (7) nbest-oracle. Its WER is the lower bound of any n-best
+        - (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
+          you have trained an RNN LM using ./rnn_lm/train.py
+        - (8) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
         """,
@@ -148,7 +153,7 @@ def get_parser():
         default=100,
         help="""Number of paths for n-best based decoding method.
         Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
         """,
     )
@@ -159,7 +164,7 @@ def get_parser():
         help="""The scale to be applied to `lattice.scores`.
         It's needed if you use any kinds of n-best based rescoring.
         Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
         A smaller value results in more unique paths.
         """,
     )
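For intuition about nbest_scale: the lattice scores are scaled down before paths are sampled, which flattens the arc-score distribution so the sampler is less dominated by the single best path. A minimal sketch of the idea; k2.random_paths is a real k2 API, while the wrapper function and the save/restore logic are illustrative rather than the exact icefall internals:

import k2

def sample_nbest_paths(lattice: k2.Fsa, num_paths: int, nbest_scale: float):
    # Smaller nbest_scale -> flatter scores -> more unique sampled paths.
    saved_scores = lattice.scores.clone()
    lattice.scores = saved_scores * nbest_scale
    paths = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
    lattice.scores = saved_scores  # restore the original scores for rescoring
    return paths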
@@ -182,11 +187,67 @@ def get_parser():
         "--lm-dir",
         type=str,
         default="data/lm",
-        help="""The LM dir.
+        help="""The n-gram LM dir.
         It should contain either G_4_gram.pt or G_4_gram.fst.txt
         """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when --method is rnn-lm.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when --method is rnn-lm.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when --method is rnn-lm.
+        It specifies the number of checkpoints to average.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-num-layers",
+        type=int,
+        default=4,
+        help="Number of RNN layers in the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
     return parser
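As a quick sanity check of the new flags (a hypothetical snippet, not part of the commit; the defaults are the ones declared above):

parser = get_parser()
args = parser.parse_args(
    [
        "--method", "rnn-lm",
        "--rnn-lm-exp-dir", "rnn_lm/exp",
        "--rnn-lm-epoch", "7",
        "--rnn-lm-avg", "2",
        "--rnn-lm-tie-weights", "true",  # parsed by icefall's str2bool
    ]
)
assert args.rnn_lm_embedding_dim == 2048  # default declared above
assert args.rnn_lm_num_layers == 4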
@@ -242,6 +303,7 @@ def ctc_greedy_search(
 
 def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
+    # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
     new_hyp: List[int] = []
     cur = 0
     while cur < len(hyp):
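The hunk truncates the function body; for reference, a self-contained sketch of the full CTC-collapse logic as in the linked wenet file, assuming blank id 0:

from typing import List

def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    """Collapse consecutive repeats and drop blanks (id 0) from a CTC path."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != 0:  # keep non-blank tokens
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1  # skip over the run of identical tokens
    return new_hyp

assert remove_duplicates_and_blank([0, 3, 3, 0, 0, 5, 5, 2]) == [3, 5, 2]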
@@ -256,6 +318,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
+    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
@@ -288,6 +351,8 @@ def decode_one_batch(
       model:
         The neural model.
+      rnn_lm_model:
+        The neural model for RNN LM.
       HLG:
         The decoding graph. Used only when params.method is NOT ctc-decoding.
       H:
@@ -436,6 +501,7 @@ def decode_one_batch(
             "nbest-rescoring",
             "whole-lattice-rescoring",
             "attention-decoder",
+            "rnn-lm",
         ]
 
         lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@@ -476,6 +542,26 @@ def decode_one_batch(
             eos_id=eos_id,
             nbest_scale=params.nbest_scale,
         )
+    elif params.method == "rnn-lm":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
+        )
+
+        best_path_dict = rescore_with_rnn_lm(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            rnn_lm_model=rnn_lm_model,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            blank_id=0,
+            nbest_scale=params.nbest_scale,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
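Note that lm_scale_list=None makes rescore_with_whole_lattice return a single rescored lattice rather than a dict of lattices keyed by LM scale. The resulting best_path_dict is then consumed the same way as in the existing attention-decoder branch; a sketch of that downstream handling (assumed to mirror the surrounding code, key format illustrative):

# Each key encodes a scale combination, e.g.
# "ngram_lm_scale_0.5_attention_scale_0.5_rnn_lm_scale_0.5" (format assumed),
# and each value is a linear FSA holding the best path per utterance.
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
    hyps = get_texts(best_path)  # token/word ID sequences, one per utterance
    hyps = [[word_table[i] for i in ids] for ids in hyps]
    ans[lm_scale_str] = hyps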
@@ -494,6 +580,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
+    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
@@ -511,6 +598,8 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
+      rnn_lm_model:
+        The neural model for RNN LM.
       HLG:
         The decoding graph. Used only when params.method is NOT ctc-decoding.
       H:
@@ -548,6 +637,7 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
+            rnn_lm_model=rnn_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,
@@ -596,7 +686,7 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
-    if params.method == "attention-decoder":
+    if params.method in ("attention-decoder", "rnn-lm"):
         # Set it to False since there are too many logs.
         enable_log = False
     else:
@@ -700,6 +790,7 @@ def main():
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "attention-decoder",
+        "rnn-lm",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
@@ -731,7 +822,11 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
             G = k2.Fsa.from_dict(d)
 
-        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+            "rnn-lm",
+        ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
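For context, the usual preparation of G before whole-lattice composition looks roughly like this; the add_epsilon_self_loops call is from the diff, while the arc_sort and lm_scores steps are assumed from the surrounding icefall code:

import k2
import torch

device = torch.device("cpu")  # or the decoding device
d = torch.load("data/lm/G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
# Epsilon self-loops let G "wait" on lattice epsilon arcs during
# composition with the whole lattice.
G = k2.add_epsilon_self_loops(G)
# Composition requires arc-sorted FSAs.
G = k2.arc_sort(G)
# Keep a separate copy of the LM scores so they can be scaled independently.
G.lm_scores = G.scores.clone()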
@@ -836,6 +931,31 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    rnn_lm_model = None
+    if params.method == "rnn-lm":
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.num_classes,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
+        if params.rnn_lm_avg == 1:
+            load_checkpoint(
+                f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+                rnn_lm_model,
+            )
+            rnn_lm_model.to(device)
+        else:
+            rnn_lm_model = load_averaged_model(
+                params.rnn_lm_exp_dir,
+                rnn_lm_model,
+                params.rnn_lm_epoch,
+                params.rnn_lm_avg,
+                device,
+            )
+        rnn_lm_model.eval()
+
     librispeech = LibriSpeechAsrDataModule(args)
 
     test_clean_cuts = librispeech.test_clean_cuts()
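load_averaged_model (imported above from icefall.utils) averages the parameters of the rnn_lm_avg checkpoints ending at epoch rnn_lm_epoch before loading them into the model. A minimal sketch of the assumed semantics; the real helper may differ in details such as non-float key handling:

import torch
from typing import List

def average_model_params(filenames: List[str], model: torch.nn.Module) -> None:
    """Average the 'model' state dicts of several checkpoints elementwise."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] = avg[k] + state[k]
    for k in avg:
        if torch.is_floating_point(avg[k]):
            avg[k] = avg[k] / n
    model.load_state_dict(avg)

# For --rnn-lm-epoch 7 --rnn-lm-avg 2 this would average epochs 6 and 7:
# average_model_params(["rnn_lm/exp/epoch-6.pt", "rnn_lm/exp/epoch-7.pt"], rnn_lm_model)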
@@ -852,6 +972,7 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
+            rnn_lm_model=rnn_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,