Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Refactor getting timestamps in fsa-based decoding (#660)

* refactor getting timestamps for fsa-based decoding
* fix doc
* fix bug

commit 32de2766d5
parent 3600ce1b5f
@@ -487,7 +487,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)

     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
@@ -598,7 +598,7 @@ def greedy_search(
         return hyp
     else:
         return DecodingResults(
-            tokens=[hyp],
+            hyps=[hyp],
             timestamps=[timestamp],
         )

@@ -712,7 +712,7 @@ def greedy_search_batch(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )

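The tokens-to-hyps rename is visible to every caller that requests timestamps. Below is a minimal sketch of the caller side after this commit; `model`, `encoder_out`, and `encoder_out_lens` are hypothetical stand-ins for objects a recipe's decode script already has, and the keyword arguments mirror what the hunks in this diff imply rather than a verified signature:

    res = greedy_search_batch(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        return_timestamps=True,
    )
    token_ids = res.hyps       # before this commit: res.tokens
    frames = res.timestamps    # frames[i][k] is the frame of token_ids[i][k]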
@@ -1049,7 +1049,7 @@ def modified_beam_search(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )

@@ -1176,7 +1176,7 @@ def _deprecated_modified_beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])


 def beam_search(
@@ -1336,7 +1336,7 @@ def beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])


 def fast_beam_search_with_nbest_rescoring(
@@ -531,7 +531,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)

     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
@@ -251,33 +251,20 @@ def get_texts(

 @dataclass
 class DecodingResults:
-    # Decoded token IDs for each utterance in the batch
-    tokens: List[List[int]]
-
     # timestamps[i][k] contains the frame number on which tokens[i][k]
     # is decoded
     timestamps: List[List[int]]

-    # hyps[i] is the recognition results, i.e., word IDs
+    # hyps[i] is the recognition results, i.e., word IDs or token IDs
     # for the i-th utterance with fast_beam_search_nbest_LG.
-    hyps: Union[List[List[int]], k2.RaggedTensor] = None
+    hyps: Union[List[List[int]], k2.RaggedTensor]


-def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
-    tokens = []
-    timestamps = []
-    for i, v in enumerate(labels):
-        if v != 0:
-            tokens.append(v)
-            timestamps.append(i)
-
-    return tokens, timestamps
-
-
 def get_texts_with_timestamp(
     best_paths: k2.Fsa, return_ragged: bool = False
 ) -> DecodingResults:
-    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.
+    """Extract the texts (as word IDs) and timestamps (as frame indexes)
+    from the best-path FSAs.
     Args:
       best_paths:
         A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
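With the default `= None` removed, `hyps` becomes a required field, so every construction site must now pass it explicitly, as the hunks above do. A small sketch with invented IDs:

    # Invented token IDs and frame indexes, for illustration only.
    res = DecodingResults(
        timestamps=[[2, 4, 7]],
        hyps=[[25, 7, 13]],
    )
    assert len(res.hyps[0]) == len(res.timestamps[0])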
@@ -292,11 +279,18 @@ def get_texts_with_timestamp(
         decoded.
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        all_aux_shape = (
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
+        )
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
         aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
-
         # remove the states and arcs axes.
         aux_shape = aux_shape.remove_axis(1)
         aux_shape = aux_shape.remove_axis(1)
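Composing the arc shape with the ragged `aux_labels.shape` pairs every arc of each best path with its (possibly empty) list of aux labels, so the arc index can serve directly as the frame index. A pure-Python picture of the structure that the loop in the next hunk iterates over, with invented values:

    # per_arc_aux[p][i] holds the aux label IDs on arc i of path p; an arc
    # contributes a timestamp only if it carries exactly one positive label.
    per_arc_aux = [
        [[], [25], [], [7], [13], []],  # path 0
    ]
    time = []
    for i, arc in enumerate(per_arc_aux[0]):
        if len(arc) == 1 and arc[0] > 0:
            time.append(i)
    print(time)  # [1, 3, 4]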
@@ -304,26 +298,26 @@ def get_texts_with_timestamp(
     else:
         # remove axis corresponding to states.
         aux_shape = best_paths.arcs.shape().remove_axis(1)
-        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
         # remove 0's and -1's.
-        aux_labels = aux_labels.remove_values_leq(0)
+        aux_labels = all_aux_labels.remove_values_leq(0)

     assert aux_labels.num_axes == 2

-    labels_shape = best_paths.arcs.shape().remove_axis(1)
-    labels_list = k2.RaggedTensor(
-        labels_shape, best_paths.labels.contiguous()
-    ).tolist()
-
-    tokens = []
     timestamps = []
-    for labels in labels_list:
-        token, time = get_tokens_and_timestamps(labels[:-1])
-        tokens.append(token)
-        timestamps.append(time)
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        for p in range(all_aux_labels.dim0):
+            time = []
+            for i, arc in enumerate(all_aux_labels[p].tolist()):
+                if len(arc) == 1 and arc[0] > 0:
+                    time.append(i)
+            timestamps.append(time)
+    else:
+        for labels in all_aux_labels.tolist():
+            time = [i for i, v in enumerate(labels) if v > 0]
+            timestamps.append(time)

     return DecodingResults(
-        tokens=tokens,
         timestamps=timestamps,
         hyps=aux_labels if return_ragged else aux_labels.tolist(),
     )
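In the plain-tensor branch each arc carries exactly one label, so timestamp extraction reduces to the list comprehension above. A toy run with invented labels:

    # 0 denotes epsilon/blank and -1 the final arc; only positive labels
    # produce timestamps.
    labels = [0, 0, 25, 0, 7, 0, 0, 13, -1]
    time = [i for i, v in enumerate(labels) if v > 0]
    print(time)  # [2, 4, 7]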
@@ -1399,8 +1393,8 @@ def parse_hyp_and_timestamp(
     hyps = []
     timestamps = []

-    N = len(res.tokens)
-    assert len(res.timestamps) == N
+    N = len(res.hyps)
+    assert len(res.timestamps) == N, (len(res.timestamps), N)
     use_word_table = False
     if (
         decoding_method == "fast_beam_search_nbest_LG"
@@ -1410,16 +1404,16 @@ def parse_hyp_and_timestamp(
         use_word_table = True

     for i in range(N):
-        tokens = sp.id_to_piece(res.tokens[i])
-        if use_word_table:
-            words = [word_table[i] for i in res.hyps[i]]
-        else:
-            words = sp.decode_pieces(tokens).split()
         time = convert_timestamp(
             res.timestamps[i], subsampling_factor, frame_shift_ms
         )
-        time = parse_timestamp(tokens, time)
-        assert len(time) == len(words), (tokens, words)
+        if use_word_table:
+            words = [word_table[i] for i in res.hyps[i]]
+        else:
+            tokens = sp.id_to_piece(res.hyps[i])
+            words = sp.decode_pieces(tokens).split()
+            time = parse_timestamp(tokens, time)
+        assert len(time) == len(words), (len(time), len(words))

         hyps.append(words)
         timestamps.append(time)
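`convert_timestamp` itself is defined elsewhere in icefall/utils.py and is not part of this diff; the following is a hedged sketch of the arithmetic it is assumed to perform, turning encoder frame indexes into seconds under the usual 10 ms feature frame shift:

    from typing import List


    def convert_timestamp_sketch(
        frames: List[int],
        subsampling_factor: int,
        frame_shift_ms: float = 10,
    ) -> List[float]:
        # One encoder frame covers subsampling_factor feature frames of
        # frame_shift_ms milliseconds each.
        frame_shift = frame_shift_ms / 1000.0
        return [f * subsampling_factor * frame_shift for f in frames]


    print(convert_timestamp_sketch([2, 4, 7], subsampling_factor=4))
    # ~[0.08, 0.16, 0.28] seconds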
|
Loading…
x
Reference in New Issue
Block a user