Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
remove skip_modules
This commit is contained in:
parent 2e80841790
commit 0ec31c84da
@@ -129,14 +129,6 @@ class Zipformer2(EncoderInterface):
             dropout = ScheduledFloat((0.0, 0.3),
                                      (20000.0, 0.1))
-
-        # this is not the probability of skipping a layer. It is the probability of
-        # dropping out the "skip module" which allows the model to skip groups of
-        # encoder stacks; when it's dropped out like this, it means we are forced
-        # to take the "direct" (non-bypass) path.
-        self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
-                                                      (warmup_batches, 0.025),
-                                                      (20000.0, 0.0))
 
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple with the same length
             as downsampling_factor"""
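For context, the removed layer_skip_dropout_prob is a ScheduledFloat, icefall's piecewise-linear schedule over the training batch count: the probability starts at 0.5, decays to 0.025 by warmup_batches, and reaches 0.0 at batch 20000, so the forced direct path matters mostly early in training. Below is a minimal standalone sketch of that interpolation idea, not icefall's actual ScheduledFloat class from scaling.py; the warmup_batches value of 500 is made up for illustration.

def scheduled_value(batch_count, *knots):
    """Piecewise-linear interpolation between (batch_count, value) knots,
    clamped at both endpoints (a sketch of the ScheduledFloat idea)."""
    if batch_count <= knots[0][0]:
        return knots[0][1]
    if batch_count >= knots[-1][0]:
        return knots[-1][1]
    for (x0, y0), (x1, y1) in zip(knots, knots[1:]):
        if x0 <= batch_count <= x1:
            t = (batch_count - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)

# The removed schedule, with a hypothetical warmup_batches = 500:
knots = ((0.0, 0.5), (500.0, 0.025), (20000.0, 0.0))
print(scheduled_value(0.0, *knots))      # 0.5: early on, half the sequences take the direct path
print(scheduled_value(10000.0, *knots))  # ~0.0128: the skip dropout has nearly decayed away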
@@ -223,41 +215,10 @@ class Zipformer2(EncoderInterface):
 
         self.encoders = nn.ModuleList(encoders)
 
-        # initializes self.skip_layers and self.skip_modules
-        self._init_skip_modules()
-
         self.downsample_output = SimpleDownsample(max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
 
-    def _init_skip_modules(self):
-        """
-        If self.downsampling_factor = (1, 2, 4, 8, 4, 2), then at the input of the layer
-        indexed 4 (in zero indexing), which has downsampling_factor=4, we combine the outputs of
-        layers 2 and 3; and at the input of the layer indexed 5, which has downsampling_factor=2,
-        we combine the outputs of layers 1 and 4.
-        """
-        skip_layers = []
-        skip_modules = []
-        z = self.downsampling_factor
-        for i in range(len(z)):
-            if i <= 1 or z[i-1] <= z[i]:
-                skip_layers.append(None)
-                skip_modules.append(Identity())
-            else:
-                # TEMP
-                for j in range(i-2, -1, -1):
-                    if z[j] <= z[i] or j == 0:
-                        # TEMP logging statement.
-                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
-                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
-                        skip_layers.append(j)
-                        skip_modules.append(BypassModule(self.encoder_dim[i]))
-                        break
-        self.skip_layers = skip_layers
-        self.skip_modules = nn.ModuleList(skip_modules)
-
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
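To make the removed pairing rule concrete: a stack only gets a skip connection when its downsampling factor is smaller than its predecessor's (the "up" side of the U-shape), and it connects to the most recent earlier stack whose factor is no larger than its own, falling back to stack 0. The helper below is a hypothetical standalone rewrite of just that selection loop, with the modules left out:

def choose_skip_layers(z):
    """Same selection logic as the removed _init_skip_modules: stack i
    skip-connects from the most recent earlier stack j with z[j] <= z[i]
    (falling back to j == 0); other stacks get None (Identity)."""
    skip_layers = []
    for i in range(len(z)):
        if i <= 1 or z[i - 1] <= z[i]:
            skip_layers.append(None)
        else:
            for j in range(i - 2, -1, -1):
                if z[j] <= z[i] or j == 0:
                    skip_layers.append(j)
                    break
    return skip_layers

print(choose_skip_layers((1, 2, 4, 8, 4, 2)))  # [None, None, None, None, 2, 1]

This reproduces the docstring's example: stack 4 draws its skip input from stack 2, and stack 5 from stack 1.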
@@ -384,23 +345,6 @@ class Zipformer2(EncoderInterface):
             ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])
 
-            if self.skip_layers[i] is not None:
-                # this is how we implement U-net-like skipping of some series of
-                # stacks. The layer_skip_dropout_prob is to discourage it from
-                # completely ignoring the middle layers, especially early in
-                # training.
-                skip_output = convert_num_channels(outputs[self.skip_layers[i]],
-                                                   self.encoder_dim[i])
-                skip_x = self.skip_modules[i](skip_output, x)
-
-                layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
-                if self.training and layer_skip_dropout_prob > 0:
-                    batch_size = x.shape[1]
-                    mask = (torch.rand((1, batch_size, 1), device=x.device) >
-                            layer_skip_dropout_prob)
-                    x = torch.where(mask, skip_x, x)
-                else:
-                    x = skip_x
-
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
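The removed branch applies its dropout per sequence, not per element: a single Bernoulli draw of shape (1, batch_size, 1) decides, for each sequence in the batch, whether to keep the bypass-combined path or fall back to the direct path. A runnable sketch of just that masking step, with a hypothetical function name and the (seq_len, batch_size, num_channels) layout Zipformer2 uses:

import torch

def apply_skip_dropout(x, skip_x, p, training=True):
    """With probability p, force a sequence onto the direct path x;
    otherwise use the skip-combined path skip_x. Sketch of the removed logic."""
    if not training or p <= 0:
        return skip_x
    batch_size = x.shape[1]
    # one draw per sequence, broadcast over time and channels
    keep_skip = torch.rand((1, batch_size, 1), device=x.device) > p
    return torch.where(keep_skip, skip_x, x)

x = torch.zeros(10, 4, 8)       # direct path
skip_x = torch.ones(10, 4, 8)   # bypass-combined path
print(apply_skip_dropout(x, skip_x, p=0.5)[0, :, 0])  # per-sequence mix of 0.0 and 1.0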