From 8d931690edfb7fefca53e8abeefbe25cdd96bbbd Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 10 Nov 2021 11:52:44 +0800
Subject: [PATCH] Refactoring.

Since the FSAs in an Nbest object are linear in structure, we can
sum the arc scores of a path to compute its total score.
---
 egs/librispeech/ASR/conformer_ctc/decode.py | 18 +++++---
 icefall/decode.py                           | 49 +++++++++------------
 2 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index ad936b6e4..b4b0d7f37 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -147,13 +147,21 @@ def get_parser():
         help="The lang dir",
     )
 
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
             "vgg_frontend": False,
@@ -532,6 +540,7 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
 
     params = get_params()
     params.update(vars(args))
@@ -572,9 +581,8 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
         )
-        HLG = HLG.to(device)
         assert HLG.requires_grad is False
 
         if not hasattr(HLG, "lm_scores"):
@@ -609,8 +617,8 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose
diff --git a/icefall/decode.py b/icefall/decode.py
index 8b7bdd27f..98f792783 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -364,23 +364,13 @@ class Nbest(object):
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        # Caution: We need a clone here. `self.fsa.scores` is a
-        # reference to a tensor representing the last field of an arc
-        # in the FSA (Remeber that an arc has four fields.) If we later assign
-        # `self.fsa.scores`, it will also change the scores on every arc, which
-        # means saved_scores will also be changed if we don't use `clone()`
-        # here.
-        saved_scores = self.fsa.scores.clone()
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        am_scores = self.fsa.scores - self.fsa.lm_scores
+        ragged_am_scores = k2.RaggedTensor(scores_shape, am_scores.contiguous())
+        tot_scores = ragged_am_scores.sum()
 
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
-
-        am_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, am_scores)
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def compute_lm_scores(self) -> k2.RaggedTensor:
         """Compute LM scores of each linear FSA (i.e., each path within
@@ -397,17 +387,16 @@ class Nbest(object):
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
""" - saved_scores = self.fsa.scores.clone() + scores_shape = self.fsa.arcs.shape().remove_axis(1) + # scores_shape has axes [path][arc] - # The `scores` of every arc consists of `am_scores` and `lm_scores` - self.fsa.scores = self.fsa.lm_scores.clone() - - lm_scores = self.fsa.get_tot_scores( - use_double_scores=True, log_semiring=False + ragged_lm_scores = k2.RaggedTensor( + scores_shape, self.fsa.lm_scores.contiguous() ) - self.fsa.scores = saved_scores - return k2.RaggedTensor(self.shape, lm_scores) + tot_scores = ragged_lm_scores.sum() + + return k2.RaggedTensor(self.shape, tot_scores) def tot_scores(self) -> k2.RaggedTensor: """Get total scores of FSAs in this Nbest. @@ -420,10 +409,14 @@ class Nbest(object): Return a ragged tensor with two axes [utt][path_scores]. Its dtype is torch.float64. """ - scores = self.fsa.get_tot_scores( - use_double_scores=True, log_semiring=False - ) - return k2.RaggedTensor(self.shape, scores) + scores_shape = self.fsa.arcs.shape().remove_axis(1) + # scores_shape has axes [path][arc] + + ragged_scores = k2.RaggedTensor(scores_shape, self.scores.contiguous()) + + tot_scores = ragged_scores.sum() + + return k2.RaggedTensor(self.shape, tot_scores) def build_levenshtein_graphs(self) -> k2.Fsa: """Return an FsaVec with axes [utt][state][arc]."""