Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Merge 1ebf714fb758942266ef8a8fdcae54c5061f762c into aec222e2fe96bba7b2a7c96bcb2327a2fd45dfdc
This commit is contained in commit 7bd6be17e8.
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import warnings
 from dataclasses import dataclass
 from typing import Dict, List, Optional
@@ -21,10 +22,11 @@ from typing import Dict, List, Optional
 import k2
 import sentencepiece as spm
 import torch
+from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

 from model import Transducer

-from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
+from icefall.decode import get_lattice, Nbest, one_best_decoding
+from icefall.utils import get_alignments, get_texts


 def fast_beam_search_one_best(
@@ -553,6 +555,9 @@ def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    decoding_graph: Optional[k2.Fsa] = None,
+    ngram_rescoring: bool = False,
+    gamma_blank: float = 1.0,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -602,6 +607,18 @@ def greedy_search_batch(

     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

+    if ngram_rescoring:
+        vocab_size = model.decoder.vocab_size
+        total_t = encoder_out.shape[0]
+        # Cache all joiner outputs during greedy search,
+        # from which non-blank frames are selected before n-gram rescoring.
+        all_logits = torch.zeros([total_t, vocab_size], device=device)
+
+        # A flag indicating whether a frame is a blank frame:
+        # 0 for a blank frame and 1 for a non-blank frame.
+        # Used to select non-blank frames for n-gram rescoring.
+        non_blank_flag = torch.zeros([total_t], device=device)
+
     offset = 0
     for batch_size in batch_size_list:
         start = offset
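The caches above are sized along the packed time axis: after encoder_proj, encoder_out stacks the frames of all utterances grouped by time step, so its first dimension is sum(batch_size_list) rather than N * T. A toy illustration of that layout (the utterance lengths below are made up for the example):

# Packed layout: batch_size_list[t] is the number of utterances
# still alive at time step t.
batch_size_list = [3, 2, 2, 1]   # e.g. utterance lengths 4, 3, 1
total_t = sum(batch_size_list)   # 8 packed frames = 4 + 3 + 1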
@@ -616,10 +633,23 @@ def greedy_search_batch(
         logits = model.joiner(
             current_encoder_out, decoder_out.unsqueeze(1), project_input=False
         )
+
         # logits' shape is (batch_size, 1, 1, vocab_size)
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
+
+        if ngram_rescoring:
+            all_logits[start:end] = logits
+
+            logits_softmax = logits.softmax(dim=1)
+
+            # 0 for a blank frame and 1 for a non-blank frame.
+            non_blank_flag[start:end] = torch.where(
+                logits_softmax[:, 0] >= gamma_blank, 0, 1
+            )
+

         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
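Per packed step, the blank decision reduces to thresholding the blank posterior. A self-contained sketch of that test, assuming blank has token id 0 (as in the patch) and an arbitrarily chosen gamma_blank of 0.95:

import torch

torch.manual_seed(0)
gamma_blank = 0.95
logits = torch.randn(4, 500)  # (batch_size, vocab_size)
probs = logits.softmax(dim=1)
# 0 where the blank posterior reaches the threshold, 1 otherwise.
non_blank_flag = torch.where(probs[:, 0] >= gamma_blank, 0, 1)
print(non_blank_flag)  # almost certainly tensor([1, 1, 1, 1]) for random logits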
@@ -643,6 +673,91 @@ def greedy_search_batch(
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])

-    return ans
+    if not ngram_rescoring:
+        return ans
+
+    assert decoding_graph is not None
+
+    # Transform logits to shape [N, T, vocab_size] to make it easier
+    # to select non-blank frames.
+    packed_all_logits = PackedSequence(
+        all_logits, torch.tensor(batch_size_list)
+    )
+    all_logits_unpacked, _ = pad_packed_sequence(
+        packed_all_logits, batch_first=True
+    )
+
+    # Transform non_blank_flag to shape [N, T].
+    packed_non_blank_flag = PackedSequence(
+        non_blank_flag, torch.tensor(batch_size_list)
+    )
+    non_blank_flag_unpacked, _ = pad_packed_sequence(
+        packed_non_blank_flag, batch_first=True
+    )
+
+    non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1)
+    max_frame_to_rescore = non_blank_logits_lens.max()
+
+    non_blank_logits = torch.zeros(
+        [N, int(max_frame_to_rescore), vocab_size], device=device
+    )
+
+    # torch.index_select only accepts a single dimension to index from,
+    # so we generate non_blank_logits one utterance at a time.
+    # Maybe there is a more efficient way to do this.
+    for i in range(N):
+        cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0]
+        assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0]
+        non_blank_logits[
+            i, : int(non_blank_logits_lens[i]), :
+        ] = torch.index_select(
+            all_logits_unpacked[i, :], 0, cur_non_blank_index
+        )
+
+    number_selected_frames = non_blank_flag.sum()
+    logging.info(f"{number_selected_frames} are selected out of {total_t} frames")
+
+    # Split log_softmax into two separate steps,
+    # so we could do blank deweighting in the probability domain if needed.
+    logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
+    logits_to_rescore = logits_to_rescore_softmax.log()
+
+    # In the paper https://arxiv.org/pdf/2101.06856.pdf,
+    # blank deweighting is applied before the non-blank frames are selected.
+    # However, in the current setup that results in a higher WER,
+    # so the blank deweighting is applied just before the n-gram rescoring.
+    # TODO: debug this blank deweighting issue.
+    blank_deweight = 0.0
+    logits_to_rescore[:, :, 0] -= blank_deweight
+
+    supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
+    supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
+    supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=logits_to_rescore,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=20,
+        output_beam=8,
+        min_active_states=30,
+        max_active_states=1000,
+        subsampling_factor=1,
+    )
+
+    best_path = one_best_decoding(
+        lattice=lattice,
+        use_double_scores=True,
+    )
+
+    token_ids = get_alignments(best_path, "labels", remove_zero_blank=True)
+
+    ans = []
+    for i in range(N):
+        usi = unsorted_indices[i]
+        ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])])
+    return ans
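The PackedSequence round trip above is the key reshaping step: per-frame rows cached in packed order become a padded [N, T, C] tensor without an explicit Python loop over time steps. A standalone sketch with toy shapes (all values illustrative, not from the patch):

import torch
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

vocab_size = 5
batch_size_list = [2, 2, 1]      # two utterances with lengths 3 and 2
total_t = sum(batch_size_list)   # 5 packed frames

packed_data = torch.arange(total_t * vocab_size, dtype=torch.float32).reshape(
    total_t, vocab_size
)
packed = PackedSequence(packed_data, torch.tensor(batch_size_list))
padded, lens = pad_packed_sequence(packed, batch_first=True)
print(padded.shape)  # torch.Size([2, 3, 5])
print(lens)          # tensor([3, 2])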
@@ -136,6 +136,28 @@ def get_parser():
         "`epoch` are loaded for averaging. ",
     )

+    parser.add_argument(
+        "--ngram-rescoring",
+        type=str2bool,
+        default=False,
+        help="Whether to use ngram_rescoring.",
+    )
+
+    parser.add_argument(
+        "--decoding-graph",
+        type=str,
+        default="trivial_graph",
+        help="One of [trivial_graph, HLG, trivial_LG, LG]; "
+        "used by greedy_search_batch with ngram-rescoring=True.",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="./data/lang_bpe_500/",
+        help="Path to decoding graphs",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -219,6 +241,12 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    parser.add_argument(
+        "--gamma-blank",
+        type=float,
+        default=1.0,
+    )
+
     return parser


@@ -293,6 +321,9 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            decoding_graph=decoding_graph,
+            ngram_rescoring=params.ngram_rescoring,
+            gamma_blank=params.gamma_blank,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -498,6 +529,11 @@ def main():
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"

+    if params.ngram_rescoring:
+        params.suffix += "-ngram-rescoring"
+        params.suffix += f"-{params.decoding_graph}"
+        params.suffix += f"-gamma_blank-{params.gamma_blank}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")

@@ -605,6 +641,24 @@ def main():
     else:
         decoding_graph = None

+    if params.ngram_rescoring and params.decoding_method == "greedy_search":
+        assert params.decoding_graph in [
+            "trivial_graph",
+            "L",
+        ], f"Unsupported decoding graph {params.decoding_graph}"
+        if params.decoding_graph == "trivial_graph":
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
+        else:
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(
+                    f"data/lang_bpe_500/{params.decoding_graph}.pt",
+                    map_location=device,
+                )
+            )
+        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

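k2.trivial_graph builds an acceptor over token ids up to max_token that (roughly) accepts any token sequence, while the loaded graph in the other branch supplies actual n-gram scores. A minimal sketch of the graph setup, assuming k2 is installed and using a toy vocab_size of 10 in place of params.vocab_size:

import torch
import k2

device = torch.device("cpu")
vocab_size = 10  # toy value; the patch uses params.vocab_size

decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)
# The patch adds epsilon self-loops before handing the graph to get_lattice.
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
print(decoding_graph.num_arcs)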
@@ -236,7 +236,9 @@ def get_texts(
     return aux_labels.tolist()


-def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
+def get_alignments(
+    best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False
+) -> List[List[int]]:
     """Extract labels or aux_labels from the best-path FSAs.

     Args:
@@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
         token_shape, getattr(best_paths, kind).contiguous()
     )
     tokens = tokens.remove_values_eq(-1)
+    if remove_zero_blank:
+        tokens = tokens.remove_values_eq(0)
     return tokens.tolist()
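The new remove_zero_blank flag simply strips blank (id 0) tokens after the -1 terminators have been removed. A toy example of the ragged-tensor operations involved (assumes k2; the values are illustrative, not real best-path labels):

import k2

tokens = k2.RaggedTensor([[1, 0, 2, -1], [0, 3, -1]])
tokens = tokens.remove_values_eq(-1)  # drop the -1 terminators
tokens = tokens.remove_values_eq(0)   # what remove_zero_blank=True adds
print(tokens.tolist())  # [[1, 2], [3]]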