From fc54a99a569deac50f203af6d5d90bfeb33067df Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 2 Jun 2022 14:19:08 +0800 Subject: [PATCH] Cutting off invalid frames of encoder_embed output --- .../ASR/pruned_transducer_stateless/decode_stream.py | 1 + .../ASR/pruned_transducer_stateless2/conformer.py | 6 ++++++ .../ASR/pruned_transducer_stateless2/streaming_decode.py | 4 ++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index a77d1b141..9263ac449 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -115,6 +115,7 @@ class DecodeStream(object): ] self.num_processed_frames += ( chunk_size + - 2 * self.params.subsampling_factor - self.params.right_context * self.params.subsampling_factor ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index cb817eac0..f72c63036 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -272,6 +272,9 @@ class Conformer(EncoderInterface): given {states[1].shape}.""" # src_key_padding_mask = make_pad_mask(lengths + left_context) + + lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output + src_key_padding_mask = make_pad_mask(lengths) assert processed_lens is not None @@ -287,6 +290,9 @@ class Conformer(EncoderInterface): ) embed = self.encoder_embed(x) + + embed = embed[:, 1:-1, :] + embed, pos_enc = self.encoder_pos(embed, left_context) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index c61975ccc..81643e3c4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -357,7 +357,7 @@ def decode_one_chunk( for stream in decode_streams: feat, feat_len = stream.get_feature_frames( - params.decode_chunk_size * params.subsampling_factor + (params.decode_chunk_size + 2) * params.subsampling_factor ) features.append(feat) feature_lens.append(feat_len) @@ -371,7 +371,7 @@ def decode_one_chunk( # if T is less than 7 there will be an error in time reduction layer, # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - tail_length = 7 + params.right_context * params.subsampling_factor + tail_length = 15 + params.right_context * params.subsampling_factor if features.size(1) < tail_length: feature_lens += tail_length - features.size(1) features = torch.cat(