Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Refactoring.
Since the FSAs in an Nbest object are linear in structure, we can sum the arc scores along each path to compute its total score.
parent 68cd287626
commit 8d931690ed
@@ -147,13 +147,21 @@ def get_parser():
         help="The lang dir",
     )

+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
     return parser


 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
             "vgg_frontend": False,
@@ -532,6 +540,7 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)

     params = get_params()
     params.update(vars(args))
@@ -572,9 +581,8 @@ def main():
         H = None
         bpe_model = None
         HLG = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
         )
-        HLG = HLG.to(device)
         assert HLG.requires_grad is False

         if not hasattr(HLG, "lm_scores"):
@@ -609,8 +617,8 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)

         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose
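The two torch.load changes above load the checkpointed FSAs directly onto the target device instead of loading to CPU and moving them afterwards. A minimal sketch of the pattern (illustrative only, not taken from the commit; the file name and device choice are placeholders):

import torch
import k2

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Load the serialized FSA dict straight onto `device`,
# so no separate .to(device) call is needed afterwards.
d = torch.load("HLG.pt", map_location=device)
HLG = k2.Fsa.from_dict(d)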
@@ -364,23 +364,13 @@ class Nbest(object):
         Return a ragged tensor with 2 axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        # Caution: We need a clone here. `self.fsa.scores` is a
-        # reference to a tensor representing the last field of an arc
-        # in the FSA (Remeber that an arc has four fields.) If we later assign
-        # `self.fsa.scores`, it will also change the scores on every arc, which
-        # means saved_scores will also be changed if we don't use `clone()`
-        # here.
-        saved_scores = self.fsa.scores.clone()
-
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
-
-        am_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, am_scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+        am_scores = self.fsa.scores - self.fsa.lm_scores
+        ragged_am_scores = k2.RaggedTensor(scores_shape, am_scores.contiguous())
+        tot_scores = ragged_am_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)

     def compute_lm_scores(self) -> k2.RaggedTensor:
         """Compute LM scores of each linear FSA (i.e., each path within
@@ -397,17 +387,16 @@ class Nbest(object):
         Return a ragged tensor with 2 axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        saved_scores = self.fsa.scores.clone()
-
-        # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.lm_scores.clone()
-
-        lm_scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        self.fsa.scores = saved_scores
-
-        return k2.RaggedTensor(self.shape, lm_scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+
+        ragged_lm_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.lm_scores.contiguous()
+        )
+
+        tot_scores = ragged_lm_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)

     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of FSAs in this Nbest.
@@ -420,10 +409,14 @@ class Nbest(object):
         Return a ragged tensor with two axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        scores = self.fsa.get_tot_scores(
-            use_double_scores=True, log_semiring=False
-        )
-        return k2.RaggedTensor(self.shape, scores)
+        scores_shape = self.fsa.arcs.shape().remove_axis(1)
+        # scores_shape has axes [path][arc]
+
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
+
+        tot_scores = ragged_scores.sum()
+
+        return k2.RaggedTensor(self.shape, tot_scores)

     def build_levenshtein_graphs(self) -> k2.Fsa:
         """Return an FsaVec with axes [utt][state][arc]."""
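As a side note on the arcs.shape().remove_axis(1) step used in all three methods above, here is a minimal sketch (again not part of the commit, assuming the same k2 API): the arcs of an FsaVec form a ragged tensor with axes [fsa][state][arc], and dropping the state axis leaves one row of arcs per path, which is exactly the shape the per-arc scores are summed over.

import k2

paths = k2.linear_fsa([[1, 2], [3, 4, 5]])  # FsaVec with 2 linear paths

shape3 = paths.arcs.shape()      # ragged shape with axes [fsa][state][arc]
shape2 = shape3.remove_axis(1)   # ragged shape with axes [fsa][arc]

# One row per path; row lengths are the arc counts per path
# (labels plus the final arc), i.e. [3, 4] -> row_splits [0, 3, 7].
print(shape2.row_splits(1))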