mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Cutting off invalid frames of encoder_embed output
This commit is contained in:
parent
9629be124d
commit
fc54a99a56
@ -115,6 +115,7 @@ class DecodeStream(object):
|
||||
]
|
||||
self.num_processed_frames += (
|
||||
chunk_size
|
||||
- 2 * self.params.subsampling_factor
|
||||
- self.params.right_context * self.params.subsampling_factor
|
||||
)
|
||||
|
||||
|
@ -272,6 +272,9 @@ class Conformer(EncoderInterface):
|
||||
given {states[1].shape}."""
|
||||
|
||||
# src_key_padding_mask = make_pad_mask(lengths + left_context)
|
||||
|
||||
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
assert processed_lens is not None
|
||||
@ -287,6 +290,9 @@ class Conformer(EncoderInterface):
|
||||
)
|
||||
|
||||
embed = self.encoder_embed(x)
|
||||
|
||||
embed = embed[:, 1:-1, :]
|
||||
|
||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
|
||||
|
@ -357,7 +357,7 @@ def decode_one_chunk(
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(
|
||||
params.decode_chunk_size * params.subsampling_factor
|
||||
(params.decode_chunk_size + 2) * params.subsampling_factor
|
||||
)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
@ -371,7 +371,7 @@ def decode_one_chunk(
|
||||
|
||||
# if T is less than 7 there will be an error in time reduction layer,
|
||||
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
||||
tail_length = 7 + params.right_context * params.subsampling_factor
|
||||
tail_length = 15 + params.right_context * params.subsampling_factor
|
||||
if features.size(1) < tail_length:
|
||||
feature_lens += tail_length - features.size(1)
|
||||
features = torch.cat(
|
||||
|
Loading…
x
Reference in New Issue
Block a user