Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Minor fixes for streaming decoding

commit 6a937ffa40
parent 88ed814197
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from typing import List, Optional, Tuple

 import k2

@@ -45,6 +46,7 @@ class DecodeStream(object):
         assert device == decoding_graph.device

         self.params = params
+        self.LOG_EPS = math.log(1e-10)

         self.states = initial_states

@@ -52,6 +54,7 @@ class DecodeStream(object):
         self.features: torch.Tensor = None
         # how many frames have been processed. (before subsampling).
         # we only modify this value in `func:get_feature_frames`.
+        self.num_frames = 0
         self.num_processed_frames: int = 0
         self._done: bool = False
         # The transcript of current utterance.

@@ -62,7 +65,11 @@ class DecodeStream(object):
         # how many frames have been processed, after subsampling (i.e. a
         # cumulative sum of the second return value of
         # encoder.streaming_forward
-        self.feature_len: int = 0
+        self.done_frames: int = 0
+
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3

         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -86,27 +93,32 @@ class DecodeStream(object):
         features: torch.Tensor,
     ) -> None:
         """Set features tensor of current utterance."""
-        self.features = features
+        assert features.dim() == 2, features.dim()
+        self.num_frames = features.size(0)
+        # tail padding
+        self.features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=self.LOG_EPS,
+        )

     def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
         """Consume chunk_size frames of features"""
-        # plus 3 here because we subsampling features with
-        # lengths = ((x_lens - 1) // 2 - 1) // 2
-        ret_chunk_size = min(
-            self.features.size(0) - self.num_processed_frames, chunk_size + 3
+        update_length = min(
+            self.num_frames - self.num_processed_frames, chunk_size
         )
+        ret_length = update_length + self.pad_length

         ret_features = self.features[
             self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_chunk_size,
-            :,
+            + ret_length
         ]
-        self.num_processed_frames += (
-            chunk_size
-            - 2 * self.params.subsampling_factor
-            - self.params.right_context * self.params.subsampling_factor
-        )

-        if self.num_processed_frames >= self.features.size(0):
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
             self._done = True

-        return ret_features, ret_chunk_size
+        return ret_features, ret_length
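For readers following the DecodeStream changes above, here is a minimal, self-contained sketch (not part of the commit) that replays the new tail-padding and chunking logic outside the class, so the role of pad_length is easier to see. All concrete values below (subsampling_factor=4, right_context=2, a 100-frame utterance of 80-dim fbank features) are illustrative assumptions, not values taken from the diff.

import math

import torch

subsampling_factor = 4                 # assumed value
right_context = 2                      # assumed value
chunk_size = 16 * subsampling_factor   # frames consumed per call, assumed

LOG_EPS = math.log(1e-10)
pad_length = (right_context + 2) * subsampling_factor + 3

features = torch.randn(100, 80)        # stand-in for a real fbank matrix
num_frames = features.size(0)

# tail padding, as in the new set_features(): append pad_length frames
# filled with LOG_EPS so that every chunk can carry its right context.
features = torch.nn.functional.pad(
    features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPS
)

num_processed_frames = 0
while num_processed_frames < num_frames:
    # as in the new get_feature_frames(): each returned chunk is
    # update_length + pad_length frames long, but the read position only
    # advances by update_length, so consecutive chunks overlap by the
    # padded lookahead.
    update_length = min(num_frames - num_processed_frames, chunk_size)
    ret_length = update_length + pad_length
    chunk = features[num_processed_frames : num_processed_frames + ret_length]
    num_processed_frames += update_length
    print(chunk.shape, num_processed_frames)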
@@ -341,20 +341,16 @@ def decode_one_chunk(
     states = []

     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []

     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)

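The changed call above no longer asks get_feature_frames() for the right-context and cut-off frames; after this commit those come from the tail padding (pad_length) held inside DecodeStream. A small worked check of the arithmetic, using assumed values (decode_chunk_size=16, subsampling_factor=4, right_context=2); only the numbers are assumptions, the formulas are the ones visible in the diff.

decode_chunk_size = 16   # assumed value
subsampling_factor = 4   # assumed value
right_context = 2        # assumed value

# old request: right context and the two cut-off frames were added at the
# call site, and the old get_feature_frames() appended 3 more frames
# (when the utterance still had enough frames left).
old_request = (decode_chunk_size + 2 + right_context) * subsampling_factor  # 80
old_returned = old_request + 3                                              # 83

# new request: only the frames that will actually be consumed; the extra
# frames now come from pad_length = (right_context + 2) * subsampling_factor + 3
# inside DecodeStream.
new_request = decode_chunk_size * subsampling_factor                        # 64
pad_length = (right_context + 2) * subsampling_factor + 3                   # 19
new_returned = new_request + pad_length                                     # 83

# identical by construction:
# (c + 2 + r) * s + 3 == c * s + (r + 2) * s + 3
assert old_returned == new_returned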
@@ -388,16 +384,15 @@ def decode_one_chunk(
         torch.stack([x[1] for x in states], dim=2),
     ]

-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)

     # Note: states will be modified in streaming_forward.
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )

     if params.decoding_method == "greedy_search":

@@ -411,7 +406,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )

@@ -423,7 +418,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:

@@ -469,7 +464,7 @@ def decode_dataset(
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80

-    log_interval = 50
+    log_interval = 100

     decode_results = []
     # Contain decode streams currently running.

@@ -557,6 +552,9 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

@@ -347,20 +347,16 @@ def decode_one_chunk(
     states = []

     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []

     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
        )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)

@@ -369,6 +365,9 @@ def decode_one_chunk(

     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
+    # we plus 2 here because we will cut off one frame on each size of
+    # encoder_embed output as they see invalid paddings. so we need extra 2
+    # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
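The tail_length comment above rests on the subsampling length formula; the short check below (right_context=2 and subsampling_factor=4 are assumed values) confirms that fewer than 7 input frames would leave zero frames after the time-reduction layers, which is why short feature chunks are padded up to tail_length.

def subsampled_len(x_len: int) -> int:
    # the length formula quoted in the comment above
    return ((x_len - 1) // 2 - 1) // 2

assert subsampled_len(6) == 0   # fewer than 7 frames -> no output frame
assert subsampled_len(7) == 1   # 7 frames is the minimum useful length

right_context = 2        # assumed value
subsampling_factor = 4   # assumed value
tail_length = 7 + (2 + right_context) * subsampling_factor
print(tail_length)       # 23 with these assumed values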
@@ -390,7 +389,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)

     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,

@@ -402,7 +397,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )

     encoder_out = model.joiner.encoder_proj(encoder_out)

@@ -414,7 +413,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )

@@ -426,7 +425,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:

@@ -561,7 +560,10 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
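The point of sorting before storing is that two recogs-*.txt files produced by different runs then line up and can be compared directly (e.g. with diff). A toy illustration; the tuple layout below is an assumption for illustration, not the exact structure used by store_transcripts().

results_run_a = [("utt2", ["HELLO"], ["HELLO"]), ("utt1", ["WORLD"], ["WORD"])]
results_run_b = [("utt1", ["WORLD"], ["WORLD"]), ("utt2", ["HELLO"], ["HELLO"])]

for line_a, line_b in zip(sorted(results_run_a), sorted(results_run_b)):
    # after sorting, line i in both files refers to the same utterance,
    # so any difference on that line is a real hypothesis difference.
    print(line_a[0], line_b[0], line_a[2] == line_b[2])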
@@ -348,20 +348,16 @@ def decode_one_chunk(
     states = []

     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []

     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)

@@ -394,7 +390,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)

     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,

@@ -402,7 +398,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )

     encoder_out = model.joiner.encoder_proj(encoder_out)

@@ -418,7 +414,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )

@@ -430,7 +426,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:

@@ -565,7 +561,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -359,20 +359,16 @@ def decode_one_chunk(
     states = []

     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []

     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)

@@ -405,7 +401,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)

     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,

@@ -413,7 +409,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )

     encoder_out = model.joiner.encoder_proj(encoder_out)

@@ -429,7 +425,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )

@@ -441,7 +437,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:

@@ -576,7 +572,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned