Mirror of https://github.com/k2-fsa/icefall.git
Have nonlin_attention and attention_squeeze operate only on every other layer.
This commit is contained in:
parent 87ef4078d3
commit 9cf5d92f39
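For context only, and not part of the commit: a minimal sketch of the layer pattern this change produces. The ToyLayer class, the layer count of 6, and attention_share_layers = 2 are assumptions for illustration; the remove_* method names and the modulo selection come from the diff below, where the real logic lives in ZipformerEncoder.

# Illustration only: which module each layer keeps after the change,
# assuming attention_share_layers = 2 and a toy stand-in for ZipformerEncoderLayer.
class ToyLayer:
    def __init__(self):
        self.nonlin_attention_module = "nonlin_attention"
        self.attention_squeeze = "attention_squeeze"

    def remove_nonlin_attention(self):
        self.nonlin_attention_module = None

    def remove_attention_squeeze(self):
        self.attention_squeeze = None


attention_share_layers = 2  # assumed value for illustration
layers = [ToyLayer() for _ in range(6)]

for i, layer in enumerate(layers):
    # Same selection logic as the ZipformerEncoder change below:
    # even-indexed layers drop nonlin_attention, odd-indexed layers drop attention_squeeze.
    if i % attention_share_layers == 0:
        layer.remove_nonlin_attention()
    else:
        layer.remove_attention_squeeze()

for i, layer in enumerate(layers):
    kept = layer.nonlin_attention_module or layer.attention_squeeze
    print(f"layer {i}: keeps {kept}")
# layer 0: keeps attention_squeeze
# layer 1: keeps nonlin_attention
# layer 2: keeps attention_squeeze
# ...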
@@ -450,6 +450,12 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None

+    def remove_nonlin_attention(self):
+        self.nonlin_attention_module = None
+
+    def remove_attention_squeeze(self):
+        self.attention_squeeze = None
+
     def get_bypass_scale(self):
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
@@ -520,14 +526,14 @@ class ZipformerEncoderLayer(nn.Module):
             first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
             first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)

-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.nonlin_attention_module is not None:
             src = src + self.nonlin_attention_module(src,
                                                      first_attn_weights[0:1])

         src = src + self.feed_forward1(src)

         # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
+        if (torch.jit.is_scripting() or use_self_attn) and self.attention_squeeze is not None:
             src = src + self.attention_squeeze(src, first_attn_weights[1:2])

         if torch.jit.is_scripting() or use_self_attn:
@@ -598,6 +604,10 @@ class ZipformerEncoder(nn.Module):
             cur_begin = cur_end
             if i % attention_share_layers != 0:
                 self.layers[i].remove_attention_weights()
+            if i % attention_share_layers == 0:
+                self.layers[i].remove_nonlin_attention()
+            else:
+                self.layers[i].remove_attention_squeeze()

     def forward(
         self,