Minor fixes for streaming decoding

pkufool 2022-06-18 08:41:52 +08:00
parent 88ed814197
commit 6a937ffa40
5 changed files with 68 additions and 62 deletions

View File

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import List, Optional, Tuple
 import k2
@ -45,6 +46,7 @@ class DecodeStream(object):
assert device == decoding_graph.device assert device == decoding_graph.device
self.params = params self.params = params
self.LOG_EPS = math.log(1e-10)
self.states = initial_states self.states = initial_states
@@ -52,6 +54,7 @@ class DecodeStream(object):
         self.features: torch.Tensor = None
         # how many frames have been processed. (before subsampling).
         # we only modify this value in `func:get_feature_frames`.
+        self.num_frames = 0
         self.num_processed_frames: int = 0
         self._done: bool = False
         # The transcript of current utterance.
@@ -62,7 +65,11 @@
         # how many frames have been processed, after subsampling (i.e. a
         # cumulative sum of the second return value of
         # encoder.streaming_forward
-        self.feature_len: int = 0
+        self.done_frames: int = 0
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -86,27 +93,32 @@
         features: torch.Tensor,
     ) -> None:
         """Set features tensor of current utterance."""
-        self.features = features
+        assert features.dim() == 2, features.dim()
+        self.num_frames = features.size(0)
+        # tail padding
+        self.features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=self.LOG_EPS,
+        )
 
     def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
         """Consume chunk_size frames of features"""
         # plus 3 here because we subsampling features with
         # lengths = ((x_lens - 1) // 2 - 1) // 2
-        ret_chunk_size = min(
-            self.features.size(0) - self.num_processed_frames, chunk_size + 3
-        )
+        update_length = min(
+            self.num_frames - self.num_processed_frames, chunk_size
+        )
+        ret_length = update_length + self.pad_length
         ret_features = self.features[
             self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_chunk_size,
-            :,
+            + ret_length
         ]
-        self.num_processed_frames += (
-            chunk_size
-            - 2 * self.params.subsampling_factor
-            - self.params.right_context * self.params.subsampling_factor
-        )
-        if self.num_processed_frames >= self.features.size(0):
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
             self._done = True
-        return ret_features, ret_chunk_size
+        return ret_features, ret_length
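A quick sanity check of the padding arithmetic above (an illustrative sketch; subsampling_factor=4, right_context=4 and decode_chunk_size=16 are assumed values, not the recipe defaults):

def subsampled_len(x_len: int) -> int:
    # length after the conv subsampling used by these recipes
    return ((x_len - 1) // 2 - 1) // 2

subsampling_factor, right_context, decode_chunk_size = 4, 4, 16

pad_length = (right_context + 2) * subsampling_factor + 3  # 27
chunk = decode_chunk_size * subsampling_factor             # what decode_one_chunk now requests (64)
ret_length = chunk + pad_length                            # what get_feature_frames returns (91)

# chunk frames + the 2 frames cut off around encoder_embed + the right context
assert subsampled_len(ret_length) == decode_chunk_size + 2 + right_context  # 22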

View File

@@ -341,20 +341,16 @@ def decode_one_chunk(
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
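Caller-side view of the same change, with the same illustrative values (a sketch, not commit code):

subsampling_factor, right_context, decode_chunk_size = 4, 4, 16

old_request = (decode_chunk_size + 2 + right_context) * subsampling_factor  # 88 frames
new_request = decode_chunk_size * subsampling_factor                        # 64 frames

# Both versions advance a stream by 64 real frames per chunk: the old code
# requested 88 frames but then subtracted the (2 + right_context) overlap
# again when updating num_processed_frames.  The context beyond each chunk
# now comes from the feature buffer that set_features has already extended
# with pad_length LOG_EPS frames, so the final chunk of an utterance also
# sees a full-length window.  Note pad_length (27) = (old_request -
# new_request) + the 3 frames the old code added via "chunk_size + 3".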
@@ -388,16 +384,15 @@ def decode_one_chunk(
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
-    # Note: states will be modified in streaming_forward.
+    processed_lens = torch.tensor(processed_lens, device=device)
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
     if params.decoding_method == "greedy_search":
@@ -411,7 +406,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -423,7 +418,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
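The renamed counter is easiest to read as a running total per stream; a small illustrative trace (the per-chunk lengths are hypothetical, not from the commit):

done_frames = 0                      # stream.done_frames, counted in subsampled frames
for encoder_out_len in (22, 22, 9):  # hypothetical encoder_out_lens for three chunks
    # what fast_beam_search receives as the total processed length this chunk:
    processed_len = done_frames + encoder_out_len
    done_frames += encoder_out_len
print(done_frames)  # 53 subsampled frames decoded so far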
@@ -469,7 +464,7 @@ def decode_dataset(
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80
-    log_interval = 50
+    log_interval = 100
     decode_results = []
     # Contain decode streams currently running.
@@ -557,6 +552,9 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

View File

@@ -347,20 +347,16 @@ def decode_one_chunk(
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
        )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
@@ -369,6 +365,9 @@ def decode_one_chunk(
     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
+    # we plus 2 here because we will cut off one frame on each size of
+    # encoder_embed output as they see invalid paddings. so we need extra 2
+    # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
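A worked check of the tail_length guard, again with illustrative subsampling_factor=4 and right_context=4 (a sketch, not commit code):

def subsampled_len(x_len: int) -> int:
    return ((x_len - 1) // 2 - 1) // 2

subsampling_factor, right_context = 4, 4
tail_length = 7 + (2 + right_context) * subsampling_factor  # 31

assert subsampled_len(7) == 1  # 7 input frames are the minimum that survive subsampling
assert subsampled_len(tail_length) == 1 + 2 + right_context
# i.e. one usable frame remains after dropping one frame on each side of the
# encoder_embed output and reserving right_context frames of look-ahead.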
@@ -390,7 +389,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -398,7 +397,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -414,7 +413,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -426,7 +425,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -561,7 +560,10 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
         # The following prints out WERs, per-word error statistics and aligned

View File

@@ -348,20 +348,16 @@ def decode_one_chunk(
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
@@ -394,7 +390,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -402,7 +398,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -418,7 +414,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -430,7 +426,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -565,7 +561,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
         # The following prints out WERs, per-word error statistics and aligned

View File

@@ -359,20 +359,16 @@ def decode_one_chunk(
     states = []
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
@@ -405,7 +401,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -413,7 +409,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -429,7 +425,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -441,7 +437,7 @@ def decode_one_chunk(
     finished_streams = []
    for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -576,7 +572,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
         # The following prints out WERs, per-word error statistics and aligned