Refactor bypass, and add bypass in the middle of the layer.

Daniel Povey 2023-04-05 14:44:16 +08:00
parent b526f3af00
commit 67fcae95a8


@@ -207,7 +207,7 @@ class Zipformer2(EncoderInterface):
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
final_layerdrop_rate=0.02 * (downsampling_factor[i] ** 0.5),
)
if downsampling_factor[i] != 1:
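
The change above lowers the final layerdrop constant while keeping its square-root scaling with each stack's downsampling factor. A quick illustration (the downsampling factors below are example values, not the config in this file):

# Illustrative only: final layerdrop rate per stack under the new 0.02 constant.
for df in (1, 2, 4, 8):
    print(df, round(0.02 * df ** 0.5, 4))   # 0.02, 0.0283, 0.04, 0.0566
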
@@ -519,23 +519,21 @@ class Zipformer2EncoderLayer(nn.Module):
dropout: FloatLike = 0.1,
cnn_module_kernel: int = 31,
causal: bool = False,
# layer_skip_rate will be overwritten to change warmup begin and end times.
# treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
# to work correctly.
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
bypass_max: FloatLike = 1.0,
) -> None:
super(Zipformer2EncoderLayer, self).__init__()
self.embed_dim = embed_dim
# probability of skipping the entire layer.
self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
# self.bypass implements layer skipping as well as bypass; see its default values.
self.bypass = BypassModule(embed_dim)
# bypass_mid is bypass used in the middle of the layer.
self.bypass_mid = BypassModule(embed_dim)
# skip probability for dynamic modules (meaning: anything but feedforward).
self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
# an additional skip probability that applies to ConvModule to stop it from
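
The FloatLike defaults in this constructor are ScheduledFloat values: piecewise-linear schedules over the training batch index, so e.g. attention_skip_rate starts at 0.2, reaches 0.05 by batch 4000, and decays to 0.0 by batch 16000. A minimal sketch of that interpolation (illustrative only, not icefall's implementation; the real ScheduledFloat lives in scaling.py and tracks the batch count itself):

# Minimal sketch of a piecewise-linear schedule in the style of ScheduledFloat.
from bisect import bisect_right

def scheduled_value(points, batch_index):
    """points: (batch, value) pairs sorted by batch.  Returns the linearly
    interpolated value at batch_index, clamped to the end values outside
    the given range."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    if batch_index <= xs[0]:
        return ys[0]
    if batch_index >= xs[-1]:
        return ys[-1]
    i = bisect_right(xs, batch_index)
    t = (batch_index - xs[i - 1]) / (xs[i] - xs[i - 1])
    return ys[i - 1] + t * (ys[i] - ys[i - 1])

# e.g. the attention_skip_rate default (0.0, 0.2), (4000.0, 0.05), (16000, 0.0):
print(scheduled_value([(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)], 2000))   # 0.125
print(scheduled_value([(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)], 16000))  # 0.0
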
@@ -547,10 +545,6 @@ class Zipformer2EncoderLayer(nn.Module):
self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
# min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
# ever becoming zero.
self.bypass_min = copy.deepcopy(bypass_min)
self.bypass_max = copy.deepcopy(bypass_max)
self.const_attention_rate = copy.deepcopy(const_attention_rate)
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
@@ -751,6 +745,9 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
float(self.ff2_skip_rate))
# bypass in the middle of the layer.
src = self.bypass_mid(src_orig, src)
self_attn = self.self_attn2(
src, attn_weights)
@@ -769,10 +766,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = self.balancer1(src)
src = self.norm(src)
bypass_scale = self.get_bypass_scale(src.shape[1])
# the next line is equivalent to: src = src * bypass_scale + src_orig *
# (1.0 - bypass_scale), but more memory efficient for backprop.
src = src_orig + (src - src_orig) * bypass_scale
src = self.bypass(src_orig, src)
src = self.balancer2(src)
src = self.whiten(src)
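
The removed comment above records that src_orig + (src - src_orig) * bypass_scale is algebraically identical to src * bypass_scale + src_orig * (1.0 - bypass_scale), just noted as more memory-efficient for backprop; the same combination now happens inside self.bypass. A quick numerical check of the equivalence (illustrative only):

import torch

torch.manual_seed(0)
src_orig = torch.randn(10, 4, 8)            # (seq_len, batch_size, num_channels)
src = torch.randn(10, 4, 8)
scale = torch.rand(8)                        # per-channel bypass scale

a = src_orig + (src - src_orig) * scale      # form used in the code
b = src * scale + src_orig * (1.0 - scale)   # form in the removed comment
assert torch.allclose(a, b, atol=1e-6)
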
@@ -819,8 +813,10 @@ class Zipformer2Encoder(nn.Module):
cur_begin = warmup_begin # interpreted as a training batch index
for i in range(num_layers):
cur_end = cur_begin + delta
# treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
self.layers[i].layer_skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
(cur_end, final_layerdrop_rate),
default=0.0)
self.layers[i].bypass_mid.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate),
(cur_end, final_layerdrop_rate),
default=0.0)
cur_begin = cur_end
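
The loop above staggers the layerdrop warmup: each layer's skip rate (now stored on its bypass and bypass_mid modules) stays at initial_layerdrop_rate until its own cur_begin, then decays to final_layerdrop_rate by cur_end, with consecutive layers getting consecutive windows of width delta. A rough sketch of the window computation with made-up numbers:

# Illustrative only: how per-layer layerdrop warmup windows are staggered.
warmup_begin, warmup_end = 1000.0, 3000.0          # hypothetical batch indexes
num_layers = 4
initial_layerdrop_rate, final_layerdrop_rate = 0.5, 0.05

delta = (warmup_end - warmup_begin) / num_layers
cur_begin = warmup_begin
for i in range(num_layers):
    cur_end = cur_begin + delta
    # each layer's skip rate would follow ScheduledFloat(
    #     (cur_begin, initial_layerdrop_rate), (cur_end, final_layerdrop_rate))
    print(f"layer {i}: skip rate {initial_layerdrop_rate} until batch {cur_begin:.0f}, "
          f"decaying to {final_layerdrop_rate} by batch {cur_end:.0f}")
    cur_begin = cur_end
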
@@ -869,6 +865,58 @@ class Zipformer2Encoder(nn.Module):
return output
class BypassModule(nn.Module):
"""
An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
layer-skipping. The bypass is limited during early stages of training to be close to
"straight-through", i.e. to not do the bypass operation much initially, in order to
force all the modules to learn something.
"""
def __init__(
self,
embed_dim: int,
skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
scale_max: FloatLike = 1.0):
super().__init__()
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
self.skip_rate = copy.deepcopy(skip_rate)
self.scale_min = copy.deepcopy(scale_min)
self.scale_max = copy.deepcopy(scale_max)
def _get_bypass_scale(self, batch_size: int):
# returns bypass-scale of shape (num_channels,),
# or (batch_size, num_channels,). This is actually the
# scale on the non-residual term, so 0 corresponds to bypassing
# this module.
if torch.jit.is_scripting() or not self.training:
return self.bypass_scale
else:
ans = limit_param_value(self.bypass_scale,
min=float(self.scale_min),
max=float(self.scale_max))
skip_rate = float(self.skip_rate)
if skip_rate != 0.0:
mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
ans = ans * mask
# now ans is of shape (batch_size, num_channels), and is zero for sequences
# on which we have randomly chosen to do layer-skipping.
return ans
def forward(self,
src_orig: Tensor,
src: Tensor):
"""
Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
Returns: something with the same shape as src and src_orig
"""
bypass_scale = self._get_bypass_scale(src.shape[1])
return src_orig + (src - src_orig) * bypass_scale
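
In effect the module interpolates between its input and output with a learnable per-channel scale, and during training it can zero that scale for randomly chosen sequences so the whole layer becomes an identity for them. A self-contained sketch of that behaviour outside of icefall (tensor shapes and the 0.5 skip rate below are just example values):

import torch

seq_len, batch_size, num_channels = 10, 4, 8
src_orig = torch.randn(seq_len, batch_size, num_channels)   # layer input
src = torch.randn(seq_len, batch_size, num_channels)        # output of the layer's modules
scale = torch.full((num_channels,), 0.5)                     # plays the role of bypass_scale

out = src_orig + (src - src_orig) * scale    # scale == 1 keeps src, scale == 0 is a pure bypass

# Per-sequence layer-skipping: zeroing the scale for some sequences in the
# batch makes the layer an identity for those sequences.
skip_rate = 0.5
mask = (torch.rand(batch_size, 1) > skip_rate).to(scale.dtype)   # per-sequence keep mask
out_skipped = src_orig + (src - src_orig) * (scale * mask)
# for any masked-out sequence b, out_skipped[:, b] equals src_orig[:, b]
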
class DownsampledZipformer2Encoder(nn.Module):
r"""
DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,