remove decoding method under test in decode.py

commit a76c46caea
parent 01bae96151
Author: Dongji Gao
Date:   2023-09-21 12:19:02 -04:00

decode.py

@@ -2,6 +2,7 @@
 # Copyright    2021  Xiaomi Corporation (Author: Liyong Guo,
 #                                                Fangjun Kuang,
 #                                                Quandong Wang)
+#              2023  Johns Hopkins University (Author: Dongji Gao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,17 +41,10 @@ from icefall.checkpoint import (
 )
 from icefall.decode import (
     get_lattice,
-    nbest_decoding,
-    nbest_oracle,
     one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_n_best_list,
-    rescore_with_rnn_lm,
-    rescore_with_whole_lattice,
 )
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -119,22 +113,7 @@ def get_parser():
           model for decoding. It produces the same results with ctc-decoding.
         - (2) 1best. Extract the best path from the decoding lattice as the
           decoding result.
-        - (3) nbest. Extract n paths from the decoding lattice; the path
-          with the highest score is the decoding result.
-        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
-          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
-          the highest score is the decoding result.
-        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
-          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
-          is the decoding result.
-        - (6) attention-decoder. Extract n paths from the LM rescored
-          lattice, the path with the highest score is the decoding result.
-        - (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
-          you have trained an RNN LM using ./rnn_lm/train.py
-        - (8) nbest-oracle. Its WER is the lower bound of any n-best
-          rescoring method can achieve. Useful for debugging n-best
-          rescoring method.
         """,
     )

     parser.add_argument(
@@ -157,28 +136,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""Number of paths for n-best based decoding method.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
-        """,
-    )
-
-    parser.add_argument(
-        "--nbest-scale",
-        type=float,
-        default=0.5,
-        help="""The scale to be applied to `lattice.scores`.
-        It's needed if you use any kinds of n-best based rescoring.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
-        A smaller value results in more unique paths.
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
     )
@@ -196,59 +153,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--rnn-lm-exp-dir",
-        type=str,
-        default="rnn_lm/exp",
-        help="""Used only when --method is rnn-lm.
-        It specifies the path to RNN LM exp dir.
-        """,
-    )
-
-    parser.add_argument(
-        "--rnn-lm-epoch",
-        type=int,
-        default=7,
-        help="""Used only when --method is rnn-lm.
-        It specifies the checkpoint to use.
-        """,
-    )
-
-    parser.add_argument(
-        "--rnn-lm-avg",
-        type=int,
-        default=2,
-        help="""Used only when --method is rnn-lm.
-        It specifies the number of checkpoints to average.
-        """,
-    )
-
-    parser.add_argument(
-        "--rnn-lm-embedding-dim",
-        type=int,
-        default=2048,
-        help="Embedding dim of the model",
-    )
-
-    parser.add_argument(
-        "--rnn-lm-hidden-dim", type=int, default=2048, help="Hidden dim of the model",
-    )
-
-    parser.add_argument(
-        "--rnn-lm-num-layers",
-        type=int,
-        default=4,
-        help="Number of RNN layers the model",
-    )
-
-    parser.add_argument(
-        "--rnn-lm-tie-weights",
-        type=str2bool,
-        default=False,
-        help="""True to share the weights between the input embedding layer and the
-        last output linear layer
-        """,
-    )
-
     return parser
@@ -256,8 +160,8 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             # parameters for conformer
-            "subsampling_factor": 2,
-            "feature_dim": 768,
+            "subsampling_factor": 4,
+            "feature_dim": 80,
             "nhead": 8,
             "dim_feedforward": 2048,
             "encoder_dim": 512,
@@ -319,7 +223,6 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
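
The hunk header above points at remove_duplicates_and_blank, the helper behind the ctc-greedy-search branch that this commit keeps. Below is a minimal sketch of that standard CTC collapse, assuming blank has id 0 (consistent with the blank_id=0 seen elsewhere in this diff); the function name and wiring are illustrative, and the icefall implementation may differ in detail:

    from typing import List

    def ctc_collapse(hyp: List[int], blank_id: int = 0) -> List[int]:
        # Merge runs of identical token ids, then drop CTC blanks.
        out: List[int] = []
        prev = None
        for tok in hyp:
            if tok != prev:
                prev = tok
                if tok != blank_id:
                    out.append(tok)
        return out

    assert ctc_collapse([0, 3, 3, 0, 0, 5, 5, 5, 2]) == [3, 5, 2]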
@@ -345,15 +248,9 @@ def decode_one_batch(
         It's the return value of :func:`get_params`.

         - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
-        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
-          rescoring.
       model:
         The neural model.
-      rnn_lm_model:
-        The neural model for RNN LM.
       HLG:
         The decoding graph. Used only when params.method is NOT ctc-decoding.
       H:
@@ -458,121 +355,23 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}

-    if params.method == "nbest-oracle":
-        # Note: You can also pass rescored lattices to it.
-        # We choose the HLG decoded lattice for speed reasons
-        # as HLG decoding is faster and the oracle WER
-        # is only slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            nbest_scale=params.nbest_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
-        return {key: hyps}
-
-    if params.method in ["1best", "nbest"]:
-        if params.method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-                nbest_scale=params.nbest_scale,
-            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+    if params.method in ["1best"]:
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        key = "no_rescore"

         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}
-
-    assert params.method in [
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-        "rnn-lm",
-    ]
-
-    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list,
-        )
-    elif params.method == "attention-decoder":
-        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None,
-        )
-        # TODO: pass `lattice` instead of `rescored_lattice` to
-        # `rescore_with_attention_decoder`
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.method == "rnn-lm":
-        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None,
-        )
-        best_path_dict = rescore_with_rnn_lm(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            rnn_lm_model=rnn_lm_model,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            blank_id=0,
-            nbest_scale=params.nbest_scale,
-        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

-    ans = dict()
-    if best_path_dict is not None:
-        for lm_scale_str, best_path in best_path_dict.items():
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            ans[lm_scale_str] = hyps
-    else:
-        ans = None
-    return ans
-

 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
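
The only lattice-based method kept in decode_one_batch is 1best via one_best_decoding, which in icefall is essentially a thin wrapper around k2.shortest_path. A minimal sketch under that assumption (the wrapper name here is illustrative; only the k2 call is load-bearing):

    import k2

    def one_best_sketch(lattice: k2.Fsa, use_double_scores: bool = True) -> k2.Fsa:
        # Pick the single highest-scoring path per utterance in the lattice
        # FsaVec; get_texts() then turns the result into word-id sequences.
        return k2.shortest_path(lattice, use_double_scores=use_double_scores)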
@@ -590,8 +389,6 @@ def decode_dataset(
       It is returned by :func:`get_params`.
     model:
       The neural model.
-    rnn_lm_model:
-      The neural model for RNN LM.
     HLG:
       The decoding graph. Used only when params.method is NOT ctc-decoding.
     H:
@@ -630,7 +427,6 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            rnn_lm_model=rnn_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,
@@ -774,58 +570,7 @@ def main():
         if not hasattr(HLG, "lm_scores"):
             HLG.lm_scores = HLG.scores.clone()

-    if params.method in (
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-        "rnn-lm",
-    ):
-        if not (params.lm_dir / "G_4_gram.pt").is_file():
-            logging.info("Loading G_4_gram.fst.txt")
-            logging.warning("It may take 8 minutes.")
-            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-                first_word_disambig_id = lexicon.word_table["#0"]
-
-                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-                # G.aux_labels is not needed in later computations, so
-                # remove it here.
-                del G.aux_labels
-                # CAUTION: The following line is crucial.
-                # Arcs entering the back-off state have label equal to #0.
-                # We have to change it to 0 here.
-                G.labels[G.labels >= first_word_disambig_id] = 0
-                # See https://github.com/k2-fsa/k2/issues/874
-                # for why we need to set G.properties to None
-                G.__dict__["_properties"] = None
-                G = k2.Fsa.from_fsas([G]).to(device)
-                G = k2.arc_sort(G)
-                # Save a dummy value so that it can be loaded in C++.
-                # See https://github.com/pytorch/pytorch/issues/67902
-                # for why we need to do this.
-                G.dummy = 1
-                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
-        else:
-            logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
-            G = k2.Fsa.from_dict(d)
-
-        if params.method in [
-            "whole-lattice-rescoring",
-            "attention-decoder",
-            "rnn-lm",
-        ]:
-            # Add epsilon self-loops to G as we will compose
-            # it with the whole lattice later
-            G = k2.add_epsilon_self_loops(G)
-            G = k2.arc_sort(G)
-            G = G.to(device)
-
-        # G.lm_scores is used to replace HLG.lm_scores during
-        # LM rescoring.
-        G.lm_scores = G.scores.clone()
-    else:
-        G = None
+    G = None

     model = Conformer(
         num_features=params.feature_dim,
@@ -919,30 +664,6 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    rnn_lm_model = None
-    if params.method == "rnn-lm":
-        rnn_lm_model = RnnLmModel(
-            vocab_size=params.num_classes,
-            embedding_dim=params.rnn_lm_embedding_dim,
-            hidden_dim=params.rnn_lm_hidden_dim,
-            num_layers=params.rnn_lm_num_layers,
-            tie_weights=params.rnn_lm_tie_weights,
-        )
-        if params.rnn_lm_avg == 1:
-            load_checkpoint(
-                f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", rnn_lm_model,
-            )
-            rnn_lm_model.to(device)
-        else:
-            rnn_lm_model = load_averaged_model(
-                params.rnn_lm_exp_dir,
-                rnn_lm_model,
-                params.rnn_lm_epoch,
-                params.rnn_lm_avg,
-                device,
-            )
-        rnn_lm_model.eval()
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
@@ -961,7 +682,6 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
-            rnn_lm_model=rnn_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,
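
With this change, decode.py keeps only the ctc-greedy-search, ctc-decoding, and 1best methods. A typical invocation using just the flags visible in this diff (any other flags, e.g. for checkpoint epoch or averaging, are omitted here):

    ./conformer_ctc2/decode.py --method 1best --exp-dir conformer_ctc2/exp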