mirror of https://github.com/k2-fsa/icefall.git
Change for mem efficiency

commit 36e49a8d61 (parent 1d0252d420)
@@ -480,13 +480,14 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )

-        first_attn_weights = attn_weights[0:3]
+        first_attn_weights = attn_weights[0:1]
         if random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
             first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
+        first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
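Note (reviewer sketch, not part of the commit): the memory saving in this hunk comes from keeping only one head-group of the attention weights and re-broadcasting it with expand(), which returns a stride-0 view rather than materializing three copies. A minimal standalone illustration of that behavior (the tensor name and shape below are made up for the example):

    import torch

    # Stand-in shape for (head_groups, batch, time1, time2); values are arbitrary.
    one_group = torch.randn(1, 2, 4, 4)

    # .expand() broadcasts dim 0 to size 3 without copying: the result is a
    # view whose stride along dim 0 is 0, so all three "copies" share the
    # same underlying 1 x 2 x 4 x 4 storage.
    three_groups = one_group.expand(3, -1, -1, -1)

    print(three_groups.shape)    # torch.Size([3, 2, 4, 4])
    print(three_groups.stride()) # (0, 16, 4, 1): stride 0 along the expanded dim

    # A real copy, by contrast, triples the memory:
    copied = one_group.repeat(3, 1, 1, 1)
    print(copied.stride())       # (32, 16, 4, 1)

The saving holds as long as downstream ops only read the view; calling .contiguous(), or any op that needs a real copy, allocates the full-size tensor again.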
@@ -1300,7 +1301,7 @@ class AttentionSqueeze(nn.Module):
             bottleneck_dim, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
             min_abs=0.05,
-            max_abs=2.0,
+            max_abs=10.0,
             max_factor=0.02,
             min_prob=0.1,
         )
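Note (reviewer sketch, not part of the commit): the kwargs in this hunk look like a Balancer-style activation constraint, where max_abs is the upper limit on per-channel mean absolute activation that the module encourages; raising it from 2.0 to 10.0 relaxes that constraint. The toy function below only illustrates the direction of the constraint; it is not icefall's implementation, which applies the nudge to gradients in the backward pass (max_factor and min_prob control how strongly and how often):

    import torch

    def abs_constraint_violation(x: torch.Tensor,
                                 min_abs: float, max_abs: float) -> torch.Tensor:
        # Toy illustration only: measures how far the per-channel mean |x|
        # falls outside [min_abs, max_abs]. Zero means the constraint is
        # already satisfied and a balancer would leave gradients alone.
        mean_abs = x.abs().mean(dim=0)
        too_small = (min_abs - mean_abs).clamp(min=0.0)
        too_large = (mean_abs - max_abs).clamp(min=0.0)
        return (too_small + too_large).sum()

    acts = torch.randn(1000, 16) * 5.0                 # mean |x| is roughly 4
    print(abs_constraint_violation(acts, 0.05, 2.0))   # positive: bound of 2.0 binds
    print(abs_constraint_violation(acts, 0.05, 10.0))  # 0.0: relaxed bound is inactive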