remove skip_modules

yaozengwei 2023-04-24 15:50:12 +08:00
parent 2e80841790
commit 0ec31c84da

@@ -129,14 +129,6 @@ class Zipformer2(EncoderInterface):
             dropout = ScheduledFloat((0.0, 0.3),
                                      (20000.0, 0.1))
 
-        # this is not the probability of skipping a layer. It is the probability of
-        # dropping out the "skip module" which allows the model to skip groups of
-        # encoder stacks; when it's dropped out like this, it means we are forced
-        # to take the "direct" (non-bypass) path.
-        self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
-                                                      (warmup_batches, 0.025),
-                                                      (20000.0, 0.0))
-
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple with the same length
             as downsampling_factor"""
@@ -223,41 +215,10 @@ class Zipformer2(EncoderInterface):
         self.encoders = nn.ModuleList(encoders)
 
-        # initializes self.skip_layers and self.skip_modules
-        self._init_skip_modules()
-
         self.downsample_output = SimpleDownsample(max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
 
-    def _init_skip_modules(self):
-        """
-        If self.downsampling_factor = (1, 2, 4, 8, 4, 2), then at the input of the layer
-        indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of
-        layers 2 and 3; and at the input of the layer indexed 5, which has subsampling_factor=2,
-        we combine the outputs of layers 1 and 4.
-        """
-        skip_layers = []
-        skip_modules = []
-        z = self.downsampling_factor
-        for i in range(len(z)):
-            if i <= 1 or z[i-1] <= z[i]:
-                skip_layers.append(None)
-                skip_modules.append(Identity())
-            else:
-                # TEMP
-                for j in range(i-2, -1, -1):
-                    if z[j] <= z[i] or j == 0:
-                        # TEMP logging statement.
-                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
-                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
-                        skip_layers.append(j)
-                        skip_modules.append(BypassModule(self.encoder_dim[i]))
-                        break
-        self.skip_layers = skip_layers
-        self.skip_modules = nn.ModuleList(skip_modules)
-
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
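
The pairing rule in the removed _init_skip_modules() can be checked in isolation. The sketch below re-implements just the skip_layers computation (omitting the Identity/BypassModule bookkeeping) and reproduces the docstring example: for downsampling_factor=(1, 2, 4, 8, 4, 2), stack 4 takes a skip connection from the output of stack 2, and stack 5 from the output of stack 1.

# Standalone sketch of the skip_layers pairing logic from the removed
# _init_skip_modules(); skip_layers[i] is the index of the earlier stack whose
# output is combined with the output of stack i-1 at the input of stack i.
def compute_skip_layers(downsampling_factor):
    skip_layers = []
    z = downsampling_factor
    for i in range(len(z)):
        if i <= 1 or z[i - 1] <= z[i]:
            skip_layers.append(None)  # no skip connection into stack i
        else:
            for j in range(i - 2, -1, -1):
                if z[j] <= z[i] or j == 0:
                    skip_layers.append(j)
                    break
    return skip_layers

print(compute_skip_layers((1, 2, 4, 8, 4, 2)))  # [None, None, None, None, 2, 1]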
@@ -384,23 +345,6 @@ class Zipformer2(EncoderInterface):
             ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])
 
-            if self.skip_layers[i] is not None:
-                # this is how we implement U-net-like skipping of some series of
-                # stacks. The layer_skip_dropout_prob is to discourage it from
-                # completely ignoring the middle layers, especially early in
-                # training.
-                skip_output = convert_num_channels(outputs[self.skip_layers[i]],
-                                                   self.encoder_dim[i])
-                skip_x = self.skip_modules[i](skip_output, x)
-                layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
-                if self.training and layer_skip_dropout_prob > 0:
-                    batch_size = x.shape[1]
-                    mask = (torch.rand((1, batch_size, 1), device=x.device) >
-                            layer_skip_dropout_prob)
-                    x = torch.where(mask, skip_x, x)
-                else:
-                    x = skip_x
-
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
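
The removed forward-pass logic selected, per sequence in the batch, between the bypass-combined output skip_x and the direct path x during training. The snippet below is a standalone sketch of that masking step only (it does not reproduce BypassModule or convert_num_channels); the tensor sizes are made up for illustration and follow the (seq_len, batch, channels) layout used by the encoder stacks.

# Standalone sketch of the per-sequence skip-module dropout from the removed code:
# with probability layer_skip_dropout_prob a sequence is forced onto the direct
# path x, otherwise it takes the bypass-combined skip_x.
import torch

def apply_skip_dropout(skip_x: torch.Tensor,
                       x: torch.Tensor,
                       layer_skip_dropout_prob: float,
                       training: bool) -> torch.Tensor:
    if training and layer_skip_dropout_prob > 0:
        batch_size = x.shape[1]
        # mask is True where the skip path is kept, False where it is dropped out.
        mask = (torch.rand((1, batch_size, 1), device=x.device) >
                layer_skip_dropout_prob)
        return torch.where(mask, skip_x, x)
    return skip_x

x = torch.randn(50, 4, 192)       # (seq_len, batch, channels), made-up sizes
skip_x = torch.randn(50, 4, 192)
out = apply_skip_dropout(skip_x, x, layer_skip_dropout_prob=0.5, training=True)
print(out.shape)  # torch.Size([50, 4, 192])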