mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Mask larger regions
This commit is contained in:
parent
d41b73000e
commit
cd0f48f508
@@ -292,7 +292,14 @@ class Zipformer2(EncoderInterface):
|
|||||||
|
|
||||||
assert self.encoder_dim[0] == _encoder_dims0
|
assert self.encoder_dim[0] == _encoder_dims0
|
||||||
|
|
||||||
max_downsampling_factor = max(self.downsampling_factor)
|
|
||||||
|
# This is a heuristic: the regions over which we mask frames are 4 times the
|
||||||
|
# largest of the downsampling_factors. Can be any integer >= 1.
|
||||||
|
downsampling_multiple = 4
|
||||||
|
|
||||||
|
downsampling_factor = [ downsampling_multiple * i for i in self.downsampling_factor ]
|
||||||
|
|
||||||
|
max_downsampling_factor = max(downsampling_factor)
|
||||||
|
|
||||||
num_frames_max = (num_frames0 + max_downsampling_factor - 1)
|
num_frames_max = (num_frames0 + max_downsampling_factor - 1)
|
||||||
|
|
||||||
@@ -318,7 +325,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
|
|
||||||
feature_masks = []
|
feature_masks = []
|
||||||
for i in range(num_encoders):
|
for i in range(num_encoders):
|
||||||
ds = self.downsampling_factor[i]
|
ds = downsampling_factor[i]
|
||||||
upsample_factor = (max_downsampling_factor // ds)
|
upsample_factor = (max_downsampling_factor // ds)
|
||||||
|
|
||||||
frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
|
frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user