add ctc-greedy-search with timestamps (#905)

This commit is contained in:
Zengwei Yao 2023-02-13 19:45:09 +08:00 committed by GitHub
parent 6a8b649e56
commit 25ee50e27c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -92,7 +92,10 @@ from icefall.decode import (
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
convert_timestamp,
get_texts, get_texts,
make_pad_mask,
parse_bpe_start_end_pairs,
parse_fsa_timestamps_and_texts, parse_fsa_timestamps_and_texts,
setup_logger, setup_logger,
store_transcripts_and_timestamps, store_transcripts_and_timestamps,
@ -167,21 +170,24 @@ def get_parser():
default="ctc-decoding", default="ctc-decoding",
help="""Decoding method. help="""Decoding method.
Supported values are: Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece - (0) ctc-greedy-search. It uses a sentence piece model,
i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words. model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM. It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the - (2) 1best. Extract the best path from the decoding lattice as the
decoding result. decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path - (3) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result. with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice, - (4) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result. the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result. is the decoding result.
you have trained an RNN LM using ./rnn_lm/train.py you have trained an RNN LM using ./rnn_lm/train.py
- (5) nbest-oracle. Its WER is the lower bound of any n-best - (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best rescoring method can achieve. Useful for debugging n-best
rescoring method. rescoring method.
""", """,
@ -269,6 +275,101 @@ def get_decoding_params() -> AttributeDict:
return params return params
def ctc_greedy_search(
    ctc_probs: torch.Tensor,
    nnet_output_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    subsampling_factor: int = 4,
    frame_shift_ms: float = 10,
) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
    """Apply CTC greedy search: take the best token per frame, collapse
    repeats/blanks, and recover word-level (start, end) timestamps.

    Args:
      ctc_probs (torch.Tensor):
        (batch, max_len, feat_dim)
      nnet_output_lens (torch.Tensor):
        (batch, )
      sp:
        The BPE model.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.
    Returns:
      utt_time_pairs:
        A list of pair list. utt_time_pairs[i] is a list of
        (start-time, end-time) pairs for each word in
        utterance-i.
      utt_words:
        A list of str list. utt_words[i] is a word list of utterance-i.
    """
    # Best token id per frame; token id 0 is treated as the CTC blank below.
    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
    topk_index = topk_index.squeeze(2)  # (B, maxlen)
    # Force padding frames (beyond each utterance's length) to blank.
    mask = make_pad_mask(nnet_output_lens)
    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
    hyps = [hyp.tolist() for hyp in topk_index]

    def get_first_tokens(tokens: List[int]) -> Tuple[List[int], List[bool]]:
        """Return (a) the collapsed non-blank token ids and (b) a per-frame
        flag telling whether that frame starts a new collapsed token."""
        is_first_token = []
        first_tokens = []
        for t, token in enumerate(tokens):
            # A frame starts a new token if it is non-blank and differs
            # from the previous frame's token.
            if token != 0 and (t == 0 or tokens[t - 1] != token):
                is_first_token.append(True)
                first_tokens.append(token)
            else:
                is_first_token.append(False)
        return first_tokens, is_first_token

    utt_time_pairs = []
    utt_words = []
    for hyp in hyps:
        first_tokens, is_first_token = get_first_tokens(hyp)
        all_tokens = sp.id_to_piece(hyp)
        # Frame-index (start, end) pairs, one per word, derived from the
        # BPE word-boundary markers.
        index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token)
        words = sp.decode(first_tokens).split()
        assert len(index_pairs) == len(words), (
            len(index_pairs),
            len(words),
            all_tokens,
        )
        start = convert_timestamp(
            frames=[i[0] for i in index_pairs],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        end = convert_timestamp(
            # The duration in frames is (end_frame_index - start_frame_index + 1)
            frames=[i[1] + 1 for i in index_pairs],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        utt_time_pairs.append(list(zip(start, end)))
        utt_words.append(words)

    return utt_time_pairs, utt_words
def remove_duplicates_and_blank(
    hyp: List[int],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Collapse consecutive repeated tokens and drop CTC blanks (id 0).

    Modified from
    https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py

    Args:
      hyp:
        Frame-level token ids; 0 is the blank token.
    Returns:
      A tuple (tokens, times):
        tokens: the de-duplicated non-blank token ids.
        times: for each kept token, the inclusive (start_frame, end_frame)
          index range of its run of repeated frames in ``hyp``.
    """
    tokens: List[int] = []
    times: List[Tuple[int, int]] = []
    num_frames = len(hyp)
    idx = 0
    while idx < num_frames:
        token = hyp[idx]
        run_start = idx
        # Advance past the whole run of identical tokens.
        while idx < num_frames and hyp[idx] == token:
            idx += 1
        # Keep only non-blank runs, remembering their frame span.
        if token != 0:
            tokens.append(token)
            times.append((run_start, idx - 1))
    return tokens, times
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
@ -360,6 +461,17 @@ def decode_one_batch(
nnet_output = model.get_ctc_output(encoder_out) nnet_output = model.get_ctc_output(encoder_out)
# nnet_output is (N, T, C) # nnet_output is (N, T, C)
if params.decoding_method == "ctc-greedy-search":
timestamps, hyps = ctc_greedy_search(
ctc_probs=nnet_output,
nnet_output_lens=encoder_out_lens,
sp=bpe_model,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
)
key = "ctc-greedy-search"
return {key: (hyps, timestamps)}
supervision_segments = torch.stack( supervision_segments = torch.stack(
( (
supervisions["sequence_idx"], supervisions["sequence_idx"],
@ -696,6 +808,7 @@ def main():
params.update(vars(args)) params.update(vars(args))
assert params.decoding_method in ( assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding", "ctc-decoding",
"1best", "1best",
"nbest", "nbest",
@ -749,7 +862,7 @@ def main():
params.sos_id = sos_id params.sos_id = sos_id
params.eos_id = eos_id params.eos_id = eos_id
if params.decoding_method == "ctc-decoding": if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
HLG = None HLG = None
H = k2.ctc_topo( H = k2.ctc_topo(
max_token=max_token_id, max_token=max_token_id,