add ctc-greedy-search with timestamps (#905)
commit 25ee50e27c
parent 6a8b649e56
@@ -92,7 +92,10 @@ from icefall.decode import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    convert_timestamp,
     get_texts,
+    make_pad_mask,
+    parse_bpe_start_end_pairs,
     parse_fsa_timestamps_and_texts,
     setup_logger,
     store_transcripts_and_timestamps,
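
For orientation, here is a minimal sketch of what the newly imported make_pad_mask helper is assumed to do (the real implementation lives in icefall.utils and may differ in detail): it returns a boolean mask that is True at padding positions, which ctc_greedy_search below uses to zero out frames beyond each utterance's length.

import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: (batch,) int tensor of valid frame counts.
    # Returns a (batch, max_len) bool mask, True where a frame is padding.
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)  # [0, 1, ..., max_len - 1]
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

# make_pad_mask_sketch(torch.tensor([3, 1])) ->
# tensor([[False, False, False],
#         [False,  True,  True]])
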
@@ -167,21 +170,24 @@ def get_parser():
         default="ctc-decoding",
         help="""Decoding method.
         Supported values are:
-        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+        - (0) ctc-greedy-search. It uses a sentence piece model,
+          i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
           model, i.e., lang_dir/bpe.model, to convert word pieces to words.
           It needs neither a lexicon nor an n-gram LM.
-        - (1) 1best. Extract the best path from the decoding lattice as the
+        - (2) 1best. Extract the best path from the decoding lattice as the
           decoding result.
-        - (2) nbest. Extract n paths from the decoding lattice; the path
+        - (3) nbest. Extract n paths from the decoding lattice; the path
           with the highest score is the decoding result.
-        - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
           rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
           the highest score is the decoding result.
-        - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
           n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
           is the decoding result.
           you have trained an RNN LM using ./rnn_lm/train.py
-        - (5) nbest-oracle. Its WER is the lower bound of any n-best
+        - (6) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
         """,
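
Given the renumbered help text above, selecting the new method from the command line would look roughly like this (a usage sketch only: the script path placeholder and the --exp-dir/--epoch/--avg flags follow common icefall recipe conventions and are assumptions, not taken from this diff; --decoding-method is the option defined above).

# Hypothetical invocation; replace <recipe> and the flag values with your setup.
./<recipe>/decode.py \
  --decoding-method ctc-greedy-search \
  --exp-dir ./exp \
  --epoch 30 \
  --avg 10
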
@@ -269,6 +275,101 @@ def get_decoding_params() -> AttributeDict:
     return params
 
 
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    nnet_output_lens: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    subsampling_factor: int = 4,
+    frame_shift_ms: float = 10,
+) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
+    """Apply CTC greedy search.
+
+    Args:
+      ctc_probs (torch.Tensor):
+        (batch, max_len, feat_dim)
+      nnet_output_lens (torch.Tensor):
+        (batch,)
+      sp:
+        The BPE model.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+
+    Returns:
+      utt_time_pairs:
+        A list of pair lists. utt_time_pairs[i] is a list of
+        (start-time, end-time) pairs for each word in utterance i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance i.
+    """
+    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
+    topk_index = topk_index.squeeze(2)  # (B, maxlen)
+    mask = make_pad_mask(nnet_output_lens)
+    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
+    hyps = [hyp.tolist() for hyp in topk_index]
+
+    def get_first_tokens(tokens: List[int]) -> Tuple[List[int], List[bool]]:
+        is_first_token = []
+        first_tokens = []
+        for t in range(len(tokens)):
+            if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]):
+                is_first_token.append(True)
+                first_tokens.append(tokens[t])
+            else:
+                is_first_token.append(False)
+        return first_tokens, is_first_token
+
+    utt_time_pairs = []
+    utt_words = []
+    for utt in range(len(hyps)):
+        first_tokens, is_first_token = get_first_tokens(hyps[utt])
+        all_tokens = sp.id_to_piece(hyps[utt])
+        index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token)
+        words = sp.decode(first_tokens).split()
+        assert len(index_pairs) == len(words), (
+            len(index_pairs),
+            len(words),
+            all_tokens,
+        )
+        start = convert_timestamp(
+            frames=[i[0] for i in index_pairs],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        end = convert_timestamp(
+            # The duration in frames is (end_frame_index - start_frame_index + 1).
+            frames=[i[1] + 1 for i in index_pairs],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        utt_time_pairs.append(list(zip(start, end)))
+        utt_words.append(words)
+
+    return utt_time_pairs, utt_words
+
+
+def remove_duplicates_and_blank(
+    hyp: List[int],
+) -> Tuple[List[int], List[Tuple[int, int]]]:
+    # Modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
+    new_hyp: List[int] = []
+    time: List[Tuple[int, int]] = []
+    cur = 0
+    start, end = -1, -1
+    while cur < len(hyp):
+        if hyp[cur] != 0:
+            new_hyp.append(hyp[cur])
+            start = cur
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            if start != -1:
+                end = cur
+            cur += 1
+        if start != -1 and end != -1:
+            time.append((start, end))
+            start, end = -1, -1
+    return new_hyp, time
+
+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
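
To make the helpers above concrete, here is a small worked example (a sketch: it exercises remove_duplicates_and_blank exactly as defined above, while the last line assumes convert_timestamp computes frame * subsampling_factor * frame_shift_ms / 1000, consistent with the arguments passed to it in ctc_greedy_search).

# Toy frame-level argmax sequence; 0 is the CTC blank.
hyp = [0, 3, 3, 0, 5, 5, 5, 0]

tokens, spans = remove_duplicates_and_blank(hyp)
# tokens == [3, 5]            (blanks removed, repeats collapsed)
# spans  == [(1, 2), (4, 6)]  ((start_frame, end_frame) for each kept token)

# Assumed timestamp arithmetic: subsampled frame i starts at
# i * subsampling_factor * frame_shift_ms / 1000 seconds.
start_sec = 4 * 4 * 10 / 1000  # frame 4, factor 4, 10 ms shift -> 0.16 s
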
@@ -360,6 +461,17 @@ def decode_one_batch(
     nnet_output = model.get_ctc_output(encoder_out)
     # nnet_output is (N, T, C)
 
+    if params.decoding_method == "ctc-greedy-search":
+        timestamps, hyps = ctc_greedy_search(
+            ctc_probs=nnet_output,
+            nnet_output_lens=encoder_out_lens,
+            sp=bpe_model,
+            subsampling_factor=params.subsampling_factor,
+            frame_shift_ms=params.frame_shift_ms,
+        )
+        key = "ctc-greedy-search"
+        return {key: (hyps, timestamps)}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
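
Because the greedy branch returns the same {key: (hyps, timestamps)} shape as the lattice-based methods, downstream scoring can consume it uniformly. A self-contained sketch of iterating such a result (the values below are fabricated for illustration):

# One utterance with two words and their (start, end) times in seconds.
results = {
    "ctc-greedy-search": (
        [["HELLO", "WORLD"]],            # hyps[i]: word list of utterance i
        [[(0.16, 0.48), (0.52, 0.96)]],  # timestamps[i]: per-word time pairs
    )
}
hyps, timestamps = results["ctc-greedy-search"]
for words, times in zip(hyps, timestamps):
    for word, (start, end) in zip(words, times):
        print(f"{start:6.2f} {end:6.2f}  {word}")
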
@@ -696,6 +808,7 @@ def main():
     params.update(vars(args))
 
     assert params.decoding_method in (
+        "ctc-greedy-search",
         "ctc-decoding",
         "1best",
         "nbest",
@@ -749,7 +862,7 @@ def main():
     params.sos_id = sos_id
     params.eos_id = eos_id
 
-    if params.decoding_method == "ctc-decoding":
+    if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,