Refactoring.

Since FSAs in an Nbest object are linear in structure, we can
sum the arc scores along each path to compute the total scores,
instead of calling the generic Fsa.get_tot_scores().
Fangjun Kuang 2021-11-10 11:52:44 +08:00
parent 68cd287626
commit 8d931690ed
2 changed files with 34 additions and 33 deletions
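The idea behind the refactoring is easy to check: a linear FSA has exactly one path, so its total score in the tropical semiring (log_semiring=False) is just the plain sum of its arc scores, and a ragged sum over the [path][arc] axes reproduces what get_tot_scores() computes. A minimal sketch of the equivalence, assuming k2 and torch are installed (the labels and scores below are toy values):

    import k2
    import torch

    # Two toy linear FSAs; k2.linear_fsa returns an FsaVec with axes
    # [fsa][state][arc].
    fsa_vec = k2.linear_fsa([[1, 2, 3], [4, 5]])
    fsa_vec.scores = torch.arange(fsa_vec.scores.numel(), dtype=torch.float32)

    # Generic route: dynamic programming over all paths of each FSA.
    expected = fsa_vec.get_tot_scores(use_double_scores=True, log_semiring=False)

    # Refactored route: drop the state axis and sum arc scores per FSA.
    scores_shape = fsa_vec.arcs.shape().remove_axis(1)  # axes [fsa][arc]
    ragged = k2.RaggedTensor(scores_shape, fsa_vec.scores.contiguous())
    assert torch.allclose(ragged.sum().to(torch.float64), expected)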

@@ -147,13 +147,21 @@ def get_parser():
         help="The lang dir",
     )
 
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
             "vgg_frontend": False,
@@ -532,6 +540,7 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
 
     params = get_params()
     params.update(vars(args))
@@ -572,9 +581,8 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
         )
-        HLG = HLG.to(device)
         assert HLG.requires_grad is False
 
         if not hasattr(HLG, "lm_scores"):
@@ -609,8 +617,8 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose
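The two decode.py hunks above make the same change: rather than loading a checkpoint onto the CPU and copying it to the device afterwards, torch.load maps the tensors onto the target device directly. A minimal sketch (the checkpoint path and device choice are illustrative):

    import k2
    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Before: materialize every tensor on the CPU, then copy to the device.
    # HLG = k2.Fsa.from_dict(torch.load("HLG.pt", map_location="cpu")).to(device)

    # After: map tensors onto the device while loading, skipping the
    # intermediate CPU copy.
    HLG = k2.Fsa.from_dict(torch.load("HLG.pt", map_location=device))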

@@ -364,23 +364,13 @@ class Nbest(object):
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        # Caution: We need a clone here. `self.fsa.scores` is a
-        # reference to a tensor representing the last field of an arc
-        # in the FSA. (Remember that an arc has four fields.) If we later
-        # assigned to `self.fsa.scores`, that would also change the scores
-        # on every arc, which means saved_scores would be changed as well
-        # if we did not use `clone()` here.
-        saved_scores = self.fsa.scores.clone()
-
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
-        am_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, am_scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        am_scores = self.fsa.scores - self.fsa.lm_scores
+        ragged_am_scores = k2.RaggedTensor(scores_shape, am_scores.contiguous())
+        tot_scores = ragged_am_scores.sum()
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def compute_lm_scores(self) -> k2.RaggedTensor:
         """Compute LM scores of each linear FSA (i.e., each path within
@@ -397,17 +387,16 @@ class Nbest(object):
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        saved_scores = self.fsa.scores.clone()
-
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.lm_scores.clone()
-        lm_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        ragged_lm_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.lm_scores.contiguous()
+        )
+
+        tot_scores = ragged_lm_scores.sum()
 
-        return k2.RaggedTensor(self.shape, lm_scores)
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of FSAs in this Nbest.
@@ -420,10 +409,14 @@ class Nbest(object):
           Return a ragged tensor with two axes [utt][path_scores].
           Its dtype is torch.float64.
         """
-        scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        return k2.RaggedTensor(self.shape, scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        ragged_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.scores.contiguous()
+        )
+        tot_scores = ragged_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)
     def build_levenshtein_graphs(self) -> k2.Fsa:
         """Return an FsaVec with axes [utt][state][arc]."""