Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Refactoring.
Since the FSAs in an Nbest object are linear in structure, we can simply sum the arc scores along each path to compute its total score.
This commit is contained in:
parent 68cd287626
commit 8d931690ed
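The core observation behind the change: a linear FSA contains exactly one path, so the tropical-semiring total score that get_tot_scores(log_semiring=False) computes via dynamic programming reduces to a plain sum of arc scores, which a ragged sum over a [path][arc] view gives directly. A minimal sketch of that equivalence (toy labels and scores, assuming a working k2 installation):

import k2
import torch

# Two linear FSAs, one path each. k2.linear_fsa appends a final arc
# (label -1) to every path, so the FSAs have 4 and 3 arcs: 7 in total.
fsa = k2.linear_fsa([[1, 2, 3], [4, 5]])
fsa.scores = torch.arange(7, dtype=torch.float32)

# Old approach: dynamic-programming total scores in the tropical semiring.
old = fsa.get_tot_scores(use_double_scores=True, log_semiring=False)

# New approach: regroup the scores as a ragged tensor with axes
# [path][arc] and sum over the arc axis.
scores_shape = fsa.arcs.shape().remove_axis(1)
new = k2.RaggedTensor(scores_shape, fsa.scores.to(torch.float64)).sum()

assert torch.allclose(old, new)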
@@ -147,13 +147,21 @@ def get_parser():
         help="The lang dir",
     )
 
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
+            "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
             "vgg_frontend": False,
@@ -532,6 +540,7 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
 
     params = get_params()
     params.update(vars(args))
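Note that lm_dir now has a default in two places: get_params() and the --lm-dir flag. Since params.update(vars(args)) runs after get_params(), the parsed (and Path-converted) argument always wins. A rough sketch of that flow, using a minimal stand-in for icefall's AttributeDict:

from pathlib import Path

class AttributeDict(dict):
    # Stand-in for icefall's AttributeDict: a dict with attribute access.
    def __getattr__(self, key):
        return self[key]

params = AttributeDict({"lm_dir": Path("data/lm")})  # default from get_params()

# Simulate parsed args after main() has done `args.lm_dir = Path(args.lm_dir)`.
args = {"lm_dir": Path("/some/other/lm")}
params.update(args)

print(params.lm_dir)  # /some/other/lm -- the CLI value overrides the default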
@@ -572,9 +581,8 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
         )
-        HLG = HLG.to(device)
         assert HLG.requires_grad is False
 
     if not hasattr(HLG, "lm_scores"):
@@ -609,8 +617,8 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose
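The two hunks above make the same simplification: instead of deserializing a checkpoint onto the CPU and then copying it with .to(device), torch.load is told to place the tensors on the target device directly via map_location. The results are equivalent; the latter just skips the intermediate CPU copy. A small illustration (the checkpoint path is hypothetical):

import torch

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Before: deserialize onto the CPU, then move every tensor to `device`.
# d = torch.load("checkpoint.pt", map_location="cpu")  # followed by .to(device)

# After: let torch.load place the tensors on `device` directly.
d = torch.load("checkpoint.pt", map_location=device)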
@@ -364,23 +364,13 @@ class Nbest(object):
         Return a ragged tensor with 2 axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        # Caution: We need a clone here. `self.fsa.scores` is a
-        # reference to a tensor representing the last field of an arc
-        # in the FSA. (Remember that an arc has four fields.) If we later
-        # assign to `self.fsa.scores`, it will also change the scores on
-        # every arc, which means saved_scores would also change if we did
-        # not use `clone()` here.
-        saved_scores = self.fsa.scores.clone()
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        am_scores = self.fsa.scores - self.fsa.lm_scores
+        ragged_am_scores = k2.RaggedTensor(scores_shape, am_scores.contiguous())
+        tot_scores = ragged_am_scores.sum()
 
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
-
-        am_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, am_scores)
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def compute_lm_scores(self) -> k2.RaggedTensor:
         """Compute LM scores of each linear FSA (i.e., each path within
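The deleted save/restore dance existed because fsa.scores is a view into the underlying arc storage: assigning to fsa.scores writes through to every arc, so a reference saved without clone() changes too, exactly as the removed comment warns. The new code never mutates self.fsa.scores, so nothing needs saving. A toy demonstration of the hazard (assuming the in-place setter behaviour the comment describes):

import k2
import torch

fsa = k2.linear_fsa([[1, 2]])  # 3 arcs: two labels plus the final arc
saved = fsa.scores             # a view into the arc storage, NOT a copy
fsa.scores = torch.tensor([7.0, 8.0, 9.0])

print(saved)  # prints 7., 8., 9. -- the "saved" values changed as well;
              # saved = fsa.scores.clone() would have kept the originals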
@@ -397,17 +387,16 @@ class Nbest(object):
         Return a ragged tensor with 2 axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        saved_scores = self.fsa.scores.clone()
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
 
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.lm_scores.clone()
-
-        lm_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, lm_scores)
+        ragged_lm_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.lm_scores.contiguous()
+        )
+
+        tot_scores = ragged_lm_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of FSAs in this Nbest.
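Both methods, and tot_scores() below, lean on the same idiom: self.fsa.arcs.shape() is a ragged shape with axes [path][state][arc], and remove_axis(1) collapses the state level so the arc scores can be regrouped per path. A quick way to see it on toy FSAs:

import k2

fsa = k2.linear_fsa([[1, 2], [3]])
print(fsa.arcs.shape())                 # ragged shape with axes [fsa][state][arc]
print(fsa.arcs.shape().remove_axis(1))  # axes [fsa][arc]; 3 + 2 = 5 arcs in total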
@@ -420,10 +409,14 @@
         Return a ragged tensor with two axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        return k2.RaggedTensor(self.shape, scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
+
+        tot_scores = ragged_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)
 
     def build_levenshtein_graphs(self) -> k2.Fsa:
         """Return an FsaVec with axes [utt][state][arc]."""
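Taken together, the three refactored methods decompose cleanly: because each arc's scores is the sum of its acoustic and LM parts, the per-path totals satisfy tot = am + lm. A toy check of that identity, with made-up scores on a hand-built FsaVec:

import k2
import torch

fsa = k2.linear_fsa([[1, 2], [3]])  # 5 arcs in total (final arcs included)
fsa.scores = torch.tensor([1.0, 2.0, 0.0, 3.0, 0.0])     # am + lm per arc
fsa.lm_scores = torch.tensor([0.5, 0.5, 0.0, 1.0, 0.0])  # the lm part alone

shape = fsa.arcs.shape().remove_axis(1)  # axes [path][arc]
tot = k2.RaggedTensor(shape, fsa.scores.contiguous()).sum()
lm = k2.RaggedTensor(shape, fsa.lm_scores.contiguous()).sum()
am = k2.RaggedTensor(shape, (fsa.scores - fsa.lm_scores).contiguous()).sum()

assert torch.allclose(tot, am + lm)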