Merge 1ebf714fb758942266ef8a8fdcae54c5061f762c into aec222e2fe96bba7b2a7c96bcb2327a2fd45dfdc

LIyong.Guo 2022-07-18 10:17:06 +02:00 committed by GitHub
commit 7bd6be17e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 178 additions and 5 deletions

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
@ -21,10 +22,11 @@ from typing import Dict, List, Optional
import k2
import sentencepiece as spm
import torch
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from icefall.decode import get_lattice, Nbest, one_best_decoding
from icefall.utils import get_alignments, get_texts
def fast_beam_search_one_best(
@ -553,6 +555,9 @@ def greedy_search_batch(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
decoding_graph: Optional[k2.Fsa] = None,
ngram_rescoring: bool = False,
gamma_blank: float = 1.0,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -602,6 +607,18 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
if ngram_rescoring:
vocab_size = model.decoder.vocab_size
total_t = encoder_out.shape[0]
# Cache all joiner outputs during greedy search;
# non-blank frames are selected from them before n-gram rescoring.
all_logits = torch.zeros([total_t, vocab_size], device=device)
# A per-frame flag indicating whether a frame is a blank frame:
# 0 for a blank frame and 1 for a non-blank frame.
# Used to select non-blank frames for n-gram rescoring.
non_blank_flag = torch.zeros([total_t], device=device)
offset = 0
for batch_size in batch_size_list:
start = offset
@ -616,10 +633,23 @@ def greedy_search_batch(
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits' shape: (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
if ngram_rescoring:
all_logits[start:end] = logits
logits_softmax = logits.softmax(dim=1)
# 0 for blank frame and 1 for non-blank frame.
non_blank_flag[start:end] = torch.where(
logits_softmax[:, 0] >= gamma_blank, 0, 1
)
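# Illustration (not part of the diff): with gamma_blank=0.9, a frame whose
# blank probability is 0.95 gets flag 0 (treated as blank and skipped during
# rescoring), while a frame with blank probability 0.8 gets flag 1 (kept).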
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
@ -643,6 +673,91 @@ def greedy_search_batch(
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
if not ngram_rescoring:
return ans
assert decoding_graph is not None
# Transform logits to shape [N, T, vocab_size] format to make it easier
# to select non-blank frames.
packed_all_logits = PackedSequence(
all_logits, torch.tensor(batch_size_list)
)
all_logits_unpacked, _ = pad_packed_sequence(
packed_all_logits, batch_first=True
)
# Transform non_blank_flag to shape [N, T]
packed_non_blank_flag = PackedSequence(
non_blank_flag, torch.tensor(batch_size_list)
)
non_blank_flag_unpacked, _ = pad_packed_sequence(
packed_non_blank_flag, batch_first=True
)
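# Illustration (not part of the diff) of the packed -> padded transform above:
# batch_sizes[t] is the number of utterances still active at frame t, so for
# batch_size_list = [3, 3, 2, 1], pad_packed_sequence(..., batch_first=True)
# returns a padded tensor of shape (N=3, T=4, ...) with lengths [4, 3, 2].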
non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1)
max_frame_to_rescore = non_blank_logits_lens.max()
non_blank_logits = torch.zeros(
[N, int(max_frame_to_rescore), vocab_size], device=device
)
# torch.index_select only accepts a single dimension to index over,
# so we generate non_blank_logits row by row.
# Maybe there is a more efficient way to do this
# (a possible vectorized alternative is sketched after this function).
for i in range(N):
cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0]
assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0]
non_blank_logits[
i, : int(non_blank_logits_lens[i]), :
] = torch.index_select(
all_logits_unpacked[i, :], 0, cur_non_blank_index
)
number_selected_frames = non_blank_flag.sum()
logging.info(f"{number_selected_frames} are selected out of {total_t} frames")
# Split log_softmax into two separate steps,
# so we could apply blank deweighting in the probability domain if needed.
logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
logits_to_rescore = logits_to_rescore_softmax.log()
# In the paper https://arxiv.org/pdf/2101.06856.pdf,
# blank deweighting is applied before the non-blank frames are selected.
# However, in the current setup that results in a higher WER,
# so the blank deweighting is applied just before the n-gram rescoring instead.
# TODO: debug this blank deweighting issue.
blank_deweight = 0.0
logits_to_rescore[:, :, 0] -= blank_deweight
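# Illustration (not part of the diff): a positive value such as blank_deweight = 2.0
# would subtract 2.0 from the blank log-probability of every frame, discouraging the
# lattice search from staying on blank; 0.0 leaves the scores unchanged.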
supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)
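# Each row of supervision_segments is (utterance_index, start_frame, num_frames);
# start_frame stays 0 because the selected non-blank frames are packed to the
# front of each utterance.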
lattice = get_lattice(
nnet_output=logits_to_rescore,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=8,
min_active_states=30,
max_active_states=1000,
subsampling_factor=1,
)
best_path = one_best_decoding(
lattice=lattice,
use_double_scores=True,
)
token_ids = get_alignments(best_path, "labels", remove_zero_blank=True)
ans = []
for i in range(N):
usi = unsorted_indices[i]
ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])])
return ans
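
A possible vectorized replacement for the per-utterance torch.index_select loop above, sketched only as an illustration and not part of this commit; it assumes the same variable meanings as in the diff (an (N, T) 0/1 mask and an (N, T, vocab_size) logits tensor):

import torch


def select_non_blank_frames(
    all_logits_unpacked: torch.Tensor, non_blank_flag_unpacked: torch.Tensor
) -> torch.Tensor:
    """Pack the frames flagged as non-blank to the front of each row.

    all_logits_unpacked: (N, T, V); non_blank_flag_unpacked: (N, T) with 0/1 entries.
    Returns an (N, max_kept, V) tensor, zero-padded on the right.
    """
    mask = non_blank_flag_unpacked.bool()
    lens = mask.sum(dim=1)  # number of kept frames per utterance
    N, T, V = all_logits_unpacked.shape
    out = torch.zeros(N, int(lens.max()), V, device=all_logits_unpacked.device)
    # Destination column of each kept frame: its 0-based rank among the kept
    # frames of the same utterance.
    col = torch.cumsum(mask.to(torch.int64), dim=1) - 1
    row = torch.arange(N, device=mask.device).unsqueeze(1).expand(N, T)
    out[row[mask], col[mask]] = all_logits_unpacked[mask]
    return out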

View File

@ -136,6 +136,28 @@ def get_parser():
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--ngram-rescoring",
type=str2bool,
default=False,
help="Whether to use ngram_rescoring.",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="trivial_graph",
help="one of [trivial_grpah, HLG, Trival_LG, LG]"
"used by greedy_search_batch with ngram-rescoring=True.",
)
parser.add_argument(
"--lang-dir",
type=str,
default="./data/lang_bpe_500/",
help="Path to decoding graphs",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -219,6 +241,12 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--gamma-blank",
type=float,
default=1.0,
help="Blank probability threshold: frames whose blank probability is "
">= gamma_blank are treated as blank frames and skipped during n-gram rescoring.",
)
return parser
@ -293,6 +321,9 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
decoding_graph=decoding_graph,
ngram_rescoring=params.ngram_rescoring,
gamma_blank=params.gamma_blank,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -498,6 +529,11 @@ def main():
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
if params.ngram_rescoring:
params.suffix += "-ngram-rescoring"
params.suffix += f"-{params.decoding_graph}"
params.suffix += f"-gamma_blank-{params.gamma_blank}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -605,6 +641,24 @@ def main():
else:
decoding_graph = None
if params.ngram_rescoring and params.decoding_method == "greedy_search":
assert params.decoding_graph in [
"trivial_graph",
"L",
], f"Unsupported decoding graph {params.decoding_graph}"
if params.decoding_graph == "trivial_graph":
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = k2.Fsa.from_dict(
torch.load(
f"data/lang_bpe_500/{params.decoding_graph}.pt",
map_location=device,
)
)
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
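# Note (assumption, not stated in the diff): k2.trivial_graph builds an acceptor
# with a self-loop on its start state for every token id from 1 to vocab_size - 1,
# i.e. it imposes no LM constraint, while the L graph adds lexicon structure.
# add_epsilon_self_loops presumably lets the graph consume any remaining
# blank (id 0) frames during intersection.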
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

View File

@ -236,7 +236,9 @@ def get_texts(
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
def get_alignments(
best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False
) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs.
Args:
@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
token_shape, getattr(best_paths, kind).contiguous()
)
tokens = tokens.remove_values_eq(-1)
if remove_zero_blank:
tokens = tokens.remove_values_eq(0)
return tokens.tolist()
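
A minimal illustration (not part of the diff, assuming a working k2 installation) of what the new remove_zero_blank flag adds on top of the existing remove_values_eq(-1) call:

import k2

# 0 is the blank/epsilon id in this recipe; dropping it leaves only real labels.
tokens = k2.RaggedTensor([[0, 25, 0, 7], [0, 0, 13]])
print(tokens.remove_values_eq(0).tolist())  # [[25, 7], [13]]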