Fix a bug in Nbest.compute_am_scores and Nbest.compute_lm_scores. (#111)

This commit is contained in:
Fangjun Kuang 2021-11-09 13:44:51 +08:00 committed by GitHub
parent 91cfecebf2
commit 04029871b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -364,7 +364,13 @@ class Nbest(object):
Return a ragged tensor with 2 axes [utt][path_scores]. Return a ragged tensor with 2 axes [utt][path_scores].
Its dtype is torch.float64. Its dtype is torch.float64.
""" """
saved_scores = self.fsa.scores # Caution: We need a clone here. `self.fsa.scores` is a
# reference to a tensor representing the last field of an arc
# in the FSA (Remeber that an arc has four fields.) If we later assign
# `self.fsa.scores`, it will also change the scores on every arc, which
# means saved_scores will also be changed if we don't use `clone()`
# here.
saved_scores = self.fsa.scores.clone()
# The `scores` of every arc consists of `am_scores` and `lm_scores` # The `scores` of every arc consists of `am_scores` and `lm_scores`
self.fsa.scores = self.fsa.scores - self.fsa.lm_scores self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
@ -391,10 +397,10 @@ class Nbest(object):
Return a ragged tensor with 2 axes [utt][path_scores]. Return a ragged tensor with 2 axes [utt][path_scores].
Its dtype is torch.float64. Its dtype is torch.float64.
""" """
saved_scores = self.fsa.scores saved_scores = self.fsa.scores.clone()
# The `scores` of every arc consists of `am_scores` and `lm_scores` # The `scores` of every arc consists of `am_scores` and `lm_scores`
self.fsa.scores = self.fsa.lm_scores self.fsa.scores = self.fsa.lm_scores.clone()
lm_scores = self.fsa.get_tot_scores( lm_scores = self.fsa.get_tot_scores(
use_double_scores=True, log_semiring=False use_double_scores=True, log_semiring=False