Mirror of https://github.com/k2-fsa/icefall.git

Refactor bypass, and add bypass in the middle of the layer.

parent b526f3af00
commit 67fcae95a8
@@ -207,7 +207,7 @@ class Zipformer2(EncoderInterface):
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                final_layerdrop_rate=0.02 * (downsampling_factor[i] ** 0.5),
             )

             if downsampling_factor[i] != 1:
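For a sense of scale: the final layerdrop rate still grows with the square root of each stack's downsampling factor, only with a smaller constant. A quick illustration (the factors below are examples, not values from any particular recipe):

# Illustration only: how final_layerdrop_rate scales with the downsampling factor.
for downsampling_factor in (1, 2, 4, 8):
    final_layerdrop_rate = 0.02 * (downsampling_factor ** 0.5)
    print(downsampling_factor, round(final_layerdrop_rate, 4))
# prints: 1 0.02, 2 0.0283, 4 0.04, 8 0.0566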
@@ -519,23 +519,21 @@ class Zipformer2EncoderLayer(nn.Module):
             dropout: FloatLike = 0.1,
             cnn_module_kernel: int = 31,
             causal: bool = False,
-            # layer_skip_rate will be overwritten to change warmup begin and end times.
-            # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            # to work correctly.
-            layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
             ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
-            bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
-            bypass_max: FloatLike = 1.0,
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim

-        # probability of skipping the entire layer.
-        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
+        # self.bypass implements layer skipping as well as bypass; see its default values.
+        self.bypass = BypassModule(embed_dim)
+        # bypass_mid is bypass used in the middle of the layer.
+        self.bypass_mid = BypassModule(embed_dim)
+
+
         # skip probability for dynamic modules (meaning: anything but feedforward).
         self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
         # an additional skip probability that applies to ConvModule to stop it from
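The skip rates above are ScheduledFloat values, i.e. piecewise-linear functions of the training batch count given as (batch_count, value) breakpoints. The snippet below is a minimal, illustration-only sketch of that interpolation; it is not the real ScheduledFloat from icefall, which additionally takes a default value.

# Minimal sketch of a piecewise-linear schedule over the training batch count,
# like the ScheduledFloat arguments above.  Values are clamped to the
# first/last breakpoint outside the given range.
def schedule_value(points, batch_count):
    points = sorted(points)
    if batch_count <= points[0][0]:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if batch_count <= x1:
            return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)
    return points[-1][1]

# e.g. the attention_skip_rate schedule (0.0, 0.2), (4000.0, 0.05), (16000, 0.0):
attn_points = [(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)]
print(schedule_value(attn_points, 2000))    # 0.125
print(schedule_value(attn_points, 10000))   # 0.025
print(schedule_value(attn_points, 30000))   # 0.0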
@@ -547,10 +545,6 @@ class Zipformer2EncoderLayer(nn.Module):
         self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
         self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

-        # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
-        # ever becoming zero.
-        self.bypass_min = copy.deepcopy(bypass_min)
-        self.bypass_max = copy.deepcopy(bypass_max)
         self.const_attention_rate = copy.deepcopy(const_attention_rate)

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
@@ -751,6 +745,9 @@ class Zipformer2EncoderLayer(nn.Module):
         src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                           float(self.ff2_skip_rate))

+        # bypass in the middle of the layer.
+        src = self.bypass_mid(src_orig, src)
+
         self_attn = self.self_attn2(
             src, attn_weights)

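The new bypass_mid applies the same learned interpolation as the end-of-layer bypass, just halfway through the layer: with a small per-channel scale the output stays close to the layer input src_orig, while a scale of 1.0 passes the half-processed activations through unchanged. A toy, single-channel illustration of the formula used in BypassModule.forward further down:

# Toy scalar illustration of out = src_orig + (src - src_orig) * scale.
src_orig, src = 1.0, 3.0
for scale in (0.0, 0.5, 1.0):
    print(scale, src_orig + (src - src_orig) * scale)
# 0.0 -> 1.0 (fully bypassed), 0.5 -> 2.0, 1.0 -> 3.0 (no bypass)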
@@ -769,10 +766,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src = self.balancer1(src)
         src = self.norm(src)

-        bypass_scale = self.get_bypass_scale(src.shape[1])
-        # the next line equivalent to: src = src * bypass_scale + src_orig *
-        # (1.0 - bypass_scale), but more memory efficient for backprop.
-        src = src_orig + (src - src_orig) * bypass_scale
+        src = self.bypass(src_orig, src)

         src = self.balancer2(src)
         src = self.whiten(src)
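The removed comment points out that src_orig + (src - src_orig) * bypass_scale equals src * bypass_scale + src_orig * (1.0 - bypass_scale) while being cheaper for backprop; a quick numerical check of that identity (shapes chosen arbitrarily):

import torch

src_orig = torch.randn(10, 4, 8)     # (seq_len, batch_size, num_channels)
src = torch.randn(10, 4, 8)
bypass_scale = torch.rand(8)         # per-channel scale

a = src_orig + (src - src_orig) * bypass_scale
b = src * bypass_scale + src_orig * (1.0 - bypass_scale)
assert torch.allclose(a, b, atol=1e-6)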
@@ -819,10 +813,12 @@ class Zipformer2Encoder(nn.Module):
         cur_begin = warmup_begin  # interpreted as a training batch index
         for i in range(num_layers):
             cur_end = cur_begin + delta
-            # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
-                                                            (cur_end, final_layerdrop_rate),
-                                                            default=0.0)
+            self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                             (cur_end, final_layerdrop_rate),
+                                                             default=0.0)
+            self.layers[i].bypass_mid.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
+                                                                 (cur_end, final_layerdrop_rate),
+                                                                 default=0.0)
             cur_begin = cur_end

     def forward(
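Each layer's skip rate now ramps from initial_layerdrop_rate down to final_layerdrop_rate over its own window of batches, and the windows are laid end to end across the layers. A small sketch with made-up numbers; computing delta as the per-layer share of the warmup span is an assumption based on the variable names:

# Made-up numbers, illustrating the staggered per-layer schedules set above.
warmup_begin, warmup_end, num_layers = 1000.0, 4000.0, 3
initial_layerdrop_rate, final_layerdrop_rate = 0.5, 0.05

delta = (warmup_end - warmup_begin) / num_layers   # assumed per-layer share of the warmup span
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    print(f"layer {i}: skip rate ramps {initial_layerdrop_rate} -> {final_layerdrop_rate} "
          f"over batches [{cur_begin:.0f}, {cur_end:.0f}]")
    cur_begin = cur_end
# layer 0: [1000, 2000], layer 1: [2000, 3000], layer 2: [3000, 4000]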
@@ -869,6 +865,58 @@ class Zipformer2Encoder(nn.Module):
         return output


+class BypassModule(nn.Module):
+    """
+    An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+    layer-skipping.  The bypass is limited during early stages of training to be close to
+    "straight-through", i.e. to not do the bypass operation much initially, in order to
+    force all the modules to learn something.
+    """
+    def __init__(
+            self,
+            embed_dim: int,
+            skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
+            scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+            scale_max: FloatLike = 1.0):
+        super().__init__()
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+        self.skip_rate = copy.deepcopy(skip_rate)
+        self.scale_min = copy.deepcopy(scale_min)
+        self.scale_max = copy.deepcopy(scale_max)
+
+
+    def _get_bypass_scale(self, batch_size: int):
+        # returns bypass-scale of shape (num_channels,),
+        # or (batch_size, num_channels,).  This is actually the
+        # scale on the non-residual term, so 0 corresponds to bypassing
+        # this module.
+        if torch.jit.is_scripting() or not self.training:
+            return self.bypass_scale
+        else:
+            ans = limit_param_value(self.bypass_scale,
+                                    min=float(self.scale_min),
+                                    max=float(self.scale_max))
+            skip_rate = float(self.skip_rate)
+            if skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+                ans = ans * mask
+                # now ans is of shape (batch_size, num_channels), and is zero for sequences
+                # on which we have randomly chosen to do layer-skipping.
+            return ans
+
+    def forward(self,
+                src_orig: Tensor,
+                src: Tensor):
+        """
+        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+        Returns: something with the same shape as src and src_orig
+        """
+        bypass_scale = self._get_bypass_scale(src.shape[1])
+        return src_orig + (src - src_orig) * bypass_scale
+
+
+
+
 class DownsampledZipformer2Encoder(nn.Module):
     r"""
     DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
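To see the new BypassModule's training-time behaviour in isolation: the learnable per-channel bypass_scale is clamped into [scale_min, scale_max] and then zeroed for a random subset of sequences, so those sequences pass through the enclosing block unchanged. The sketch below mirrors _get_bypass_scale and forward with plain tensors; the clamp stands in for limit_param_value, and the fixed skip_rate / scale_min values replace the ScheduledFloat schedules:

import torch

seq_len, batch_size, num_channels = 10, 4, 8
src_orig = torch.randn(seq_len, batch_size, num_channels)
src = torch.randn(seq_len, batch_size, num_channels)

bypass_scale = torch.full((num_channels,), 0.5)   # a learnable nn.Parameter in the real module
scale_min, scale_max, skip_rate = 0.2, 1.0, 0.5   # example values; scheduled in the real module

scale = bypass_scale.clamp(min=scale_min, max=scale_max)
mask = (torch.rand(batch_size, 1) > skip_rate).to(src.dtype)
scale = scale * mask                              # (batch_size, num_channels); zero rows skip the block

out = src_orig + (src - src_orig) * scale
# rows where mask == 0 reproduce src_orig exactly, i.e. the block is bypassed for those sequences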