Mirror of https://github.com/k2-fsa/icefall.git
Replace SimpleCombiner with BypassModule, for simplicity
Refactor code for simplicity
Fix bug
This commit is contained in:
parent 5f790c41f7
commit fb6a1c1464
@@ -255,8 +255,7 @@ class Zipformer2(EncoderInterface):
                 logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                              f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
                 skip_layers.append(j)
-                skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
-                                                   min_weight=(0.0, 0.25)))
+                skip_modules.append(BypassModule(self.encoder_dim[i]))
                 break
         self.skip_layers = skip_layers
         self.skip_modules = nn.ModuleList(skip_modules)
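The call site now passes `self.encoder_dim[i]` because a bypass-style combiner expects both of its inputs to have the same channel dimension (the skip output is converted with `convert_num_channels` in the forward-pass hunk below). As a rough, self-contained sketch of what such a combiner computes — the class name `BypassLikeCombiner` and the clamp constants are illustrative assumptions, not the actual `BypassModule` from zipformer.py:

import torch
import torch.nn as nn


class BypassLikeCombiner(nn.Module):
    """Minimal sketch of a bypass-style combiner: a per-channel learned
    interpolation between an earlier output (src_orig) and the current
    output (src).  Illustrative only, not the real BypassModule."""

    def __init__(self, embed_dim: int, scale_min: float = 0.2, scale_max: float = 1.0):
        super().__init__()
        # One learnable mixing weight per channel, initialised to 0.5.
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        self.scale_min = scale_min
        self.scale_max = scale_max

    def forward(self, src_orig: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        # scale == 0 returns src_orig unchanged; scale == 1 returns src.
        scale = self.bypass_scale.clamp(self.scale_min, self.scale_max)
        return src_orig + (src - src_orig) * scale

A scale near 0 keeps the skip input, a scale near 1 keeps the current output, so the module can learn per channel how much of the intervening stacks to retain.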
@@ -389,22 +388,25 @@ class Zipformer2(EncoderInterface):
 
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
+            x = convert_num_channels(x, self.encoder_dim[i])
 
             if self.skip_layers[i] is not None:
                 # this how we implement U-net-like skipping of some series of
                 # stacks. The layer_skip_dropout_prob is to discourage it from
                 # completely ignoring the middle layers, especially early in
                 # training,
-                batch_size = x.shape[1]
-                skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
+                skip_output = convert_num_channels(outputs[self.skip_layers[i]],
+                                                   self.encoder_dim[i])
+                skip_x = self.skip_modules[i](skip_output, x)
+
                 layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
                 if self.training and layer_skip_dropout_prob > 0:
+                    batch_size = x.shape[1]
                     mask = (torch.rand((1, batch_size, 1), device=x.device) >
                             layer_skip_dropout_prob)
                     x = torch.where(mask, skip_x, x)
                 else:
                     x = skip_x
-            x = convert_num_channels(x, self.encoder_dim[i])
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
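The `(1, batch_size, 1)` random mask in this hunk implements the layer-skip dropout at utterance granularity: each sequence in the batch either takes the skip-combined `skip_x` or keeps the plain `x`, with all frames and channels of that sequence treated the same way. A small, self-contained demonstration of that broadcast (the shapes and the 0.1 probability are made up for illustration):

import torch

# A (1, batch, 1) Boolean mask broadcasts over (seq_len, batch, channels), so
# each sequence in the batch either takes the skip-combined path or keeps the
# plain one.
seq_len, batch_size, channels = 5, 4, 3
x = torch.zeros(seq_len, batch_size, channels)       # stands for the direct path
skip_x = torch.ones(seq_len, batch_size, channels)   # stands for the skip-combined path
layer_skip_dropout_prob = 0.1
mask = torch.rand((1, batch_size, 1)) > layer_skip_dropout_prob
out = torch.where(mask, skip_x, x)
print(out[0, :, 0])  # 1.0 for sequences that used skip_x, 0.0 for the ~10% that did not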
@@ -524,14 +526,15 @@ class Zipformer2EncoderLayer(nn.Module):
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
             ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
+            bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim, skip_rate=0.0)
+        self.bypass_mid = BypassModule(embed_dim)
 
 
         # skip probability for dynamic modules (meaning: anything but feedforward).
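The new `bypass_skip_rate` default, `ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0)`, reads as (batch-count, value) breakpoints: roughly 0.5 at the start of training, decaying to 0.02 by batch 4000. The sketch below approximates that piecewise-linear behaviour as a plain function; it is not icefall's `ScheduledFloat` class (which also carries a `default` fallback value), just a way to see what the numbers mean:

# Standalone approximation of a ScheduledFloat-style schedule: (batch, value)
# breakpoints with linear interpolation between them and the endpoint values
# held outside the range.  A sketch, not icefall's ScheduledFloat class.
def scheduled_value(batch_count: float, points) -> float:
    points = sorted(points)
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)


# bypass_skip_rate above: 0.5 at the start of training, decaying to 0.02 by batch 4000.
print(scheduled_value(0, [(0.0, 0.5), (4000.0, 0.02)]))     # 0.5
print(scheduled_value(2000, [(0.0, 0.5), (4000.0, 0.02)]))  # 0.26
print(scheduled_value(4000, [(0.0, 0.5), (4000.0, 0.02)]))  # 0.02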
@@ -872,7 +875,7 @@ class BypassModule(nn.Module):
     def __init__(
             self,
             embed_dim: int,
-            skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
+            skip_rate: FloatLike = 0.0,
             scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
             scale_max: FloatLike = 1.0):
         super().__init__()
@@ -931,7 +934,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                           downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
+        self.out_combiner = BypassModule(dim)
 
 
     def forward(self,
@@ -1048,46 +1051,6 @@ class SimpleUpsample(torch.nn.Module):
         src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src
 
 
-class SimpleCombiner(torch.nn.Module):
-    """
-    A very simple way of combining 2 vectors of 2 different dims, via a
-    learned weighted combination in the shared part of the dim.
-    Args:
-        dim1: the dimension of the first input, e.g. 256
-        dim2: the dimension of the second input, e.g. 384.
-              The output will have the same dimension as dim2.
-    """
-    def __init__(self,
-                 dim: int,
-                 min_weight: Tuple[float, float] = (0., 0.)):
-        super(SimpleCombiner, self).__init__()
-        initial_weight1 = 0.1
-        self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
-        self.min_weight = min_weight
-
-    def forward(self,
-                src1: Tensor,
-                src2: Tensor) -> Tensor:
-        """
-        src1: (*, other_dim)
-        src2: (*, dim)
-
-        Returns: a tensor of shape (*, dim)
-        """
-        assert src1.shape[:-1] == src2.shape[:-1]
-        num_channels = src2.shape[-1]
-        src1 = convert_num_channels(src1, num_channels)
-
-        weight1 = limit_param_value(self.weight1,
-                                    min=self.min_weight[0],
-                                    max=1.0-self.min_weight[1],
-                                    training=self.training)
-
-        return src1 * weight1 + src2 * (1.0 - weight1)
-
-
-
-
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
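Note that the removed `SimpleCombiner.forward` and a bypass-style combiner compute the same algebraic form, a per-channel convex combination of the two inputs; the differences lie in how the weight is initialised, clamped, and optionally skipped during training. A quick numerical check of that identity (shapes are arbitrary):

import torch

# src1 * w + src2 * (1 - w)  ==  src2 + (src1 - src2) * w
dim = 4
w = torch.rand(dim)
src1, src2 = torch.randn(3, dim), torch.randn(3, dim)
a = src1 * w + src2 * (1.0 - w)
b = src2 + (src1 - src2) * w
print(torch.allclose(a, b))  # True (up to floating-point tolerance)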