Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Have nonlin_attention and attention_squeeze operate only on every other layer.
parent 87ef4078d3
commit 9cf5d92f39
@@ -450,6 +450,12 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
+    def remove_nonlin_attention(self):
+        self.nonlin_attention_module = None
+
+    def remove_attention_squeeze(self):
+        self.attention_squeeze = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
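The remove_* methods all detach an optional submodule by overwriting it with None. A minimal sketch (not part of the diff) of why plain None-assignment is enough under PyTorch's nn.Module; the nn.Linear here is only a stand-in for the real Zipformer modules:

import torch.nn as nn

# Hypothetical mini-example: nn.Module permits overwriting a registered
# submodule with None, which removes it (and its parameters) from the layer.
layer = nn.Module()
layer.nonlin_attention_module = nn.Linear(4, 4)   # stand-in for the real module
layer.nonlin_attention_module = None              # what remove_nonlin_attention() does
print(list(layer.parameters()))                   # [] -- the parameters are gone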
@@ -520,14 +526,14 @@ class ZipformerEncoderLayer(nn.Module):
         first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
         first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
 
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.nonlin_attention_module is not None:
             src = src + self.nonlin_attention_module(src,
                                                      first_attn_weights[0:1])
 
         src = src + self.feed_forward1(src)
 
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.attention_squeeze is not None:
             src = src + self.attention_squeeze(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
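A self-contained sketch of the guard pattern above, with hypothetical nn.Linear modules standing in for nonlin_attention_module and attention_squeeze (the real modules also take attention weights as a second argument): a layer whose optional module has been removed simply skips that residual branch.

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    # Hypothetical illustration only, not the real ZipformerEncoderLayer.
    def __init__(self, keep_nonlin: bool, keep_squeeze: bool):
        super().__init__()
        self.nonlin_attention_module = nn.Linear(8, 8) if keep_nonlin else None
        self.attention_squeeze = nn.Linear(8, 8) if keep_squeeze else None

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        if self.nonlin_attention_module is not None:
            src = src + self.nonlin_attention_module(src)  # residual branch
        if self.attention_squeeze is not None:
            src = src + self.attention_squeeze(src)        # residual branch
        return src

x = torch.randn(2, 8)
print(ToyLayer(True, False)(x).shape)  # torch.Size([2, 8]); squeeze branch skipped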
@@ -598,6 +604,10 @@ class ZipformerEncoder(nn.Module):
             cur_begin = cur_end
             if i % attention_share_layers != 0:
                 self.layers[i].remove_attention_weights()
+            if i % attention_share_layers == 0:
+                self.layers[i].remove_nonlin_attention()
+            else:
+                self.layers[i].remove_attention_squeeze()
 
     def forward(
         self,
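To make the resulting layout concrete, a small sketch of which optional modules each layer keeps after this loop, assuming attention_share_layers = 2 and six layers (both values are hypothetical; the real ones come from the encoder configuration):

attention_share_layers = 2  # assumed value for illustration
num_layers = 6              # assumed value for illustration

for i in range(num_layers):
    keeps_attn_weights = (i % attention_share_layers == 0)
    keeps_nonlin_attn  = (i % attention_share_layers != 0)
    keeps_attn_squeeze = (i % attention_share_layers == 0)
    print(f"layer {i}: attn_weights={keeps_attn_weights}, "
          f"nonlin_attention={keeps_nonlin_attn}, attention_squeeze={keeps_attn_squeeze}")

With these assumed values, nonlin_attention remains on the odd-indexed layers and attention_squeeze on the even-indexed ones, so each operates only on every other layer, as the commit message states.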