Restore the layer-skipping changes from scaled_adam_219 and scaled_adam_exp220 that were accidentally lost.
parent e4a22bbe96
commit efbb1d25c7
@@ -85,6 +85,10 @@ class Zipformer(EncoderInterface):
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
         self.output_downsampling_factor = output_downsampling_factor
 
+        # will be written to, see set_batch_count()
+        self.batch_count = 0
+        self.warmup_end = warmup_batches
+
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
 
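The two attributes added above feed the layer-skip dropout schedule introduced later in this commit (_get_layer_skip_dropout_prob): the probability of dropping a skip connection decays linearly from 0.5 at the start of training to 0.025 once self.batch_count reaches self.warmup_end. A minimal sketch of that schedule in isolation, using a hypothetical warmup value rather than anything taken from the recipe:

# Sketch of the schedule driven by batch_count and warmup_end; the
# warmup value below is hypothetical.
def layer_skip_dropout_prob(batch_count: float, warmup_end: float,
                            min_dropout_prob: float = 0.025) -> float:
    # Linear decay from 0.5 at batch_count=0 down to min_dropout_prob at
    # batch_count=warmup_end; constant afterwards.
    if batch_count > warmup_end:
        return min_dropout_prob
    return 0.5 - (batch_count / warmup_end) * (0.5 - min_dropout_prob)

warmup_end = 4000.0  # hypothetical warmup_batches value
for bc in (0, 1000, 2000, 4000, 10000):
    print(bc, layer_skip_dropout_prob(bc, warmup_end))
# Decays from 0.5 at batch 0 to 0.025 at batch 4000, then stays at 0.025.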
@@ -132,11 +136,53 @@ class Zipformer(EncoderInterface):
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)
 
+        # initializes self.skip_layers and self.skip_modules
+        self._init_skip_modules()
+
         self.downsample_output = AttentionDownsample(encoder_dims[-1],
                                                      encoder_dims[-1],
                                                      downsample=output_downsampling_factor)
 
 
+    def _get_layer_skip_dropout_prob(self):
+        if not self.training:
+            return 0.0
+        batch_count = self.batch_count
+        min_dropout_prob = 0.025
+
+        if batch_count > self.warmup_end:
+            return min_dropout_prob
+        else:
+            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)
+
+    def _init_skip_modules(self):
+        """
+        If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
+        indexed 4 (in zero indexing), which has downsampling_factor=4, we combine the output of
+        layers 2 and 3; and at the input of layer indexed 5, which has downsampling_factor=2,
+        we combine the outputs of layers 1 and 4.
+        """
+        skip_layers = []
+        skip_modules = []
+        z = self.zipformer_downsampling_factors
+        for i in range(len(z)):
+            if i <= 1 or z[i-1] <= z[i]:
+                skip_layers.append(None)
+                skip_modules.append(nn.Identity())
+            else:
+                # TEMP
+                for j in range(i-2, -1, -1):
+                    if z[j] <= z[i] or j == 0:
+                        # TEMP logging statement.
+                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
+                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
+                        skip_layers.append(j)
+                        skip_modules.append(SimpleCombiner(self.encoder_dims[j],
+                                                           self.encoder_dims[i-1]))
+                        break
+        self.skip_layers = skip_layers
+        self.skip_modules = nn.ModuleList(skip_modules)
+
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
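The selection rule in _init_skip_modules can be checked against the docstring's example in isolation; the standalone sketch below mirrors the loop above but skips module construction (SimpleCombiner and nn.Identity), so the result is only the skip_layers list.

# Standalone sketch of the skip-layer selection rule above, applied to the
# docstring's example tuple; module construction is omitted.
def choose_skip_layers(z):
    skip_layers = []
    for i in range(len(z)):
        if i <= 1 or z[i - 1] <= z[i]:
            # First two stacks, or downsampling factor not decreasing: no skip.
            skip_layers.append(None)
        else:
            # Walk backwards to the nearest earlier stack whose downsampling
            # factor is <= the current one, falling back to stack 0.
            for j in range(i - 2, -1, -1):
                if z[j] <= z[i] or j == 0:
                    skip_layers.append(j)
                    break
    return skip_layers

print(choose_skip_layers((1, 2, 4, 8, 4, 2)))
# [None, None, None, None, 2, 1]: stack 4 combines the outputs of stacks
# 2 and 3, and stack 5 combines the outputs of stacks 1 and 4.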
@@ -221,20 +267,25 @@ class Zipformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
+        outputs = []
         feature_masks = self.get_feature_masks(x)
 
         for i, module in enumerate(self.encoders):
             ds = self.zipformer_downsampling_factors[i]
+            if self.skip_layers[i] is not None:
+                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
+                if (not self.training) or random.random() > layer_skip_dropout_prob:
+                    x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
+            outputs.append(x)
 
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
         lengths = (lengths + 1) // 2
 
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         return x, lengths
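In the forward pass above, each configured skip connection is always taken at inference time but is dropped with probability layer_skip_dropout_prob during training. A minimal sketch of just that gating decision, with a simple averaging stand-in for SimpleCombiner (which is not shown in this diff) and assuming matching tensor shapes:

import random

import torch

def maybe_combine(earlier_out: torch.Tensor,
                  x: torch.Tensor,
                  layer_skip_dropout_prob: float,
                  training: bool) -> torch.Tensor:
    # Mirrors the gating in forward(): always apply the skip connection at
    # eval time; during training, drop it with the given probability.
    if (not training) or random.random() > layer_skip_dropout_prob:
        # Hypothetical stand-in for SimpleCombiner: plain averaging,
        # assuming earlier_out and x have the same shape.
        return 0.5 * (earlier_out + x)
    return x

x = torch.randn(10, 2, 384)        # (seq_len, batch, channels), i.e. (T, N, C)
earlier = torch.randn(10, 2, 384)  # output of an earlier encoder stack
y = maybe_combine(earlier, x, layer_skip_dropout_prob=0.3, training=True)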