mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
revert changes in decode.py and utils.py
This commit is contained in:
parent
85f95db6f9
commit
95b2408ed1
@ -26,7 +26,7 @@ import torch
|
||||
from icefall.context_graph import ContextGraph, ContextState
|
||||
from icefall.lm_wrapper import LmScorer
|
||||
from icefall.ngram_lm import NgramLm, NgramLmStateCost
|
||||
from icefall.utils import DecodingResults, add_eos, add_sos, get_texts
|
||||
from icefall.utils import add_eos, add_sos, get_texts
|
||||
|
||||
DEFAULT_LM_SCALE = [
|
||||
0.01,
|
||||
@ -1485,52 +1485,25 @@ def ctc_greedy_search(
|
||||
ctc_output: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
blank_id: int = 0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
) -> List[List[int]]:
|
||||
"""CTC greedy search.
|
||||
|
||||
Args:
|
||||
ctc_output: (batch, seq_len, vocab_size)
|
||||
encoder_out_lens: (batch,)
|
||||
Returns:
|
||||
Union[List[List[int]], DecodingResults]: greedy search result
|
||||
List[List[int]]: greedy search result
|
||||
"""
|
||||
batch = ctc_output.shape[0]
|
||||
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
|
||||
|
||||
hyps = [[] for _ in range(batch)]
|
||||
# timestamps[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
timestamps = [[] for _ in range(batch)]
|
||||
# scores[n][i] is the logits on which hyp[n][i] is decoded
|
||||
scores = [[] for _ in range(batch)]
|
||||
|
||||
for i in range(batch):
|
||||
cur_pos = -1
|
||||
last = -1
|
||||
next_other = torch.where(index[i,0 : encoder_out_lens[i]] != last)[0]
|
||||
flag = False # whether decoding words
|
||||
while next_other.size(0) > 0:
|
||||
cur_pos = cur_pos + 1 + next_other[0].item()
|
||||
last = index[i,cur_pos].item()
|
||||
if flag: # last word decoding finished
|
||||
timestamps[i][-1] = timestamps[i][-1] + (cur_pos - 1, )
|
||||
if last != blank_id:
|
||||
hyps[i].append(last)
|
||||
timestamps[i].append((cur_pos,))
|
||||
scores[i].append(ctc_output[i, cur_pos, last].item())
|
||||
flag = True # Decoding words
|
||||
else:
|
||||
flag = False
|
||||
|
||||
next_other = torch.where(index[i,cur_pos + 1 : encoder_out_lens[i]] != last)[0]
|
||||
if flag:
|
||||
timestamps[i][-1] = timestamps[i][-1] + (encoder_out_lens[i].item() - 1, )
|
||||
|
||||
if not return_timestamps:
|
||||
return hyps
|
||||
else:
|
||||
return DecodingResults(hyps = hyps, timestamps = timestamps, scores = scores)
|
||||
hyps = [
|
||||
torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
|
||||
]
|
||||
|
||||
hyps = [h[h != blank_id].tolist() for h in hyps]
|
||||
|
||||
return hyps
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hypothesis:
|
||||
|
@ -360,7 +360,7 @@ class KeywordResult:
|
||||
class DecodingResults:
|
||||
# timestamps[i][k] contains the frame number on which tokens[i][k]
|
||||
# is decoded
|
||||
timestamps: List[List[Union[int, Tuple[int, int]]]]
|
||||
timestamps: List[List[int]]
|
||||
|
||||
# hyps[i] is the recognition results, i.e., word IDs or token IDs
|
||||
# for the i-th utterance with fast_beam_search_nbest_LG.
|
||||
@ -583,40 +583,6 @@ def store_transcripts_and_timestamps(
|
||||
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
||||
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
||||
|
||||
def store_transcripts_and_timestamps_withoutref(
    filename: Pathlike,
    texts: Iterable[
        Tuple[str, List[str], List[str], List[Union[float, Tuple[float, float]]]]
    ],
) -> None:
    """Save transcripts and the timestamps of the predicted results to a file.

    For every entry a ``ref=`` line and a ``hyp=`` line are written. A
    ``timestamp_hyp=`` line follows only when the entry has at least one
    timestamp. No timestamps are written for the ``ref`` element.

    Args:
      filename:
        File to save the results to. It is opened in write mode with
        utf8 encoding, so existing content is overwritten.
      texts:
        An iterable of 4-tuples ``(cut_id, ref, hyp, time_hyp)``. ``cut_id``
        prefixes every line written for the entry, ``ref`` and ``hyp`` are
        written as they are, and ``time_hyp`` holds the timestamps of ``hyp``:
        either a list of float numbers or a list of ``(start, end)`` pairs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp, time_hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)

            # Entries without timestamps get no timestamp line at all.
            if len(time_hyp) == 0:
                continue

            # The type of the first element decides the format of the
            # whole list.
            if isinstance(time_hyp[0], tuple):
                # each element is <start, end> pair
                items = ["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]
            else:
                # each element is a float number
                items = ["%0.3f" % i for i in time_hyp]

            s = "[" + ", ".join(items) + "]"
            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
@ -1873,30 +1839,6 @@ def convert_timestamp(
|
||||
|
||||
return time
|
||||
|
||||
def convert_timestamp_duration(
    frames: List[Tuple[int, int]],
    subsampling_factor: int,
    frame_shift_ms: float = 10,
) -> List[Tuple[float, float]]:
    """Convert (start, end) frame pairs to time (in seconds) given subsampling
    factor and frame shift (in milliseconds).

    Args:
      frames:
        A list of (start_frame, end_frame) pairs. The frame numbers are
        the ones after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      frame_shift_ms:
        Frame shift in milliseconds between two contiguous frames.
    Returns:
      Return a list of (start, end) pairs in seconds, one for each given
      pair of frames. Every value is rounded to 3 decimal places. An empty
      input gives an empty list.
    """
    frame_shift = frame_shift_ms / 1000.0

    def to_seconds(frame: int) -> float:
        # Multiply by the (integer) subsampling factor first so that the
        # only floating point operation is the final scaling.
        return round(frame * subsampling_factor * frame_shift, ndigits=3)

    return [(to_seconds(f_start), to_seconds(f_end)) for f_start, f_end in frames]
|
||||
|
||||
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
||||
"""
|
||||
@ -1982,43 +1924,6 @@ def parse_hyp_and_timestamp(
|
||||
|
||||
return hyps, timestamps
|
||||
|
||||
def parse_hyp_and_timestamp_ch(
    res: DecodingResults,
    subsampling_factor: int,
    word_table: k2.SymbolTable,
    frame_shift_ms: float = 10,
) -> Tuple[List[List[str]], List[List[Tuple[float, float]]]]:
    """Parse hypothesis and timestamps of Chinese characters.

    Args:
      res:
        A DecodingResults object. Each entry of ``res.timestamps[i]`` must be
        a (start_frame, end_frame) pair, because it is unpacked as such by
        ``convert_timestamp_duration``.
      subsampling_factor:
        The integer subsampling factor.
      word_table:
        The word symbol table. It must not be None; it is indexed with
        every ID in ``res.hyps``.
      frame_shift_ms:
        The float frame shift used for feature extraction.

    Returns:
      Return a tuple ``(hyps, timestamps)``. ``hyps[i]`` is the list of symbols
      looked up in ``word_table`` for the i-th utterance and ``timestamps[i]``
      is the list of (start, end) times in seconds, with one pair per symbol.
    """
    N = len(res.hyps)
    assert len(res.timestamps) == N, (len(res.timestamps), N)
    assert word_table is not None

    hyps = []
    timestamps = []

    for hyp, frames in zip(res.hyps, res.timestamps):
        time = convert_timestamp_duration(frames, subsampling_factor, frame_shift_ms)
        words_decoded = [word_table[idx] for idx in hyp]
        # Check the one-to-one mapping between symbols and timestamps
        # before the results are recorded.
        assert len(time) == len(words_decoded), (len(time), len(words_decoded))

        hyps.append(words_decoded)
        timestamps.append(time)

    return hyps, timestamps
|
||||
|
||||
|
||||
# `is_module_available` is copied from
|
||||
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
|
||||
|
Loading…
x
Reference in New Issue
Block a user