mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
add ctc-greedy-search with timestamps (#905)
This commit is contained in:
parent
6a8b649e56
commit
25ee50e27c
@ -92,7 +92,10 @@ from icefall.decode import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
convert_timestamp,
|
||||
get_texts,
|
||||
make_pad_mask,
|
||||
parse_bpe_start_end_pairs,
|
||||
parse_fsa_timestamps_and_texts,
|
||||
setup_logger,
|
||||
store_transcripts_and_timestamps,
|
||||
@ -167,21 +170,24 @@ def get_parser():
|
||||
default="ctc-decoding",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
- (0) ctc-greedy-search. It uses a sentence piece model,
|
||||
i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||
- (2) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||
- (3) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
you have trained an RNN LM using ./rnn_lm/train.py
|
||||
- (5) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
@ -269,6 +275,101 @@ def get_decoding_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def ctc_greedy_search(
    ctc_probs: torch.Tensor,
    nnet_output_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    subsampling_factor: int = 4,
    frame_shift_ms: float = 10,
) -> Tuple[List[List[Tuple[float, float]]], List[List[str]]]:
    """Apply CTC greedy search and derive word-level timestamps.

    For every frame the highest-scoring token is selected. Token ID 0 is
    treated as the CTC blank, and padded frames are overwritten with 0 so
    that they behave like blanks.

    Args:
      ctc_probs (torch.Tensor):
        (batch, max_len, num_tokens). The best token of each frame is
        picked along the last dimension.
      nnet_output_lens (torch.Tensor):
        (batch, ). Passed to ``make_pad_mask`` to locate the padded frames.
      sp:
        The BPE model.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.

    Returns:
      utt_time_pairs:
        A list of pair lists. utt_time_pairs[i] is a list of
        (start-time, end-time) pairs, one for each word in utterance-i.
        The values are whatever ``convert_timestamp`` produces.
      utt_words:
        A list of str lists. utt_words[i] is the word list of utterance-i.
    """
    # Only the indices of the best tokens are needed; the scores are dropped.
    _, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
    topk_index = topk_index.squeeze(2)  # (B, maxlen)
    mask = make_pad_mask(nnet_output_lens)
    # Frames selected by the mask are set to 0, i.e. the ID treated as blank.
    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
    hyps = [hyp.tolist() for hyp in topk_index]

    def get_first_tokens(tokens: List[int]) -> Tuple[List[int], List[bool]]:
        """Collapse repeats and drop blanks in a frame-level token sequence.

        Returns the collapsed token IDs together with a per-frame flag that
        is True exactly on the first frame of each emitted token.
        """
        is_first_token = []
        first_tokens = []
        for t, token in enumerate(tokens):
            starts_new_token = token != 0 and (t == 0 or tokens[t - 1] != token)
            is_first_token.append(starts_new_token)
            if starts_new_token:
                first_tokens.append(token)
        return first_tokens, is_first_token

    utt_time_pairs = []
    utt_words = []
    for hyp in hyps:
        first_tokens, is_first_token = get_first_tokens(hyp)
        all_tokens = sp.id_to_piece(hyp)
        index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token)
        words = sp.decode(first_tokens).split()
        # Every decoded word must have exactly one (start, end) frame pair.
        assert len(index_pairs) == len(words), (
            len(index_pairs),
            len(words),
            all_tokens,
        )
        start = convert_timestamp(
            frames=[pair[0] for pair in index_pairs],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        end = convert_timestamp(
            # The duration in frames is (end_frame_index - start_frame_index + 1)
            frames=[pair[1] + 1 for pair in index_pairs],
            subsampling_factor=subsampling_factor,
            frame_shift_ms=frame_shift_ms,
        )
        utt_time_pairs.append(list(zip(start, end)))
        utt_words.append(words)

    return utt_time_pairs, utt_words
|
||||
|
||||
|
||||
def remove_duplicates_and_blank(
    hyp: List[int], blank_id: int = 0
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Collapse a frame-level CTC token sequence and record token spans.

    Consecutive identical tokens are merged into a single token and runs of
    the blank token are dropped.

    Args:
      hyp:
        Token IDs, one per frame.
      blank_id:
        ID of the CTC blank token. Defaults to 0.

    Returns:
      new_hyp:
        The token IDs left after merging repeats and removing blanks.
      time:
        time[i] is the inclusive (start_frame, end_frame) index pair of the
        run of frames that produced new_hyp[i].
    """
    # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
    new_hyp: List[int] = []
    time: List[Tuple[int, int]] = []
    num_frames = len(hyp)
    cur = 0
    while cur < num_frames:
        token = hyp[cur]
        start = cur
        # Advance to the first frame past the current run of identical tokens.
        while cur < num_frames and hyp[cur] == token:
            cur += 1
        if token != blank_id:
            new_hyp.append(token)
            time.append((start, cur - 1))
    return new_hyp, time
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
@ -360,6 +461,17 @@ def decode_one_batch(
|
||||
nnet_output = model.get_ctc_output(encoder_out)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
timestamps, hyps = ctc_greedy_search(
|
||||
ctc_probs=nnet_output,
|
||||
nnet_output_lens=encoder_out_lens,
|
||||
sp=bpe_model,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
frame_shift_ms=params.frame_shift_ms,
|
||||
)
|
||||
key = "ctc-greedy-search"
|
||||
return {key: (hyps, timestamps)}
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
@ -696,6 +808,7 @@ def main():
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-greedy-search",
|
||||
"ctc-decoding",
|
||||
"1best",
|
||||
"nbest",
|
||||
@ -749,7 +862,7 @@ def main():
|
||||
params.sos_id = sos_id
|
||||
params.eos_id = eos_id
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
|
Loading…
x
Reference in New Issue
Block a user