Have nonlin_attention and attention_squeeze operate only on every other layer.

Daniel Povey 2022-11-28 16:24:24 +08:00
parent 87ef4078d3
commit 9cf5d92f39


@@ -450,6 +450,12 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
+    def remove_nonlin_attention(self):
+        self.nonlin_attention_module = None
+
+    def remove_attention_squeeze(self):
+        self.attention_squeeze = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
@@ -520,14 +526,14 @@ class ZipformerEncoderLayer(nn.Module):
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
             first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
 
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.nonlin_attention_module is not None:
             src = src + self.nonlin_attention_module(src,
                                                      first_attn_weights[0:1])
 
         src = src + self.feed_forward1(src)
 
         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.attention_squeeze is not None:
             src = src + self.attention_squeeze(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
@@ -598,6 +604,10 @@ class ZipformerEncoder(nn.Module):
             cur_begin = cur_end
             if i % attention_share_layers != 0:
                 self.layers[i].remove_attention_weights()
+            if i % attention_share_layers == 0:
+                self.layers[i].remove_nonlin_attention()
+            else:
+                self.layers[i].remove_attention_squeeze()
 
     def forward(
         self,
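
To illustrate the effect of the alternation above: assuming attention_share_layers == 2 (an assumption; the diff does not show its configured value), the construction loop leaves even-indexed layers with the shared attention weights plus attention_squeeze, and odd-indexed layers with nonlin_attention only. Below is a minimal standalone sketch of that bookkeeping, not the icefall implementation: DummyLayer and its attributes are placeholders, and only the method names and the modulo logic mirror the diff.

# Minimal sketch of which optional modules each layer keeps after pruning,
# assuming attention_share_layers == 2.  DummyLayer is a stand-in for
# ZipformerEncoderLayer; object() marks "module present".
class DummyLayer:
    def __init__(self):
        self.self_attn_weights = object()
        self.nonlin_attention_module = object()
        self.attention_squeeze = object()

    def remove_attention_weights(self):
        self.self_attn_weights = None

    def remove_nonlin_attention(self):
        self.nonlin_attention_module = None

    def remove_attention_squeeze(self):
        self.attention_squeeze = None


attention_share_layers = 2
layers = [DummyLayer() for _ in range(6)]

for i, layer in enumerate(layers):
    if i % attention_share_layers != 0:
        layer.remove_attention_weights()   # reuse weights from the previous layer
    if i % attention_share_layers == 0:
        layer.remove_nonlin_attention()    # even layers: attention_squeeze only
    else:
        layer.remove_attention_squeeze()   # odd layers: nonlin_attention only

for i, layer in enumerate(layers):
    print(i,
          "attn_weights" if layer.self_attn_weights is not None else "-",
          "nonlin_attention" if layer.nonlin_attention_module is not None else "-",
          "attention_squeeze" if layer.attention_squeeze is not None else "-")
# Expected: layers 0, 2, 4 keep attn_weights + attention_squeeze;
#           layers 1, 3, 5 keep nonlin_attention only.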