Replace SimpleCombiner with BypassModule, for simplicity

Refactor code for simplicity

Fix bug
This commit is contained in:
Daniel Povey 2023-04-09 16:57:59 +08:00
parent 5f790c41f7
commit fb6a1c1464

View File

@ -255,8 +255,7 @@ class Zipformer2(EncoderInterface):
logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.") f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
skip_layers.append(j) skip_layers.append(j)
skip_modules.append(SimpleCombiner(self.encoder_dim[i-1], skip_modules.append(BypassModule(self.encoder_dim[i]))
min_weight=(0.0, 0.25)))
break break
self.skip_layers = skip_layers self.skip_layers = skip_layers
self.skip_modules = nn.ModuleList(skip_modules) self.skip_modules = nn.ModuleList(skip_modules)
@ -389,22 +388,25 @@ class Zipformer2(EncoderInterface):
for i, module in enumerate(self.encoders): for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i] ds = self.downsampling_factor[i]
x = convert_num_channels(x, self.encoder_dim[i])
if self.skip_layers[i] is not None: if self.skip_layers[i] is not None:
# this how we implement U-net-like skipping of some series of # this how we implement U-net-like skipping of some series of
# stacks. The layer_skip_dropout_prob is to discourage it from # stacks. The layer_skip_dropout_prob is to discourage it from
# completely ignoring the middle layers, especially early in # completely ignoring the middle layers, especially early in
# training, # training,
batch_size = x.shape[1] skip_output = convert_num_channels(outputs[self.skip_layers[i]],
skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x) self.encoder_dim[i])
skip_x = self.skip_modules[i](skip_output, x)
layer_skip_dropout_prob = float(self.layer_skip_dropout_prob) layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
if self.training and layer_skip_dropout_prob > 0: if self.training and layer_skip_dropout_prob > 0:
batch_size = x.shape[1]
mask = (torch.rand((1, batch_size, 1), device=x.device) > mask = (torch.rand((1, batch_size, 1), device=x.device) >
layer_skip_dropout_prob) layer_skip_dropout_prob)
x = torch.where(mask, skip_x, x) x = torch.where(mask, skip_x, x)
else: else:
x = skip_x x = skip_x
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x, x = module(x,
chunk_size=chunk_size, chunk_size=chunk_size,
feature_mask=feature_masks[i], feature_mask=feature_masks[i],
@ -524,14 +526,15 @@ class Zipformer2EncoderLayer(nn.Module):
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
) -> None: ) -> None:
super(Zipformer2EncoderLayer, self).__init__() super(Zipformer2EncoderLayer, self).__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
# self.bypass implements layer skipping as well as bypass; see its default values. # self.bypass implements layer skipping as well as bypass; see its default values.
self.bypass = BypassModule(embed_dim) self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
# bypass_mid is bypass used in the middle of the layer. # bypass_mid is bypass used in the middle of the layer.
self.bypass_mid = BypassModule(embed_dim, skip_rate=0.0) self.bypass_mid = BypassModule(embed_dim)
# skip probability for dynamic modules (meaning: anything but feedforward). # skip probability for dynamic modules (meaning: anything but feedforward).
@ -872,7 +875,7 @@ class BypassModule(nn.Module):
def __init__( def __init__(
self, self,
embed_dim: int, embed_dim: int,
skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), skip_rate: FloatLike = 0.0,
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
scale_max: FloatLike = 1.0): scale_max: FloatLike = 1.0):
super().__init__() super().__init__()
@ -931,7 +934,7 @@ class DownsampledZipformer2Encoder(nn.Module):
downsample, dropout) downsample, dropout)
self.encoder = encoder self.encoder = encoder
self.upsample = SimpleUpsample(dim, downsample) self.upsample = SimpleUpsample(dim, downsample)
self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25)) self.out_combiner = BypassModule(dim)
def forward(self, def forward(self,
@ -1048,46 +1051,6 @@ class SimpleUpsample(torch.nn.Module):
src = src.reshape(seq_len * upsample, batch_size, num_channels) src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src return src
class SimpleCombiner(torch.nn.Module):
"""
A very simple way of combining 2 vectors of 2 different dims, via a
learned weighted combination in the shared part of the dim.
Args:
dim1: the dimension of the first input, e.g. 256
dim2: the dimension of the second input, e.g. 384.
The output will have the same dimension as dim2.
"""
def __init__(self,
dim: int,
min_weight: Tuple[float, float] = (0., 0.)):
super(SimpleCombiner, self).__init__()
initial_weight1 = 0.1
self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
self.min_weight = min_weight
def forward(self,
src1: Tensor,
src2: Tensor) -> Tensor:
"""
src1: (*, other_dim)
src2: (*, dim)
Returns: a tensor of shape (*, dim)
"""
assert src1.shape[:-1] == src2.shape[:-1]
num_channels = src2.shape[-1]
src1 = convert_num_channels(src1, num_channels)
weight1 = limit_param_value(self.weight1,
min=self.min_weight[0],
max=1.0-self.min_weight[1],
training=self.training)
return src1 * weight1 + src2 * (1.0 - weight1)
class CompactRelPositionalEncoding(torch.nn.Module): class CompactRelPositionalEncoding(torch.nn.Module):
""" """