diff --git a/icefall/decode.py b/icefall/decode.py
index 2cdcbf491..5d836bd48 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -26,7 +26,7 @@ import torch
 from icefall.context_graph import ContextGraph, ContextState
 from icefall.lm_wrapper import LmScorer
 from icefall.ngram_lm import NgramLm, NgramLmStateCost
-from icefall.utils import DecodingResults, add_eos, add_sos, get_texts
+from icefall.utils import add_eos, add_sos, get_texts
 
 DEFAULT_LM_SCALE = [
     0.01,
@@ -1485,52 +1485,25 @@ def ctc_greedy_search(
     ctc_output: torch.Tensor,
     encoder_out_lens: torch.Tensor,
     blank_id: int = 0,
-    return_timestamps: bool = False,
-) -> Union[List[List[int]], DecodingResults]:
+) -> List[List[int]]:
     """CTC greedy search.
 
     Args:
         ctc_output: (batch, seq_len, vocab_size)
        encoder_out_lens: (batch,)
     Returns:
-        Union[List[List[int]], DecodingResults]: greedy search result
+        List[List[int]]: greedy search result
     """
     batch = ctc_output.shape[0]
     index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
-
-    hyps = [[] for _ in range(batch)]
-    # timestamps[n][i] is the frame index after subsampling
-    # on which hyp[n][i] is decoded
-    timestamps = [[] for _ in range(batch)]
-    # scores[n][i] is the logits on which hyp[n][i] is decoded
-    scores = [[] for _ in range(batch)]
-
-    for i in range(batch):
-        cur_pos = -1
-        last = -1
-        next_other = torch.where(index[i,0 : encoder_out_lens[i]] != last)[0]
-        flag = False # whether decoding words
-        while next_other.size(0) > 0:
-            cur_pos = cur_pos + 1 + next_other[0].item()
-            last = index[i,cur_pos].item()
-            if flag: # last word decoding finished
-                timestamps[i][-1] = timestamps[i][-1] + (cur_pos - 1, )
-            if last != blank_id:
-                hyps[i].append(last)
-                timestamps[i].append((cur_pos,))
-                scores[i].append(ctc_output[i, cur_pos, last].item())
-                flag = True # Decoding words
-            else:
-                flag = False
-
-            next_other = torch.where(index[i,cur_pos + 1 : encoder_out_lens[i]] != last)[0]
-        if flag:
-            timestamps[i][-1] = timestamps[i][-1] + (encoder_out_lens[i].item() - 1, )
-
-    if not return_timestamps:
-        return hyps
-    else:
-        return DecodingResults(hyps = hyps, timestamps = timestamps, scores = scores)
+    hyps = [
+        torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
+    ]
+
+    hyps = [h[h != blank_id].tolist() for h in hyps]
+
+    return hyps
+
 
 @dataclass
 class Hypothesis:
diff --git a/icefall/utils.py b/icefall/utils.py
index 5510f0476..ffb926566 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -360,7 +360,7 @@ class KeywordResult:
 class DecodingResults:
     # timestamps[i][k] contains the frame number on which tokens[i][k]
     # is decoded
-    timestamps: List[List[Union[int, Tuple[int, int]]]]
+    timestamps: List[List[int]]
 
     # hyps[i] is the recognition results, i.e., word IDs or token IDs
     # for the i-th utterance with fast_beam_search_nbest_LG.
@@ -583,40 +583,6 @@ def store_transcripts_and_timestamps(
             s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
             print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
 
-def store_transcripts_and_timestamps_withoutref(
-    filename: Pathlike,
-    texts: Iterable[Tuple[str, List[str], List[str], List[Tuple[float]]]],
-) -> None:
-    """Save predicted results and reference transcripts as well as their timestamps
-    to a file.
-
-    Args:
-      filename:
-        File to save the results to.
-      texts:
-        An iterable of tuples. The first element is the cur_id, the second is
-        the reference transcript and the third element is the predicted result.
-    Returns:
-      Return None.
-    """
-    with open(filename, "w", encoding="utf8") as f:
-        for cut_id, ref, hyp, time_hyp in texts:
-            print(f"{cut_id}:\tref={ref}", file=f)
-            print(f"{cut_id}:\thyp={hyp}", file=f)
-
-            if len(time_hyp) > 0:
-                if isinstance(time_hyp[0], tuple):
-                    # each element is pair
-                    s = (
-                        "["
-                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp])
-                        + "]"
-                    )
-                else:
-                    # each element is a float number
-                    s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
-                print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
-
 
 def write_error_stats(
     f: TextIO,
@@ -1873,30 +1839,6 @@ def convert_timestamp(
     return time
 
 
-def convert_timestamp_duration(
-    frames: List[Tuple[int, int]],
-    subsampling_factor: int,
-    frame_shift_ms: float = 10,
-) -> List[Tuple[float, float]]:
-    """Convert frame numbers to time (in seconds) given subsampling factor
-    and frame shift (in milliseconds).
-
-    Args:
-      frames:
-        A list of frame numbers after subsampling.
-      subsampling_factor:
-        The subsampling factor of the model.
-      frame_shift_ms:
-        Frame shift in milliseconds between two contiguous frames.
-    Return:
-      Return the time in seconds corresponding to each given frame.
-    """
-    frame_shift = frame_shift_ms / 1000.0
-    time = []
-    for f_start, f_end in frames:
-        time.append((round(f_start * subsampling_factor * frame_shift, ndigits=3), round(f_end * subsampling_factor * frame_shift, ndigits=3)))
-
-    return time
 
 def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
     """
@@ -1982,43 +1924,6 @@ def parse_hyp_and_timestamp(
     return hyps, timestamps
 
 
-def parse_hyp_and_timestamp_ch(
-    res: DecodingResults,
-    subsampling_factor: int,
-    word_table: k2.SymbolTable,
-    frame_shift_ms: float = 10,
-) -> Tuple[List[List[str]], List[List[float]]]:
-    """Parse hypothesis and timestamps of Chinese characters.
-
-    Args:
-      res:
-        A DecodingResults object.
-      subsampling_factor:
-        The integer subsampling factor.
-      word_table:
-        The word symbol table.
-      frame_shift_ms:
-        The float frame shift used for feature extraction.
-
-    Returns:
-      Return a list of hypothesis and timestamp.
-    """
-    hyps = []
-    timestamps = []
-
-    N = len(res.hyps)
-    assert len(res.timestamps) == N, (len(res.timestamps), N)
-    assert word_table is not None
-
-    for i in range(N):
-        time = convert_timestamp_duration(res.timestamps[i], subsampling_factor, frame_shift_ms)
-        words_decoded = [word_table[idx] for idx in res.hyps[i]]
-        hyps.append(words_decoded)
-        timestamps.append(time)
-        assert len(time) == len(words_decoded), (len(time), len(words_decoded))
-
-    return hyps, timestamps
-
 # `is_module_available` is copied from
 # https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9