Refactoring.

Since the FSAs in an Nbest object are linear in structure, we can
simply sum the arc scores along each path to obtain its total score.
Fangjun Kuang 2021-11-10 11:52:44 +08:00
parent 68cd287626
commit 8d931690ed
2 changed files with 34 additions and 33 deletions
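
The idea behind the refactoring can be checked in isolation: a linear FSA has exactly one path, so summing its arc scores in the tropical semiring gives the same number as a generic forward pass. The sketch below is not part of the commit; it assumes a working k2 installation, and the toy labels and scores are made up:

import k2
import torch

# Three linear FSAs (one per hypothesis path); k2.linear_fsa returns an FsaVec.
paths = k2.linear_fsa([[1, 2, 3], [4, 5], [6]])

# Give every arc (including the final -1 arcs) a non-trivial score.
paths.scores = torch.arange(paths.scores.numel(), dtype=torch.float32)

# Regroup the flat per-arc scores by path: [path][state][arc] -> [path][arc].
scores_shape = paths.arcs.shape().remove_axis(1)
ragged_scores = k2.RaggedTensor(scores_shape, paths.scores.contiguous())

# Summing along each path...
tot_scores = ragged_scores.sum()

# ...matches the generic computation in the tropical semiring.
expected = paths.get_tot_scores(use_double_scores=False, log_semiring=False)
assert torch.allclose(tot_scores, expected)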


@@ -147,13 +147,21 @@ def get_parser():
         help="The lang dir",
     )
 
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
+            "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
             "vgg_frontend": False,
@@ -532,6 +540,7 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
 
     params = get_params()
     params.update(vars(args))
@@ -572,9 +581,8 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
         )
-        HLG = HLG.to(device)
         assert HLG.requires_grad is False
 
     if not hasattr(HLG, "lm_scores"):
@@ -609,8 +617,8 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose


@@ -364,23 +364,13 @@ class Nbest(object):
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        # Caution: We need a clone here. `self.fsa.scores` is a
-        # reference to a tensor representing the last field of an arc
-        # in the FSA (Remember that an arc has four fields.) If we later assign
-        # `self.fsa.scores`, it will also change the scores on every arc, which
-        # means saved_scores will also be changed if we don't use `clone()`
-        # here.
-        saved_scores = self.fsa.scores.clone()
-
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
-
-        am_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, am_scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        am_scores = self.fsa.scores - self.fsa.lm_scores
+        ragged_am_scores = k2.RaggedTensor(scores_shape, am_scores.contiguous())
+        tot_scores = ragged_am_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def compute_lm_scores(self) -> k2.RaggedTensor:
         """Compute LM scores of each linear FSA (i.e., each path within
@@ -397,17 +387,16 @@ class Nbest(object):
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        saved_scores = self.fsa.scores.clone()
-
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.lm_scores.clone()
-
-        lm_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, lm_scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+
+        ragged_lm_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.lm_scores.contiguous()
+        )
+
+        tot_scores = ragged_lm_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of FSAs in this Nbest.
@@ -420,10 +409,14 @@ class Nbest(object):
           Return a ragged tensor with two axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        return k2.RaggedTensor(self.shape, scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+
+        ragged_scores = k2.RaggedTensor(scores_shape, self.scores.contiguous())
+
+        tot_scores = ragged_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def build_levenshtein_graphs(self) -> k2.Fsa:
         """Return an FsaVec with axes [utt][state][arc]."""