mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
revert changes in decode.py and utils.py
This commit is contained in:
parent
85f95db6f9
commit
95b2408ed1
@ -26,7 +26,7 @@ import torch
|
|||||||
from icefall.context_graph import ContextGraph, ContextState
|
from icefall.context_graph import ContextGraph, ContextState
|
||||||
from icefall.lm_wrapper import LmScorer
|
from icefall.lm_wrapper import LmScorer
|
||||||
from icefall.ngram_lm import NgramLm, NgramLmStateCost
|
from icefall.ngram_lm import NgramLm, NgramLmStateCost
|
||||||
from icefall.utils import DecodingResults, add_eos, add_sos, get_texts
|
from icefall.utils import add_eos, add_sos, get_texts
|
||||||
|
|
||||||
DEFAULT_LM_SCALE = [
|
DEFAULT_LM_SCALE = [
|
||||||
0.01,
|
0.01,
|
||||||
@ -1485,52 +1485,25 @@ def ctc_greedy_search(
|
|||||||
ctc_output: torch.Tensor,
|
ctc_output: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
blank_id: int = 0,
|
blank_id: int = 0,
|
||||||
return_timestamps: bool = False,
|
) -> List[List[int]]:
|
||||||
) -> Union[List[List[int]], DecodingResults]:
|
|
||||||
"""CTC greedy search.
|
"""CTC greedy search.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctc_output: (batch, seq_len, vocab_size)
|
ctc_output: (batch, seq_len, vocab_size)
|
||||||
encoder_out_lens: (batch,)
|
encoder_out_lens: (batch,)
|
||||||
Returns:
|
Returns:
|
||||||
Union[List[List[int]], DecodingResults]: greedy search result
|
List[List[int]]: greedy search result
|
||||||
"""
|
"""
|
||||||
batch = ctc_output.shape[0]
|
batch = ctc_output.shape[0]
|
||||||
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
|
index = ctc_output.argmax(dim=-1) # (batch, seq_len)
|
||||||
|
hyps = [
|
||||||
|
torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch)
|
||||||
|
]
|
||||||
|
|
||||||
hyps = [[] for _ in range(batch)]
|
hyps = [h[h != blank_id].tolist() for h in hyps]
|
||||||
# timestamps[n][i] is the frame index after subsampling
|
|
||||||
# on which hyp[n][i] is decoded
|
|
||||||
timestamps = [[] for _ in range(batch)]
|
|
||||||
# scores[n][i] is the logits on which hyp[n][i] is decoded
|
|
||||||
scores = [[] for _ in range(batch)]
|
|
||||||
|
|
||||||
for i in range(batch):
|
|
||||||
cur_pos = -1
|
|
||||||
last = -1
|
|
||||||
next_other = torch.where(index[i,0 : encoder_out_lens[i]] != last)[0]
|
|
||||||
flag = False # whether decoding words
|
|
||||||
while next_other.size(0) > 0:
|
|
||||||
cur_pos = cur_pos + 1 + next_other[0].item()
|
|
||||||
last = index[i,cur_pos].item()
|
|
||||||
if flag: # last word decoding finished
|
|
||||||
timestamps[i][-1] = timestamps[i][-1] + (cur_pos - 1, )
|
|
||||||
if last != blank_id:
|
|
||||||
hyps[i].append(last)
|
|
||||||
timestamps[i].append((cur_pos,))
|
|
||||||
scores[i].append(ctc_output[i, cur_pos, last].item())
|
|
||||||
flag = True # Decoding words
|
|
||||||
else:
|
|
||||||
flag = False
|
|
||||||
|
|
||||||
next_other = torch.where(index[i,cur_pos + 1 : encoder_out_lens[i]] != last)[0]
|
|
||||||
if flag:
|
|
||||||
timestamps[i][-1] = timestamps[i][-1] + (encoder_out_lens[i].item() - 1, )
|
|
||||||
|
|
||||||
if not return_timestamps:
|
|
||||||
return hyps
|
return hyps
|
||||||
else:
|
|
||||||
return DecodingResults(hyps = hyps, timestamps = timestamps, scores = scores)
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Hypothesis:
|
class Hypothesis:
|
||||||
|
@ -360,7 +360,7 @@ class KeywordResult:
|
|||||||
class DecodingResults:
|
class DecodingResults:
|
||||||
# timestamps[i][k] contains the frame number on which tokens[i][k]
|
# timestamps[i][k] contains the frame number on which tokens[i][k]
|
||||||
# is decoded
|
# is decoded
|
||||||
timestamps: List[List[Union[int, Tuple[int, int]]]]
|
timestamps: List[List[int]]
|
||||||
|
|
||||||
# hyps[i] is the recognition results, i.e., word IDs or token IDs
|
# hyps[i] is the recognition results, i.e., word IDs or token IDs
|
||||||
# for the i-th utterance with fast_beam_search_nbest_LG.
|
# for the i-th utterance with fast_beam_search_nbest_LG.
|
||||||
@ -583,40 +583,6 @@ def store_transcripts_and_timestamps(
|
|||||||
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
||||||
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
||||||
|
|
||||||
def store_transcripts_and_timestamps_withoutref(
|
|
||||||
filename: Pathlike,
|
|
||||||
texts: Iterable[Tuple[str, List[str], List[str], List[Tuple[float]]]],
|
|
||||||
) -> None:
|
|
||||||
"""Save predicted results and reference transcripts as well as their timestamps
|
|
||||||
to a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename:
|
|
||||||
File to save the results to.
|
|
||||||
texts:
|
|
||||||
An iterable of tuples. The first element is the cur_id, the second is
|
|
||||||
the reference transcript and the third element is the predicted result.
|
|
||||||
Returns:
|
|
||||||
Return None.
|
|
||||||
"""
|
|
||||||
with open(filename, "w", encoding="utf8") as f:
|
|
||||||
for cut_id, ref, hyp, time_hyp in texts:
|
|
||||||
print(f"{cut_id}:\tref={ref}", file=f)
|
|
||||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
|
||||||
|
|
||||||
if len(time_hyp) > 0:
|
|
||||||
if isinstance(time_hyp[0], tuple):
|
|
||||||
# each element is <start, end> pair
|
|
||||||
s = (
|
|
||||||
"["
|
|
||||||
+ ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp])
|
|
||||||
+ "]"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# each element is a float number
|
|
||||||
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
|
|
||||||
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
|
|
||||||
|
|
||||||
|
|
||||||
def write_error_stats(
|
def write_error_stats(
|
||||||
f: TextIO,
|
f: TextIO,
|
||||||
@ -1873,30 +1839,6 @@ def convert_timestamp(
|
|||||||
|
|
||||||
return time
|
return time
|
||||||
|
|
||||||
def convert_timestamp_duration(
|
|
||||||
frames: List[Tuple[int, int]],
|
|
||||||
subsampling_factor: int,
|
|
||||||
frame_shift_ms: float = 10,
|
|
||||||
) -> List[Tuple[float, float]]:
|
|
||||||
"""Convert frame numbers to time (in seconds) given subsampling factor
|
|
||||||
and frame shift (in milliseconds).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frames:
|
|
||||||
A list of frame numbers after subsampling.
|
|
||||||
subsampling_factor:
|
|
||||||
The subsampling factor of the model.
|
|
||||||
frame_shift_ms:
|
|
||||||
Frame shift in milliseconds between two contiguous frames.
|
|
||||||
Return:
|
|
||||||
Return the time in seconds corresponding to each given frame.
|
|
||||||
"""
|
|
||||||
frame_shift = frame_shift_ms / 1000.0
|
|
||||||
time = []
|
|
||||||
for f_start, f_end in frames:
|
|
||||||
time.append((round(f_start * subsampling_factor * frame_shift, ndigits=3), round(f_end * subsampling_factor * frame_shift, ndigits=3)))
|
|
||||||
|
|
||||||
return time
|
|
||||||
|
|
||||||
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
||||||
"""
|
"""
|
||||||
@ -1982,43 +1924,6 @@ def parse_hyp_and_timestamp(
|
|||||||
|
|
||||||
return hyps, timestamps
|
return hyps, timestamps
|
||||||
|
|
||||||
def parse_hyp_and_timestamp_ch(
|
|
||||||
res: DecodingResults,
|
|
||||||
subsampling_factor: int,
|
|
||||||
word_table: k2.SymbolTable,
|
|
||||||
frame_shift_ms: float = 10,
|
|
||||||
) -> Tuple[List[List[str]], List[List[float]]]:
|
|
||||||
"""Parse hypothesis and timestamps of Chinese characters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
res:
|
|
||||||
A DecodingResults object.
|
|
||||||
subsampling_factor:
|
|
||||||
The integer subsampling factor.
|
|
||||||
word_table:
|
|
||||||
The word symbol table.
|
|
||||||
frame_shift_ms:
|
|
||||||
The float frame shift used for feature extraction.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a list of hypothesis and timestamp.
|
|
||||||
"""
|
|
||||||
hyps = []
|
|
||||||
timestamps = []
|
|
||||||
|
|
||||||
N = len(res.hyps)
|
|
||||||
assert len(res.timestamps) == N, (len(res.timestamps), N)
|
|
||||||
assert word_table is not None
|
|
||||||
|
|
||||||
for i in range(N):
|
|
||||||
time = convert_timestamp_duration(res.timestamps[i], subsampling_factor, frame_shift_ms)
|
|
||||||
words_decoded = [word_table[idx] for idx in res.hyps[i]]
|
|
||||||
hyps.append(words_decoded)
|
|
||||||
timestamps.append(time)
|
|
||||||
assert len(time) == len(words_decoded), (len(time), len(words_decoded))
|
|
||||||
|
|
||||||
return hyps, timestamps
|
|
||||||
|
|
||||||
|
|
||||||
# `is_module_available` is copied from
|
# `is_module_available` is copied from
|
||||||
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
|
# https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
|
||||||
|
Loading…
x
Reference in New Issue
Block a user