Mirror of https://github.com/k2-fsa/icefall.git
Refactor getting timestamps in fsa-based decoding (#660)
* refactor getting timestamps for fsa-based decoding
* fix doc
* fix bug
This commit is contained in:
parent 3600ce1b5f
commit 32de2766d5
@@ -487,7 +487,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
 
     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
@@ -598,7 +598,7 @@ def greedy_search(
         return hyp
     else:
         return DecodingResults(
-            tokens=[hyp],
+            hyps=[hyp],
             timestamps=[timestamp],
         )
 
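
After the rename, callers read hypotheses from the hyps field instead of tokens. A minimal usage sketch (a hypothetical call site; model and encoder_out are assumed to be prepared as in the surrounding decode scripts):

# Hypothetical call site for the patched greedy_search.
res = greedy_search(
    model=model,
    encoder_out=encoder_out,
    max_sym_per_frame=1,
    return_timestamps=True,
)
token_ids = res.hyps[0]     # before this commit: res.tokens[0]
frames = res.timestamps[0]  # frame index at which each token was emitted
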
@@ -712,7 +712,7 @@ def greedy_search_batch(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )
 
@@ -1049,7 +1049,7 @@ def modified_beam_search(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )
 
@@ -1176,7 +1176,7 @@ def _deprecated_modified_beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
 
 
 def beam_search(
@@ -1336,7 +1336,7 @@ def beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
 
 
 def fast_beam_search_with_nbest_rescoring(
@@ -531,7 +531,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
 
     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
@@ -251,33 +251,20 @@ def get_texts(
 
 @dataclass
 class DecodingResults:
-    # Decoded token IDs for each utterance in the batch
-    tokens: List[List[int]]
-
     # timestamps[i][k] contains the frame number on which tokens[i][k]
     # is decoded
     timestamps: List[List[int]]
 
-    # hyps[i] is the recognition results, i.e., word IDs
+    # hyps[i] is the recognition results, i.e., word IDs or token IDs
     # for the i-th utterance with fast_beam_search_nbest_LG.
-    hyps: Union[List[List[int]], k2.RaggedTensor] = None
-
-
-def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
-    tokens = []
-    timestamps = []
-    for i, v in enumerate(labels):
-        if v != 0:
-            tokens.append(v)
-            timestamps.append(i)
-
-    return tokens, timestamps
+    hyps: Union[List[List[int]], k2.RaggedTensor]
 
 
 def get_texts_with_timestamp(
     best_paths: k2.Fsa, return_ragged: bool = False
 ) -> DecodingResults:
-    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.
+    """Extract the texts (as word IDs) and timestamps (as frame indexes)
+    from the best-path FSAs.
     Args:
       best_paths:
         A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
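
Net effect of the hunk above, shown as a self-contained sketch: the tokens field is gone, and hyps no longer has a "= None" default, so every construction site must pass it explicitly, which is what the call-site hunks above change. The imports are added here only for self-containment; icefall/utils.py already has them.

from dataclasses import dataclass
from typing import List, Union

import k2


@dataclass
class DecodingResults:
    # timestamps[i][k] is the frame on which the k-th symbol of the
    # i-th utterance is decoded
    timestamps: List[List[int]]

    # hyps[i] holds the word IDs or token IDs for the i-th utterance
    hyps: Union[List[List[int]], k2.RaggedTensor]


# Both fields are now mandatory when building a batch of results.
res = DecodingResults(hyps=[[23, 7, 91]], timestamps=[[2, 5, 9]])
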
@@ -292,11 +279,18 @@ def get_texts_with_timestamp(
         decoded.
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        all_aux_shape = (
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
+        )
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
         aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
 
         # remove the states and arcs axes.
         aux_shape = aux_shape.remove_axis(1)
         aux_shape = aux_shape.remove_axis(1)
@@ -304,26 +298,26 @@ def get_texts_with_timestamp(
     else:
         # remove axis corresponding to states.
         aux_shape = best_paths.arcs.shape().remove_axis(1)
-        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
         # remove 0's and -1's.
-        aux_labels = aux_labels.remove_values_leq(0)
+        aux_labels = all_aux_labels.remove_values_leq(0)
 
     assert aux_labels.num_axes == 2
 
-    labels_shape = best_paths.arcs.shape().remove_axis(1)
-    labels_list = k2.RaggedTensor(
-        labels_shape, best_paths.labels.contiguous()
-    ).tolist()
-
-    tokens = []
     timestamps = []
-    for labels in labels_list:
-        token, time = get_tokens_and_timestamps(labels[:-1])
-        tokens.append(token)
-        timestamps.append(time)
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        for p in range(all_aux_labels.dim0):
+            time = []
+            for i, arc in enumerate(all_aux_labels[p].tolist()):
+                if len(arc) == 1 and arc[0] > 0:
+                    time.append(i)
+            timestamps.append(time)
+    else:
+        for labels in all_aux_labels.tolist():
+            time = [i for i, v in enumerate(labels) if v > 0]
+            timestamps.append(time)
 
     return DecodingResults(
-        tokens=tokens,
         timestamps=timestamps,
         hyps=aux_labels if return_ragged else aux_labels.tolist(),
     )
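
With get_tokens_and_timestamps deleted, timestamps now come straight from aux_labels: the frame of a symbol is the index of the arc that emits it, and values <= 0 (epsilon/blank 0 and the final -1) emit nothing. A k2-free sketch of that rule, using plain lists in place of ragged tensors:

from typing import List


def timestamps_from_aux_labels(aux_labels: List[int]) -> List[int]:
    # Mirrors the non-ragged branch above: keep the index of every
    # arc whose aux label is a real symbol (> 0).
    return [i for i, v in enumerate(aux_labels) if v > 0]


# Arcs 1 and 4 emit symbols, so their indexes become the timestamps.
assert timestamps_from_aux_labels([0, 5, 0, 0, 9, -1]) == [1, 4]
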
@@ -1399,8 +1393,8 @@ def parse_hyp_and_timestamp(
     hyps = []
     timestamps = []
 
-    N = len(res.tokens)
-    assert len(res.timestamps) == N
+    N = len(res.hyps)
+    assert len(res.timestamps) == N, (len(res.timestamps), N)
     use_word_table = False
     if (
         decoding_method == "fast_beam_search_nbest_LG"
@@ -1410,16 +1404,16 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        tokens = sp.id_to_piece(res.tokens[i])
-        if use_word_table:
-            words = [word_table[i] for i in res.hyps[i]]
-        else:
-            words = sp.decode_pieces(tokens).split()
         time = convert_timestamp(
             res.timestamps[i], subsampling_factor, frame_shift_ms
         )
-        time = parse_timestamp(tokens, time)
-        assert len(time) == len(words), (tokens, words)
+        if use_word_table:
+            words = [word_table[i] for i in res.hyps[i]]
+        else:
+            tokens = sp.id_to_piece(res.hyps[i])
+            words = sp.decode_pieces(tokens).split()
+            time = parse_timestamp(tokens, time)
+        assert len(time) == len(words), (len(time), len(words))
 
         hyps.append(words)
         timestamps.append(time)
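
parse_hyp_and_timestamp turns the frame indexes into seconds via convert_timestamp before pairing them with words. A sketch of what that conversion amounts to, assuming the usual convention that one model output frame spans subsampling_factor acoustic frames of frame_shift_ms milliseconds each (an illustrative stand-in, not the library function itself):

from typing import List


def frames_to_seconds(
    frames: List[int], subsampling_factor: int, frame_shift_ms: float = 10.0
) -> List[float]:
    # Illustrative stand-in for convert_timestamp.
    return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]


# e.g. frame 25 with 4x subsampling and a 10 ms shift -> 1.0 s
assert frames_to_seconds([25], 4) == [1.0]
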