Mirror of https://github.com/k2-fsa/icefall.git
Replace SimpleCombiner with BypassModule, for simplicity

Refactor code for simplicity; fix a bug.

parent 5f790c41f7
commit fb6a1c1464
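At a glance: SimpleCombiner (deleted at the bottom of this diff) combined two streams with a learned per-channel weighted sum, while BypassModule combines them residual-style with a learned per-channel bypass scale. BypassModule's forward() is not part of this diff, so the following is only a minimal sketch of the assumed combination rule; the class name suffix and the initial scale are illustrative.

import torch
import torch.nn as nn
from torch import Tensor

class BypassModuleSketch(nn.Module):
    def __init__(self, embed_dim: int, initial_scale: float = 0.5):
        super().__init__()
        # One learnable scale per channel: scale == 1 passes `src` through
        # unchanged, scale == 0 returns `src_orig`, i.e. bypasses the module.
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), initial_scale))

    def forward(self, src_orig: Tensor, src: Tensor) -> Tensor:
        # src_orig, src: (seq_len, batch_size, embed_dim)
        return src_orig + (src - src_orig) * self.bypass_scale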
@@ -255,8 +255,7 @@ class Zipformer2(EncoderInterface):
                     logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                                  f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
                     skip_layers.append(j)
-                    skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
-                                                       min_weight=(0.0, 0.25)))
+                    skip_modules.append(BypassModule(self.encoder_dim[i]))
                     break
         self.skip_layers = skip_layers
         self.skip_modules = nn.ModuleList(skip_modules)
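Note that the new skip modules are built with self.encoder_dim[i] rather than self.encoder_dim[i-1]: SimpleCombiner converted channels inside its forward(), whereas BypassModule expects both inputs at the same dimension, which is why the next hunk converts the skip output explicitly. convert_num_channels() itself is outside this diff; here is a hedged sketch consistent with its use in these hunks, assuming it truncates or zero-pads the channel dimension.

import torch
from torch import Tensor

def convert_num_channels_sketch(x: Tensor, num_channels: int) -> Tensor:
    # Truncate if the target is narrower...
    if num_channels <= x.shape[-1]:
        return x[..., :num_channels]
    # ...otherwise zero-pad the channel (last) dimension.
    pad = torch.zeros(*x.shape[:-1], num_channels - x.shape[-1],
                      dtype=x.dtype, device=x.device)
    return torch.cat((x, pad), dim=-1)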
@@ -389,22 +388,25 @@ class Zipformer2(EncoderInterface):
 
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
-            x = convert_num_channels(x, self.encoder_dim[i])
 
             if self.skip_layers[i] is not None:
                 # this is how we implement U-net-like skipping of some series of
                 # stacks.  The layer_skip_dropout_prob is to discourage it from
                 # completely ignoring the middle layers, especially early in
                 # training.
-                batch_size = x.shape[1]
-                skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+                skip_output = convert_num_channels(outputs[self.skip_layers[i]],
+                                                   self.encoder_dim[i])
+                skip_x = self.skip_modules[i](skip_output, x)
+
                 layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
                 if self.training and layer_skip_dropout_prob > 0:
+                    batch_size = x.shape[1]
                     mask = (torch.rand((1, batch_size, 1), device=x.device) >
                             layer_skip_dropout_prob)
                     x = torch.where(mask, skip_x, x)
                 else:
                     x = skip_x
+            x = convert_num_channels(x, self.encoder_dim[i])
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
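The layer-skip dropout above draws one Bernoulli sample per sequence: the (1, batch_size, 1) mask broadcasts over time and channels, so each utterance either takes the skip-combined output or keeps the plain one. A standalone demonstration with illustrative shapes:

import torch

seq_len, batch_size, num_channels = 10, 4, 8
x = torch.randn(seq_len, batch_size, num_channels)       # plain stack input
skip_x = torch.randn(seq_len, batch_size, num_channels)  # skip-combined version

layer_skip_dropout_prob = 0.1
# One draw per sequence; True keeps skip_x, False falls back to x.
mask = torch.rand((1, batch_size, 1)) > layer_skip_dropout_prob
out = torch.where(mask, skip_x, x)
assert out.shape == (seq_len, batch_size, num_channels)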
@@ -524,14 +526,15 @@ class Zipformer2EncoderLayer(nn.Module):
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
             ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
+            bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim, skip_rate=0.0)
+        self.bypass_mid = BypassModule(embed_dim)
 
 
         # skip probability for dynamic modules (meaning: anything but feedforward).
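The new bypass_skip_rate default, like the other ScheduledFloat values in this signature, is a sequence of (batch_count, value) breakpoints, understood to be interpolated piecewise-linearly as training progresses. An illustrative helper showing that reading (this is not icefall's ScheduledFloat implementation):

def scheduled_float_sketch(batch_count: float, *points) -> float:
    # points: ((x0, y0), (x1, y1), ...) with batch counts x ascending.
    if batch_count <= points[0][0]:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if batch_count <= x1:
            t = (batch_count - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
    return points[-1][1]

# The new bypass_skip_rate starts at 0.5 and decays to 0.02 by batch 4000:
assert scheduled_float_sketch(0, (0.0, 0.5), (4000.0, 0.02)) == 0.5
assert abs(scheduled_float_sketch(2000, (0.0, 0.5), (4000.0, 0.02)) - 0.26) < 1e-9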
@@ -872,7 +875,7 @@ class BypassModule(nn.Module):
     def __init__(
         self,
         embed_dim: int,
-        skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
+        skip_rate: FloatLike = 0.0,
         scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
         scale_max: FloatLike = 1.0):
         super().__init__()
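Here BypassModule's default skip_rate drops from a scheduled value to a plain 0.0, matching the previous hunk where Zipformer2EncoderLayer now passes its own schedule explicitly. The module body is outside this diff; the sketch below is one plausible reading of how skip_rate, scale_min and scale_max interact during training, with all names illustrative:

import torch

def get_bypass_scale_sketch(bypass_scale: torch.Tensor, batch_size: int,
                            skip_rate: float, scale_min: float,
                            scale_max: float) -> torch.Tensor:
    # Clamp the learned per-channel scale into [scale_min, scale_max];
    # the scheduled scale_min keeps early training close to straight-through.
    scale = bypass_scale.clamp(min=scale_min, max=scale_max)
    if skip_rate > 0.0:
        # With probability skip_rate, zero a sequence's scale so the layer's
        # output reduces to src_orig, i.e. the layer is skipped entirely.
        keep = torch.rand(batch_size, 1, device=scale.device) > skip_rate
        scale = scale * keep
    return scale  # broadcasts against (seq_len, batch_size, num_channels)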
@@ -931,7 +934,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                                    downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
+        self.out_combiner = BypassModule(dim)
 
 
     def forward(self,
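With out_combiner now a BypassModule, the upsampled encoder output is merged with the pre-downsampling input through the bypass rule instead of a weighted sum. A hedged sketch of the data flow implied for DownsampledZipformer2Encoder.forward; the module internals here are assumptions:

import torch.nn as nn
from torch import Tensor

def downsampled_encoder_forward_sketch(src: Tensor,
                                       downsample: nn.Module,
                                       encoder: nn.Module,
                                       upsample: nn.Module,
                                       out_combiner: nn.Module) -> Tensor:
    src_orig = src
    src = downsample(src)           # (seq_len // ds, batch, dim)
    src = encoder(src)
    src = upsample(src)             # back up to roughly (seq_len, batch, dim)
    src = src[: src_orig.shape[0]]  # trim any overshoot from upsampling
    # A BypassModule combiner computes src_orig + (src - src_orig) * scale.
    return out_combiner(src_orig, src)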
@@ -1048,46 +1051,6 @@ class SimpleUpsample(torch.nn.Module):
         src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src
 
-class SimpleCombiner(torch.nn.Module):
-    """
-    A very simple way of combining 2 vectors of 2 different dims, via a
-    learned weighted combination in the shared part of the dim.
-    Args:
-        dim1: the dimension of the first input, e.g. 256
-        dim2: the dimension of the second input, e.g. 384.
-    The output will have the same dimension as dim2.
-    """
-    def __init__(self,
-                 dim: int,
-                 min_weight: Tuple[float, float] = (0., 0.)):
-        super(SimpleCombiner, self).__init__()
-        initial_weight1 = 0.1
-        self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
-        self.min_weight = min_weight
-
-    def forward(self,
-                src1: Tensor,
-                src2: Tensor) -> Tensor:
-        """
-        src1: (*, other_dim)
-        src2: (*, dim)
-
-        Returns: a tensor of shape (*, dim)
-        """
-        assert src1.shape[:-1] == src2.shape[:-1]
-        num_channels = src2.shape[-1]
-        src1 = convert_num_channels(src1, num_channels)
-
-
-        weight1 = limit_param_value(self.weight1,
-                                    min=self.min_weight[0],
-                                    max=1.0-self.min_weight[1],
-                                    training=self.training)
-
-        return src1 * weight1 + src2 * (1.0 - weight1)
-
-
-
 
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
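The deleted SimpleCombiner and its BypassModule replacement parameterize the same family of per-channel convex combinations: src1 * w + src2 * (1 - w) equals src1 + (src2 - src1) * s with s = 1 - w. The swap therefore changes initialization, clamping and per-sequence skipping rather than the set of expressible functions. A quick numerical check:

import torch

src1 = torch.randn(5, 2, 8)  # e.g. a channel-converted skip output
src2 = torch.randn(5, 2, 8)  # the current stack's input
w = torch.rand(8)            # SimpleCombiner-style per-channel weight

simple = src1 * w + src2 * (1.0 - w)       # SimpleCombiner's combination
bypass = src1 + (src2 - src1) * (1.0 - w)  # BypassModule form, s = 1 - w
assert torch.allclose(simple, bypass, atol=1e-6)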