Remove nonlin_skip_rate, introduce conv_skip_rate.
This commit is contained in:
parent 1506b83c7b
commit 864ff96322
@@ -398,8 +398,8 @@ class ZipformerEncoderLayer(nn.Module):
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
         # to work correctly.
         layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
-        dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
+        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (16000, 0.0), default=0),
         const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
         bypass_max: FloatLike = 1.0,
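These defaults rely on ScheduledFloat (defined elsewhere in the recipe), which evaluates to a value interpolated piecewise-linearly in the current batch count; e.g. the new conv_skip_rate decays from 0.2 at batch 0 to 0.0 at batch 16000. Below is a minimal sketch of that behavior; the class name and interpolation details are assumptions, not icefall's actual implementation.

    # Hypothetical sketch of ScheduledFloat's piecewise-linear behavior.
    class PiecewiseLinearFloat:
        """Interpolates linearly between (batch_count, value) breakpoints,
        clamping to the end values outside the given range."""

        def __init__(self, *points, default=0.0):
            self.points = sorted(points)  # e.g. (0.0, 0.2), (16000, 0.0)
            self.default = default        # used until batch_count is set
            self.batch_count = None       # the trainer would update this

        def __float__(self):
            if self.batch_count is None:
                return float(self.default)
            x = self.batch_count
            x0, y0 = self.points[0]
            if x <= x0:
                return float(y0)
            for x1, y1 in self.points[1:]:
                if x <= x1:
                    # linear interpolation between surrounding breakpoints
                    return float(y0 + (y1 - y0) * (x - x0) / (x1 - x0))
                x0, y0 = x1, y1
            return float(y0)  # past the last breakpoint: clamp

    sched = PiecewiseLinearFloat((0.0, 0.2), (16000, 0.0))
    sched.batch_count = 8000
    print(float(sched))  # 0.1: halfway through the decay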
@@ -410,7 +410,7 @@ class ZipformerEncoderLayer(nn.Module):
         # probability of skipping the entire layer.
         self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward).
-        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
+        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
         # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
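The copy.deepcopy calls give each layer its own schedule object, so mutating one layer's schedule state cannot leak into another's. A small illustration, reusing the hypothetical PiecewiseLinearFloat from the sketch above:

    import copy

    shared = PiecewiseLinearFloat((0.0, 0.2), (4000.0, 0.0))
    per_layer = copy.deepcopy(shared)  # independent copy, as in the commit
    aliased = shared                   # shared reference, for contrast

    shared.batch_count = 2000
    print(float(aliased))    # 0.1: the alias sees the mutated state
    print(float(per_layer))  # 0.0: the deep copy kept its own (default) state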
@@ -507,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_orig = src

         # dropout rate for non-feedforward submodules
-        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
+        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
         if self.self_attn_weights is not None:
@@ -528,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module):
             # skip the layer
             return src, attn_weights

-        use_self_attn = (random.random() >= dynamic_skip_rate)
+        use_self_attn = (random.random() >= attention_skip_rate)
         if use_self_attn:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
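The `random.random() >= attention_skip_rate` gate, combined with the `if self.training else 0.0` guard above, is the usual stochastic-depth-style pattern: the submodule is dropped at random during training and always runs in eval. A self-contained sketch of the pattern (StochasticSkip is a hypothetical wrapper, not a class in this file):

    import random

    import torch
    import torch.nn as nn

    class StochasticSkip(nn.Module):
        """Hypothetical wrapper showing the skip pattern: drop the wrapped
        submodule with probability skip_rate during training, always run
        it in eval mode."""

        def __init__(self, module: nn.Module, skip_rate: float = 0.2):
            super().__init__()
            self.module = module
            self.skip_rate = skip_rate

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            skip_rate = self.skip_rate if self.training else 0.0
            if random.random() >= skip_rate:
                src = src + self.module(src)  # residual add, as in the diff
            return src

    layer = StochasticSkip(nn.Linear(16, 16), skip_rate=0.2)
    out = layer(torch.randn(4, 16))  # skipped on ~20% of training steps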
@@ -555,7 +555,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn(
                 src, attn_weights)

-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate + float(self.conv_skip_rate):
+        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward2(src)
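The torch.jit.is_scripting() short-circuit makes this branch unconditionally true under TorchScript, so exported models always run the conv module; the random skip only ever fires when running eagerly. A sketch of the guard in isolation (maybe_run_conv is a hypothetical helper, not part of this file):

    import random

    import torch

    def maybe_run_conv(src, conv_module, conv_skip_rate):
        # Under torch.jit.script() the first operand is True, so the
        # conv module always runs in an exported model; the random skip
        # only applies in eager mode.
        if torch.jit.is_scripting() or random.random() >= conv_skip_rate:
            src = src + conv_module(src)
        return src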