Mirror of https://github.com/k2-fsa/icefall.git
Introduce nonlin_skip_rate

commit 75a1e05e49
@@ -397,6 +397,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
+            nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (40000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
             bypass_max: FloatLike = 1.0,
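For reference, FloatLike values like these are ScheduledFloat objects from icefall's scaling.py, which behave (to a good approximation) as piecewise-linear functions of the training batch count: a schedule such as ((0.0, 0.2), (4000.0, 0.0)) starts at 0.2 and decays linearly to 0.0 by batch 4000, then stays there. The stand-alone sketch below only illustrates that interpolation; schedule_value is a hypothetical helper, not icefall's implementation.

from typing import Sequence, Tuple

def schedule_value(points: Sequence[Tuple[float, float]], batch_count: float) -> float:
    """Piecewise-linear interpolation over (batch_count, value) breakpoints,
    clamped to the first/last value outside the scheduled range."""
    pts = sorted(points)
    if batch_count <= pts[0][0]:
        return pts[0][1]
    if batch_count >= pts[-1][0]:
        return pts[-1][1]
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        if x0 <= batch_count <= x1:
            frac = (batch_count - x0) / (x1 - x0)
            return y0 + frac * (y1 - y0)
    return pts[-1][1]

# The dynamic_skip_rate schedule above: 0.2 at batch 0, 0.0 from batch 4000 onward.
print(schedule_value([(0.0, 0.2), (4000.0, 0.0)], 0))       # -> 0.2
print(schedule_value([(0.0, 0.2), (4000.0, 0.0)], 2000))    # -> 0.1
print(schedule_value([(0.0, 0.2), (4000.0, 0.0)], 10000))   # -> 0.0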
@@ -406,8 +407,12 @@ class ZipformerEncoderLayer(nn.Module):

         # probability of skipping the entire layer.
         self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
-        # skip probability for dynamic modules (meaning: anything but feedforward)
+        # skip probability for dynamic modules (meaning: anything but feedforward).
         self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
+        # an additional skip probability that applies to NonlinAttentionModule to stop it from
+        # contributing too much early on.
+        self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
+
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
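The bypass_min / bypass_max comments above describe a constraint on the learned bypass_scale: it is clamped into [bypass_min, bypass_max], but only with probability 0.5, so the gradient with respect to the unconstrained parameter does not vanish permanently. The sketch below is one plausible reading of that comment, not the repository's actual method; get_bypass_scale and its signature are made up for illustration.

import random
import torch

def get_bypass_scale(bypass_scale: torch.Tensor,
                     bypass_min: float,
                     bypass_max: float,
                     training: bool) -> torch.Tensor:
    # Hypothetical sketch: at inference the raw learned scale is used; during
    # training it is clamped into [bypass_min, bypass_max] only about half the
    # time, so grads w.r.t. the unclamped parameter never become exactly zero.
    if not training or random.random() < 0.5:
        return bypass_scale
    return bypass_scale.clamp(min=bypass_min, max=bypass_max)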
@@ -534,7 +539,7 @@ class ZipformerEncoderLayer(nn.Module):
         selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
         selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

-        if torch.jit.is_scripting() or use_self_attn:
+        if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
             src = src + self.nonlin_attention_module(src,
                                                      selected_attn_weights[0:1])

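The hunk above is where the new rate takes effect: the NonlinAttention branch now runs only when random.random() >= float(self.nonlin_skip_rate), i.e. its residual contribution is dropped with probability equal to the currently scheduled rate (largest at the start of training, decaying to zero by the end of the schedule). A minimal stand-alone sketch of that pattern, using a made-up helper name rather than the repository's forward code:

import random
import torch

def add_with_random_skip(src: torch.Tensor,
                         module: torch.nn.Module,
                         attn_weights: torch.Tensor,
                         skip_rate: float) -> torch.Tensor:
    # Sketch of the skip pattern: with probability `skip_rate` the module's
    # residual contribution is omitted entirely; otherwise it is added to src.
    if random.random() >= skip_rate:
        return src + module(src, attn_weights)
    return src

With the schedule given earlier (0.2 at batch 0, decaying to 0.0), roughly one in five forward passes omits the module early in training, and none do once the rate reaches zero.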