Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Use different heads for nonlin/squeeze on alternate layers
Commit 85bd9859e9 (parent: d8185201e9)
@@ -504,32 +504,37 @@ class ZipformerEncoderLayer(nn.Module):
             )
         # else rely on the ones passed in

+        # use different heads for nonlin_attention_module and attention_squeeze, depending on
+        # whether this module has its own self_attn_weights submodule or is borrowing
+        # attention weights from another one.
+        head_offset = 0 if self.self_attn_weights is not None else 2
+
         if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src, attn_weights

         use_self_attn = (random.random() >= dynamic_skip_rate)
         if use_self_attn:
-            first_attn_weights = attn_weights[0:3]
+            selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 # only need the mask, can just use the 1st one and expand later
-                first_attn_weights = first_attn_weights[0:1]
-                first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
-                first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
-                first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
+                selected_attn_weights = selected_attn_weights[0:1]
+                selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
+                selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
+                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
-                                                     first_attn_weights[0:1])
+                                                     selected_attn_weights[0:1])

         src = src + self.feed_forward1(src)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze(src, first_attn_weights[1:2])
+            src = src + self.attention_squeeze(src, selected_attn_weights[1:2])

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
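
To make the head-selection change concrete, here is a minimal, self-contained sketch of the idea. The (num_heads, batch, seq, seq) weight layout and the names has_own_attn / nonlin_weights / squeeze_weights are assumptions for illustration, not the actual icefall API. The point of the offset is that a layer borrowing a neighbor's attention weights draws on a different pair of heads than the layer that computed them, so the two layers' nonlin/squeeze modules do not consume identical attention patterns.

    # Sketch only: shapes and names are assumed, not taken from icefall.
    import torch

    num_heads, batch, seq = 4, 2, 5
    attn_weights = torch.softmax(torch.randn(num_heads, batch, seq, seq), dim=-1)

    # A layer that owns its self_attn_weights module uses heads 0..1;
    # a layer reusing weights from another layer uses heads 2..3.
    has_own_attn = False  # hypothetical stand-in for `self.self_attn_weights is not None`
    head_offset = 0 if has_own_attn else 2
    selected = attn_weights[head_offset:head_offset + 2]

    nonlin_weights = selected[0:1]   # head feeding the nonlinear attention module
    squeeze_weights = selected[1:2]  # head feeding the attention-squeeze (pooling) module
    print(nonlin_weights.shape, squeeze_weights.shape)  # both (1, 2, 5, 5)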
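
The const_attention_rate branch discards the learned attention pattern but keeps its mask, replacing the weights with a uniform average over the positions each query is allowed to attend to. A standalone sketch of that transformation, under the same assumed shape:

    # Sketch of the constant-attention trick: keep only the mask (which
    # positions have nonzero weight), renormalize so every allowed position
    # gets equal weight, then broadcast back across heads. Since the mask is
    # shared across heads, one head suffices before expanding.
    import torch

    weights = torch.softmax(torch.randn(2, 1, 4, 4), dim=-1)
    weights[..., 2:] = 0.0  # pretend the last two positions are masked out
    weights = weights / weights.sum(dim=-1, keepdim=True)

    w = weights[0:1]                              # one head carries the mask
    w = (w > 0.0).to(w.dtype)                     # 1.0 wherever attention was allowed
    w = w * (1.0 / w.sum(dim=-1, keepdim=True))   # uniform over allowed positions
    w = w.expand(2, -1, -1, -1)                   # restore the head dimension
    print(w[0, 0, 0])                             # tensor([0.5000, 0.5000, 0.0000, 0.0000])

This is the "averaging-over-time" behavior the in-code comment describes: during training, the affected modules occasionally see a flat average instead of a peaked attention distribution.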