Refactor getting timestamps in fsa-based decoding (#660)

* refactor getting timestamps for fsa-based decoding

* fix doc

* fix bug
Zengwei Yao 2022-11-05 22:36:06 +08:00 committed by GitHub
parent 3600ce1b5f
commit 32de2766d5
4 changed files with 41 additions and 47 deletions


@@ -487,7 +487,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)

     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,
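
The only change in this hunk is the DecodingResults keyword (tokens= becomes hyps=). As a minimal sketch of the surrounding merge pattern, using a simplified stand-in dataclass and invented IDs (the real hyps field also accepts a k2.RaggedTensor):

from dataclasses import dataclass
from typing import List


@dataclass
class DecodingResults:  # simplified stand-in for icefall's dataclass
    timestamps: List[List[int]]
    hyps: List[List[int]]


# Sketch of the merge decode_one_batch performs: per-utterance results
# (values invented) are flattened into one DecodingResults, which is then
# handed to parse_hyp_and_timestamp.
per_utt = [
    DecodingResults(timestamps=[[2, 5, 9]], hyps=[[12, 7, 30]]),
    DecodingResults(timestamps=[[1, 6]], hyps=[[44, 3]]),
]
tokens: List[List[int]] = []
timestamps: List[List[int]] = []
for r in per_utt:
    tokens.extend(r.hyps)
    timestamps.extend(r.timestamps)
res = DecodingResults(hyps=tokens, timestamps=timestamps)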


@@ -598,7 +598,7 @@ def greedy_search(
         return hyp
     else:
         return DecodingResults(
-            tokens=[hyp],
+            hyps=[hyp],
             timestamps=[timestamp],
         )
@@ -712,7 +712,7 @@ def greedy_search_batch(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
             timestamps=ans_timestamps,
         )
@@ -1049,7 +1049,7 @@ def modified_beam_search(
         return ans
     else:
         return DecodingResults(
-            tokens=ans,
+            hyps=ans,
            timestamps=ans_timestamps,
         )
@@ -1176,7 +1176,7 @@ def _deprecated_modified_beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])


 def beam_search(
@@ -1336,7 +1336,7 @@ def beam_search(
     if not return_timestamps:
         return ys
     else:
-        return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp])
+        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])


 def fast_beam_search_with_nbest_rescoring(


@@ -531,7 +531,7 @@ def decode_one_batch(
             )
             tokens.extend(res.tokens)
             timestamps.extend(res.timestamps)
-        res = DecodingResults(tokens=tokens, timestamps=timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)

     hyps, timestamps = parse_hyp_and_timestamp(
         decoding_method=params.decoding_method,


@@ -251,33 +251,20 @@ def get_texts(
 @dataclass
 class DecodingResults:
-    # Decoded token IDs for each utterance in the batch
-    tokens: List[List[int]]
-
     # timestamps[i][k] contains the frame number on which tokens[i][k]
     # is decoded
     timestamps: List[List[int]]

-    # hyps[i] is the recognition results, i.e., word IDs
+    # hyps[i] is the recognition results, i.e., word IDs or token IDs
     # for the i-th utterance with fast_beam_search_nbest_LG.
-    hyps: Union[List[List[int]], k2.RaggedTensor] = None
-
-
-def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]:
-    tokens = []
-    timestamps = []
-    for i, v in enumerate(labels):
-        if v != 0:
-            tokens.append(v)
-            timestamps.append(i)
-    return tokens, timestamps
+    hyps: Union[List[List[int]], k2.RaggedTensor]


 def get_texts_with_timestamp(
     best_paths: k2.Fsa, return_ragged: bool = False
 ) -> DecodingResults:
-    """Extract the texts (as word IDs) and timestamps from the best-path FSAs.
+    """Extract the texts (as word IDs) and timestamps (as frame indexes)
+    from the best-path FSAs.

     Args:
       best_paths:
         A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
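
After this hunk, hyps is the single required payload: token IDs for most decoding methods, word IDs when fast_beam_search_nbest_LG is used (hence the docstring caveat). A hedged sketch with a simplified stand-in class and invented IDs:

from dataclasses import dataclass
from typing import List


@dataclass
class DecodingResults:  # simplified stand-in; the real hyps may be ragged
    # timestamps[i][k] is the frame on which hyps[i][k] was decoded
    timestamps: List[List[int]]
    # hyps[i] holds token IDs, or word IDs with fast_beam_search_nbest_LG
    hyps: List[List[int]]


# Token-ID result, e.g. from greedy_search (IDs invented):
res_tok = DecodingResults(timestamps=[[1, 4, 9]], hyps=[[15, 8, 92]])
# Word-ID result from fast_beam_search_nbest_LG (IDs invented):
res_word = DecodingResults(timestamps=[[1, 4, 9]], hyps=[[201, 57, 9]])
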
@@ -292,11 +279,18 @@ def get_texts_with_timestamp(
         decoded.
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        all_aux_shape = (
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
+        )
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
         aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
         # remove the states and arcs axes.
         aux_shape = aux_shape.remove_axis(1)
         aux_shape = aux_shape.remove_axis(1)
@@ -304,26 +298,26 @@ def get_texts_with_timestamp(
     else:
         # remove axis corresponding to states.
         aux_shape = best_paths.arcs.shape().remove_axis(1)
-        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
         # remove 0's and -1's.
-        aux_labels = aux_labels.remove_values_leq(0)
+        aux_labels = all_aux_labels.remove_values_leq(0)

     assert aux_labels.num_axes == 2

-    labels_shape = best_paths.arcs.shape().remove_axis(1)
-    labels_list = k2.RaggedTensor(
-        labels_shape, best_paths.labels.contiguous()
-    ).tolist()
-
-    tokens = []
     timestamps = []
-    for labels in labels_list:
-        token, time = get_tokens_and_timestamps(labels[:-1])
-        tokens.append(token)
-        timestamps.append(time)
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        for p in range(all_aux_labels.dim0):
+            time = []
+            for i, arc in enumerate(all_aux_labels[p].tolist()):
+                if len(arc) == 1 and arc[0] > 0:
+                    time.append(i)
+            timestamps.append(time)
+    else:
+        for labels in all_aux_labels.tolist():
+            time = [i for i, v in enumerate(labels) if v > 0]
+            timestamps.append(time)

     return DecodingResults(
-        tokens=tokens,
         timestamps=timestamps,
         hyps=aux_labels if return_ragged else aux_labels.tolist(),
     )
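
The new timestamp logic reads frame indexes straight off the arc-aligned aux labels: arc i of a linear best path corresponds to frame i, and a positive aux label there means a token was emitted on that frame. A pure-Python mock of both branches (no k2; label values invented):

from typing import List


def timestamps_from_ragged(all_aux_labels: List[List[List[int]]]) -> List[List[int]]:
    # all_aux_labels[p][i] mimics the ragged aux labels on arc i (frame i)
    # of path p: [] or [0] for epsilon/blank, [tok_id] for an emitted token.
    timestamps = []
    for path in all_aux_labels:
        timestamps.append(
            [i for i, arc in enumerate(path) if len(arc) == 1 and arc[0] > 0]
        )
    return timestamps


def timestamps_from_linear(all_aux_labels: List[List[int]]) -> List[List[int]]:
    # Non-ragged case: one aux label per arc; 0 is epsilon, -1 is the final arc.
    return [
        [i for i, v in enumerate(labels) if v > 0] for labels in all_aux_labels
    ]


# One hypothetical 6-frame utterance in each representation.
print(timestamps_from_ragged([[[], [7], [], [], [9], []]]))  # [[1, 4]]
print(timestamps_from_linear([[0, 7, 0, 0, 9, -1]]))         # [[1, 4]]
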
@@ -1399,8 +1393,8 @@ def parse_hyp_and_timestamp(
     hyps = []
     timestamps = []

-    N = len(res.tokens)
-    assert len(res.timestamps) == N
+    N = len(res.hyps)
+    assert len(res.timestamps) == N, (len(res.timestamps), N)

     use_word_table = False
     if (
         decoding_method == "fast_beam_search_nbest_LG"
@@ -1410,16 +1404,16 @@ def parse_hyp_and_timestamp(
         use_word_table = True

     for i in range(N):
-        tokens = sp.id_to_piece(res.tokens[i])
-        if use_word_table:
-            words = [word_table[i] for i in res.hyps[i]]
-        else:
-            words = sp.decode_pieces(tokens).split()
         time = convert_timestamp(
             res.timestamps[i], subsampling_factor, frame_shift_ms
         )
-        time = parse_timestamp(tokens, time)
-        assert len(time) == len(words), (tokens, words)
+        if use_word_table:
+            words = [word_table[i] for i in res.hyps[i]]
+        else:
+            tokens = sp.id_to_piece(res.hyps[i])
+            words = sp.decode_pieces(tokens).split()
+            time = parse_timestamp(tokens, time)
+        assert len(time) == len(words), (len(time), len(words))
         hyps.append(words)
         timestamps.append(time)
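
parse_hyp_and_timestamp now indexes res.hyps and converts frames to seconds before branching on the word table; parse_timestamp (which collapses token times to word times) applies only in the BPE branch. For orientation, a sketch of the frame-to-seconds arithmetic, assuming convert_timestamp maps a decoder frame index to seconds via the subsampling factor and frame shift (the real helper lives in icefall/utils.py and may differ in detail):

from typing import List


def convert_timestamp_sketch(
    frames: List[int], subsampling_factor: int, frame_shift_ms: float = 10
) -> List[float]:
    # Assumed behavior: each decoder output frame spans subsampling_factor
    # acoustic frames of frame_shift_ms milliseconds each.
    frame_shift_s = frame_shift_ms / 1000.0
    return [f * subsampling_factor * frame_shift_s for f in frames]


# Frames 1 and 4 with 4x subsampling and a 10 ms shift correspond to
# 0.04 s and 0.16 s.
print(convert_timestamp_sketch([1, 4], subsampling_factor=4))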