revert changes in decode.py and utils.py

hhzzff 2025-07-04 15:33:47 +08:00
parent 85f95db6f9
commit 95b2408ed1
2 changed files with 12 additions and 134 deletions

decode.py

@@ -26,7 +26,7 @@ import torch
 from icefall.context_graph import ContextGraph, ContextState
 from icefall.lm_wrapper import LmScorer
 from icefall.ngram_lm import NgramLm, NgramLmStateCost
-from icefall.utils import DecodingResults, add_eos, add_sos, get_texts
+from icefall.utils import add_eos, add_sos, get_texts
 
 DEFAULT_LM_SCALE = [
     0.01,
@@ -1485,52 +1485,25 @@ def ctc_greedy_search(
     ctc_output: torch.Tensor,
     encoder_out_lens: torch.Tensor,
     blank_id: int = 0,
-    return_timestamps: bool = False,
-) -> Union[List[List[int]], DecodingResults]:
+) -> List[List[int]]:
     """CTC greedy search.
 
     Args:
         ctc_output: (batch, seq_len, vocab_size)
        encoder_out_lens: (batch,)
     Returns:
-        Union[List[List[int]], DecodingResults]: greedy search result
+        List[List[int]]: greedy search result
     """
     batch = ctc_output.shape[0]
     index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
-    hyps = [[] for _ in range(batch)]
-    # timestamps[n][i] is the frame index after subsampling
-    # on which hyp[n][i] is decoded
-    timestamps = [[] for _ in range(batch)]
-    # scores[n][i] is the logits on which hyp[n][i] is decoded
-    scores = [[] for _ in range(batch)]
-    for i in range(batch):
-        cur_pos = -1
-        last = -1
-        next_other = torch.where(index[i, 0 : encoder_out_lens[i]] != last)[0]
-        flag = False  # whether decoding words
-        while next_other.size(0) > 0:
-            cur_pos = cur_pos + 1 + next_other[0].item()
-            last = index[i, cur_pos].item()
-            if flag:  # last word decoding finished
-                timestamps[i][-1] = timestamps[i][-1] + (cur_pos - 1,)
-            if last != blank_id:
-                hyps[i].append(last)
-                timestamps[i].append((cur_pos,))
-                scores[i].append(ctc_output[i, cur_pos, last].item())
-                flag = True  # decoding words
-            else:
-                flag = False
-            next_other = torch.where(index[i, cur_pos + 1 : encoder_out_lens[i]] != last)[0]
-        if flag:
-            timestamps[i][-1] = timestamps[i][-1] + (encoder_out_lens[i].item() - 1,)
-    if not return_timestamps:
-        return hyps
-    else:
-        return DecodingResults(hyps=hyps, timestamps=timestamps, scores=scores)
+    hyps = [
+        torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
+    ]
+    hyps = [h[h != blank_id].tolist() for h in hyps]
+    return hyps
 
 @dataclass
 class Hypothesis:
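
Note: the version restored by this revert collapses the frame-wise argmax with torch.unique_consecutive and then filters out blanks. A minimal standalone sketch of that logic (the helper name greedy_ids and the toy tensor are illustrative, not part of this commit):

import torch

def greedy_ids(ctc_output, encoder_out_lens, blank_id=0):
    # Frame-wise best token, then collapse repeats, then drop blanks --
    # the same three steps as the restored ctc_greedy_search above.
    index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
    hyps = [
        torch.unique_consecutive(index[i, : encoder_out_lens[i]])
        for i in range(ctc_output.shape[0])
    ]
    return [h[h != blank_id].tolist() for h in hyps]

# Toy check: 1 utterance, 6 frames, vocab of 4. The frame-wise argmax is
# [1, 1, 0, 2, 2, 3]; collapsing repeats and removing blank (0) gives [1, 2, 3].
logits = torch.full((1, 6, 4), -10.0)
for t, tok in enumerate([1, 1, 0, 2, 2, 3]):
    logits[0, t, tok] = 0.0
assert greedy_ids(logits, torch.tensor([6])) == [[1, 2, 3]]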

utils.py

@@ -360,7 +360,7 @@ class KeywordResult:
 class DecodingResults:
     # timestamps[i][k] contains the frame number on which tokens[i][k]
     # is decoded
-    timestamps: List[List[Union[int, Tuple[int, int]]]]
+    timestamps: List[List[int]]
 
     # hyps[i] is the recognition results, i.e., word IDs or token IDs
     # for the i-th utterance with fast_beam_search_nbest_LG.
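
Note: after the revert, each timestamp entry is again a single frame index rather than a (start, end) pair. A simplified stand-in illustrating the shape of the data (the class here is a sketch; the real DecodingResults in icefall/utils.py carries more fields):

from dataclasses import dataclass
from typing import List

@dataclass
class Results:  # hypothetical, simplified stand-in
    timestamps: List[List[int]]  # timestamps[i][k]: frame of tokens[i][k]
    hyps: List[List[int]]        # hyps[i]: token/word IDs of utterance i

r = Results(timestamps=[[2, 5, 9]], hyps=[[23, 7, 104]])
assert len(r.timestamps[0]) == len(r.hyps[0])  # one frame per decoded token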
@@ -583,40 +583,6 @@ def store_transcripts_and_timestamps(
             s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
             print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
 
 
-def store_transcripts_and_timestamps_withoutref(
-    filename: Pathlike,
-    texts: Iterable[Tuple[str, List[str], List[str], List[Tuple[float]]]],
-) -> None:
-    """Save predicted results and reference transcripts as well as their
-    timestamps to a file.
-
-    Args:
-      filename:
-        File to save the results to.
-      texts:
-        An iterable of tuples. The first element is the cut_id, the second is
-        the reference transcript, the third is the predicted result, and the
-        fourth is the list of hypothesis timestamps.
-    Returns:
-      Return None.
-    """
-    with open(filename, "w", encoding="utf8") as f:
-        for cut_id, ref, hyp, time_hyp in texts:
-            print(f"{cut_id}:\tref={ref}", file=f)
-            print(f"{cut_id}:\thyp={hyp}", file=f)
-            if len(time_hyp) > 0:
-                if isinstance(time_hyp[0], tuple):
-                    # each element is a <start, end> pair
-                    s = (
-                        "["
-                        + ", ".join(["(%0.3f, %0.3f)" % (i, j) for (i, j) in time_hyp])
-                        + "]"
-                    )
-                else:
-                    # each element is a float number
-                    s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
-                print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
-
-
 def write_error_stats(
     f: TextIO,
@@ -1873,30 +1839,6 @@ def convert_timestamp(
     return time
 
 
-def convert_timestamp_duration(
-    frames: List[Tuple[int, int]],
-    subsampling_factor: int,
-    frame_shift_ms: float = 10,
-) -> List[Tuple[float, float]]:
-    """Convert (start, end) frame pairs to times (in seconds) given the
-    subsampling factor and the frame shift (in milliseconds).
-
-    Args:
-      frames:
-        A list of (start, end) frame pairs after subsampling.
-      subsampling_factor:
-        The subsampling factor of the model.
-      frame_shift_ms:
-        Frame shift in milliseconds between two contiguous frames.
-    Return:
-      Return the (start, end) times in seconds corresponding to each given
-      frame pair.
-    """
-    frame_shift = frame_shift_ms / 1000.0
-    time = []
-    for f_start, f_end in frames:
-        time.append(
-            (
-                round(f_start * subsampling_factor * frame_shift, ndigits=3),
-                round(f_end * subsampling_factor * frame_shift, ndigits=3),
-            )
-        )
-    return time
-
-
 def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
     """
@@ -1982,43 +1924,6 @@ def parse_hyp_and_timestamp(
     return hyps, timestamps
 
 
-def parse_hyp_and_timestamp_ch(
-    res: DecodingResults,
-    subsampling_factor: int,
-    word_table: k2.SymbolTable,
-    frame_shift_ms: float = 10,
-) -> Tuple[List[List[str]], List[List[float]]]:
-    """Parse hypotheses and timestamps of Chinese characters.
-
-    Args:
-      res:
-        A DecodingResults object.
-      subsampling_factor:
-        The integer subsampling factor.
-      word_table:
-        The word symbol table.
-      frame_shift_ms:
-        The float frame shift used for feature extraction.
-    Returns:
-      Return a list of hypotheses and a list of timestamps.
-    """
-    hyps = []
-    timestamps = []
-    N = len(res.hyps)
-    assert len(res.timestamps) == N, (len(res.timestamps), N)
-    assert word_table is not None
-    for i in range(N):
-        time = convert_timestamp_duration(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
-        words_decoded = [word_table[idx] for idx in res.hyps[i]]
-        hyps.append(words_decoded)
-        timestamps.append(time)
-        assert len(time) == len(words_decoded), (len(time), len(words_decoded))
-    return hyps, timestamps
-
-
 # `is_module_available` is copied from
 # https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
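
Note: the deleted convert_timestamp_duration was the tuple-valued counterpart of the retained convert_timestamp, which maps a single frame index to seconds as frame * subsampling_factor * frame_shift_ms / 1000. A standalone sketch of that retained conversion (frames_to_seconds is an illustrative name, not the icefall API):

from typing import List

def frames_to_seconds(
    frames: List[int], subsampling_factor: int, frame_shift_ms: float = 10
) -> List[float]:
    # seconds = frame_index_after_subsampling * subsampling_factor * frame_shift
    frame_shift = frame_shift_ms / 1000.0
    return [round(f * subsampling_factor * frame_shift, ndigits=3) for f in frames]

# With subsampling_factor=4 and a 10 ms shift, frame 25 -> 25 * 4 * 0.010 = 1.0 s.
assert frames_to_seconds([0, 25, 100], subsampling_factor=4) == [0.0, 1.0, 4.0]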