Replace SimpleCombiner with BypassModule, for simplicity

Refactor code for simplicity

Fix bug
This commit is contained in:
Daniel Povey 2023-04-09 16:57:59 +08:00
parent 5f790c41f7
commit fb6a1c1464

View File

@ -255,8 +255,7 @@ class Zipformer2(EncoderInterface):
logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
skip_layers.append(j)
skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
min_weight=(0.0, 0.25)))
skip_modules.append(BypassModule(self.encoder_dim[i]))
break
self.skip_layers = skip_layers
self.skip_modules = nn.ModuleList(skip_modules)
@ -389,22 +388,25 @@ class Zipformer2(EncoderInterface):
for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
x = convert_num_channels(x, self.encoder_dim[i])
if self.skip_layers[i] is not None:
# this how we implement U-net-like skipping of some series of
# stacks. The layer_skip_dropout_prob is to discourage it from
# completely ignoring the middle layers, especially early in
# training,
batch_size = x.shape[1]
skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
skip_output = convert_num_channels(outputs[self.skip_layers[i]],
self.encoder_dim[i])
skip_x = self.skip_modules[i](skip_output, x)
layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
if self.training and layer_skip_dropout_prob > 0:
batch_size = x.shape[1]
mask = (torch.rand((1, batch_size, 1), device=x.device) >
layer_skip_dropout_prob)
x = torch.where(mask, skip_x, x)
else:
x = skip_x
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x,
chunk_size=chunk_size,
feature_mask=feature_masks[i],
@ -524,14 +526,15 @@ class Zipformer2EncoderLayer(nn.Module):
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
) -> None:
super(Zipformer2EncoderLayer, self).__init__()
self.embed_dim = embed_dim
# self.bypass implements layer skipping as well as bypass; see its default values.
self.bypass = BypassModule(embed_dim)
self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
# bypass_mid is bypass used in the middle of the layer.
self.bypass_mid = BypassModule(embed_dim, skip_rate=0.0)
self.bypass_mid = BypassModule(embed_dim)
# skip probability for dynamic modules (meaning: anything but feedforward).
@ -872,7 +875,7 @@ class BypassModule(nn.Module):
def __init__(
self,
embed_dim: int,
skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
skip_rate: FloatLike = 0.0,
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
scale_max: FloatLike = 1.0):
super().__init__()
@ -931,7 +934,7 @@ class DownsampledZipformer2Encoder(nn.Module):
downsample, dropout)
self.encoder = encoder
self.upsample = SimpleUpsample(dim, downsample)
self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
self.out_combiner = BypassModule(dim)
def forward(self,
@ -1048,46 +1051,6 @@ class SimpleUpsample(torch.nn.Module):
src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src
class SimpleCombiner(torch.nn.Module):
"""
A very simple way of combining 2 vectors of 2 different dims, via a
learned weighted combination in the shared part of the dim.
Args:
dim1: the dimension of the first input, e.g. 256
dim2: the dimension of the second input, e.g. 384.
The output will have the same dimension as dim2.
"""
def __init__(self,
dim: int,
min_weight: Tuple[float, float] = (0., 0.)):
super(SimpleCombiner, self).__init__()
initial_weight1 = 0.1
self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
self.min_weight = min_weight
def forward(self,
src1: Tensor,
src2: Tensor) -> Tensor:
"""
src1: (*, other_dim)
src2: (*, dim)
Returns: a tensor of shape (*, dim)
"""
assert src1.shape[:-1] == src2.shape[:-1]
num_channels = src2.shape[-1]
src1 = convert_num_channels(src1, num_channels)
weight1 = limit_param_value(self.weight1,
min=self.min_weight[0],
max=1.0-self.min_weight[1],
training=self.training)
return src1 * weight1 + src2 * (1.0 - weight1)
class CompactRelPositionalEncoding(torch.nn.Module):
"""