Add attention rescore pipeline

parent 286dce7b0f
commit 0669aa8ab9
@@ -100,6 +100,7 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
+    batch_idx: int,
     lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
@@ -201,6 +202,7 @@ def decode_one_batch(
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "attention-decoder",
+        "attention-decoder-v2",
     ]

     lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
@@ -232,6 +234,23 @@ def decode_one_batch(
             sos_id=sos_id,
             eos_id=eos_id,
         )
+    elif params.method == "attention-decoder-v2":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+        best_path_dict = rescore_with_attention_decoder_v2(
+            lattice=rescored_lattice,
+            batch_idx=batch_idx,
+            dump_best_matching_feature=params.dump_feature,
+            num_paths=params.num_paths,
+            top_k=params.top_k,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -295,6 +314,7 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
+            batch_idx=batch_idx,
             lexicon=lexicon,
             G=G,
             sos_id=sos_id,
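The "attention-decoder-v2" branch added above first rescores the whole lattice with the 4-gram G (lm_scale_list=None, so a rescored lattice rather than a dict of best paths comes back) and then calls rescore_with_attention_decoder_v2, which rescores only the top-k paths per utterance exactly with the attention decoder and estimates scores for the rest (see the large hunk further down). Below is a toy, torch-only sketch of that selection idea; the scores, the +0.5 "exact" bonus, and the 0.8 damping factor are made up for illustration and are not the commit's estimation model.

# Toy, torch-only illustration (not part of the commit) of the selection idea
# behind "attention-decoder-v2": rescore only the top-k paths exactly,
# estimate scores for the remaining paths, then spend the second exact
# rescoring budget on the most promising of the remainder.
import torch

first_pass_scores = torch.tensor([2.1, 1.8, 1.7, 1.2, 0.9, 0.4])  # 6 paths
k = 2

# Exact (expensive) rescoring of the current top-k paths.
topk = torch.topk(first_pass_scores, k).indices
exact = {int(i): float(first_pass_scores[i]) + 0.5 for i in topk}  # stand-in

# Cheap estimate for the remaining paths (the commit uses best-matching stats
# plus a small estimation model; here we simply damp the first-pass score).
remain = [i for i in range(6) if i not in exact]
estimated = {i: float(first_pass_scores[i]) * 0.8 for i in remain}

# Pick the k most promising remaining paths and rescore them exactly too.
remain_topk = sorted(estimated, key=estimated.get, reverse=True)[:k]
for i in remain_topk:
    exact[i] = float(first_pass_scores[i]) + 0.5

best = max(exact, key=exact.get)
print("exactly rescored paths:", sorted(exact))
print("best path index:", best)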
@@ -25,7 +25,7 @@ stop_stage=100
 # - librispeech-vocab.txt
 # - librispeech-lexicon.txt
 #
-# - $do_dir/musan
+# - $dl_dir/musan
 # This directory contains the following directories downloaded from
 # http://www.openslr.org/17/
 #
@@ -721,3 +721,248 @@ def rescore_with_attention_decoder(
         key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
         ans[key] = best_path_fsa
     return ans
+
+
+def rescore_nbest_with_attention_decoder(
+    nbest: Nbest,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
+) -> Nbest:
+    """Rescore an n-best list with an attention decoder. The paths with
+    rescored scores are returned as a new Nbest.
+
+    Args:
+      nbest:
+        An Nbest containing the n-best paths of the given sequences.
+        It can be the return value of :func:`generate_nbest_list`.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `[T, N, C]`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape `[N, T]`.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+    Returns:
+      An Nbest containing the same paths, whose scores are replaced by the
+      attention-decoder scores.
+    """
+    num_seqs = nbest.shape.Dim0()
+    token_seq = k2.RaggedInt(nbest.shape, nbest.fsa.labels.contiguous())
+
+    # Remove -1 from token_seq. There are no epsilon tokens in token_seq;
+    # we removed them when generating the nbest list.
+    token_seq = k2.ragged.remove_values_leq(token_seq, -1)
+
+    token_ids = k2.ragged.to_list(token_seq)
+
+    path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long)
+    expanded_memory = memory.index_select(1, path_to_seq_map_long)
+
+    expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+        0, path_to_seq_map_long
+    )
+
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == num_seqs
+
+    attention_scores = torch.zeros(
+        nbest.fsa.labels.size()[0],
+        dtype=torch.float32,
+        device=nbest.fsa.device,
+    )
+    start_index = 0
+    for i in range(num_seqs):
+        # Plus 1 to fill the score of the final arc.
+        tokens_num = len(token_ids[i]) + 1
+        attention_scores[start_index : start_index + tokens_num] = \
+            nll[i][:tokens_num]
+        start_index += tokens_num
+
+    fsas = nbest.fsa.clone()
+    fsas.scores = attention_scores
+    return Nbest(fsas, nbest.shape.clone())
+
+
+def rescore_with_attention_decoder_v2(
+    lattice: k2.Fsa,
+    batch_idx: int,
+    dump_best_matching_feature: bool,
+    num_paths: int,
+    top_k: int,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
+) -> Dict[str, k2.Fsa]:
+    """Extract n paths from the given lattice and use an attention decoder
+    to rescore them. The path with the highest score is used as the
+    decoding output.
+
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+      batch_idx:
+        The index of the current batch.
+      dump_best_matching_feature:
+        If True, compute the best matching features for the split n-best
+        lists instead of decoding.
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      top_k:
+        Number of paths per utterance that are rescored exactly with the
+        attention decoder.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `[T, N, C]`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape `[N, T]`.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+    Returns:
+      A dict of FsaVec, whose key is a string of the form
+      ngram_lm_scale_X_attention_scale_Y and whose value is the
+      best decoding path for each sequence in the lattice.
+    """
+    nbest = generate_nbest_list(lattice, num_paths)
+    # Now we have an nbest with scores from the lattice.
+    nbest = nbest.intersect(lattice)
+
+    if dump_best_matching_feature:
+        nbest_k, nbest_q = nbest.split(k=top_k, sort=False)
+        rescored_nbest_k = rescore_nbest_with_attention_decoder(
+            nbest=nbest_k,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+        stats_tensor = get_best_matching_stats(
+            rescored_nbest_k,
+            nbest_q,
+            max_order=3,
+        )
+        rescored_nbest_q = rescore_nbest_with_attention_decoder(
+            nbest=nbest_q,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+        # TODO: return feature & label or dump them to file.
+
+    nbest_topk, nbest_remain = nbest.split(k=top_k)
+
+    rescored_nbest_topk = rescore_nbest_with_attention_decoder(
+        nbest=nbest_topk,
+        model=model,
+        memory=memory,
+        memory_key_padding_mask=memory_key_padding_mask,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    stats_tensor = get_best_matching_stats(
+        rescored_nbest_topk,
+        nbest_remain,
+        max_order=3,
+    )
+    # Run the rescore estimation model to get the mean and var of each token.
+    mean, var = rescore_est_model(stats_tensor)
+    # Calculate the estimated scores of nbest_remain and select its top-k.
+    nbest_remain_topk = nbest_remain.top_k(k=top_k)
+    rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
+        nbest=nbest_remain_topk,
+        model=model,
+        memory=memory,
+        memory_key_padding_mask=memory_key_padding_mask,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    best_path_dict = get_best_path_from_nbests(
+        rescored_nbest_topk,
+        rescored_nbest_remain_topk,
+    )
+
+    return best_path_dict
+
+
+def generate_nbest_list(
+    lats: k2.Fsa,
+    num_paths: int,
+    aux_labels: bool = False,
+) -> Nbest:
+    '''Generate an n-best list from a lattice.
+
+    Args:
+      lats:
+        The decoding lattice from the first pass after LM rescoring.
+        lats is an FsaVec. It can be the return value of
+        :func:`rescore_with_whole_lattice`.
+      num_paths:
+        Size of n for the n-best list. CAUTION: After removing paths
+        that represent the same word sequences, the number of paths
+        in different sequences may not be equal.
+      aux_labels:
+        If True, build the n-best list from the aux_labels (word IDs) of
+        the lattice; otherwise build it from the tokens.
+    Return:
+      Return an Nbest object. Note the returned FSAs don't have epsilon
+      self-loops.
+    '''
+    assert len(lats.shape) == 3
+
+    # First, extract `num_paths` paths for each sequence.
+    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
+    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+
+    # seqs is a k2.RaggedInt sharing the same shape as `paths`.
+    # Note that it also contains 0s and -1s.
+    # The last entry in each sublist is -1.
+    # Its axes are [seq][path][word_id]
+    if aux_labels:
+        # If aux_labels is enabled, seqs contains word IDs.
+        assert hasattr(lats, "aux_labels")
+        seqs = k2.index(lats.aux_labels, paths)
+    else:
+        # CAUTION: We use `phones` instead of `tokens` here because
+        # :func:`compile_HLG` uses `phones`
+        #
+        # Note: compile_HLG is from k2-fsa/snowfall
+        assert hasattr(lats, 'phones')
+
+        assert not hasattr(lats, 'tokens')
+        lats.tokens = lats.phones
+        seqs = k2.index(lats.tokens, paths)
+
+    # Remove epsilons (0s) and -1 from seqs.
+    seqs = k2.ragged.remove_values_leq(seqs, 0)
+
+    # unique_seqs is still a k2.RaggedInt with axes [seq][path][word_id],
+    # but the number of paths in each sequence may be different.
+    unique_seqs, _, _ = k2.ragged.unique_sequences(
+        seqs, need_num_repeats=False, need_new2old_indexes=False)
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
+
+    # Remove the seq axis.
+    # Now unique_seqs has only two axes [path][word_id].
+    unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
+
+    fsas = k2.linear_fsa(unique_seqs)
+
+    return Nbest(fsa=fsas, shape=seq_to_path_shape)
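In rescore_nbest_with_attention_decoder above, the decoder's per-path negative log-likelihoods are copied back onto the flat arc-score tensor of the linear FSAs, one extra position per path for the final arc. Below is a standalone, torch-only illustration of that indexing, with toy tensors standing in for the output of model.decoder_nll(); it is not the icefall code itself.

# Illustrative sketch (not part of the commit): how per-path decoder NLLs are
# scattered back onto the flat arc-score tensor of a linear-FSA n-best list.
import torch

# Suppose 2 paths with 3 and 2 tokens respectively; each linear FSA also has a
# final arc, hence the "+1" when copying scores.
token_ids = [[5, 7, 9], [4, 6]]
# nll rows are padded to the longest sequence (+1 for the final position).
nll = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                    [0.5, 0.6, 0.7, 0.0]])

num_arcs = sum(len(t) + 1 for t in token_ids)  # one extra arc per path
attention_scores = torch.zeros(num_arcs)

start = 0
for i, tokens in enumerate(token_ids):
    n = len(tokens) + 1            # plus 1 to fill the score of the final arc
    attention_scores[start:start + n] = nll[i][:n]
    start += n

print(attention_scores)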
icefall/nbest.py (156 lines changed)
@@ -5,10 +5,9 @@
 # See https://github.com/k2-fsa/snowfall/issues/232 for more details
 #
 import logging
-from typing import List
+from typing import List, Tuple

 import torch
-import _k2
 import k2

 # Note: We use `utterance` and `sequence` interchangeably in the comment
@@ -19,7 +18,7 @@ class Nbest(object):
    An Nbest object contains two fields:

        (1) fsa, its type is k2.Fsa
-       (2) shape, its type is k2.RaggedShape (alias to _k2.RaggedShape)
+       (2) shape, its type is k2.RaggedShape

    The field `fsa` is an FsaVec containing a vector of **linear** FSAs.

@@ -29,7 +28,7 @@ class Nbest(object):
    of paths, which is also the number of FSAs in `fsa`.
    '''

-    def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None:
+    def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
        assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}'
        assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}'

@@ -85,7 +84,7 @@ class Nbest(object):

        return Nbest(fsa=one_best, shape=self.shape)

-    def total_scores(self) -> _k2.RaggedFloat:
+    def total_scores(self) -> k2.RaggedFloat:
        '''Get total scores of the FSAs in this Nbest.

        Note:
@@ -99,7 +98,7 @@ class Nbest(object):
                                         log_semiring=False)
        # We use single precision here since we only wrap k2.RaggedFloat.
        # If k2.RaggedDouble is wrapped, we can use double precision here.
-        return _k2.RaggedFloat(self.shape, scores.float())
+        return k2.RaggedFloat(self.shape, scores.float())

    def top_k(self, k: int) -> 'Nbest':
        '''Get a subset of paths in the Nbest. The resulting Nbest is regular
@@ -144,121 +143,66 @@ class Nbest(object):
         return Nbest(top_k_fsas, top_k_shape)


-def whole_lattice_rescoring(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
-    '''Rescore the 1st pass lattice with an LM.
-
-    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
-    This function replaces the 3-gram LM in `lats` with a 4-gram LM.
-
-    Args:
-      lats:
-        The decoding lattice from the 1st pass. We assume it is the result
-        of intersecting HLG with the network output.
-      G_with_epsilon_loops:
-        An LM. It is usually a 4-gram LM with epsilon self-loops.
-        It should be arc sorted.
-    Returns:
-      Return a new lattice rescored with a given G.
-    '''
-    assert len(lats.shape) == 3, f'{lats.shape}'
-    assert hasattr(lats, 'lm_scores')
-    assert G_with_epsilon_loops.shape == (1, None, None), \
-        f'{G_with_epsilon_loops.shape}'
-
-    device = lats.device
-    lats.scores = lats.scores - lats.lm_scores
-    # Now lats contains only acoustic scores
-
-    # We will use lm_scores from the given G, so remove lats.lm_scores here
-    del lats.lm_scores
-    assert hasattr(lats, 'lm_scores') is False
-
-    # inverted_lats has word IDs as labels.
-    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
-    # if lats.aux_labels is a ragged tensor
-    inverted_lats = k2.invert(lats)
-    num_seqs = lats.shape[0]
-
-    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
-
-    while True:
-        try:
-            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
-                                                 inverted_lats,
-                                                 b_to_a_map,
-                                                 sorted_match_a=True)
-            break
-        except RuntimeError as e:
-            logging.info(f'Caught exception:\n{e}\n')
-            # Usually, this is an OOM exception. We reduce
-            # the size of the lattice and redo k2.intersect_device()
-
-            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
-            # to avoid OOM. We may need to fine tune it.
-            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
-            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
-            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')
-
-    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))
-
-    # inv_rescoring_lats has token IDs as labels
-    # and word IDs as aux_labels.
-    inv_rescoring_lats = k2.invert(rescoring_lats)
-    return inv_rescoring_lats
-
-
-def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest:
-    '''Generate an n-best list from a lattice.
-
-    Args:
-      lats:
-        The decoding lattice from the first pass after LM rescoring.
-        lats is an FsaVec. It can be the return value of
-        :func:`whole_lattice_rescoring`
-      num_paths:
-        Size of n for n-best list. CAUTION: After removing paths
-        that represent the same token sequences, the number of paths
-        in different sequences may not be equal.
-    Return:
-      Return an Nbest object. Note the returned FSAs don't have epsilon
-      self-loops.
-    '''
-    assert len(lats.shape) == 3
-
-    # CAUTION: We use `phones` instead of `tokens` here because
-    # :func:`compile_HLG` uses `phones`
-    #
-    # Note: compile_HLG is from k2-fsa/snowfall
-    assert hasattr(lats, 'phones')
-
-    assert not hasattr(lats, 'tokens')
-    lats.tokens = lats.phones
-    # we use tokens instead of phones in the following code
-
-    # First, extract `num_paths` paths for each sequence.
-    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
-    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
-
-    # token_seqs is a k2.RaggedInt sharing the same shape as `paths`
-    # but it contains token IDs. Note that it also contains 0s and -1s.
-    # The last entry in each sublist is -1.
-    # Its axes are [seq][path][token_id]
-    token_seqs = k2.index(lats.tokens, paths)
-
-    # Remove epsilons (0s) and -1 from token_seqs
-    token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)
-
-    # unique_token_seqs is still a k2.RaggedInt with axes [seq][path][token_id],
-    # but the number of paths in each sequence may be different.
-    unique_token_seqs, _, _ = k2.ragged.unique_sequences(
-        token_seqs, need_num_repeats=False, need_new2old_indexes=False)
-
-    seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)
-
-    # Remove the seq axis.
-    # Now unique_token_seqs has only two axes [path][token_id]
-    unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)
-
-    token_fsas = k2.linear_fsa(unique_token_seqs)
-
-    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
+    def split(self, k: int, sort: bool = True) -> Tuple['Nbest', 'Nbest']:
+        '''Split the paths in the Nbest into two parts: the first part
+        contains the first k paths of each sequence in the Nbest, the
+        second part contains the remaining paths.
+
+        If the sort flag is true, we select the top-k paths according to
+        the total_scores of each path in descending order. If an utterance
+        has fewer than k paths, the first part gets all of its paths and
+        the second part is empty for that utterance.
+
+        Args:
+          k:
+            Number of paths in the first part of each utterance.
+        Returns:
+          Return a tuple of two new Nbest objects.
+        '''
+        # indexes contains idx01's for self.shape
+        indexes = torch.arange(
+            self.shape.num_elements(), dtype=torch.int32,
+            device=self.shape.device
+        )
+
+        if sort:
+            ragged_scores = self.total_scores()
+
+            # ragged_scores.values()[indexes] is sorted
+            indexes = k2.ragged.sort_sublist(
+                ragged_scores, descending=True, need_new2old_indexes=True
+            )
+
+        ragged_indexes = k2.RaggedInt(self.shape, indexes)
+
+        padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
+
+        # Select the idx01's of the top-k paths of each utterance
+        first_indexes = padded_indexes[:, :k].flatten().contiguous()
+
+        # Remove the padding elements
+        first_indexes = first_indexes[first_indexes >= 0]
+
+        first_fsas = k2.index_fsa(self.fsa, first_indexes)
+
+        first_row_ids = k2.index(self.shape.row_ids(1), first_indexes)
+        first_shape = k2.ragged.create_ragged_shape2(row_ids=first_row_ids)
+
+        first_nbest = Nbest(first_fsas, first_shape)
+
+        # Select the idx01's of the remaining paths of each utterance
+        second_indexes = padded_indexes[:, k:].flatten().contiguous()
+
+        # Remove the padding elements
+        second_indexes = second_indexes[second_indexes >= 0]
+
+        second_fsas = k2.index_fsa(self.fsa, second_indexes)
+
+        second_row_ids = k2.index(self.shape.row_ids(1), second_indexes)
+        second_shape = k2.ragged.create_ragged_shape2(row_ids=second_row_ids)
+
+        second_nbest = Nbest(second_fsas, second_shape)
+
+        return first_nbest, second_nbest
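Nbest.split() selects at most k paths per utterance by padding the ragged path indexes with -1, slicing [:, :k] and [:, k:], and filtering out the padding. Below is a small, torch-only illustration of that trick on a toy ragged layout; k2.ragged.pad is emulated by hand and the index values are made up.

# Illustrative sketch (not from the commit): the pad / slice / filter trick
# that Nbest.split() uses to pick the first k paths per utterance.
import torch

# Two utterances with 4 and 2 paths; idx01 values 0..5 index the flat path list.
row_splits = [0, 4, 6]
indexes = torch.arange(6, dtype=torch.int32)
k = 3

padded = torch.full((2, 4), -1, dtype=torch.int32)  # pad ragged rows with -1
for r in range(2):
    row = indexes[row_splits[r]:row_splits[r + 1]]
    padded[r, :len(row)] = row

first = padded[:, :k].flatten()
first = first[first >= 0]          # drop padding; utterance 2 keeps only 2 paths
second = padded[:, k:].flatten()
second = second[second >= 0]

print(first.tolist())   # [0, 1, 2, 4, 5]
print(second.tolist())  # [3]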
@@ -5,7 +5,7 @@ import subprocess
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
-from nbest import Nbest
+from icefall.nbest import Nbest
 from pathlib import Path
 from typing import Dict, Iterable, List, TextIO, Tuple, Union
