mirror of https://github.com/k2-fsa/icefall.git
Merge branch 'scaled_adam_exp747' into scaled_adam_exp748
commit 4d61d39d36
@@ -338,8 +338,14 @@ class Zipformer(EncoderInterface):
             # this is how we implement U-net-like skipping of some series of
             # stacks.  The layer_skip_dropout_prob is to discourage it, especially
             # early in training, from completely ignoring the middle layers.
-            if not (self.training and random.random() < float(self.layer_skip_dropout_prob)):
-                x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+            batch_size = x.shape[0]
+            skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+            if self.training:
+                mask = (torch.rand((batch_size, 1, 1), device=x.device) >
+                        float(self.layer_skip_dropout_prob))
+                x = torch.where(mask, skip_x, x)
+            else:
+                x = skip_x
             x = module(x,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
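For reference, a minimal standalone sketch of the per-example skip dropout this hunk introduces: a Bernoulli mask drawn over dimension 0 decides, via torch.where, whether each example keeps the skip-combined output or the plain one. The helper name apply_skip_dropout and the toy shapes below are illustrative only, not part of the commit.

import torch

def apply_skip_dropout(x: torch.Tensor,
                       skip_x: torch.Tensor,
                       dropout_prob: float,
                       training: bool) -> torch.Tensor:
    # During training, each example independently keeps the plain `x`
    # with probability `dropout_prob`; otherwise it takes the
    # skip-combined `skip_x`.  At eval time the skip path is always used.
    if not training:
        return skip_x
    keep = torch.rand((x.shape[0], 1, 1), device=x.device) > dropout_prob
    return torch.where(keep, skip_x, x)

# toy usage: dim 0 indexes the examples being masked (mirroring the diff's
# `batch_size = x.shape[0]`); 4 examples, 10 frames, 8 channels
x = torch.randn(4, 10, 8)
skip_x = x + torch.randn_like(x)
y = apply_skip_dropout(x, skip_x, dropout_prob=0.5, training=True)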
@@ -463,13 +469,24 @@ class ZipformerEncoderLayer(nn.Module):
     def remove_attention_weights(self):
         self.self_attn_weights = None
 
-    def get_bypass_scale(self):
+    def get_bypass_scale(self, batch_size: int):
+        # returns bypass-scale of shape (num_channels,),
+        # or (batch_size, num_channels,).  This is actually the
+        # scale on the delta src - src_orig, so 0 corresponds to bypassing
+        # this module.
         if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
         else:
-            return limit_param_value(self.bypass_scale,
-                                     min=float(self.bypass_min),
-                                     max=float(self.bypass_max))
+            ans = limit_param_value(self.bypass_scale,
+                                    min=float(self.bypass_min),
+                                    max=float(self.bypass_max))
+            layer_skip_rate = float(self.layer_skip_rate)
+            if layer_skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate
+                ans = ans * mask
+            # now ans is of shape (batch_size, num_channels), and is zero for sequences
+            # on which we have randomly chosen to do layer-skipping.
+            return ans
 
     def forward(
         self,
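The effect of the new get_bypass_scale() can be sketched in isolation as below. This is a simplification under stated assumptions: it keeps the per-sequence masking by layer_skip_rate but omits the limit_param_value clamping, and the function name and shapes are illustrative only.

import torch

def bypass_scale_with_layer_skip(bypass_scale: torch.Tensor,
                                 batch_size: int,
                                 layer_skip_rate: float,
                                 training: bool) -> torch.Tensor:
    # `bypass_scale` has shape (num_channels,).  In training, rows of the
    # returned (batch_size, num_channels) tensor are zeroed with probability
    # `layer_skip_rate`; a zero row scales the layer's delta to zero for
    # that sequence, i.e. the layer is effectively skipped for it.
    if not training or layer_skip_rate == 0.0:
        return bypass_scale
    mask = torch.rand((batch_size, 1), device=bypass_scale.device) > layer_skip_rate
    return bypass_scale * mask

scale = bypass_scale_with_layer_skip(torch.full((8,), 0.9),
                                     batch_size=4, layer_skip_rate=0.3,
                                     training=True)
print(scale.shape)  # torch.Size([4, 8]); some rows may be all zeros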
@@ -521,10 +538,6 @@ class ZipformerEncoderLayer(nn.Module):
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2
 
-        if self.training and random.random() < float(self.layer_skip_rate):
-            # skip the layer
-            return src, attn_weights
-
         use_self_attn = (random.random() >= attention_skip_rate)
         if use_self_attn:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
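The block removed here skipped the layer with a single random draw for the whole batch; after this commit that decision lives in get_bypass_scale(), which draws a mask per sequence instead. A rough contrast, with made-up helper names:

import random
import torch

def skip_decision_old(layer_skip_rate: float) -> bool:
    # one coin flip: either every sequence in the batch skips the layer,
    # or none of them do
    return random.random() < layer_skip_rate

def skip_decision_new(batch_size: int, layer_skip_rate: float) -> torch.Tensor:
    # one coin flip per sequence: shape (batch_size,), True where that
    # sequence skips the layer (its bypass scale is zeroed)
    return torch.rand(batch_size) < layer_skip_rate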
@@ -564,7 +577,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         delta = src - src_orig
 
-        src = src_orig + delta * self.get_bypass_scale()
+        src = src_orig + delta * self.get_bypass_scale(src.shape[1])
         src = self.whiten(src)
 
         return src, attn_weights
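The updated call passes src.shape[1] as the batch size, which fits the usual (seq_len, batch_size, num_channels) layout of src in this file (an assumption here). A toy demonstration of how a zeroed per-sequence bypass scale reduces to skipping the layer for that sequence:

import torch

# hypothetical shapes for illustration
seq_len, batch_size, num_channels = 10, 4, 8
src_orig = torch.randn(seq_len, batch_size, num_channels)
delta = torch.randn(seq_len, batch_size, num_channels)   # stands in for src - src_orig

scale = torch.full((batch_size, num_channels), 0.9)
scale[2] = 0.0  # pretend sequence 2 was randomly chosen for layer-skipping

src = src_orig + delta * scale                    # scale broadcasts over seq_len
assert torch.allclose(src[:, 2], src_orig[:, 2])  # the skipped sequence is unchanged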