Merge branch 'scaled_adam_exp747' into scaled_adam_exp748

Author: Daniel Povey
Date: 2022-12-20 23:23:49 +08:00
Commit: 4d61d39d36


@@ -338,8 +338,14 @@ class Zipformer(EncoderInterface):
                 # this how we implement U-net-like skipping of some series of
                 # stacks. The layer_skip_dropout_prob is to discourage it, especially
                 # early in training, from completely ignoring the middle layers.
-                if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
-                    x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+                batch_size = x.shape[0]
+                skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+                if self.training:
+                    mask = (torch.rand((batch_size, 1, 1), device=x.device) >
+                            float(self.layer_skip_dropout_prob))
+                    x = torch.where(mask, skip_x, x)
+                else:
+                    x = skip_x
             x = module(x,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
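
For reference, a minimal self-contained sketch of the per-sequence skip-dropout pattern this hunk introduces. The (batch_size, 1, 1) mask shape suggests x is laid out as (batch_size, seq_len, num_channels) at this point; that layout and the function name below are assumptions for illustration, not part of the commit:

    import torch

    def drop_out_skip_connection(skip_x: torch.Tensor,
                                 x: torch.Tensor,
                                 layer_skip_dropout_prob: float,
                                 training: bool) -> torch.Tensor:
        # skip_x and x are assumed to be (batch_size, seq_len, num_channels).
        # In training, each sequence independently keeps the U-net-style skip
        # connection with probability 1 - layer_skip_dropout_prob and otherwise
        # falls back to the un-skipped x; in eval the skip is always applied.
        if not training:
            return skip_x
        batch_size = x.shape[0]
        mask = (torch.rand((batch_size, 1, 1), device=x.device)
                > layer_skip_dropout_prob)
        return torch.where(mask, skip_x, x)

The old code made one random.random() draw per forward call, so the whole minibatch either used the skip connection or ignored it; the new code makes that decision independently per sequence.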
@@ -463,13 +469,24 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None

-    def get_bypass_scale(self):
+    def get_bypass_scale(self, batch_size: int):
+        # returns bypass-scale of shape (num_channels,),
+        # or (batch_size, num_channels,).  This is actually the
+        # scale on the delta src - src_orig, so 0 correponds to bypassing
+        # this module.
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
         else:
-            return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_min),
-                                     max=float(self.bypass_max))
+            ans = limit_param_value(self.bypass_scale,
+                                    min=float(self.bypass_min),
+                                    max=float(self.bypass_max))
+            layer_skip_rate = float(self.layer_skip_rate)
+            if layer_skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate
+                ans = ans * mask
+            # now ans is of shape (batch_size, num_channels), and is zero for sequences
+            # on which we have randomly chosen to do layer-skipping.
+            return ans

     def forward(
         self,
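
A hedged, standalone illustration of the training-time branch of the reworked get_bypass_scale. limit_param_value comes from the repo's scaling utilities and is approximated here by a plain clamp, so treat this as a sketch of the masking logic only:

    import torch

    def masked_bypass_scale(bypass_scale: torch.Tensor,  # (num_channels,)
                            batch_size: int,
                            layer_skip_rate: float,
                            bypass_min: float,
                            bypass_max: float) -> torch.Tensor:
        # Stand-in for limit_param_value: only the value range is enforced here.
        ans = bypass_scale.clamp(min=bypass_min, max=bypass_max)
        if layer_skip_rate != 0.0:
            # Per-sequence keep-mask: rows drawn False zero the scale, so those
            # sequences bypass the module (their delta contributes nothing).
            keep = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate
            ans = ans * keep
        return ans  # (num_channels,) or (batch_size, num_channels)

Because the forward pass computes src = src_orig + delta * scale, a zero row here turns the layer into an identity for that sequence, which is what replaces the per-call layer skip deleted in the next hunk.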
@@ -521,10 +538,6 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2

-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src, attn_weights
-
         use_self_attn = (random.random() >= attention_skip_rate)
         if use_self_attn:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
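
The block deleted above skipped the layer for the entire minibatch with a single random.random() draw. A quick illustrative check (not part of the commit; sizes are arbitrary) that a zeroed bypass-scale row reproduces that "return src unchanged" behaviour, now per sequence:

    import torch

    seq_len, batch_size, num_channels = 10, 4, 8
    src_orig = torch.randn(seq_len, batch_size, num_channels)
    src = torch.randn(seq_len, batch_size, num_channels)  # after the layer's sub-modules
    scale = torch.ones(batch_size, num_channels)
    scale[1] = 0.0                                        # pretend sequence 1 was chosen for skipping
    out = src_orig + (src - src_orig) * scale
    assert torch.allclose(out[:, 1], src_orig[:, 1])      # skipped sequence is a pure bypass
    assert torch.allclose(out[:, 0], src[:, 0])           # other sequences pass through unchanged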
@@ -564,7 +577,7 @@ class ZipformerEncoderLayer(nn.Module):
         delta = src - src_orig
-        src = src_orig + delta * self.get_bypass_scale()
+        src = src_orig + delta * self.get_bypass_scale(src.shape[1])
         src = self.whiten(src)
         return src, attn_weights
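
Inside ZipformerEncoderLayer.forward, src is laid out as (seq_len, batch_size, num_channels), which is what the (batch_size, num_channels) scale shape and the src.shape[1] argument imply. A small sanity sketch (illustrative only) showing that both possible return shapes of get_bypass_scale broadcast against delta, so nothing else at the call site needs to change:

    import torch

    seq_len, batch_size, num_channels = 10, 4, 8
    delta = torch.randn(seq_len, batch_size, num_channels)
    eval_scale = torch.full((num_channels,), 0.5)              # eval / torch.jit.is_scripting() path
    train_scale = torch.full((batch_size, num_channels), 0.5)  # training path with layer_skip_rate > 0
    for scale in (eval_scale, train_scale):
        assert (delta * scale).shape == (seq_len, batch_size, num_channels)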