mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add rnn lm decoding
This commit is contained in:
parent
40eb8c43c9
commit
4634911dc2
@ -46,13 +46,16 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
|
rescore_with_rnn_lm,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.rnn_lm.model import RnnLmModel
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_texts,
|
get_texts,
|
||||||
|
load_averaged_model,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
str2bool,
|
||||||
@ -116,7 +119,9 @@ def get_parser():
|
|||||||
is the decoding result.
|
is the decoding result.
|
||||||
- (6) attention-decoder. Extract n paths from the LM rescored
|
- (6) attention-decoder. Extract n paths from the LM rescored
|
||||||
lattice, the path with the highest score is the decoding result.
|
lattice, the path with the highest score is the decoding result.
|
||||||
- (7) nbest-oracle. Its WER is the lower bound of any n-best
|
- (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
|
||||||
|
you have trained an RNN LM using ./rnn_lm/train.py
|
||||||
|
- (8) nbest-oracle. Its WER is the lower bound of any n-best
|
||||||
rescoring method can achieve. Useful for debugging n-best
|
rescoring method can achieve. Useful for debugging n-best
|
||||||
rescoring method.
|
rescoring method.
|
||||||
""",
|
""",
|
||||||
@ -148,7 +153,7 @@ def get_parser():
|
|||||||
default=100,
|
default=100,
|
||||||
help="""Number of paths for n-best based decoding method.
|
help="""Number of paths for n-best based decoding method.
|
||||||
Used only when "method" is one of the following values:
|
Used only when "method" is one of the following values:
|
||||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -159,7 +164,7 @@ def get_parser():
|
|||||||
help="""The scale to be applied to `lattice.scores`.
|
help="""The scale to be applied to `lattice.scores`.
|
||||||
It's needed if you use any kinds of n-best based rescoring.
|
It's needed if you use any kinds of n-best based rescoring.
|
||||||
Used only when "method" is one of the following values:
|
Used only when "method" is one of the following values:
|
||||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||||
A smaller value results in more unique paths.
|
A smaller value results in more unique paths.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -182,11 +187,67 @@ def get_parser():
|
|||||||
"--lm-dir",
|
"--lm-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lm",
|
default="data/lm",
|
||||||
help="""The LM dir.
|
help="""The n-gram LM dir.
|
||||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="rnn_lm/exp",
|
||||||
|
help="""Used only when --method is rnn-lm.
|
||||||
|
It specifies the path to RNN LM exp dir.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-epoch",
|
||||||
|
type=int,
|
||||||
|
default=7,
|
||||||
|
help="""Used only when --method is rnn-lm.
|
||||||
|
It specifies the checkpoint to use.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-avg",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="""Used only when --method is rnn-lm.
|
||||||
|
It specifies the number of checkpoints to average.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-embedding-dim",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Embedding dim of the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-hidden-dim",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Hidden dim of the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-num-layers",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of RNN layers the model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-tie-weights",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to share the weights between the input embedding layer and the
|
||||||
|
last output linear layer
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -242,6 +303,7 @@ def ctc_greedy_search(
|
|||||||
|
|
||||||
|
|
||||||
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||||
|
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
||||||
new_hyp: List[int] = []
|
new_hyp: List[int] = []
|
||||||
cur = 0
|
cur = 0
|
||||||
while cur < len(hyp):
|
while cur < len(hyp):
|
||||||
@ -256,6 +318,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
|||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
rnn_lm_model: Optional[nn.Module],
|
||||||
HLG: Optional[k2.Fsa],
|
HLG: Optional[k2.Fsa],
|
||||||
H: Optional[k2.Fsa],
|
H: Optional[k2.Fsa],
|
||||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
@ -288,6 +351,8 @@ def decode_one_batch(
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
|
rnn_lm_model:
|
||||||
|
The neural model for RNN LM.
|
||||||
HLG:
|
HLG:
|
||||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||||
H:
|
H:
|
||||||
@ -436,6 +501,7 @@ def decode_one_batch(
|
|||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
"attention-decoder",
|
"attention-decoder",
|
||||||
|
"rnn-lm",
|
||||||
]
|
]
|
||||||
|
|
||||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||||
@ -476,6 +542,26 @@ def decode_one_batch(
|
|||||||
eos_id=eos_id,
|
eos_id=eos_id,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
)
|
)
|
||||||
|
elif params.method == "rnn-lm":
|
||||||
|
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||||
|
rescored_lattice = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path_dict = rescore_with_rnn_lm(
|
||||||
|
lattice=rescored_lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
rnn_lm_model=rnn_lm_model,
|
||||||
|
model=model,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
blank_id=0,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert False, f"Unsupported decoding method: {params.method}"
|
assert False, f"Unsupported decoding method: {params.method}"
|
||||||
|
|
||||||
@ -494,6 +580,7 @@ def decode_dataset(
|
|||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
rnn_lm_model: Optional[nn.Module],
|
||||||
HLG: Optional[k2.Fsa],
|
HLG: Optional[k2.Fsa],
|
||||||
H: Optional[k2.Fsa],
|
H: Optional[k2.Fsa],
|
||||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
@ -511,6 +598,8 @@ def decode_dataset(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
|
rnn_lm_model:
|
||||||
|
The neural model for RNN LM.
|
||||||
HLG:
|
HLG:
|
||||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||||
H:
|
H:
|
||||||
@ -548,6 +637,7 @@ def decode_dataset(
|
|||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
rnn_lm_model=rnn_lm_model,
|
||||||
HLG=HLG,
|
HLG=HLG,
|
||||||
H=H,
|
H=H,
|
||||||
bpe_model=bpe_model,
|
bpe_model=bpe_model,
|
||||||
@ -596,7 +686,7 @@ def save_results(
|
|||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||||
):
|
):
|
||||||
if params.method == "attention-decoder":
|
if params.method in ("attention-decoder", "rnn-lm"):
|
||||||
# Set it to False since there are too many logs.
|
# Set it to False since there are too many logs.
|
||||||
enable_log = False
|
enable_log = False
|
||||||
else:
|
else:
|
||||||
@ -700,6 +790,7 @@ def main():
|
|||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
"attention-decoder",
|
"attention-decoder",
|
||||||
|
"rnn-lm",
|
||||||
):
|
):
|
||||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||||
logging.info("Loading G_4_gram.fst.txt")
|
logging.info("Loading G_4_gram.fst.txt")
|
||||||
@ -731,7 +822,11 @@ def main():
|
|||||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||||
G = k2.Fsa.from_dict(d)
|
G = k2.Fsa.from_dict(d)
|
||||||
|
|
||||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
if params.method in [
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
"attention-decoder",
|
||||||
|
"rnn-lm",
|
||||||
|
]:
|
||||||
# Add epsilon self-loops to G as we will compose
|
# Add epsilon self-loops to G as we will compose
|
||||||
# it with the whole lattice later
|
# it with the whole lattice later
|
||||||
G = k2.add_epsilon_self_loops(G)
|
G = k2.add_epsilon_self_loops(G)
|
||||||
@ -836,6 +931,31 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
rnn_lm_model = None
|
||||||
|
if params.method == "rnn-lm":
|
||||||
|
rnn_lm_model = RnnLmModel(
|
||||||
|
vocab_size=params.num_classes,
|
||||||
|
embedding_dim=params.rnn_lm_embedding_dim,
|
||||||
|
hidden_dim=params.rnn_lm_hidden_dim,
|
||||||
|
num_layers=params.rnn_lm_num_layers,
|
||||||
|
tie_weights=params.rnn_lm_tie_weights,
|
||||||
|
)
|
||||||
|
if params.rnn_lm_avg == 1:
|
||||||
|
load_checkpoint(
|
||||||
|
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||||
|
rnn_lm_model,
|
||||||
|
)
|
||||||
|
rnn_lm_model.to(device)
|
||||||
|
else:
|
||||||
|
rnn_lm_model = load_averaged_model(
|
||||||
|
params.rnn_lm_exp_dir,
|
||||||
|
rnn_lm_model,
|
||||||
|
params.rnn_lm_epoch,
|
||||||
|
params.rnn_lm_avg,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
rnn_lm_model.eval()
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
@ -852,6 +972,7 @@ def main():
|
|||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
rnn_lm_model=rnn_lm_model,
|
||||||
HLG=HLG,
|
HLG=HLG,
|
||||||
H=H,
|
H=H,
|
||||||
bpe_model=bpe_model,
|
bpe_model=bpe_model,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user