Cutting off invalid frames of encoder_embed output

This commit is contained in:
pkufool 2022-06-02 14:19:08 +08:00
parent 9629be124d
commit fc54a99a56
3 changed files with 9 additions and 2 deletions

View File

@ -115,6 +115,7 @@ class DecodeStream(object):
] ]
self.num_processed_frames += ( self.num_processed_frames += (
chunk_size chunk_size
- 2 * self.params.subsampling_factor
- self.params.right_context * self.params.subsampling_factor - self.params.right_context * self.params.subsampling_factor
) )

View File

@ -272,6 +272,9 @@ class Conformer(EncoderInterface):
given {states[1].shape}.""" given {states[1].shape}."""
# src_key_padding_mask = make_pad_mask(lengths + left_context) # src_key_padding_mask = make_pad_mask(lengths + left_context)
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
assert processed_lens is not None assert processed_lens is not None
@ -287,6 +290,9 @@ class Conformer(EncoderInterface):
) )
embed = self.encoder_embed(x) embed = self.encoder_embed(x)
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context) embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)

View File

@ -357,7 +357,7 @@ def decode_one_chunk(
for stream in decode_streams: for stream in decode_streams:
feat, feat_len = stream.get_feature_frames( feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor (params.decode_chunk_size + 2) * params.subsampling_factor
) )
features.append(feat) features.append(feat)
feature_lens.append(feat_len) feature_lens.append(feat_len)
@ -371,7 +371,7 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer, # if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2 # because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 7 + params.right_context * params.subsampling_factor tail_length = 15 + params.right_context * params.subsampling_factor
if features.size(1) < tail_length: if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1) feature_lens += tail_length - features.size(1)
features = torch.cat( features = torch.cat(