Cut off invalid frames of the encoder_embed output

This commit is contained in:
pkufool 2022-06-02 14:19:08 +08:00
parent 9629be124d
commit fc54a99a56
3 changed files with 9 additions and 2 deletions

View File

@ -115,6 +115,7 @@ class DecodeStream(object):
]
self.num_processed_frames += (
chunk_size
- 2 * self.params.subsampling_factor
- self.params.right_context * self.params.subsampling_factor
)

View File

@ -272,6 +272,9 @@ class Conformer(EncoderInterface):
given {states[1].shape}."""
# src_key_padding_mask = make_pad_mask(lengths + left_context)
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
assert processed_lens is not None
@ -287,6 +290,9 @@ class Conformer(EncoderInterface):
)
embed = self.encoder_embed(x)
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)

View File

@ -357,7 +357,7 @@ def decode_one_chunk(
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
(params.decode_chunk_size + 2) * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
@ -371,7 +371,7 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 7 + params.right_context * params.subsampling_factor
tail_length = 15 + params.right_context * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(