add ctc-greedy-search with timestamps (#905)

This commit is contained in:
Zengwei Yao 2023-02-13 19:45:09 +08:00 committed by GitHub
parent 6a8b649e56
commit 25ee50e27c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -92,7 +92,10 @@ from icefall.decode import (
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
convert_timestamp,
get_texts,
make_pad_mask,
parse_bpe_start_end_pairs,
parse_fsa_timestamps_and_texts,
setup_logger,
store_transcripts_and_timestamps,
@ -167,21 +170,24 @@ def get_parser():
default="ctc-decoding",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
- (0) ctc-greedy-search. It uses a sentence piece model,
i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
- (2) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
- (3) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
you have trained an RNN LM using ./rnn_lm/train.py
- (5) nbest-oracle. Its WER is the lower bound of any n-best
- (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
@ -269,6 +275,101 @@ def get_decoding_params() -> AttributeDict:
return params
def ctc_greedy_search(
    ctc_probs: torch.Tensor,
    nnet_output_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    subsampling_factor: int = 4,
    frame_shift_ms: float = 10,
) -> Tuple[List[List[Tuple[float, float]]], List[List[str]]]:
    """Apply CTC greedy search and derive per-word timestamps.

    For every frame the highest-scoring token is selected. Token id 0 is
    treated as blank, and consecutive repeats of a token are collapsed.

    Args:
      ctc_probs (torch.Tensor):
        (batch, max_len, feat_dim)
      nnet_output_lens (torch.Tensor):
        (batch, )
      sp:
        The BPE model.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.

    Returns:
      utt_time_pairs:
        A list of pair list. utt_time_pairs[i] is a list of
        (start-time, end-time) pairs for each word in
        utterance-i.
      utt_words:
        A list of str list. utt_words[i] is a word list of utterance-i.
    """
    # Only the index of the best token per frame is needed; the scores
    # returned by topk are discarded.
    _, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
    topk_index = topk_index.squeeze(2)  # (B, maxlen)
    # NOTE(review): make_pad_mask presumably marks the frames beyond each
    # utterance's length -- confirm against icefall.utils.
    mask = make_pad_mask(nnet_output_lens)
    # Masked positions are overwritten with 0, the id that is skipped as
    # blank below, so they never yield a token.
    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
    hyps = [hyp.tolist() for hyp in topk_index]

    def get_first_tokens(tokens: List[int]) -> Tuple[List[int], List[bool]]:
        """Find the frames where a new non-blank token starts.

        Returns:
          first_tokens:
            The token ids left after dropping id 0 and collapsing
            consecutive repeats.
          is_first_token:
            One flag per frame; True where the frame holds a non-zero
            token that differs from the token of the previous frame.
        """
        first_tokens = []
        is_first_token = []
        prev = None  # No previous frame before the first one.
        for token in tokens:
            is_first = token != 0 and token != prev
            is_first_token.append(is_first)
            if is_first:
                first_tokens.append(token)
            prev = token
        return first_tokens, is_first_token

    utt_time_pairs = []
    utt_words = []
    for hyp in hyps:
        first_tokens, is_first_token = get_first_tokens(hyp)
        all_tokens = sp.id_to_piece(hyp)
        index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token)
        words = sp.decode(first_tokens).split()
        # There must be exactly one (start, end) frame pair per decoded word.
        assert len(index_pairs) == len(words), (
            len(index_pairs),
            len(words),
            all_tokens,
        )
        start = convert_timestamp(
            frames=[pair[0] for pair in index_pairs],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        end = convert_timestamp(
            # The duration in frames is (end_frame_index - start_frame_index + 1)
            frames=[pair[1] + 1 for pair in index_pairs],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        utt_time_pairs.append(list(zip(start, end)))
        utt_words.append(words)

    return utt_time_pairs, utt_words
def remove_duplicates_and_blank(
    hyp: List[int],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Collapse repeated tokens and drop blanks from a frame-level sequence.

    Modified from
    https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py

    Args:
      hyp:
        One token id per frame. Id 0 is treated as blank.

    Returns:
      new_hyp:
        The token ids left after merging each run of identical consecutive
        ids into one and removing the runs whose id is 0.
      time:
        time[i] is the (start, end) frame-index pair, both inclusive, of the
        run that produced new_hyp[i].
    """
    new_hyp: List[int] = []
    time: List[Tuple[int, int]] = []
    num_frames = len(hyp)
    run_start = 0
    while run_start < num_frames:
        token = hyp[run_start]
        # Extend the run to the last consecutive frame holding the same id.
        run_end = run_start
        while run_end + 1 < num_frames and hyp[run_end + 1] == token:
            run_end += 1
        if token != 0:
            new_hyp.append(token)
            time.append((run_start, run_end))
        run_start = run_end + 1
    return new_hyp, time
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -360,6 +461,17 @@ def decode_one_batch(
nnet_output = model.get_ctc_output(encoder_out)
# nnet_output is (N, T, C)
if params.decoding_method == "ctc-greedy-search":
timestamps, hyps = ctc_greedy_search(
ctc_probs=nnet_output,
nnet_output_lens=encoder_out_lens,
sp=bpe_model,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
)
key = "ctc-greedy-search"
return {key: (hyps, timestamps)}
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
@ -696,6 +808,7 @@ def main():
params.update(vars(args))
assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
"1best",
"nbest",
@ -749,7 +862,7 @@ def main():
params.sos_id = sos_id
params.eos_id = eos_id
if params.decoding_method == "ctc-decoding":
if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,