mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
add rnn lm decoding
This commit is contained in:
parent
40eb8c43c9
commit
4634911dc2
@ -46,13 +46,16 @@ from icefall.decode import (
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_rnn_lm,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
load_averaged_model,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
@ -116,7 +119,9 @@ def get_parser():
|
||||
is the decoding result.
|
||||
- (6) attention-decoder. Extract n paths from the LM rescored
|
||||
lattice, the path with the highest score is the decoding result.
|
||||
- (7) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
- (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
|
||||
you have trained an RNN LM using ./rnn_lm/train.py
|
||||
- (8) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
@ -148,7 +153,7 @@ def get_parser():
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
||||
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
@ -159,7 +164,7 @@ def get_parser():
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
||||
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
@ -182,11 +187,67 @@ def get_parser():
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The LM dir.
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-epoch",
|
||||
type=int,
|
||||
default=7,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the checkpoint to use.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-avg",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Used only when --method is rnn-lm.
|
||||
It specifies the number of checkpoints to average.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -242,6 +303,7 @@ def ctc_greedy_search(
|
||||
|
||||
|
||||
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
||||
new_hyp: List[int] = []
|
||||
cur = 0
|
||||
while cur < len(hyp):
|
||||
@ -256,6 +318,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
rnn_lm_model: Optional[nn.Module],
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
@ -288,6 +351,8 @@ def decode_one_batch(
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
rnn_lm_model:
|
||||
The RNN language model. Used only when params.method is rnn-lm; otherwise it may be None.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
@ -436,6 +501,7 @@ def decode_one_batch(
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
@ -476,6 +542,26 @@ def decode_one_batch(
|
||||
eos_id=eos_id,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.method == "rnn-lm":
|
||||
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
|
||||
best_path_dict = rescore_with_rnn_lm(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
blank_id=0,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
|
||||
@ -494,6 +580,7 @@ def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
rnn_lm_model: Optional[nn.Module],
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
@ -511,6 +598,8 @@ def decode_dataset(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
rnn_lm_model:
|
||||
The RNN language model, passed on to decode_one_batch. Used only when params.method is rnn-lm; otherwise it may be None.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
@ -548,6 +637,7 @@ def decode_dataset(
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
@ -596,7 +686,7 @@ def save_results(
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
if params.method == "attention-decoder":
|
||||
if params.method in ("attention-decoder", "rnn-lm"):
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
@ -700,6 +790,7 @@ def main():
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
@ -731,7 +822,11 @@ def main():
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
if params.method in [
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
"rnn-lm",
|
||||
]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
@ -836,6 +931,31 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
rnn_lm_model = None
|
||||
if params.method == "rnn-lm":
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.num_classes,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
if params.rnn_lm_avg == 1:
|
||||
load_checkpoint(
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
rnn_lm_model.to(device)
|
||||
else:
|
||||
rnn_lm_model = load_averaged_model(
|
||||
params.rnn_lm_exp_dir,
|
||||
rnn_lm_model,
|
||||
params.rnn_lm_epoch,
|
||||
params.rnn_lm_avg,
|
||||
device,
|
||||
)
|
||||
rnn_lm_model.eval()
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
@ -852,6 +972,7 @@ def main():
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
rnn_lm_model=rnn_lm_model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
|
Loading…
x
Reference in New Issue
Block a user